
Commit 86b6433

gagika authored and Google-ML-Automation committed
Custom Qwen 30B-A3B
PiperOrigin-RevId: 896749554
1 parent 38112ca commit 86b6433

14 files changed

Lines changed: 368 additions & 24 deletions


src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ class DecoderBlockType(enum.Enum):
   QWEN2 = "qwen2"
   QWEN3 = "qwen3"
   QWEN3_MOE = "qwen3_moe"
+  QWEN3_CUSTOM_MOE = "qwen3_custom_moe"
   QWEN3_NEXT = "qwen3_next"
   GPT3 = "gpt3"
   GPT_OSS = "gpt_oss"

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -160,8 +160,10 @@ base_emb_dim: 2048
 base_num_query_heads: 16
 base_num_kv_heads: 16
 base_mlp_dim: 7168
+dense_init_scale: 1.0
 base_num_decoder_layers: 16
 head_dim: 128
+attention_output_dim: -1
 # Those parameters are only used with global attention for Gemma4.
 global_head_dim: 0
 global_num_kv_heads: 0
@@ -195,6 +197,7 @@ num_experts_per_tok: 1
 megablox: true
 sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
+moe_expert_input_dim: -1 # feature dimension of the tokens entering the MoE expert blocks.
 load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
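Both new keys use -1 as a "no override" sentinel. A minimal sketch of the fallback behavior, based on the field descriptions in this commit (moe_expert_input_dim falls back to emb_dim; attention_output_dim only takes effect when positive):

def resolve_moe_expert_input_dim(emb_dim: int, moe_expert_input_dim: int = -1) -> int:
  # -1 (the default) leaves the MoE input at the embedding dimension.
  return moe_expert_input_dim if moe_expert_input_dim > 0 else emb_dim

assert resolve_moe_expert_input_dim(2048) == 2048      # default configs: unchanged
assert resolve_moe_expert_input_dim(2048, 768) == 768  # custom Qwen3 override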
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Model config for custom Qwen3-30B-A3B
+
+# Core Architectural Parameters
+decoder_block: "qwen3_custom_moe"
+base_emb_dim: 2048
+base_mlp_dim: 2048
+base_num_query_heads: 16
+base_num_kv_heads: 2
+base_num_decoder_layers: 48
+head_dim: 256
+mlp_activations: ["silu", "linear"]
+vocab_size: 151936
+normalization_layer_epsilon: 1.0e-6
+use_qk_norm: True
+attention_output_dim: 768
+moe_expert_input_dim: 768
+
+# MoE Specific Parameters
+num_experts: 128
+num_experts_per_tok: 8
+base_moe_mlp_dim: 2048
+norm_topk_prob: true
+
+# RoPE Settings
+rope_max_timescale: 10_000_000
+
+# General Model Settings
+enable_dropout: False
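Taken together with the moe.py changes below, this config implies expert kernels built from moe_expert_input_dim (768) rather than emb_dim (2048). A sketch of the per-expert shapes, assuming the usual gated-MLP expert structure suggested by mlp_activations ["silu", "linear"] (illustrative only, not the actual module):

import jax
import jax.numpy as jnp

num_experts, expert_in, moe_mlp = 128, 768, 2048
wi_0 = jnp.zeros((num_experts, expert_in, moe_mlp))  # gate projection
wi_1 = jnp.zeros((num_experts, expert_in, moe_mlp))  # up projection
wo = jnp.zeros((num_experts, moe_mlp, expert_in))    # down projection

tokens = jnp.zeros((4, expert_in))  # tokens entering one expert
hidden = jax.nn.silu(tokens @ wi_0[0]) * (tokens @ wi_1[0])
out = hidden @ wo[0]
assert out.shape == (4, expert_in)  # experts map 768 -> 2048 -> 768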

src/maxtext/configs/types.py

Lines changed: 12 additions & 0 deletions
@@ -256,6 +256,7 @@ class ProfilerType(str, Enum):
     "qwen3-480b-a35b",
     "qwen3-next-80b-a3b",
     "qwen3-omni-30b-a3b",
+    "qwen3-custom-30b-a3b",
     "gpt3-175b",
     "gpt3-22b",
     "gpt3-6b",
@@ -267,6 +268,7 @@ class ProfilerType(str, Enum):
     "olmo3-7b",
     "olmo3-7b-pt",
     "olmo3-32b",
+    "qwen3-custom-moe",
 ]

@@ -447,11 +449,13 @@ class ModelArchitecture(BaseModel):
   base_num_query_heads: int = Field(16, description="Base number of query heads.")
   base_num_kv_heads: int = Field(16, description="Base number of key/value heads.")
   base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
+  dense_init_scale: float = Field(1.0, description="Initialization scale for dense layers")
   base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
   head_dim: int = Field(
       128,
       description="Model query and key head dimension.",
   )
+  attention_output_dim: int = Field(-1, description="Override output dimension for attention block if set to a positive value.")
   global_head_dim: int = Field(
       0,
       description="Model query and key head dimension for global attention layers.",
@@ -646,6 +650,7 @@ class MoEGeneral(BaseModel):
   num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
+  moe_expert_input_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(
       True,
@@ -2802,4 +2807,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     else:
       self.constant_bound_config = []

+    if self.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
+      if self.moe_expert_input_dim != self.attention_output_dim:
+        raise ValueError(
+            f"For qwen3_custom_moe, moe_expert_input_dim ({self.moe_expert_input_dim}) "
+            f"must be equal to attention_output_dim ({self.attention_output_dim})"
+        )
+
     return self
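The new check rejects configs where the attention output width and the MoE expert input width disagree. A standalone sketch of the same cross-field check as a Pydantic after-validator (the surrounding config class is simplified away; only the two field names come from the commit):

from pydantic import BaseModel, model_validator

class CustomMoeDims(BaseModel):
  attention_output_dim: int = -1
  moe_expert_input_dim: int = -1

  @model_validator(mode="after")
  def check_dims(self):
    if self.moe_expert_input_dim != self.attention_output_dim:
      raise ValueError(
          f"moe_expert_input_dim ({self.moe_expert_input_dim}) "
          f"must be equal to attention_output_dim ({self.attention_output_dim})"
      )
    return self

CustomMoeDims(attention_output_dim=768, moe_expert_input_dim=768)  # passes
# CustomMoeDims(attention_output_dim=768, moe_expert_input_dim=512) raises ValueError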

src/maxtext/layers/decoders.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,7 @@
     olmo3,
     qwen2,
     qwen3,
+    qwen3_custom,
     simple_layer,
 )
 from maxtext.multimodal import utils as mm_utils
@@ -475,6 +476,8 @@ def get_decoder_layers(self):
         return [qwen3.Qwen3DecoderLayerToLinen]
       case DecoderBlockType.QWEN3_MOE:
         return [qwen3.Qwen3MoeDecoderLayerToLinen]
+      case DecoderBlockType.QWEN3_CUSTOM_MOE:
+        return [qwen3_custom.Qwen3CustomMoeDecoderLayerToLinen]
       case DecoderBlockType.QWEN3_NEXT:
         return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen]
       case DecoderBlockType.SIMPLE:
@@ -533,6 +536,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.QWEN2,
         DecoderBlockType.QWEN3,
         DecoderBlockType.QWEN3_MOE,
+        DecoderBlockType.QWEN3_CUSTOM_MOE,
         DecoderBlockType.GPT_OSS,
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,
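The new block slots into the existing match/case dispatch. A trimmed standalone sketch of the pattern (layer classes replaced by strings for illustration; the real method returns Linen layer classes):

import enum

class DecoderBlockType(enum.Enum):
  QWEN3_MOE = "qwen3_moe"
  QWEN3_CUSTOM_MOE = "qwen3_custom_moe"

def get_decoder_layers(block: DecoderBlockType):
  match block:
    case DecoderBlockType.QWEN3_MOE:
      return ["Qwen3MoeDecoderLayerToLinen"]
    case DecoderBlockType.QWEN3_CUSTOM_MOE:
      return ["Qwen3CustomMoeDecoderLayerToLinen"]
    case _:
      raise ValueError(f"Unsupported decoder block: {block}")

assert get_decoder_layers(DecoderBlockType.QWEN3_CUSTOM_MOE) == ["Qwen3CustomMoeDecoderLayerToLinen"]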

src/maxtext/layers/moe.py

Lines changed: 20 additions & 12 deletions
@@ -349,6 +349,10 @@ def __init__(
     self.quant = quant
     self.rngs = rngs

+    self.moe_expert_input_dim = getattr(self.config, "moe_expert_input_dim", -1)
+    if self.moe_expert_input_dim <= 0:
+      self.moe_expert_input_dim = self.config.emb_dim
+
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
       self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp")
@@ -374,7 +378,7 @@ def __init__(
     self._expert_parallelism_name = "expert"

     self.gate = GateLogit(
-        in_features_shape=self.config.emb_dim,
+        in_features_shape=self.moe_expert_input_dim,
         out_features_shape=self.num_experts,
         mesh=self.mesh,
         model_name=self.config.model_name,
@@ -400,14 +404,14 @@ def __init__(
       # During aqt convert state we delete kernel weight from params to save
       # memory. Instead they are retrieved from the tensors stored in the 'aqt'
       # collection.
-      self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
-      self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
+      self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
+      self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
     else:
       self.wi_0 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -417,7 +421,7 @@ def __init__(
       self.wi_1 = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (num_experts, self.config.emb_dim, intermediate_dim),
+              (num_experts, self.moe_expert_input_dim, intermediate_dim),
               weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -427,7 +431,7 @@ def __init__(
       self.wo = nnx.Param(
           self.kernel_init(
               self.rngs.params(),
-              (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+              (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
               self.weight_dtype,
               kernel_in_axis,
               kernel_out_axis,
@@ -439,7 +443,7 @@ def __init__(
       wi_bias_axes = ("exp", "activation_mlp")
       wo_bias_axes = ("exp", "activation_embed_moe")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
-      wo_bias_shape = (self.num_experts, self.config.emb_dim)
+      wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
       self.wi_0_bias = nnx.Param(
           default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
           sharding=wi_bias_axes,
@@ -1182,7 +1186,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
       max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
       buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-      output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+      output_shape = jax.lax.empty((buffer_size, self.moe_model_dim), dtype=x.dtype)

       x = jax.lax.ragged_all_to_all(
           x,
@@ -1337,7 +1341,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       )

       # Sum up the partial outputs across the expert shards.
-      output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
+      output = jnp.reshape(output, (-1, sequence_length, self.moe_model_dim // self.get_tensor_parallelism_size()))
       output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

     else:
@@ -1348,7 +1352,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       output_shape = jax.lax.empty(
           (
               original_inputs_first_dim,
-              self.config.emb_dim // self.get_tensor_parallelism_size(),
+              self.moe_model_dim // self.get_tensor_parallelism_size(),
           ),
           dtype=intermediate_output.dtype,
       )
@@ -2095,6 +2099,10 @@ def __init__(
     self.dtype = dtype
     self.quant = quant
     self.rngs = rngs
+    self.moe_model_dim = getattr(self.config, "moe_model_dim", -1)
+    if self.moe_model_dim <= 0:
+      self.moe_model_dim = self.config.emb_dim
+
     # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
     # existing checkpoints for routed experts.
     self.MoeBlock_0 = RoutedMoE(
@@ -2116,7 +2124,7 @@ def __init__(
     )
     self.shared_experts = linears.MlpBlock(
         mesh=self.mesh,
-        in_features=self.config.emb_dim,
+        in_features=self.moe_model_dim,
         intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim,
         activations=self.config.mlp_activations,
         intermediate_dropout_rate=self.config.dropout_rate,
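The getattr fallback keeps older configs working: when the field is absent or non-positive, the expert input dimension collapses back to emb_dim. A standalone sketch of the pattern (the Config classes are hypothetical stand-ins for the MaxText config object):

class Config:
  emb_dim = 2048
  moe_expert_input_dim = 768  # set by the custom Qwen3 config

def expert_input_dim(config) -> int:
  dim = getattr(config, "moe_expert_input_dim", -1)
  return dim if dim > 0 else config.emb_dim

assert expert_input_dim(Config) == 768

class LegacyConfig:  # field absent: falls back to emb_dim
  emb_dim = 2048

assert expert_input_dim(LegacyConfig) == 2048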

src/maxtext/models/deepseek.py

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@ def __init__(
     self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
         config=self.config,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,

src/maxtext/models/gemma4.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def __init__(
     self.moe_block = moe.RoutedAndSharedMoE(
         config=config,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         weight_dtype=config.weight_dtype,
         dtype=config.dtype,

src/maxtext/models/gpt_oss.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def __init__(
         num_experts=config.num_experts,
         num_experts_per_tok=config.num_experts_per_tok,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         intermediate_dim=config.mlp_dim,
         dtype=config.dtype,

src/maxtext/models/llama4.py

Lines changed: 1 addition & 1 deletion
@@ -403,7 +403,7 @@ def __init__(
     self.Llama4MoEBlock_0 = RoutedAndSharedMoE(
         config=config,
         mesh=self.mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         dtype=config.dtype,
         weight_dtype=config.weight_dtype,
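The last four files make the previously hard-coded 1.0 init scale configurable via dense_init_scale. Assuming nd_dense_init wraps a variance-scaling initializer (an assumption suggested by its signature, not confirmed by this diff), the scale multiplies the initialization variance, so weight stddev grows roughly as sqrt(scale / fan_in):

import jax
import jax.numpy as jnp

# Sketch of the effect of the scale argument, using JAX's public
# variance_scaling initializer as a stand-in for nd_dense_init.
key = jax.random.PRNGKey(0)
shape = (2048, 7168)  # (fan_in, fan_out)
for scale in (1.0, 0.25):
  init = jax.nn.initializers.variance_scaling(scale, "fan_in", "truncated_normal")
  w = init(key, shape, jnp.float32)
  print(scale, float(jnp.std(w)))  # stddev roughly halves when scale drops 4x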
