Skip to content

Commit da4e047

Browse files
Adds ability to pass chat_template_path argument into MaxText SFT, loading a separate chat_template from the tokenizer that's provided.
PiperOrigin-RevId: 900245511
1 parent 51c7f2b commit da4e047

3 files changed

Lines changed: 75 additions & 0 deletions

File tree

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,10 @@ def preprocessing_pipeline(
249249
dataset = dataset.select_columns(data_column_names)
250250

251251
if use_sft:
252+
if not chat_template and chat_template_path:
253+
chat_template = instruction_data_processing.get_chat_template_from_path(
254+
chat_template_path
255+
)
252256
data_processing_utils.validate_and_configure_sft_columns(data_column_names, tokenizer, chat_template)
253257

254258
# convert instruction dataset to conversational format

src/maxtext/input_pipeline/instruction_data_processing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,37 @@ def load_template_from_file(template_path):
3333
return template_config
3434

3535

36+
def get_chat_template_from_path(chat_template_path):
37+
"""Loads a chat template from a file."""
38+
if not chat_template_path:
39+
return None
40+
41+
current_dir = os.path.dirname(os.path.abspath(__file__))
42+
repo_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
43+
full_path = os.path.join(repo_root, chat_template_path)
44+
45+
if not os.path.isfile(full_path):
46+
return None
47+
48+
if full_path.endswith((".jinja", ".j2", ".txt")):
49+
with open(full_path, "r", encoding="utf-8") as f:
50+
return f.read()
51+
52+
if full_path.endswith(".json"):
53+
with open(full_path, "r", encoding="utf-8") as f:
54+
try:
55+
template_config = json.load(f)
56+
if (
57+
isinstance(template_config, dict)
58+
and "chat_template" in template_config
59+
):
60+
return template_config["chat_template"]
61+
except json.JSONDecodeError:
62+
return None
63+
64+
return None
65+
66+
3667
def get_template_placeholders(template):
3768
"""Dynamically extracts the format keys (placeholders) from a template string."""
3869
# Finds all names inside {...}

tests/unit/instruction_data_processing_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
"""Instruction data processing test."""
1616

17+
import json
18+
import os
19+
import tempfile
1720
import unittest
1821

1922
from maxtext.input_pipeline import instruction_data_processing
@@ -139,6 +142,43 @@ def test_map_qa_data_to_conversation_with_no_templates(self):
139142
if data["role"] == "assistant":
140143
self.assertEqual(data["content"], "The capital of Germany is Berlin.")
141144

145+
def test_get_chat_template_from_path(self):
146+
with tempfile.TemporaryDirectory() as tmpdir:
147+
# Test .jinja file
148+
jinja_path = os.path.join(tmpdir, "test.jinja")
149+
with open(jinja_path, "w", encoding="utf-8") as f:
150+
f.write("test jinja template")
151+
self.assertEqual(
152+
instruction_data_processing.get_chat_template_from_path(jinja_path),
153+
"test jinja template",
154+
)
155+
156+
# Test .json file with chat_template
157+
json_path = os.path.join(tmpdir, "test.json")
158+
with open(json_path, "w", encoding="utf-8") as f:
159+
json.dump({"chat_template": "test json template"}, f)
160+
self.assertEqual(
161+
instruction_data_processing.get_chat_template_from_path(json_path),
162+
"test json template",
163+
)
164+
165+
# Test .json file without chat_template
166+
json_no_key_path = os.path.join(tmpdir, "no_key.json")
167+
with open(json_no_key_path, "w", encoding="utf-8") as f:
168+
json.dump({"other_key": "other_value"}, f)
169+
self.assertIsNone(
170+
instruction_data_processing.get_chat_template_from_path(
171+
json_no_key_path
172+
)
173+
)
174+
175+
# Test non-existent file
176+
self.assertIsNone(
177+
instruction_data_processing.get_chat_template_from_path(
178+
"non_existent.jinja"
179+
)
180+
)
181+
142182

143183
if __name__ == "__main__":
144184
unittest.main()

0 commit comments

Comments
 (0)