Skip to content

Commit ade2859

Browse files
committed
adding support for None chat templates.
1 parent 44fc6d0 commit ade2859

3 files changed

Lines changed: 51 additions & 1 deletion

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def get_dataset(
142142
)
143143

144144
template_config = load_template_from_file(tmvp_config.chat_template_path)
145+
if template_config is None:
146+
max_logging.warning(
147+
f"Failed to load chat template from {tmvp_config.chat_template_path}. Proceeding without chat template."
148+
)
145149

146150
loaded_dataset = (
147151
grain.MapDataset.source(data)
@@ -339,6 +343,10 @@ def prepare_openinstructmath2_dataset(
339343
split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
340344
splits = prepare_openinstructmath2_dataset(split=split_name)
341345
template_config = load_template_from_file(trainer_config.chat_template_path)
346+
if template_config is None:
347+
max_logging.warning(
348+
f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
349+
)
342350

343351
train_dataset = (
344352
grain.MapDataset.source(splits["train"])
@@ -616,6 +624,10 @@ def _reward_fn(**kwargs):
616624
)
617625
# Instantiate the custom MaxText chat parser
618626
template_config = load_template_from_file(trainer_config.chat_template_path)
627+
if template_config is None:
628+
max_logging.warning(
629+
f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
630+
)
619631
chat_parser = utils_rl.MaxTextChatParser(
620632
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
621633
)

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,11 @@ def make_optimizer(learning_rate):
526526
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
527527

528528

529-
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
529+
def format_maxtext_messages(messages: list[str], template_config: dict | None, tmvp_config) -> list[dict[str, str]]:
530530
"""Helper to inject MaxText's system prompt into the input user messages."""
531+
if template_config is None:
532+
return [{"role": "user", "content": msg} for msg in messages]
533+
531534
formatted_messages = []
532535
for msg in messages:
533536
formatted_content = template_config["TEMPLATE"].format(

tests/post_training/unit/rl_utils_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,5 +370,40 @@ def test_returns_optimizer_with_clipping(self):
370370
self.assertIn("learning_rate", state.hyperparams)
371371

372372

373+
class TestFormatMaxTextMessages(unittest.TestCase):
374+
"""Tests for utils_rl.format_maxtext_messages."""
375+
376+
def setUp(self):
377+
self.config = _make_config()
378+
self.template_config = {
379+
"SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}. "
380+
+ "Solution between {solution_start_token} and {solution_end_token}.",
381+
"TEMPLATE": "system: {system_prompt}\nquestion: {question}",
382+
}
383+
384+
@pytest.mark.cpu_only
385+
def test_format_with_template(self):
386+
"""Test formatting when a template is provided."""
387+
messages = ["What is 2+2?"]
388+
formatted = utils_rl.format_maxtext_messages(messages, self.template_config, self.config)
389+
self.assertEqual(len(formatted), 1)
390+
self.assertEqual(formatted[0]["role"], "user")
391+
expected_content = (
392+
"system: Reason between <reasoning> and </reasoning>. "
393+
"Solution between <answer> and </answer>.\n"
394+
"question: What is 2+2?"
395+
)
396+
self.assertEqual(formatted[0]["content"], expected_content)
397+
398+
@pytest.mark.cpu_only
399+
def test_format_without_template(self):
400+
"""Test formatting when template_config is None (the fix)."""
401+
messages = ["What is 2+2?"]
402+
formatted = utils_rl.format_maxtext_messages(messages, None, self.config)
403+
self.assertEqual(len(formatted), 1)
404+
self.assertEqual(formatted[0]["role"], "user")
405+
self.assertEqual(formatted[0]["content"], "What is 2+2?")
406+
407+
373408
if __name__ == "__main__":
374409
unittest.main()

0 commit comments

Comments
 (0)