Skip to content

Commit d26d2d7

Browse files
Merge pull request #3649 from AI-Hypercomputer:shuningjin-override
PiperOrigin-RevId: 899661073
2 parents 84e02f1 + e23d7e3 commit d26d2d7

6 files changed

Lines changed: 51 additions & 12 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
run_name: ""
2222

2323
model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
24-
override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing.
24+
override_model_config: False # When set to true allows overriding model parameters via CLI (or kwargs or env vars) for the purpose of debugging/testing.
2525
override_logical_axis_rules: False # When set overrides logical axis rules instead of merging them.
2626
debug:
2727
rl: False # RL-specific debugging

src/maxtext/configs/pyconfig.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ def yaml_key_to_env_key(s: str) -> str:
9393
return _MAX_PREFIX + s.upper()
9494

9595

96+
def validate_no_keys_overridden_twice(keys1: list[str], keys2: list[str]):
97+
overridden_keys = [k for k in keys1 if k in keys2]
98+
if overridden_keys:
99+
raise ValueError(
100+
f"Keys {overridden_keys} are overridden by both model config and CLI/kwargs. "
101+
"This is not allowed, unless setting `override_model_config=True`."
102+
)
103+
104+
96105
def resolve_config_path(param: str) -> str:
97106
"""Resolve config path to auto rewrite to use new src folder."""
98107
if os.path.isfile(param):
@@ -330,6 +339,8 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
330339
model_cfg = {k: v for k, v in model_loaded_cfg.items() if k not in overrides_cfg}
331340
else:
332341
model_cfg = model_loaded_cfg
342+
# Validate that no keys are overridden by both model config and CLI/kwargs
343+
validate_no_keys_overridden_twice(model_loaded_cfg.keys(), overrides_cfg.keys())
333344
else:
334345
logger.warning("Model config for '%s' not found at %s", model_name, model_config_path)
335346

@@ -368,10 +379,17 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
368379
for k in tuple(raw_keys_dict.keys()):
369380
env_key = yaml_key_to_env_key(k)
370381
if env_key in os.environ:
382+
# Validate that no keys are overridden by both CLI/kwargs and environment variable
371383
if k in cli_keys or k in kwargs_keys:
372384
raise ValueError(
373385
f"Key '{k}' is overridden by both CLI/kwargs and environment variable '{env_key}'. This is not allowed."
374386
)
387+
# Validate that no keys are overridden by both model config and environment variable
388+
if not temp_cfg.get("override_model_config") and k in model_cfg.keys():
389+
raise ValueError(
390+
f"Key '{k}' is overridden by both model config and environment variable '{env_key}'. "
391+
"This is not allowed, unless setting `override_model_config=True`."
392+
)
375393

376394
new_proposal = os.environ.get(env_key)
377395
original_value = raw_keys_dict.get(k)

tests/unit/mhc_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,15 @@ def setUp(self):
101101
per_device_batch_size=4,
102102
max_target_length=7,
103103
max_prefill_predict_length=7,
104+
attention="dot_product",
105+
routed_bias_update_rate=0.01,
106+
load_balance_loss_weight=0.02,
107+
# override
108+
override_model_config=True,
104109
base_emb_dim=self.dim,
105110
mhc_expansion_rate=3,
106111
num_experts=4,
107112
num_experts_per_tok=2,
108-
attention="dot_product",
109-
routed_bias_update_rate=0.01,
110-
load_balance_loss_weight=0.02,
111113
engram_layers=[],
112114
)
113115
devices_array = maxtext_utils.create_device_mesh(self.config)

tests/unit/nnx_decoders_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def _make_config(**overrides):
7070
**_BASE_CONFIG,
7171
**extra_args,
7272
**overrides,
73+
override_model_config=True,
7374
)
7475

7576

tests/unit/pyconfig_test.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,19 @@ def test_overriding_model(self):
8383
base_emb_dim=1024, # Defined as 3072 in gemma-7b
8484
)
8585

86-
self.assertEqual(config.base_emb_dim, 1024)
87-
self.assertEqual(config.base_mlp_dim, 24576)
86+
self.assertEqual(config.base_emb_dim, 1024) # override
87+
self.assertEqual(config.base_mlp_dim, 24576) # unchanged
88+
89+
def test_overriding_model_raises_error(self):
90+
"""Test that overriding a model config with override_model_config=False raises an error."""
91+
with self.assertRaises(ValueError):
92+
pyconfig.initialize(
93+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
94+
skip_jax_distributed_system=True,
95+
model_name="gemma-7b",
96+
override_model_config=False,
97+
base_emb_dim=1024, # Defined as 3072 in gemma-7b
98+
)
8899

89100
def test_overriding_model_in_sft(self):
90101
# TODO: Update MAXTEXT_PKG_DIR after repo restructuring is complete.
@@ -93,10 +104,11 @@ def test_overriding_model_in_sft(self):
93104
skip_jax_distributed_system=True,
94105
model_name="llama3.1-8b",
95106
override_model_config=True,
107+
base_emb_dim=1024, # Defined as 4096 in llama3.1-8b
96108
)
97109

98-
self.assertEqual(config.base_emb_dim, 4096)
99-
self.assertEqual(config.base_mlp_dim, 14336)
110+
self.assertEqual(config.base_emb_dim, 1024) # override
111+
self.assertEqual(config.base_mlp_dim, 14336) # unchanged
100112

101113
def test_resolve_config_path(self):
102114
self.assertEqual(resolve_config_path("foo"), os.path.join("src", "foo"))
@@ -121,7 +133,7 @@ def test_config_file_mapping(self):
121133
self.assertTrue(os.path.isfile(full_path), f"Default config for '{module}' not found at {full_path}")
122134

123135
def test_module_from_path(self):
124-
import maxtext.trainers.pre_train.train as train_module
136+
import maxtext.trainers.pre_train.train as train_module # pylint: disable=import-outside-toplevel
125137

126138
module_file = train_module.__file__
127139
result = _module_from_path(module_file)

tests/unit/train_compile_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,8 @@ def test_indexer_dense_warmup(self):
817817
"max_target_length=1024",
818818
"attention=flash",
819819
"use_tokamax_splash=True",
820+
# override
821+
"override_model_config=True",
820822
"engram_layers=[]",
821823
# dense warmup
822824
"indexer_sparse_training=False",
@@ -842,6 +844,8 @@ def test_indexer_sparse_training(self):
842844
"max_target_length=1024",
843845
"attention=flash",
844846
"use_tokamax_splash=True",
847+
# override
848+
"override_model_config=True",
845849
"engram_layers=[]",
846850
# sparse training
847851
"indexer_sparse_training=True",
@@ -869,7 +873,7 @@ def test_olmo3_7b(self):
869873

870874
@pytest.mark.cpu_only
871875
def test_mhc_integration(self):
872-
"""AOT test for Manifold-onstrained Hyper Connection implementation"""
876+
"""AOT test for Manifold-constrained Hyper Connection implementation"""
873877
compiled_trainstep_file = "/tmp/test_mhc_integration"
874878
train_compile_main(
875879
(
@@ -881,10 +885,12 @@ def test_mhc_integration(self):
881885
"model_name=deepseek-custom",
882886
"per_device_batch_size=4",
883887
"scan_layers=True",
884-
"max_target_length=1024",
885-
"mhc_expansion_rate=4",
886888
"attention=flash",
887889
"use_tokamax_splash=True",
890+
"max_target_length=1024",
891+
# override
892+
"override_model_config=True",
893+
"mhc_expansion_rate=4",
888894
"engram_layers=[]",
889895
)
890896
)

0 commit comments

Comments
 (0)