@@ -349,6 +349,10 @@ def __init__(
     self.quant = quant
     self.rngs = rngs

+    self.moe_expert_input_dim = getattr(self.config, "moe_expert_input_dim", -1)
+    if self.moe_expert_input_dim <= 0:
+      self.moe_expert_input_dim = self.config.emb_dim
+
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
       self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp")
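The new attribute follows a sentinel-default pattern: read an optional config field with `getattr`, then fall back to `emb_dim` whenever the value is unset or non-positive. A minimal standalone sketch of that pattern (`SimpleNamespace` is a stand-in for the real config object, not part of the change):

```python
# Sketch of the fallback logic added above; SimpleNamespace stands in for the config.
from types import SimpleNamespace

def resolve_expert_input_dim(config):
  # getattr returns the -1 sentinel when the field is absent;
  # any non-positive value falls back to emb_dim.
  dim = getattr(config, "moe_expert_input_dim", -1)
  return dim if dim > 0 else config.emb_dim

print(resolve_expert_input_dim(SimpleNamespace(emb_dim=4096)))  # 4096 (fallback)
print(resolve_expert_input_dim(SimpleNamespace(emb_dim=4096, moe_expert_input_dim=7168)))  # 7168 (override)
```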
@@ -374,7 +378,7 @@ def __init__(
     self._expert_parallelism_name = "expert"

     self.gate = GateLogit(
-        in_features_shape=self.config.emb_dim,
+        in_features_shape=self.moe_expert_input_dim,
         out_features_shape=self.num_experts,
         mesh=self.mesh,
         model_name=self.config.model_name,
@@ -400,14 +404,14 @@ def __init__(
       # During aqt convert state we delete kernel weight from params to save
       # memory. Instead they are retrieved from the tensors stored in the 'aqt'
       # collection.
-      self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
+      self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -417,7 +421,7 @@ def __init__(
       self.wi_1 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -427,7 +431,7 @@ def __init__(
       self.wo = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+              (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
               self.weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -439,7 +443,7 @@ def __init__(
       wi_bias_axes = ("exp", "activation_mlp")
       wo_bias_axes = ("exp", "activation_embed_moe")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
-      wo_bias_shape = (self.num_experts, self.config.emb_dim)
+      wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
       self.wi_0_bias = nnx.Param(
           default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
           sharding=wi_bias_axes,
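Taken together, these shape changes decouple the experts' input width from `emb_dim`: `wi_0` and `wi_1` become `(num_experts, moe_expert_input_dim, intermediate_dim)`, `wo` becomes `(num_experts, intermediate_dim, moe_expert_input_dim)`, and the gate projects from the same width. A dense single-host sketch of the resulting contraction with toy sizes (illustrative only; the real kernel dispatches tokens to their top-k experts instead of weighting all experts):

```python
import jax
import jax.numpy as jnp

T, D_in, F, E, top_k = 8, 16, 32, 4, 2  # tokens, moe_expert_input_dim, intermediate, experts, k
k0, k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(k0, (T, D_in))
wi_0 = jax.random.normal(k1, (E, D_in, F)) * 0.02
wi_1 = jax.random.normal(k2, (E, D_in, F)) * 0.02
wo = jax.random.normal(k3, (E, F, D_in)) * 0.02
gate = jax.random.normal(k4, (D_in, E)) * 0.02  # gate input width is D_in, not emb_dim

probs = jax.nn.softmax(x @ gate, axis=-1)        # (T, E)
weights, ids = jax.lax.top_k(probs, top_k)       # (T, k) each
routing = jnp.zeros_like(probs).at[jnp.arange(T)[:, None], ids].set(weights)

# Gated MLP per expert, then a routing-weighted combine back to width D_in.
h = jax.nn.silu(jnp.einsum("td,edf->tef", x, wi_0)) * jnp.einsum("td,edf->tef", x, wi_1)
out = jnp.einsum("tef,efd,te->td", h, wo, routing)  # (T, D_in)
print(out.shape)
```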
@@ -1182,7 +1186,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
       max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
       buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-      output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+      output_shape = jax.lax.empty((buffer_size, self.moe_expert_input_dim), dtype=x.dtype)

       x = jax.lax.ragged_all_to_all(
           x,
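The buffer sizing in the hunk above is worst-case arithmetic: each of the `num_expert_parallelism` shards may send every one of its `batch_size * sequence_length` tokens here at most `min(local_expert_size, num_experts_per_tok)` times. Restated with illustrative numbers (not from any real config):

```python
# Worst-case receive buffer for the ragged all-to-all; numbers are illustrative.
num_expert_parallelism = 4            # expert-parallel shards
batch_size, sequence_length = 2, 1024
local_expert_size = 8                 # experts hosted on each shard
num_experts_per_tok = 2               # router top-k

max_local_experts_per_tok = min(local_expert_size, num_experts_per_tok)  # 2
buffer_size = num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok
print(buffer_size)  # 16384 rows, each of width moe_expert_input_dim
```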
@@ -1337,7 +1341,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       )

       # Sum up the partial outputs across the expert shards.
-      output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
+      output = jnp.reshape(output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size()))
       output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

     else:
@@ -1348,7 +1352,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       output_shape = jax.lax.empty(
           (
               original_inputs_first_dim,
-              self.config.emb_dim // self.get_tensor_parallelism_size(),
+              self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
           ),
           dtype=intermediate_output.dtype,
       )
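`psum_scatter` with `tiled=True` fuses the cross-shard sum of the partial expert outputs with the scatter that leaves each shard holding only its slice, instead of a full `psum` followed by slicing. A runnable toy sketch under `shard_map` (the single-axis mesh and the `"expert"` axis name are assumptions for illustration; works on any device count):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("expert",))
n = mesh.shape["expert"]

def combine(partial):
  # partial: this shard's contribution to every row, shape (4 * n, 8).
  # The collective sums over shards and keeps 4 rows on each shard.
  return jax.lax.psum_scatter(partial, "expert", scatter_dimension=0, tiled=True)

x = jnp.ones((n * 4 * n, 8))  # sharded over dim 0: each shard sees (4 * n, 8)
y = shard_map(combine, mesh, in_specs=P("expert"), out_specs=P("expert"))(x)
print(y.shape)  # (4 * n, 8): the summed result, scattered over the expert axis
```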
@@ -2095,6 +2099,10 @@ def __init__(
     self.dtype = dtype
     self.quant = quant
     self.rngs = rngs
+    self.moe_model_dim = getattr(self.config, "moe_model_dim", -1)
+    if self.moe_model_dim <= 0:
+      self.moe_model_dim = self.config.emb_dim
+
     # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
     # existing checkpoints for routed experts.
     self.MoeBlock_0 = RoutedMoE(
@@ -2116,7 +2124,7 @@ def __init__(
     )
     self.shared_experts = linears.MlpBlock(
         mesh=self.mesh,
-        in_features=self.config.emb_dim,
+        in_features=self.moe_model_dim,
         intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim,
         activations=self.config.mlp_activations,
         intermediate_dropout_rate=self.config.dropout_rate,
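With `in_features` tied to `moe_model_dim`, the shared-experts MLP stays shape-compatible with the routed path whenever the MoE block runs at a width other than `emb_dim`. A toy shape-contract sketch (the weight names and the final add are illustrative, not the actual `MlpBlock` internals):

```python
import jax
import jax.numpy as jnp

B, S, moe_model_dim, hidden = 2, 4, 128, 256
x = jnp.ones((B, S, moe_model_dim))
w_in = jnp.full((moe_model_dim, hidden), 0.01)   # in_features must be moe_model_dim
w_out = jnp.full((hidden, moe_model_dim), 0.01)

shared = jnp.einsum("bsd,dh->bsh", x, w_in)
shared = jax.nn.silu(shared)                      # one of config.mlp_activations, assumed
shared = jnp.einsum("bsh,hd->bsd", shared, w_out)

routed = jnp.zeros_like(x)                        # stand-in for MoeBlock_0(x)
print((routed + shared).shape)  # (2, 4, 128): both paths agree on moe_model_dim
```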