@@ -349,6 +349,10 @@ def __init__(
     self.quant = quant
     self.rngs = rngs

+    self.moe_expert_input_dim = (
+        self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim
+    )
+
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
       self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
@@ -374,7 +378,7 @@ def __init__(
     self._expert_parallelism_name = "expert"

     self.gate = GateLogit(
-        in_features_shape=self.config.emb_dim,
+        in_features_shape=self.moe_expert_input_dim,
         out_features_shape=self.num_experts,
         mesh=self.mesh,
         model_name=self.config.model_name,
@@ -400,14 +404,14 @@ def __init__(
       # During aqt convert state we delete kernel weight from params to save
       # memory. Instead they are retrieved from the tensors stored in the 'aqt'
       # collection.
-      self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
+      self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -417,7 +421,7 @@ def __init__(
       self.wi_1 = nnx.Param(
          self.kernel_init(
              self.rngs.params(),
-             (num_experts, self.config.emb_dim, intermediate_dim),
+             (num_experts, self.moe_expert_input_dim, intermediate_dim),
              weight_dtype,
              kernel_in_axis,
              kernel_out_axis,
@@ -427,7 +431,7 @@ def __init__(
       self.wo = nnx.Param(
          self.kernel_init(
              self.rngs.params(),
-             (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+             (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
              self.weight_dtype,
              kernel_in_axis,
              kernel_out_axis,
@@ -439,7 +443,7 @@ def __init__(
       wi_bias_axes = ("exp", "activation_mlp")
       wo_bias_axes = ("exp", "activation_embed")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
-      wo_bias_shape = (self.num_experts, self.config.emb_dim)
+      wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
       self.wi_0_bias = nnx.Param(
           default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
           sharding=wi_bias_axes,
@@ -1208,7 +1212,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
           self.config.num_experts_per_tok,
           self.config.ragged_buffer_factor,
       )
-      output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+      output_shape = jax.lax.empty((buffer_size, self.moe_expert_input_dim), dtype=x.dtype)

       x = jax.lax.ragged_all_to_all(
           x,
@@ -1345,7 +1349,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
      )

      # Sum up the partial outputs across the expert shards.
-     output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
+     output = jnp.reshape(
+         output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size())
+     )
      output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

    else:
@@ -1356,7 +1362,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
      output_shape = jax.lax.empty(
          (
              original_inputs_first_dim,
-             self.config.emb_dim // self.get_tensor_parallelism_size(),
+             self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
          ),
          dtype=intermediate_output.dtype,
      )
@@ -2112,14 +2118,18 @@ def __init__(
     self.dtype = dtype
     self.quant = quant
     self.rngs = rngs
+    self.moe_expert_input_dim = (
+        self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim
+    )
+
     # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
     # existing checkpoints for routed experts.
     self.MoeBlock_0 = RoutedMoE(
         config=self.config,
         num_experts=self.config.num_experts,
         num_experts_per_tok=self.config.num_experts_per_tok,
         mesh=self.mesh,
-        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=self.kernel_init,
         kernel_axes=("embed_moe", None),
         intermediate_dim=self.config.moe_mlp_dim,
         dtype=self.config.dtype,
@@ -2133,9 +2143,10 @@ def __init__(
     )
     self.shared_experts = linears.MlpBlock(
         mesh=self.mesh,
-        in_features=self.config.emb_dim,
+        in_features=self.moe_expert_input_dim,
         intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim,
         activations=self.config.mlp_activations,
+        kernel_init=self.kernel_init,
         intermediate_dropout_rate=self.config.dropout_rate,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
0 commit comments