Skip to content

Commit 12531e6

Browse files
Implement and update the following models in NNX decoder: DeepSeek/Gemma3/Llama4
1 parent b37fee0 commit 12531e6

4 files changed

Lines changed: 433 additions & 211 deletions

File tree

src/maxtext/layers/initializers.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
94 94
out_sharding = metadata["sharding"]
95 95

96 96
if out_sharding is not None:
97+
if nnx.PARTITION_NAME in metadata:
98+
partition_name = metadata[nnx.PARTITION_NAME]
99+
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
100+
101+
sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
102+
if partition_name not in sharding_list:
103+
sharding_list.insert(scan_axis, partition_name)
104+
105+
out_sharding = tuple(sharding_list)
106+
97 107
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
98 108
variable.value,
99 109
out_sharding, # type: ignore[arg-type]

0 commit comments

Comments
 (0)