Skip to content

Commit 63409ee

Browse files
Adds the ability to pass a chat_template_path argument to MaxText SFT, so that a chat template can be loaded from a file, separately from the tokenizer that is provided.
PiperOrigin-RevId: 900245511
1 parent 51c7f2b commit 63409ee

3 files changed

Lines changed: 57 additions & 0 deletions

File tree

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ def preprocessing_pipeline(
249249
dataset = dataset.select_columns(data_column_names)
250250

251251
if use_sft:
252+
if not chat_template and chat_template_path:
253+
chat_template = instruction_data_processing.get_chat_template_from_path(chat_template_path)
252254
data_processing_utils.validate_and_configure_sft_columns(data_column_names, tokenizer, chat_template)
253255

254256
# convert instruction dataset to conversational format

src/maxtext/input_pipeline/instruction_data_processing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,34 @@ def load_template_from_file(template_path):
3333
return template_config
3434

3535

36+
def get_chat_template_from_path(chat_template_path):
  """Loads a chat template from a file.

  A relative `chat_template_path` is joined onto the directory two levels above
  the directory that contains this module. An absolute path is used unchanged,
  because `os.path.join` discards earlier components once a later component is
  absolute.

  The file type is chosen from the file name's suffix, compared
  case-insensitively (so `TEMPLATE.JSON` is handled like `template.json`):

  * `.jinja`, `.j2`, `.txt`: the entire file content is returned.
  * `.json`: the value stored under the top-level `"chat_template"` key is
    returned.

  Args:
    chat_template_path: Path to the template file. May be empty or None.

  Returns:
    The template read from the file, or None when the path is empty, the file
    does not exist, the suffix is not supported, the JSON cannot be parsed, or
    the parsed JSON is not an object containing a `"chat_template"` key.
    Callers cannot tell these cases apart from the return value alone.
  """
  if not chat_template_path:
    return None

  module_dir = os.path.dirname(os.path.abspath(__file__))
  # NOTE(review): two levels above this module's directory; for a module at
  # src/maxtext/input_pipeline/ that is src/, not the repository checkout
  # root -- confirm this is the intended base for relative template paths.
  base_dir = os.path.abspath(os.path.join(module_dir, "..", ".."))
  full_path = os.path.join(base_dir, chat_template_path)

  if not os.path.isfile(full_path):
    return None

  # Compare suffixes on a lowered copy so upper/mixed-case extensions are
  # accepted; the file itself is still opened through the original path.
  lowered_path = full_path.lower()

  if lowered_path.endswith((".jinja", ".j2", ".txt")):
    with open(full_path, "r", encoding="utf-8") as f:
      return f.read()

  if lowered_path.endswith(".json"):
    with open(full_path, "r", encoding="utf-8") as f:
      # Keep the try body to the single call that can raise JSONDecodeError.
      try:
        template_config = json.load(f)
      except json.JSONDecodeError:
        return None
    if isinstance(template_config, dict) and "chat_template" in template_config:
      return template_config["chat_template"]

  return None
62+
63+
3664
def get_template_placeholders(template):
3765
"""Dynamically extracts the format keys (placeholders) from a template string."""
3866
# Finds all names inside {...}

tests/unit/instruction_data_processing_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
"""Instruction data processing test."""
1616

1717
import unittest
18+
import os
19+
import json
20+
import tempfile
21+
import shutil
1822

1923
from maxtext.input_pipeline import instruction_data_processing
2024

@@ -139,6 +143,29 @@ def test_map_qa_data_to_conversation_with_no_templates(self):
139143
if data["role"] == "assistant":
140144
self.assertEqual(data["content"], "The capital of Germany is Berlin.")
141145

146+
def test_get_chat_template_from_path(self):
  """Checks what get_chat_template_from_path returns for each kind of input.

  Covers plain-text and JSON templates, plus every case that yields None:
  JSON without the "chat_template" key, non-object JSON, malformed JSON, an
  unsupported extension, a missing file, and an empty or None path.
  """
  load = instruction_data_processing.get_chat_template_from_path

  with tempfile.TemporaryDirectory() as tmpdir:
    # Files are written as UTF-8 explicitly because the loader reads them as
    # UTF-8; relying on the platform default encoding could diverge.

    # .jinja file: the whole content is the template.
    jinja_path = os.path.join(tmpdir, "test.jinja")
    with open(jinja_path, "w", encoding="utf-8") as f:
      f.write("test jinja template")
    self.assertEqual(load(jinja_path), "test jinja template")

    # .json file with a chat_template key.
    json_path = os.path.join(tmpdir, "test.json")
    with open(json_path, "w", encoding="utf-8") as f:
      json.dump({"chat_template": "test json template"}, f)
    self.assertEqual(load(json_path), "test json template")

    # .json file without a chat_template key.
    json_no_key_path = os.path.join(tmpdir, "no_key.json")
    with open(json_no_key_path, "w", encoding="utf-8") as f:
      json.dump({"other_key": "other_value"}, f)
    self.assertIsNone(load(json_no_key_path))

    # .json file whose top-level value is not an object.
    json_list_path = os.path.join(tmpdir, "list.json")
    with open(json_list_path, "w", encoding="utf-8") as f:
      json.dump(["chat_template"], f)
    self.assertIsNone(load(json_list_path))

    # .json file that is not valid JSON.
    broken_json_path = os.path.join(tmpdir, "broken.json")
    with open(broken_json_path, "w", encoding="utf-8") as f:
      f.write("{not valid json")
    self.assertIsNone(load(broken_json_path))

    # Existing file with an extension the loader does not handle.
    unsupported_path = os.path.join(tmpdir, "template.yaml")
    with open(unsupported_path, "w", encoding="utf-8") as f:
      f.write("chat_template: ignored")
    self.assertIsNone(load(unsupported_path))

    # Absolute path to a file that does not exist.
    self.assertIsNone(load(os.path.join(tmpdir, "missing.jinja")))

  # Relative path to a file that does not exist.
  self.assertIsNone(load("non_existent.jinja"))

  # Empty and None paths short-circuit before any filesystem access.
  self.assertIsNone(load(""))
  self.assertIsNone(load(None))
168+
142169

143170
if __name__ == "__main__":
144171
unittest.main()

0 commit comments

Comments
 (0)