Skip to content

Commit 0f94640

Browse files
Fix unit test
1 parent 12531e6 commit 0f94640

3 files changed

Lines changed: 28 additions & 49 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,8 @@ position_id_per_seconds: 25
 subslice_shape: ""

 # NNX
-enable_nnx: False
-pure_nnx_decoder: False
+enable_nnx: True
+pure_nnx_decoder: True
 pure_nnx: False

 ################################## Qwen3-Next Specific Configs ##################################

src/maxtext/layers/nnx_decoders.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
 import functools
 import inspect
 import warnings
+import dataclasses
 from typing import Any

 import jax
@@ -472,13 +473,7 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
   def _create_scanned_layers(
       self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs
   ):
-    """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization.
-
-    Uses jax.lax.scan instead of nnx.vmap to reduce peak memory during initialization.
-    With vmap, all layers' parameters are created simultaneously (O(N) peak memory).
-    With scan, parameters are created one layer at a time (O(1) peak intermediate memory),
-    which prevents OOM on memory-constrained devices like TPU v6e-4.
-    """
+    """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization."""
     scan_axis = self.config.param_scan_axis

     # Fork rngs to get per-layer RNG states for scanning
@@ -489,10 +484,6 @@ def _create_scanned_layers(

     rngs_graphdef, rngs_state = nnx.split(forked_rngs)

-    # Create a reference layer to capture the module graph structure (graphdef).
-    # This layer's params are discarded — only the structure is kept.
-    # Must use the first slice of the forked rngs (not a dummy Rngs(0)) so the
-    # graphdef has the same number of RNG state leaves as the scan-created layers.
     first_rng_state = jax.tree.map(lambda x: x[0], rngs_state)
     ref_rngs = nnx.merge(rngs_graphdef, first_rng_state)
     ref_layer = decoder_layer_class(
@@ -501,9 +492,6 @@ def _create_scanned_layers(
     layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...)
     del ref_layer

-    # Sequentially create each layer's parameters via jax.lax.scan.
-    # The scan body is traced once; XLA executes it N times with different RNG keys,
-    # keeping only one layer's intermediate state alive at a time.
     def scan_body(carry, rng_state_slice):
       layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice)
       layer = decoder_layer_class(
@@ -519,47 +507,40 @@ def scan_body(carry, rng_state_slice):

     _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state)

-    # jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis.
     if scan_axis != 0:
       stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params)

-    # Add partition metadata that nnx.vmap's transform_metadata would normally set.
-    # This metadata is read by variable_to_logically_partitioned() in initializers.py
-    # and by nnx.get_partition_spec() to produce correct sharding specs.
     def _add_scan_metadata(state, axis):
       def _update_leaf(leaf):
-        if isinstance(leaf, nnx.VariableState):
-          metadata = leaf.get_metadata()
-          metadata[nnx.PARTITION_NAME] = metadata_axis_name
-          metadata["param_scan_axis"] = axis
-
-          # Patch all sharding configurations in metadata so that nnx.get_partition_spec
-          # returns a 3D spec matching the actual 3D tensor rank, instead of the original 2D.
-          for key in ["out_sharding", "sharding", "kernel_axes"]:
-            if key in metadata and metadata[key] is not None:
-              val = metadata[key]
+        if hasattr(leaf, "replace") and hasattr(leaf, "value"):
+          replace_kwargs = {}
+          if hasattr(leaf, "get_metadata"):
+            replace_kwargs.update(leaf.get_metadata())
+
+          replace_kwargs[nnx.PARTITION_NAME] = metadata_axis_name
+          replace_kwargs["param_scan_axis"] = axis
+
+          for key in ["sharding", "out_sharding", "kernel_axes", "sharding_names"]:
+            val = getattr(leaf, key, None)
+            if val is None and key in replace_kwargs:
+              val = replace_kwargs[key]
+
+            if val is not None:
               if isinstance(val, str):
                 val = (val,)
               if isinstance(val, tuple):
-                sharding_list = list(val)
-                sharding_list.insert(axis, metadata_axis_name)
-                metadata[key] = tuple(sharding_list)
-
-          # Ensure the native 'sharding' property is also updated if it exists separately
-          replace_kwargs = dict(metadata)
-          if hasattr(leaf, "sharding") and leaf.sharding is not None:
-            val = leaf.sharding
-            if isinstance(val, str):
-              val = (val,)
-            if isinstance(val, tuple):
-              sharding_list = list(val)
-              sharding_list.insert(axis, metadata_axis_name)
-              replace_kwargs["sharding"] = tuple(sharding_list)
+                l = list(val)
+                # Safely insert the scan axis into the logical axes string
+                if metadata_axis_name not in l:
+                  insert_idx = min(axis, len(l))
+                  l.insert(insert_idx, metadata_axis_name)
+                replace_kwargs[key] = tuple(l)

           return leaf.replace(**replace_kwargs)
         return leaf

-      return jax.tree.map(_update_leaf, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
+      # We must use a custom is_leaf to catch the VariableState instances
+      return jax.tree.map(_update_leaf, state, is_leaf=lambda x: hasattr(x, "replace") and hasattr(x, "value"))

     stacked_params = _add_scan_metadata(stacked_params, scan_axis)
     stacked_rest = _add_scan_metadata(stacked_rest, 0)
@@ -811,7 +792,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs):
       )
     elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT:
       return functools.partial(
-          normalizations.Qwen3NextRMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs
+          normalizations.RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs
       )
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

tests/unit/train_compile_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,6 @@ def test_moe_deepseek_pipeline_subset(self):
             "pipeline_parallel_layers=56",
             "ici_expert_parallelism=16",
             "dcn_pipeline_parallelism=8",
-            "first_num_dense_layers=8",
-            "base_num_decoder_layers=72",
         )
     )

@@ -655,7 +653,7 @@ def test_pipeline_subset(self):
             "per_device_batch_size=1",
             "max_target_length=1024",
             "pipeline_parallel_layers=56",
-            "base_num_decoder_layers=64",  # Must be divisible by dcn_pipeline_parallelism=8 in NNX scan path.
+            "base_num_decoder_layers=61",  # Remainder of 5 will fail when sharded incorrectly.
             "ici_expert_parallelism=16",
             "dcn_pipeline_parallelism=8",
         )

0 commit comments

Comments
 (0)