1919import functools
2020import inspect
2121import warnings
22+ import dataclasses
2223from typing import Any
2324
2425import jax
@@ -472,13 +473,7 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
472473 def _create_scanned_layers (
473474 self , decoder_layer_class , length : int , metadata_axis_name : str , rngs : nnx .Rngs , ** layer_kwargs
474475 ):
475- """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization.
476-
477- Uses jax.lax.scan instead of nnx.vmap to reduce peak memory during initialization.
478- With vmap, all layers' parameters are created simultaneously (O(N) peak memory).
479- With scan, parameters are created one layer at a time (O(1) peak intermediate memory),
480- which prevents OOM on memory-constrained devices like TPU v6e-4.
481- """
476+ """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization."""
482477 scan_axis = self .config .param_scan_axis
483478
484479 # Fork rngs to get per-layer RNG states for scanning
@@ -489,10 +484,6 @@ def _create_scanned_layers(
489484
490485 rngs_graphdef , rngs_state = nnx .split (forked_rngs )
491486
492- # Create a reference layer to capture the module graph structure (graphdef).
493- # This layer's params are discarded — only the structure is kept.
494- # Must use the first slice of the forked rngs (not a dummy Rngs(0)) so the
495- # graphdef has the same number of RNG state leaves as the scan-created layers.
496487 first_rng_state = jax .tree .map (lambda x : x [0 ], rngs_state )
497488 ref_rngs = nnx .merge (rngs_graphdef , first_rng_state )
498489 ref_layer = decoder_layer_class (
@@ -501,9 +492,6 @@ def _create_scanned_layers(
501492 layer_graphdef , _ , _ = nnx .split (ref_layer , nnx .Param , ...)
502493 del ref_layer
503494
504- # Sequentially create each layer's parameters via jax.lax.scan.
505- # The scan body is traced once; XLA executes it N times with different RNG keys,
506- # keeping only one layer's intermediate state alive at a time.
507495 def scan_body (carry , rng_state_slice ):
508496 layer_rngs = nnx .merge (rngs_graphdef , rng_state_slice )
509497 layer = decoder_layer_class (
@@ -519,47 +507,40 @@ def scan_body(carry, rng_state_slice):
519507
520508 _ , (stacked_params , stacked_rest ) = jax .lax .scan (scan_body , None , rngs_state )
521509
522- # jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis.
523510 if scan_axis != 0 :
524511 stacked_params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), stacked_params )
525512
526- # Add partition metadata that nnx.vmap's transform_metadata would normally set.
527- # This metadata is read by variable_to_logically_partitioned() in initializers.py
528- # and by nnx.get_partition_spec() to produce correct sharding specs.
529513 def _add_scan_metadata (state , axis ):
530514 def _update_leaf (leaf ):
531- if isinstance (leaf , nnx .VariableState ):
532- metadata = leaf .get_metadata ()
533- metadata [nnx .PARTITION_NAME ] = metadata_axis_name
534- metadata ["param_scan_axis" ] = axis
535-
536- # Patch all sharding configurations in metadata so that nnx.get_partition_spec
537- # returns a 3D spec matching the actual 3D tensor rank, instead of the original 2D.
538- for key in ["out_sharding" , "sharding" , "kernel_axes" ]:
539- if key in metadata and metadata [key ] is not None :
540- val = metadata [key ]
515+ if hasattr (leaf , "replace" ) and hasattr (leaf , "value" ):
516+ replace_kwargs = {}
517+ if hasattr (leaf , "get_metadata" ):
518+ replace_kwargs .update (leaf .get_metadata ())
519+
520+ replace_kwargs [nnx .PARTITION_NAME ] = metadata_axis_name
521+ replace_kwargs ["param_scan_axis" ] = axis
522+
523+ for key in ["sharding" , "out_sharding" , "kernel_axes" , "sharding_names" ]:
524+ val = getattr (leaf , key , None )
525+ if val is None and key in replace_kwargs :
526+ val = replace_kwargs [key ]
527+
528+ if val is not None :
541529 if isinstance (val , str ):
542530 val = (val ,)
543531 if isinstance (val , tuple ):
544- sharding_list = list (val )
545- sharding_list .insert (axis , metadata_axis_name )
546- metadata [key ] = tuple (sharding_list )
547-
548- # Ensure the native 'sharding' property is also updated if it exists separately
549- replace_kwargs = dict (metadata )
550- if hasattr (leaf , "sharding" ) and leaf .sharding is not None :
551- val = leaf .sharding
552- if isinstance (val , str ):
553- val = (val ,)
554- if isinstance (val , tuple ):
555- sharding_list = list (val )
556- sharding_list .insert (axis , metadata_axis_name )
557- replace_kwargs ["sharding" ] = tuple (sharding_list )
532+ l = list (val )
533+ # Safely insert the scan axis into the logical axes string
534+ if metadata_axis_name not in l :
535+ insert_idx = min (axis , len (l ))
536+ l .insert (insert_idx , metadata_axis_name )
537+ replace_kwargs [key ] = tuple (l )
558538
559539 return leaf .replace (** replace_kwargs )
560540 return leaf
561541
562- return jax .tree .map (_update_leaf , state , is_leaf = lambda x : isinstance (x , nnx .VariableState ))
542+ # We must use a custom is_leaf to catch the VariableState instances
543+ return jax .tree .map (_update_leaf , state , is_leaf = lambda x : hasattr (x , "replace" ) and hasattr (x , "value" ))
563544
564545 stacked_params = _add_scan_metadata (stacked_params , scan_axis )
565546 stacked_rest = _add_scan_metadata (stacked_rest , 0 )
@@ -811,7 +792,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs):
811792 )
812793 elif self .config .decoder_block == DecoderBlockType .QWEN3_NEXT :
813794 return functools .partial (
814- normalizations .Qwen3NextRMSNorm , num_features = num_features , shard_mode = self .config .shard_mode , rngs = rngs
795+ normalizations .RMSNorm , num_features = num_features , shard_mode = self .config .shard_mode , rngs = rngs
815796 )
816797 else :
817798 raise ValueError (f"Incorrect decoder_block name { self .config .decoder_block .value = } " )
0 commit comments