Skip to content

Commit 4bae533

Browse files
committed
NNX: add TrainState, model creation utilities, and training loop support
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
1 parent d8dd362 commit 4bae533

10 files changed

Lines changed: 1203 additions & 239 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
import datetime
2222
from etils import epath
23+
from flax import nnx
2324
from flax.training import train_state
2425
import jax
2526
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -521,7 +522,7 @@ def load_state_if_possible(
521522
load_parameters_from_path: str,
522523
load_full_state_from_path: str,
523524
checkpoint_storage_concurrent_gb: int,
524-
abstract_unboxed_pre_state: train_state.TrainState,
525+
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
525526
enable_single_replica_ckpt_restoring: bool | None = False,
526527
dataset_type: str | None = "tfds",
527528
step: int = -1, # -1 means latest
@@ -625,9 +626,14 @@ def map_to_pspec(data):
625626
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
626627

627628
if load_parameters_from_path != "":
629+
if isinstance(abstract_unboxed_pre_state, nnx.State):
630+
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
631+
else:
632+
params = abstract_unboxed_pre_state.params
633+
628634
restored_params = load_params_from_path(
629635
load_parameters_from_path,
630-
abstract_unboxed_pre_state.params,
636+
params,
631637
checkpoint_storage_concurrent_gb,
632638
use_ocdbt=use_ocdbt,
633639
use_zarr3=use_zarr3,

src/maxtext/trainers/pre_train/train.py

Lines changed: 248 additions & 154 deletions
Large diffs are not rendered by default.

src/maxtext/utils/maxtext_utils.py

Lines changed: 176 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,17 @@
2222

2323
from flax import linen as nn
2424
from flax.linen import partitioning as nn_partitioning
25-
from flax.training import train_state
25+
from flax.training.train_state import TrainState
2626

2727
import numpy as np
2828

29-
from jax.experimental import mesh_utils
30-
from jax.experimental.serialize_executable import deserialize_and_load
31-
from jax.sharding import AxisType, Mesh
32-
3329
import jax
3430
import jax.numpy as jnp
31+
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec
32+
from jax.experimental import mesh_utils
33+
from jax.experimental.serialize_executable import deserialize_and_load
3534

3635
import optax
37-
3836
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
3937
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
4038

@@ -48,6 +46,7 @@
4846
from maxtext.utils import max_logging
4947
from maxtext.utils import max_utils
5048
from maxtext.utils import sharding
49+
from maxtext.utils import maxtext_utils_nnx
5150

5251
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
5352

@@ -994,15 +993,15 @@ def _apply_update(path, param):
994993
return state.replace(params=new_params)
995994

996995

997-
def init_decode_state(apply_fn, params) -> TrainState:
  """Init train state with null opt state for decode."""
  # Decoding never applies gradients, so tx is None and opt_state stays empty.
  return TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={})  # type: ignore
10011000

10021001

10031002
def init_training_state(apply_fn, params, tx):
  """Init train state (including optimizer state from tx) for training."""
  # NOTE(review): previous docstring said "null opt state for decode" — a
  # copy-paste from init_decode_state; TrainState.create initializes a real
  # optimizer state via tx.init(params).
  state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
  return state
10071006

10081007

@@ -1124,7 +1123,7 @@ def setup_initial_state(
11241123
is_training: True to initialize training state, False for decode state
11251124
11261125
Returns:
1127-
state: the initialized train state
1126+
train_state: the initialized train state. For NNX, this is a TrainStateNNX instance
11281127
state_mesh_annotations: the mesh annotations for the train state
11291128
"""
11301129

@@ -1163,19 +1162,32 @@ def setup_initial_state(
11631162
else:
11641163
# The update of data_iterator state happens in place, no need to assign explicitly
11651164
state = restored["items"]
1165+
1166+
# TODO: For NNX, convert the pure dict to nnx.State.
11661167
else:
11671168
init_state_partial = init_state_fn
11681169
init_state_partial.__name__ = "initialize_state"
1169-
# pylint: disable=not-callable
1170-
state = jax.jit(
1171-
init_state_partial,
1172-
in_shardings=None,
1173-
out_shardings=state_mesh_shardings,
1174-
)()
1170+
if config.pure_nnx:
1171+
state = jax.jit(
1172+
lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure
1173+
in_shardings=None,
1174+
out_shardings=state_mesh_shardings,
1175+
)()
1176+
else:
1177+
# pylint: disable=not-callable
1178+
state = jax.jit(
1179+
init_state_partial,
1180+
in_shardings=None,
1181+
out_shardings=state_mesh_shardings,
1182+
)()
11751183
if raw_params: # If we loaded a partial state, we need to merge it.
1176-
state = state.replace(params=raw_params)
1177-
1178-
state = max_utils.unbox_logicallypartioned(state)
1184+
if config.pure_nnx:
1185+
# raw_params should have the same sharding info as in the model
1186+
nnx.update(state.model, raw_params)
1187+
else:
1188+
state = state.replace(params=raw_params)
1189+
if not config.pure_nnx:
1190+
state = max_utils.unbox_logicallypartioned(state)
11791191

11801192
return state, state_mesh_annotations, state_mesh_shardings, data_iterator
11811193

@@ -1191,6 +1203,9 @@ def get_logical_annotations(config, mesh, init_state_fn):
11911203

11921204
def get_abstract_state(config, mesh, init_state_fn, is_training=True):
11931205
"""Get a shaped abstraction of the state (including optimizer)"""
1206+
if config.pure_nnx:
1207+
return get_abstract_state_nnx(config, mesh, init_state_fn, is_training)
1208+
11941209
init_state_partial = init_state_fn
11951210

11961211
with nn_partitioning.axis_rules(config.logical_axis_rules):
@@ -1234,6 +1249,148 @@ def move(path, x):
12341249
)
12351250

12361251

1252+
def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State:
  """Resolve a NamedSharding for every NNX variable, scan-axis aware.

  Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this
  function also inserts the partition_name axis at the correct scan_axis position
  for parameters created by _create_scanned_layers. Without this, scanned
  parameters get a 2D partition spec applied to a 3D tensor, placing sharding on
  the stacked-layers dimension instead of the embedding dimension.

  Args:
    abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)).
    mesh: JAX physical mesh.

  Returns:
    Same tree structure as abs_var_state but each Variable's value replaced with
    NamedSharding.
  """

  def _to_named_sharding(variable):
    value = variable.get_value()
    if not hasattr(value, "shape"):
      # Non-tensor leaf (e.g. an optax MaskedNode for a non-trainable param):
      # keep it as-is so the treedef still matches abs_var_state in the
      # downstream jax.tree.map.
      return variable
    metadata = variable.get_metadata()
    axis_names = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
    if not axis_names:
      spec = PartitionSpec()
    else:
      if nnx.PARTITION_NAME in metadata:
        # Scanned layers: _add_scan_metadata records the stacked-layers axis
        # name under nnx.PARTITION_NAME and its position under "param_scan_axis".
        # OptVariable (optimizer state) inherits param_scan_axis from the model
        # Param via to_opt_state(), so the metadata value is trusted instead of
        # hardcoding 0 for non-Param types; stacked_rest non-Param variables get
        # param_scan_axis=0 explicitly from _add_scan_metadata.
        stack_axis_name = metadata[nnx.PARTITION_NAME]
        stack_axis_pos = metadata.get("param_scan_axis", 0)
        axis_names = [axis_names] if isinstance(axis_names, str) else list(axis_names)
        # Flax 0.12.6's _remap_sharding_metadata renames 'sharding' ->
        # 'out_sharding' and may have inserted the scan axis already; guard
        # against double-insertion.
        if stack_axis_name not in axis_names:
          axis_names.insert(stack_axis_pos, stack_axis_name)
        axis_names = tuple(axis_names)
      # Convert logical axis names to physical mesh axes using the rules active
      # in the current context plus any variable-local rules.
      context_rules = get_logical_axis_rules()
      local_rules = metadata.get("sharding_rules", ())
      if context_rules or local_rules:
        merged_rules = composite_rules(context_rules, local_rules)
        spec = PartitionSpec(*from_sharding_rules(axis_names, merged_rules))
      else:
        spec = PartitionSpec(*axis_names)
    return variable.replace(NamedSharding(mesh, spec))

  return jax.tree.map(_to_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable))
1307+
1308+
1309+
def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True):
  """Calculates the abstract sharded state and memory placement for an NNX TrainState.

  This function performs an abstract trace of the NNX model and optimizer,
  resolves logical sharding annotations into physical JAX shardings, and applies
  memory placement optimizations such as optimizer sharding over data and host
  memory offloading (pinning to CPU RAM).

  Args:
    config: Configuration object containing sharding and offloading hyperparameters
      (e.g., shard_optimizer_over_data, optimizer_memory_host_offload).
    mesh: JAX physical mesh used to resolve logical axis names to physical devices.
    nnx_init_trainstate_fn: A zero-argument factory function that produces a
      TrainStateNNX instance during the abstract trace.
    is_training: Boolean indicating if the state is for training. If True,
      optimizer state is processed and memory offloading strategies are applied.

  Returns:
    A tuple (abstract_sharded_state, state_mesh_annotations, state_mesh_shardings):
      abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with
        fully resolved physical sharding and memory_kind metadata.
      state_mesh_annotations: An nnx.State tree consisting of the raw
        PartitionSpec objects corresponding to each parameter/variable.
      state_mesh_shardings: An nnx.State tree consisting of the raw JAX
        Sharding objects corresponding to each parameter/variable.
  """
  # NOTE(review): the docstring previously described the second tuple element as
  # None, but the code returns state_mesh_annotations — fixed above.
  assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given."

  with nn_partitioning.axis_rules(config.logical_axis_rules):
    # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model so that
    # get_nnx_named_sharding_with_scan_axis can insert the stacked-layers axis
    # into the partition spec. nnx.get_abstract_model uses get_var_pspec
    # internally, which ignores the nnx.PARTITION_NAME / param_scan_axis
    # metadata set by _create_scanned_layers, causing the 2D partition spec to
    # be misapplied to the 3D stacked parameter tensor.
    # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable
    # calls var.shape for every variable when a global mesh is active, but
    # masked optimizer state variables (e.g. from trainable_parameters_mask)
    # have value=MaskedNode(), which has no .shape and would raise
    # AttributeError. Sharding is handled explicitly below, so auto-assignment
    # is not needed here.
    abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
    _, abs_var_state = nnx.split(abs_model)
    named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh)
    abstract_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
        abs_var_state,
        named_sharding_state,
    )

  state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)

  if is_training and config.shard_optimizer_over_data:
    # Add the data axis to the optimizer state's sharding.
    optimizer_sharding = jax.tree_util.tree_map_with_path(
        functools.partial(sharding.add_data_to_sharding, mesh),
        abstract_state.optimizer,
        state_mesh_shardings.optimizer,
    )
    state_mesh_shardings.optimizer = optimizer_sharding
  if is_training and config.optimizer_memory_host_offload:
    # Pin optimizer state to host (CPU) memory.
    optimizer_sharding = jax.tree_util.tree_map_with_path(
        maxtext_utils_nnx.move_memory_to_host,
        state_mesh_shardings.optimizer,
        is_leaf=lambda x: isinstance(x, NamedSharding),
    )
    state_mesh_shardings.optimizer = optimizer_sharding
  if is_training and config.parameter_memory_host_offload:
    assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading."
    # Offload only the model parameters (nnx.Param); update the sharding tree
    # in place so non-Param leaves keep their device placement.
    _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...)
    state_params = jax.tree_util.tree_map_with_path(
        maxtext_utils_nnx.move_memory_to_host,
        state_params,
        is_leaf=lambda x: isinstance(x, NamedSharding),
    )
    nnx.update(state_mesh_shardings, state_params)

  abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings)
  state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
  return (
      abstract_sharded_state,
      state_mesh_annotations,
      state_mesh_shardings,
  )
1392+
1393+
12371394
def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None):
12381395
"""Get a shaped abstraction of the state (including optimizer)"""
12391396

0 commit comments

Comments
 (0)