Skip to content

Commit b9e5220

Browse files
committed
adding support for None chat templates.
1 parent 44fc6d0 commit b9e5220

3 files changed

Lines changed: 52 additions & 1 deletion

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def get_dataset(
142142
)
143143

144144
template_config = load_template_from_file(tmvp_config.chat_template_path)
145+
if template_config is None:
146+
max_logging.warning(
147+
f"Failed to load chat template from {tmvp_config.chat_template_path}. Proceeding without chat template."
148+
)
145149

146150
loaded_dataset = (
147151
grain.MapDataset.source(data)
@@ -339,6 +343,10 @@ def prepare_openinstructmath2_dataset(
339343
split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
340344
splits = prepare_openinstructmath2_dataset(split=split_name)
341345
template_config = load_template_from_file(trainer_config.chat_template_path)
346+
if template_config is None:
347+
max_logging.warning(
348+
f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
349+
)
342350

343351
train_dataset = (
344352
grain.MapDataset.source(splits["train"])
@@ -616,6 +624,10 @@ def _reward_fn(**kwargs):
616624
)
617625
# Instantiate the custom MaxText chat parser
618626
template_config = load_template_from_file(trainer_config.chat_template_path)
627+
if template_config is None:
628+
max_logging.warning(
629+
f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
630+
)
619631
chat_parser = utils_rl.MaxTextChatParser(
620632
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
621633
)

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,13 @@ def make_optimizer(learning_rate):
526526
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
527527

528528

529-
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
529+
def format_maxtext_messages(
530+
messages: list[dict[str, str]], template_config: dict | None, tmvp_config
531+
) -> list[dict[str, str]]:
530532
"""Helper to inject MaxText's system prompt into the input user messages."""
533+
if template_config is None:
534+
return [{"role": "user", "content": msg} for msg in messages]
535+
531536
formatted_messages = []
532537
for msg in messages:
533538
formatted_content = template_config["TEMPLATE"].format(

tests/post_training/unit/rl_utils_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,5 +370,39 @@ def test_returns_optimizer_with_clipping(self):
370370
self.assertIn("learning_rate", state.hyperparams)
371371

372372

373+
class TestFormatMaxTextMessages(unittest.TestCase):
374+
"""Tests for utils_rl.format_maxtext_messages."""
375+
376+
def setUp(self):
377+
self.config = _make_config()
378+
self.template_config = {
379+
"SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}. Solution between {solution_start_token} and {solution_end_token}.",
380+
"TEMPLATE": "system: {system_prompt}\nquestion: {question}",
381+
}
382+
383+
@pytest.mark.cpu_only
384+
def test_format_with_template(self):
385+
"""Test formatting when a template is provided."""
386+
messages = ["What is 2+2?"]
387+
formatted = utils_rl.format_maxtext_messages(messages, self.template_config, self.config)
388+
self.assertEqual(len(formatted), 1)
389+
self.assertEqual(formatted[0]["role"], "user")
390+
expected_content = (
391+
"system: Reason between <reasoning> and </reasoning>. "
392+
"Solution between <answer> and </answer>.\n"
393+
"question: What is 2+2?"
394+
)
395+
self.assertEqual(formatted[0]["content"], expected_content)
396+
397+
@pytest.mark.cpu_only
398+
def test_format_without_template(self):
399+
"""Test formatting when template_config is None (the fix)."""
400+
messages = ["What is 2+2?"]
401+
formatted = utils_rl.format_maxtext_messages(messages, None, self.config)
402+
self.assertEqual(len(formatted), 1)
403+
self.assertEqual(formatted[0]["role"], "user")
404+
self.assertEqual(formatted[0]["content"], "What is 2+2?")
405+
406+
373407
if __name__ == "__main__":
374408
unittest.main()

0 commit comments

Comments
 (0)