Skip to content

Commit 0d2ca8a

Browse files
committed
add additional config to import libtpu flags
1 parent d5cbf3c commit 0d2ca8a

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ rope_attention_scaling: False # Scale the rotary embedding output
879879
compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
880880
compile_topology: '' # Target hardware version, e.g. 'v5e-256'
881881
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
882+
compile_libtpu_flags: "" # LIBTPU_INIT_ARGS for compilation only
882883

883884
decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature)
884885
decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p

src/maxtext/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ class LayoutAndSharding(BaseModel):
840840
description="Allowed percentage of non-sharded parameters.",
841841
)
842842
shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
843+
compile_libtpu_flags: str = Field("", description="LIBTPU_INIT_ARGS for compilation only.")
843844
internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
844845
internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
845846

@@ -2174,6 +2175,9 @@ def validate_and_set_hlo_dump_defaults():
21742175

21752176
# pylint: enable=access-member-before-definition
21762177

2178+
# Append compile_libtpu_flags to the LIBTPU_INIT_ARGS environment variable
2179+
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + self.compile_libtpu_flags
2180+
21772181
# Validate and initiate hlo dump related configs
21782182
validate_and_set_hlo_dump_defaults()
21792183

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def is_oom(argv: Sequence[str]) -> bool:
219219
def main(argv: Sequence[str]) -> None:
220220
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
221221
os.environ["LIBTPU_INIT_ARGS"] = (
222-
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
222+
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true "
223223
)
224224
print("Starting train_compile.py...", flush=True)
225225

0 commit comments

Comments
 (0)