2222
2323from flax import linen as nn
2424from flax .linen import partitioning as nn_partitioning
25- from flax .training import train_state
25+ from flax .training . train_state import TrainState
2626
2727import numpy as np
2828
29- from jax .experimental import mesh_utils
30- from jax .experimental .serialize_executable import deserialize_and_load
31- from jax .sharding import AxisType , Mesh
32-
3329import jax
3430import jax .numpy as jnp
31+ from jax .sharding import AxisType , Mesh , NamedSharding , PartitionSpec
32+ from jax .experimental import mesh_utils
33+ from jax .experimental .serialize_executable import deserialize_and_load
3534
3635import optax
37-
3836import orbax .checkpoint .experimental .emergency .checkpoint_manager as emergency_checkpoint_manager
3937import orbax .checkpoint .experimental .emergency .replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
4038
4846from maxtext .utils import max_logging
4947from maxtext .utils import max_utils
5048from maxtext .utils import sharding
49+ from maxtext .utils import maxtext_utils_nnx
5150
5251OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
5352
@@ -994,15 +993,15 @@ def _apply_update(path, param):
994993 return state .replace (params = new_params )
995994
996995
def init_decode_state(apply_fn, params) -> TrainState:
  """Build a decode-only TrainState: step 0, given params, and no optimizer.

  Decode never calls apply_gradients, so ``tx`` and ``opt_state`` are left null.
  """
  # type: ignore below mirrors the original: TrainState expects a real optax
  # transform for tx, but decode deliberately passes None.
  return TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={})  # type: ignore
10011000
10021001
def init_training_state(apply_fn, params, tx):
  """Init train state with the given optimizer for training.

  (Docstring fixed: the previous text said "null opt state for decode", which was
  copied from init_decode_state; this function builds a full training state.)

  Args:
    apply_fn: the model's apply function.
    params: the initialized model parameters.
    tx: the optax gradient transformation (optimizer); its state is initialized
      by TrainState.create via tx.init(params).

  Returns:
    A TrainState at step 0 with params, apply_fn, tx, and initialized opt_state.
  """
  state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
  return state
10071006
10081007
@@ -1124,7 +1123,7 @@ def setup_initial_state(
11241123 is_training: True to initialize training state, False for decode state
11251124
11261125 Returns:
1127- state : the initialized train state
1126+ train_state : the initialized train state. For NNX, this is a TrainStateNNX instance
11281127 state_mesh_annotations: the mesh annotations for the train state
11291128 """
11301129
@@ -1163,19 +1162,32 @@ def setup_initial_state(
11631162 else :
11641163 # The update of data_iterator state happens in place, no need to assign explicitly
11651164 state = restored ["items" ]
1165+
1166+ # TODO: For NNX, convert the pure dict to nnx.State.
11661167 else :
11671168 init_state_partial = init_state_fn
11681169 init_state_partial .__name__ = "initialize_state"
1169- # pylint: disable=not-callable
1170- state = jax .jit (
1171- init_state_partial ,
1172- in_shardings = None ,
1173- out_shardings = state_mesh_shardings ,
1174- )()
1170+ if config .pure_nnx :
1171+ state = jax .jit (
1172+ lambda : nnx .state (init_state_partial ()), # Get state only, mapping to out_sharding structure
1173+ in_shardings = None ,
1174+ out_shardings = state_mesh_shardings ,
1175+ )()
1176+ else :
1177+ # pylint: disable=not-callable
1178+ state = jax .jit (
1179+ init_state_partial ,
1180+ in_shardings = None ,
1181+ out_shardings = state_mesh_shardings ,
1182+ )()
11751183 if raw_params : # If we loaded a partial state, we need to merge it.
1176- state = state .replace (params = raw_params )
1177-
1178- state = max_utils .unbox_logicallypartioned (state )
1184+ if config .pure_nnx :
1185+ # raw_params should have the same sharding info as in the model
1186+ nnx .update (state .model , raw_params )
1187+ else :
1188+ state = state .replace (params = raw_params )
1189+ if not config .pure_nnx :
1190+ state = max_utils .unbox_logicallypartioned (state )
11791191
11801192 return state , state_mesh_annotations , state_mesh_shardings , data_iterator
11811193
@@ -1191,6 +1203,9 @@ def get_logical_annotations(config, mesh, init_state_fn):
11911203
11921204def get_abstract_state (config , mesh , init_state_fn , is_training = True ):
11931205 """Get a shaped abstraction of the state (including optimizer)"""
1206+ if config .pure_nnx :
1207+ return get_abstract_state_nnx (config , mesh , init_state_fn , is_training )
1208+
11941209 init_state_partial = init_state_fn
11951210
11961211 with nn_partitioning .axis_rules (config .logical_axis_rules ):
@@ -1234,6 +1249,148 @@ def move(path, x):
12341249 )
12351250
12361251
def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State:
  """Resolve a NamedSharding for each NNX variable, honoring the scan (stacked layers) axis.

  flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model) ignores the
  nnx.PARTITION_NAME / "param_scan_axis" metadata written by _create_scanned_layers,
  so a scanned parameter would receive a 2D partition spec on its 3D stacked tensor,
  sharding the stacked-layers dimension instead of the embedding dimension. This
  helper inserts the partition_name axis at the recorded scan position before
  translating logical axis names into physical mesh axes.

  Args:
    abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)).
    mesh: JAX physical mesh.

  Returns:
    Same tree structure as abs_var_state, with each Variable's value replaced by
    its NamedSharding.
  """

  def _logical_axes(metadata):
    """Return the logical axis-name tuple for a variable (scan axis inserted), or None."""
    logical = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
    if not logical:
      return None
    if nnx.PARTITION_NAME in metadata:
      stack_axis_name = metadata[nnx.PARTITION_NAME]
      # Always read param_scan_axis from metadata (default 0). OptVariable
      # (optimizer state) inherits param_scan_axis=1 from the model Param via
      # to_opt_state(), so hardcoding 0 for non-Param types would be wrong;
      # stacked_rest non-Param variables carry an explicit param_scan_axis=0.
      insert_at = metadata.get("param_scan_axis", 0)
      axes = [logical] if isinstance(logical, str) else list(logical)
      # Guard against double-insertion: Flax 0.12.6's _remap_sharding_metadata
      # renames 'sharding' -> 'out_sharding', so the scan axis may already be
      # present. Insert only when missing.
      if stack_axis_name not in axes:
        axes.insert(insert_at, stack_axis_name)
      logical = tuple(axes)
    return logical

  def _to_named_sharding(variable):
    value = variable.get_value()
    if not hasattr(value, "shape"):
      # Non-tensor leaf (e.g. optax MaskedNode for non-trainable params): keep it
      # as-is so the treedef still matches abs_var_state in later jax.tree.map calls.
      return variable
    metadata = variable.get_metadata()
    logical = _logical_axes(metadata)
    if logical is None:
      pspec = PartitionSpec()
    else:
      # Map logical axis names to physical mesh axes using the active context rules
      # combined with any per-variable rules.
      context_rules = get_logical_axis_rules()
      local_rules = metadata.get("sharding_rules", ())
      if context_rules or local_rules:
        pspec = PartitionSpec(*from_sharding_rules(logical, composite_rules(context_rules, local_rules)))
      else:
        pspec = PartitionSpec(*logical)
    return variable.replace(NamedSharding(mesh, pspec))

  return jax.tree.map(_to_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable))
1307+
1308+
def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True):
  """Calculates the abstract sharded state and memory placement for an NNX TrainState.

  This function performs an abstract trace of the NNX model and optimizer using
  `nnx.eval_shape`. It resolves logical sharding annotations into physical
  JAX shardings and applies memory placement optimizations such as optimizer
  sharding and host memory offloading (pinning to CPU RAM).

  Args:
    config: Configuration object containing sharding and offloading hyperparameters
      (e.g., shard_optimizer_over_data, optimizer_memory_host_offload).
    mesh: JAX physical mesh used to resolve logical axis names to physical devices.
    nnx_init_trainstate_fn: A zero-argument factory function that produces a
      TrainStateNNX instance during the abstract trace.
    is_training: Boolean indicating if the state is for training. If True,
      optimizer state is processed and memory offloading strategies are applied.

  Returns:
    A tuple containing (abstract_sharded_state, state_mesh_annotations, state_mesh_shardings):
      abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with
        fully resolved physical sharding and memory_kind metadata.
      state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec
        objects corresponding to each parameter/variable.
      state_mesh_shardings: An nnx.State tree consisting of the raw JAX
        Sharding objects corresponding to each parameter/variable.
  """
  assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given."

  with nn_partitioning.axis_rules(config.logical_axis_rules):
    # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply
    # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers
    # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally
    # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers,
    # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor.
    # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls
    # var.shape for every variable when a global mesh is active, but masked optimizer
    # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode()
    # which has no .shape and would raise AttributeError. We handle sharding
    # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not
    # needed here.
    abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
    _, abs_var_state = nnx.split(abs_model)
    named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh)
    # Attach each variable's resolved NamedSharding to its abstract shape/dtype.
    abstract_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
        abs_var_state,
        named_sharding_state,
    )

  state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)

  if is_training and config.shard_optimizer_over_data:
    # Add data to sharding for optimizer state
    optimizer_sharding = jax.tree_util.tree_map_with_path(
        functools.partial(sharding.add_data_to_sharding, mesh),
        abstract_state.optimizer,
        state_mesh_shardings.optimizer,
    )
    state_mesh_shardings.optimizer = optimizer_sharding
  if is_training and config.optimizer_memory_host_offload:
    # Pin optimizer-state shardings to host (CPU RAM) memory_kind.
    optimizer_sharding = jax.tree_util.tree_map_with_path(
        maxtext_utils_nnx.move_memory_to_host,
        state_mesh_shardings.optimizer,
        is_leaf=lambda x: isinstance(x, NamedSharding),
    )
    state_mesh_shardings.optimizer = optimizer_sharding
  if is_training and config.parameter_memory_host_offload:
    assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading."
    # Offload only the Param subtree; nnx.update merges it back in place.
    _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...)
    state_params = jax.tree_util.tree_map_with_path(
        maxtext_utils_nnx.move_memory_to_host,
        state_params,
        is_leaf=lambda x: isinstance(x, NamedSharding),
    )
    nnx.update(state_mesh_shardings, state_params)

  abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings)
  state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
  return (
      abstract_sharded_state,
      state_mesh_annotations,
      state_mesh_shardings,
  )
1392+
1393+
12371394def get_prefill_kv_cache_annotations (model , config , rng , mesh , page_state : None | PageState = None ):
12381395 """Get a shaped abstraction of the state (including optimizer)"""
12391396
0 commit comments