@@ -887,6 +887,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded):
887887 )
888888 return input_offsets , send_sizes , output_offsets , recv_sizes
889889
def get_ragged_buffer_size(self, local_expert_size, local_batch):
    """Return the row capacity for the ragged all-to-all output buffer.

    When ``config.ragged_buffer_factor`` is positive, the buffer is sized as
    that factor times the balanced per-shard share (``local_batch`` rows).
    Otherwise it falls back to the worst case: every local row assigned to
    every local expert, capped at ``config.num_experts_per_tok`` experts per
    token, since a token cannot be routed to more experts than it selects.

    Args:
        local_expert_size: number of experts hosted on this shard.
        local_batch: number of rows held locally (presumably
            batch * sequence_length on this shard — TODO confirm at caller).

    Returns:
        The buffer row count as an ``int``.
    """
    factor = self.config.ragged_buffer_factor
    if factor > 0.0:
        # Scale the perfectly-balanced load by the configured headroom factor.
        return int(local_batch * factor)
    # Worst-case fallback: cap local experts per token at num_experts_per_tok.
    experts_per_tok_cap = min(local_expert_size, self.config.num_experts_per_tok)
    return int(local_batch * experts_per_tok_cap)
def transform_bias(self, experts_index, *biases):
    """Gather bias entries for the chosen experts.

    Indexes each bias tensor in ``biases`` with ``experts_index`` and returns
    the gathered values as a tuple, one entry per input bias, in order.
    """
    gathered = []
    for bias in biases:
        gathered.append(bias[experts_index])
    return tuple(gathered)
@@ -1181,7 +1189,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11811189 # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
11821190 # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
11831191 max_local_experts_per_tok = min (local_expert_size , self .config .num_experts_per_tok )
1184- buffer_size = int ( num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok )
1192+ buffer_size = self . get_ragged_buffer_size ( local_expert_size , jnp . shape ( x )[ 0 ] )
11851193 output_shape = jax .lax .empty ((buffer_size , self .config .emb_dim ), dtype = x .dtype )
11861194
11871195 x = jax .lax .ragged_all_to_all (
0 commit comments