@@ -349,6 +349,14 @@ def __init__(
     self.quant = quant
     self.rngs = rngs

+    self.moe_expert_input_dim = getattr(self.config, "moe_expert_input_dim", -1)
+    if self.moe_expert_input_dim <= 0:
+      self.moe_expert_input_dim = self.config.emb_dim
+
+    self.moe_model_dim = getattr(self.config, "moe_model_dim", -1)
+    if self.moe_model_dim <= 0:
+      self.moe_model_dim = self.config.emb_dim
+
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
       self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
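The two new attributes above follow the same fallback pattern: read an optional config field and use it only when it is positive, otherwise fall back to config.emb_dim. A minimal standalone sketch of that behavior (the SimpleNamespace config and the resolve_moe_dim helper are illustrative only, not part of this change):

from types import SimpleNamespace

def resolve_moe_dim(config, field_name):
  """Return config.<field_name> when set to a positive value, else config.emb_dim."""
  value = getattr(config, field_name, -1)
  return value if value > 0 else config.emb_dim

# Field missing (or <= 0): falls back to emb_dim.
cfg = SimpleNamespace(emb_dim=4096)
assert resolve_moe_dim(cfg, "moe_expert_input_dim") == 4096

# Field set to a positive value: used as-is.
cfg = SimpleNamespace(emb_dim=4096, moe_expert_input_dim=2048)
assert resolve_moe_dim(cfg, "moe_expert_input_dim") == 2048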
@@ -374,7 +382,7 @@ def __init__(
     self._expert_parallelism_name = "expert"

     self.gate = GateLogit(
-        in_features_shape=self.config.emb_dim,
+        in_features_shape=self.moe_expert_input_dim,
         out_features_shape=self.num_experts,
         mesh=self.mesh,
         model_name=self.config.model_name,
@@ -400,14 +408,14 @@ def __init__(
       # During aqt convert state we delete kernel weight from params to save
       # memory. Instead they are retrieved from the tensors stored in the 'aqt'
       # collection.
-      self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
+      self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -417,7 +425,7 @@ def __init__(
       self.wi_1 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -427,7 +435,7 @@ def __init__(
       self.wo = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+              (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
               self.weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -439,7 +447,7 @@ def __init__(
       wi_bias_axes = ("exp", "activation_mlp")
       wo_bias_axes = ("exp", "activation_embed")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
-      wo_bias_shape = (self.num_experts, self.config.emb_dim)
+      wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
       self.wi_0_bias = nnx.Param(
           default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
           sharding=wi_bias_axes,
@@ -1172,7 +1180,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
         # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
         max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
         buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-        output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+        output_shape = jax.lax.empty((buffer_size, self.moe_model_dim), dtype=x.dtype)

         x = jax.lax.ragged_all_to_all(
             x,
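To make the buffer sizing above concrete, a small worked example with hypothetical values (none of these numbers come from a real config):

# Hypothetical sharding and batch settings, for illustration only.
num_expert_parallelism = 4   # expert-parallel shards
batch_size = 2
sequence_length = 128
local_expert_size = 16       # experts resident on each shard
num_experts_per_tok = 8      # router top-k

# A token can use at most min(local_expert_size, num_experts_per_tok) experts per shard.
max_local_experts_per_tok = min(local_expert_size, num_experts_per_tok)  # 8
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
assert buffer_size == 8192   # rows in the (buffer_size, moe_model_dim) receive buffer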
@@ -1327,7 +1335,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         )

         # Sum up the partial outputs across the expert shards.
-        output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
+        output = jnp.reshape(output, (-1, sequence_length, self.moe_model_dim // self.get_tensor_parallelism_size()))
         output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

       else:
@@ -1338,7 +1346,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         output_shape = jax.lax.empty(
             (
                 original_inputs_first_dim,
-                self.config.emb_dim // self.get_tensor_parallelism_size(),
+                self.moe_model_dim // self.get_tensor_parallelism_size(),
             ),
             dtype=intermediate_output.dtype,
         )
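Both branches above split the model dimension across tensor-parallel shards, so the (possibly overridden) moe_model_dim must divide evenly by the tensor parallelism size. A hedged sanity check one might add early in setup (the helper name is hypothetical, not part of this change):

def check_tp_divisibility(moe_model_dim, tensor_parallelism_size):
  """Fail fast if the MoE model dimension cannot be split across TP shards."""
  if moe_model_dim % tensor_parallelism_size != 0:
    raise ValueError(
        f"moe_model_dim={moe_model_dim} is not divisible by "
        f"tensor parallelism size {tensor_parallelism_size}"
    )

check_tp_divisibility(moe_model_dim=4096, tensor_parallelism_size=4)  # OK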
@@ -2094,6 +2102,10 @@ def __init__(
     self.dtype = dtype
     self.quant = quant
     self.rngs = rngs
+    self.moe_model_dim = getattr(self.config, "moe_model_dim", -1)
+    if self.moe_model_dim <= 0:
+      self.moe_model_dim = self.config.emb_dim
+
     # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
     # existing checkpoints for routed experts.
     self.MoeBlock_0 = RoutedMoE(
@@ -2115,7 +2127,7 @@ def __init__(
     )
     self.shared_experts = linears.MlpBlock(
         mesh=self.mesh,
-        in_features=self.config.emb_dim,
+        in_features=self.moe_model_dim,
         intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim,
         activations=self.config.mlp_activations,
         intermediate_dropout_rate=self.config.dropout_rate,