
Commit 9350e8a

gagika authored and Google-ML-Automation committed
Add custom Qwen3 model with configurable attention and latent MoE.

Specifically, this introduces:

* `attention_output_dim` and `moe_expert_input_dim`, which allow the attention block output and the MoE expert input to have different dimensionalities than the base embedding dimension.
* A `dense_init_scale` config that makes the initialization scale for dense layers configurable across all models (replacing the hardcoded 1.0).

PiperOrigin-RevId: 896749554
1 parent e06745f · commit 9350e8a

16 files changed: 442 additions & 57 deletions
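The two new dimensions decouple shapes that previously all defaulted to `emb_dim`: the attention out-projection can emit a narrower vector, and the MoE experts consume that narrower vector directly. A minimal JAX sketch of the intended shape flow, using illustrative values from the new Qwen3 config in this commit (a standalone sketch, not MaxText code):

```python
import jax
import jax.numpy as jnp

# Illustrative values from the qwen3-custom-30b-a3b config added below.
num_heads, head_dim = 16, 256
attention_output_dim = 768   # out-projection target, instead of emb_dim=2048
moe_expert_input_dim = 768   # must equal attention_output_dim for this block

key = jax.random.PRNGKey(0)
x = jnp.ones((2, 8, num_heads * head_dim))  # concatenated attention head outputs
w_out = jax.random.normal(key, (num_heads * head_dim, attention_output_dim))
attn_out = x @ w_out  # (2, 8, 768) rather than (2, 8, 2048)

# The MoE expert input kernels are sized against the same 768-wide features.
num_experts, moe_mlp_dim = 128, 2048
wi_0 = jnp.zeros((num_experts, moe_expert_input_dim, moe_mlp_dim))
assert attn_out.shape[-1] == wi_0.shape[1]  # tokens fit the expert input
```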


src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ class DecoderBlockType(enum.Enum):
   QWEN2 = "qwen2"
   QWEN3 = "qwen3"
   QWEN3_MOE = "qwen3_moe"
+  QWEN3_CUSTOM_MOE = "qwen3_custom_moe"
   QWEN3_NEXT = "qwen3_next"
   GPT3 = "gpt3"
   GPT_OSS = "gpt_oss"

src/maxtext/configs/base.yml

Lines changed: 4 additions & 1 deletion
@@ -160,8 +160,10 @@ base_emb_dim: 2048
 base_num_query_heads: 16
 base_num_kv_heads: 16
 base_mlp_dim: 7168
+dense_init_scale: 1.0
 base_num_decoder_layers: 16
 head_dim: 128
+attention_output_dim: -1
 # Those parameters are only used with global attention for Gemma4.
 global_head_dim: 0
 global_num_kv_heads: 0
@@ -200,6 +202,8 @@ ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer
 # When set to 1.0 this buffer if set to the size assuming perfectly balanced. If the routing dictates
 # a size larger than this then tokens are dropped.
 # In general if ragged_buffer_factor > 0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
+moe_expert_input_dim: -1 # feature dimension of the tokens entering the MoE expert blocks.
+base_moe_mlp_dim: -1 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
 load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
@@ -241,7 +245,6 @@ shard_exp_on_fsdp: False
 use_2d_fsdp_sharding: False

 # deepseek moe
-base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
 first_num_dense_layers: 0 # number of initial dense layers in the model
 shared_experts: 0
 routed_scaling_factor: 1.0 # scaling factor for routing scores
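Both new base.yml entries use -1 as an "inherit" sentinel. A one-function sketch of the fallback that the moe.py change below applies (the helper name here is illustrative, not part of the config API):

```python
def resolve_moe_expert_input_dim(emb_dim: int, moe_expert_input_dim: int) -> int:
  """Non-positive values mean 'inherit emb_dim', matching the moe.py fallback."""
  return emb_dim if moe_expert_input_dim <= 0 else moe_expert_input_dim

assert resolve_moe_expert_input_dim(2048, -1) == 2048  # default behavior
assert resolve_moe_expert_input_dim(2048, 768) == 768  # custom override
```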
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Model config for custom Qwen3-30B-A3B
+
+# Core Architectural Parameters
+decoder_block: "qwen3_custom_moe"
+base_emb_dim: 2048
+base_num_query_heads: 16
+base_num_kv_heads: 2
+base_num_decoder_layers: 48
+head_dim: 256
+mlp_activations: ["silu", "linear"]
+vocab_size: 151936
+normalization_layer_epsilon: 1.0e-6
+use_qk_norm: True
+attention_output_dim: 768
+moe_expert_input_dim: 768
+
+# MoE Specific Parameters
+num_experts: 128
+num_experts_per_tok: 8
+base_moe_mlp_dim: 2048
+base_mlp_dim: 2048
+norm_topk_prob: true
+
+# RoPE Settings
+rope_max_timescale: 10_000_000
+
+# General Model Settings
+enable_dropout: False
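As a sanity check on this config, the routed-expert parameter arithmetic per layer works out as follows (a rough sketch counting only the gated-MLP expert kernels wi_0, wi_1, and wo; biases, norms, and the router gate are ignored):

```python
moe_expert_input_dim = 768   # expert input width from the config above
base_moe_mlp_dim = 2048      # expert intermediate width
num_experts, num_experts_per_tok = 128, 8

# Gated MLP per expert: wi_0 and wi_1 project in -> mlp, wo projects mlp -> in.
per_expert = (2 * moe_expert_input_dim * base_moe_mlp_dim
              + base_moe_mlp_dim * moe_expert_input_dim)   # 4,718,592 params
stored_per_layer = num_experts * per_expert                # ~0.6B stored
active_per_token = num_experts_per_tok * per_expert        # ~38M active

print(f"stored: {stored_per_layer:,}  active per token: {active_per_token:,}")
```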

src/maxtext/configs/types.py

Lines changed: 20 additions & 1 deletion
@@ -256,6 +256,7 @@ class ProfilerType(str, Enum):
     "qwen3-480b-a35b",
     "qwen3-next-80b-a3b",
     "qwen3-omni-30b-a3b",
+    "qwen3-custom-30b-a3b",
     "gpt3-175b",
     "gpt3-22b",
     "gpt3-6b",
@@ -447,11 +448,16 @@ class ModelArchitecture(BaseModel):
   base_num_query_heads: int = Field(16, description="Base number of query heads.")
   base_num_kv_heads: int = Field(16, description="Base number of key/value heads.")
   base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
+  dense_init_scale: float = Field(1.0, description="Initialization scale for dense layers")
   base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
   head_dim: int = Field(
       128,
       description="Model query and key head dimension.",
   )
+  attention_output_dim: int = Field(
+      -1,
+      description="Override output dimension for attention block if set to a positive value.",
+  )
   global_head_dim: int = Field(
       0,
       description="Model query and key head dimension for global attention layers.",
@@ -647,6 +653,11 @@ class MoEGeneral(BaseModel):
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
   ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
+  moe_expert_input_dim: int = Field(
+      -1,
+      description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.",
+  )
+  base_moe_mlp_dim: int = Field(-1, description="Intermediate dimension at MoE layer.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(
       True,
@@ -737,7 +748,6 @@ class MoEKernels(BaseModel):
 class DeepSeekMoE(BaseModel):
   """Configuration specific to DeepSeek-style MoE layers."""

-  base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
   first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
   shared_experts: NonNegativeInt = Field(0, description="Number of shared experts.")
   routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
@@ -2557,6 +2567,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         f"but got {self.engram_vocab_bases}."
       )
     if self.num_experts > 1:
+      if self.moe_mlp_dim <= 0:
+        raise ValueError("moe_mlp_dim must be positive for MoE models (num_experts > 1)")
       is_fully_moe = (
           self.interleave_moe_layer_step == 1
           and self.first_num_dense_layers == 0
@@ -2814,4 +2826,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     else:
       self.constant_bound_config = []

+    if self.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
+      if self.moe_expert_input_dim != self.attention_output_dim:
+        raise ValueError(
+            f"For qwen3_custom_moe, moe_expert_input_dim ({self.moe_expert_input_dim}) "
+            f"must be equal to attention_output_dim ({self.attention_output_dim})"
+        )
+
     return self
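Restated outside the Pydantic model, the two new checks behave like this (a standalone re-statement for illustration; the real checks live in the config validator above):

```python
def check_moe_config(decoder_block: str, num_experts: int, moe_mlp_dim: int,
                     moe_expert_input_dim: int, attention_output_dim: int) -> None:
  """Mirrors the two validations added in this commit."""
  if num_experts > 1 and moe_mlp_dim <= 0:
    raise ValueError("moe_mlp_dim must be positive for MoE models (num_experts > 1)")
  if decoder_block == "qwen3_custom_moe" and moe_expert_input_dim != attention_output_dim:
    raise ValueError(
        f"For qwen3_custom_moe, moe_expert_input_dim ({moe_expert_input_dim}) "
        f"must be equal to attention_output_dim ({attention_output_dim})"
    )

check_moe_config("qwen3_custom_moe", 128, 2048, 768, 768)   # passes
# check_moe_config("qwen3_custom_moe", 128, 2048, 768, -1)  # would raise
```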

src/maxtext/layers/decoders.py

Lines changed: 4 additions & 0 deletions
@@ -56,6 +56,7 @@
     olmo3,
     qwen2,
     qwen3,
+    qwen3_custom,
     simple_layer,
 )
 from maxtext.multimodal import utils as mm_utils
@@ -476,6 +477,8 @@ def get_decoder_layers(self):
         return [qwen3.Qwen3DecoderLayerToLinen]
       case DecoderBlockType.QWEN3_MOE:
         return [qwen3.Qwen3MoeDecoderLayerToLinen]
+      case DecoderBlockType.QWEN3_CUSTOM_MOE:
+        return [qwen3_custom.Qwen3CustomMoeDecoderLayerToLinen]
       case DecoderBlockType.QWEN3_NEXT:
         return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen]
       case DecoderBlockType.SIMPLE:
@@ -534,6 +537,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.QWEN2,
         DecoderBlockType.QWEN3,
         DecoderBlockType.QWEN3_MOE,
+        DecoderBlockType.QWEN3_CUSTOM_MOE,
         DecoderBlockType.GPT_OSS,
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,

src/maxtext/layers/moe.py

Lines changed: 24 additions & 13 deletions
@@ -349,6 +349,10 @@ def __init__(
     self.quant = quant
     self.rngs = rngs

+    self.moe_expert_input_dim = (
+        self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim
+    )
+
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
       self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
@@ -374,7 +378,7 @@ def __init__(
     self._expert_parallelism_name = "expert"

     self.gate = GateLogit(
-        in_features_shape=self.config.emb_dim,
+        in_features_shape=self.moe_expert_input_dim,
         out_features_shape=self.num_experts,
         mesh=self.mesh,
         model_name=self.config.model_name,
@@ -400,14 +404,14 @@ def __init__(
       # During aqt convert state we delete kernel weight from params to save
       # memory. Instead they are retrieved from the tensors stored in the 'aqt'
       # collection.
-      self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
+      self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -417,7 +421,7 @@ def __init__(
       self.wi_1 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -427,7 +431,7 @@ def __init__(
       self.wo = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+              (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
               self.weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -439,7 +443,7 @@ def __init__(
       wi_bias_axes = ("exp", "activation_mlp")
       wo_bias_axes = ("exp", "activation_embed")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
-      wo_bias_shape = (self.num_experts, self.config.emb_dim)
+      wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
       self.wi_0_bias = nnx.Param(
           default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
           sharding=wi_bias_axes,
@@ -1208,7 +1212,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
           self.config.num_experts_per_tok,
           self.config.ragged_buffer_factor,
       )
-      output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+      output_shape = jax.lax.empty((buffer_size, self.moe_expert_input_dim), dtype=x.dtype)

       x = jax.lax.ragged_all_to_all(
           x,
@@ -1345,7 +1349,9 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       )

       # Sum up the partial outputs across the expert shards.
-      output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
+      output = jnp.reshape(
+          output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size())
+      )
       output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

     else:
@@ -1356,7 +1362,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       output_shape = jax.lax.empty(
           (
               original_inputs_first_dim,
-              self.config.emb_dim // self.get_tensor_parallelism_size(),
+              self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
           ),
           dtype=intermediate_output.dtype,
       )
@@ -2112,14 +2118,18 @@ def __init__(
     self.dtype = dtype
     self.quant = quant
     self.rngs = rngs
+    self.moe_expert_input_dim = (
+        self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim
+    )
+
     # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
     # existing checkpoints for routed experts.
     self.MoeBlock_0 = RoutedMoE(
         config=self.config,
         num_experts=self.config.num_experts,
         num_experts_per_tok=self.config.num_experts_per_tok,
         mesh=self.mesh,
-        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=self.kernel_init,
        kernel_axes=("embed_moe", None),
        intermediate_dim=self.config.moe_mlp_dim,
        dtype=self.config.dtype,
@@ -2133,9 +2143,10 @@ def __init__(
     )
     self.shared_experts = linears.MlpBlock(
         mesh=self.mesh,
-        in_features=self.config.emb_dim,
+        in_features=self.moe_expert_input_dim,
         intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim,
         activations=self.config.mlp_activations,
+        kernel_init=self.kernel_init,
         intermediate_dropout_rate=self.config.dropout_rate,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
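Taken together, the moe.py edits resize every expert-facing tensor from `emb_dim` to the resolved `moe_expert_input_dim`. A compact summary of the resulting RoutedMoE kernel shapes (a sketch mirroring the diff above):

```python
def expert_kernel_shapes(num_experts, emb_dim, moe_expert_input_dim, intermediate_dim):
  """Returns the expert kernel shapes after this change, with the -1 fallback."""
  d_in = emb_dim if moe_expert_input_dim <= 0 else moe_expert_input_dim
  return {
      "wi_0": (num_experts, d_in, intermediate_dim),
      "wi_1": (num_experts, d_in, intermediate_dim),
      "wo": (num_experts, intermediate_dim, d_in),
      "wo_bias": (num_experts, d_in),
  }

# With the -1 default everything still matches the old emb_dim-based shapes:
assert expert_kernel_shapes(128, 2048, -1, 2048)["wi_0"] == (128, 2048, 2048)
# With the custom Qwen3 config the experts run in a narrower 768-wide space:
assert expert_kernel_shapes(128, 2048, 768, 2048)["wo"] == (128, 2048, 768)
```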

src/maxtext/models/deepseek.py

Lines changed: 1 addition & 1 deletion
@@ -419,7 +419,7 @@ def __init__(
     self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
         config=self.config,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,

src/maxtext/models/gemma4.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def __init__(
     self.moe_block = moe.RoutedAndSharedMoE(
         config=config,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         weight_dtype=config.weight_dtype,
         dtype=config.dtype,

src/maxtext/models/gpt_oss.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def __init__(
         num_experts=config.num_experts,
         num_experts_per_tok=config.num_experts_per_tok,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         intermediate_dim=config.mlp_dim,
         dtype=config.dtype,

src/maxtext/models/llama4.py

Lines changed: 1 addition & 1 deletion
@@ -403,7 +403,7 @@ def __init__(
     self.Llama4MoEBlock_0 = RoutedAndSharedMoE(
         config=config,
         mesh=self.mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         dtype=config.dtype,
         weight_dtype=config.weight_dtype,
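The same one-line substitution appears in deepseek.py, gemma4.py, gpt_oss.py, and llama4.py: the hardcoded scale 1.0 becomes `config.dense_init_scale`. A standalone sketch of what that scale controls in a fan-in truncated-normal initializer (a stand-in for the `nd_dense_init` these files call, assuming it follows standard variance-scaling conventions):

```python
import jax
import jax.numpy as jnp

def nd_dense_init_sketch(scale: float):
  """Fan-in variance-scaled truncated-normal init; a sketch, not MaxText's."""
  def init(key, shape, dtype=jnp.float32):
    fan_in = shape[0]  # first axis treated as fan-in
    # stddev shrinks with fan-in; the divisor corrects for truncation to [-2, 2]
    stddev = jnp.sqrt(scale / fan_in) / 0.87962566103423978
    return jax.random.truncated_normal(key, -2.0, 2.0, shape, dtype) * stddev
  return init

kernel_init = nd_dense_init_sketch(1.0)  # the previously hardcoded behavior
w = kernel_init(jax.random.PRNGKey(0), (2048, 7168))
# dense_init_scale=0.5 would halve the initial weight variance model-wide.
```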
