Skip to content

Commit 4a0b8cb

Browse files
Merge pull request #3535 from AI-Hypercomputer:aireen/tiktoken
PiperOrigin-RevId: 892656253
2 parents f2216e2 + b47abba commit 4a0b8cb

4 files changed

Lines changed: 41 additions & 10 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,7 @@ num_vocab_tiling: 1
597597
# Tokenizer
598598
vocab_size: 32_000 # powers of 2 for sharding
599599
tokenizer_path: ""
600-
# tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken
601-
# grain pipeline supports tokenizer_type: sentencepiece, huggingface
600+
# grain and tfds pipelines support tokenizer_type: sentencepiece, huggingface, tiktoken
602601
# hf pipeline only supports huggingface type, and will ignore tokenizer_type flag
603602
tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface"
604603
use_chat_template: False

src/maxtext/configs/types.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2580,13 +2580,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25802580
raise ValueError("When dataset_type=grain, please set grain_train_files or grain_train_mixture_config_path")
25812581
if self.eval_interval > 0 and not self.grain_eval_files:
25822582
raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
2583-
if self.tokenizer_type not in (
2584-
TokenizerType.SENTENCEPIECE,
2585-
TokenizerType.HUGGINGFACE,
2586-
):
2587-
raise ValueError(
2588-
f"grain pipeline only supports tokenizer_type: sentencepiece, huggingface, but got {self.tokenizer_type}"
2589-
)
25902583
elif self.dataset_type == DatasetType.TFDS:
25912584
if not self.dataset_name:
25922585
raise ValueError("dataset_name can't be empty when dataset_type=tfds")

src/maxtext/input_pipeline/grain_tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TokenizerTransformBase:
3030
# pylint: disable=attribute-defined-outside-init
3131
feature_names: str | Sequence[str]
3232
sequence_length: int | Sequence[int]
33-
tokenizer: tokenizer.SentencePieceTokenizer | tokenizer.HFTokenizer
33+
tokenizer: tokenizer.SentencePieceTokenizer | tokenizer.HFTokenizer | tokenizer.TikTokenTokenizer
3434

3535
def __post_init__(self):
3636
self._processor = None

tests/unit/grain_data_processing_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,45 @@ def test_for_loop_repeatable(self):
251251
super().test_for_loop_repeatable()
252252

253253

254+
class GrainArrayRecordTiktokenTest(GrainArrayRecordProcessingTest):
255+
"""Test grain data processing with best_fit packing strategy."""
256+
257+
def setUp(self):
258+
super().setUp()
259+
self.config = self._make_config(
260+
tokenizer_type="tiktoken",
261+
tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
262+
)
263+
self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
264+
265+
# Only runs test_train_ds from parent class, skip other tests
266+
@pytest.mark.skip(reason="skip for tokenizer testing")
267+
def test_batch_determinism(self):
268+
pass
269+
270+
@pytest.mark.skip(reason="skip for tokenizer testing")
271+
def test_for_loop_repeatable(self):
272+
pass
273+
274+
275+
class GrainArrayRecordHFTokenizerTest(GrainArrayRecordProcessingTest):
276+
"""Test grain data processing with best_fit packing strategy."""
277+
278+
def setUp(self):
279+
super().setUp()
280+
self.config = self._make_config(tokenizer_type="huggingface", tokenizer_path="deepseek-ai/DeepSeek-V3")
281+
self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
282+
283+
# Only runs test_train_ds from parent class, skip other tests
284+
@pytest.mark.skip(reason="skip for tokenizer testing")
285+
def test_batch_determinism(self):
286+
pass
287+
288+
@pytest.mark.skip(reason="skip for tokenizer testing")
289+
def test_for_loop_repeatable(self):
290+
pass
291+
292+
254293
class GrainArrayRecordBestFitPackingTest(GrainArrayRecordProcessingTest):
255294
"""Test grain data processing with best_fit packing strategy."""
256295

0 commit comments

Comments
 (0)