@@ -349,6 +349,10 @@ def __init__(
349349 self .quant = quant
350350 self .rngs = rngs
351351
352+ self .moe_expert_input_dim = getattr (self .config , "moe_expert_input_dim" , - 1 )
353+ if self .moe_expert_input_dim <= 0 :
354+ self .moe_expert_input_dim = self .config .emb_dim
355+
352356 if self .config .shard_exp_on_fsdp :
353357 # special sharding for dsv3
354358 self .wi_kernel_axes = ("embed_moe" , None , "mlp_moe" )
@@ -374,7 +378,7 @@ def __init__(
374378 self ._expert_parallelism_name = "expert"
375379
376380 self .gate = GateLogit (
377- in_features_shape = self .config . emb_dim ,
381+ in_features_shape = self .moe_expert_input_dim ,
378382 out_features_shape = self .num_experts ,
379383 mesh = self .mesh ,
380384 model_name = self .config .model_name ,
@@ -400,14 +404,14 @@ def __init__(
400404 # During aqt convert state we delete kernel weight from params to save
401405 # memory. Instead they are retrieved from the tensors stored in the 'aqt'
402406 # collection.
403- self .wi_0 = jnp .zeros ((num_experts , self .config . emb_dim , intermediate_dim ))
404- self .wi_1 = jnp .zeros ((num_experts , self .config . emb_dim , intermediate_dim ))
405- self .wo = jnp .zeros ((num_experts , intermediate_dim , self .config . emb_dim ))
407+ self .wi_0 = jnp .zeros ((num_experts , self .moe_expert_input_dim , intermediate_dim ))
408+ self .wi_1 = jnp .zeros ((num_experts , self .moe_expert_input_dim , intermediate_dim ))
409+ self .wo = jnp .zeros ((num_experts , intermediate_dim , self .moe_expert_input_dim ))
406410 else :
407411 self .wi_0 = nnx .Param (
408412 self .kernel_init (
409413 self .rngs .params (),
410- (num_experts , self .config . emb_dim , intermediate_dim ),
414+ (num_experts , self .moe_expert_input_dim , intermediate_dim ),
411415 weight_dtype ,
412416 kernel_in_axis ,
413417 kernel_out_axis ,
@@ -417,7 +421,7 @@ def __init__(
417421 self .wi_1 = nnx .Param (
418422 self .kernel_init (
419423 self .rngs .params (),
420- (num_experts , self .config . emb_dim , intermediate_dim ),
424+ (num_experts , self .moe_expert_input_dim , intermediate_dim ),
421425 weight_dtype ,
422426 kernel_in_axis ,
423427 kernel_out_axis ,
@@ -427,7 +431,7 @@ def __init__(
427431 self .wo = nnx .Param (
428432 self .kernel_init (
429433 self .rngs .params (),
430- (self .num_experts , self .intermediate_dim , self .config . emb_dim ),
434+ (self .num_experts , self .intermediate_dim , self .moe_expert_input_dim ),
431435 self .weight_dtype ,
432436 kernel_in_axis ,
433437 kernel_out_axis ,
@@ -439,7 +443,7 @@ def __init__(
439443 wi_bias_axes = ("exp" , "activation_mlp" )
440444 wo_bias_axes = ("exp" , "activation_embed" )
441445 wi_bias_shape = (self .num_experts , self .intermediate_dim )
442- wo_bias_shape = (self .num_experts , self .config . emb_dim )
446+ wo_bias_shape = (self .num_experts , self .moe_expert_input_dim )
443447 self .wi_0_bias = nnx .Param (
444448 default_bias_init (self .rngs .params (), wi_bias_shape , self .weight_dtype ),
445449 sharding = wi_bias_axes ,
@@ -1172,7 +1176,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11721176 # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
11731177 max_local_experts_per_tok = min (local_expert_size , self .config .num_experts_per_tok )
11741178 buffer_size = int (num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok )
1175- output_shape = jax .lax .empty ((buffer_size , self .config .emb_dim ), dtype = x .dtype )
1179+ output_shape = jax .lax .empty ((buffer_size , self .config .moe_model_dim ), dtype = x .dtype )
11761180
11771181 x = jax .lax .ragged_all_to_all (
11781182 x ,
@@ -1327,7 +1331,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13271331 )
13281332
13291333 # Sum up the partial outputs across the expert shards.
1330- output = jnp .reshape (output , (- 1 , sequence_length , self .config . emb_dim // self .get_tensor_parallelism_size ()))
1334+ output = jnp .reshape (output , (- 1 , sequence_length , self .moe_model_dim // self .get_tensor_parallelism_size ()))
13311335 output = jax .lax .psum_scatter (output , self ._expert_parallelism_name , scatter_dimension = 0 , tiled = True )
13321336
13331337 else :
@@ -1338,7 +1342,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13381342 output_shape = jax .lax .empty (
13391343 (
13401344 original_inputs_first_dim ,
1341- self .config . emb_dim // self .get_tensor_parallelism_size (),
1345+ self .moe_model_dim // self .get_tensor_parallelism_size (),
13421346 ),
13431347 dtype = intermediate_output .dtype ,
13441348 )
@@ -2094,6 +2098,10 @@ def __init__(
20942098 self .dtype = dtype
20952099 self .quant = quant
20962100 self .rngs = rngs
2101+ self .moe_model_dim = getattr (self .config , "moe_model_dim" , - 1 )
2102+ if self .moe_model_dim <= 0 :
2103+ self .moe_model_dim = self .config .emb_dim
2104+
20972105 # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
20982106 # existing checkpoints for routed experts.
20992107 self .MoeBlock_0 = RoutedMoE (
@@ -2115,7 +2123,7 @@ def __init__(
21152123 )
21162124 self .shared_experts = linears .MlpBlock (
21172125 mesh = self .mesh ,
2118- in_features = self .config . emb_dim ,
2126+ in_features = self .moe_model_dim ,
21192127 intermediate_dim = self .config .shared_experts * shared_expert_mlp_dim ,
21202128 activations = self .config .mlp_activations ,
21212129 intermediate_dropout_rate = self .config .dropout_rate ,