Skip to content

Commit baec42f

Browse files
Merge pull request #3536 from AI-Hypercomputer:autocheckpoint-v7x
PiperOrigin-RevId: 892599489
2 parents 87b1861 + 8ff7694 commit baec42f

4 files changed

Lines changed: 22 additions & 3 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def create_orbax_checkpoint_manager(
221221
enable_single_controller: bool = False,
222222
colocated_python_checkpointing: bool = False,
223223
enable_single_replica_ckpt_restoring: bool = False,
224+
enable_autocheckpoint: bool = False,
224225
):
225226
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
226227
if not enable_checkpointing:
@@ -248,11 +249,21 @@ def create_orbax_checkpoint_manager(
248249
# local storage checkpoint needs parent directory created
249250
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250251
if enable_continuous_checkpointing:
252+
max_logging.log("Enabling policy for continuous checkpointing.")
251253
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252-
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
254+
elif enable_autocheckpoint:
255+
max_logging.log("Enabling policy for autocheckpoint.")
256+
save_decision_policy = save_decision_policy_lib.AnySavePolicy(
257+
[
258+
save_decision_policy_lib.PreemptionCheckpointingPolicy(),
259+
save_decision_policy_lib.FixedIntervalPolicy(save_interval_steps),
260+
]
261+
)
253262
else:
263+
max_logging.log("Enabling policy for fixed interval checkpointing.")
254264
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps)
255-
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
265+
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
266+
256267
async_options = None
257268
if enable_continuous_checkpointing:
258269
async_options = ocp.AsyncOptions(
@@ -752,6 +763,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
752763
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
753764
or (step % config.checkpoint_period == 0)
754765
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
766+
or (config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step))
755767
):
756768
blocking_until_ready_start = time.time()
757769
max_logging.log(f"Waiting for step {step} to finish before checkpoint...")

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ source_checkpoint_layout: "orbax"
8383

8484
# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
8585
colocated_python_checkpointing: False
86+
87+
# Enables autocheckpoint: saves a checkpoint when a preemption is detected, in addition to the regular interval-based saves.
88+
enable_autocheckpoint: False
8689
############################### end checkpointing ##################################
8790

8891

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ class Checkpointing(BaseModel):
333333
False,
334334
description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
335335
)
336+
enable_autocheckpoint: bool = Field(
337+
False, description="If True, enables autocheckpoint or preemption induced checkpointing."
338+
)
336339

337340

338341
class OrbaxStorage(BaseModel):

src/maxtext/utils/train_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=bare-except, consider-using-generator
16-
""" Utils that are only interesting for training in MaxText. """
16+
"""Utils that are only interesting for training in MaxText."""
1717

1818
import os
1919
import jax
@@ -82,6 +82,7 @@ def create_training_tools(config, model, mesh):
8282
config.enable_single_controller,
8383
config.colocated_python_checkpointing,
8484
config.enable_single_replica_ckpt_restoring,
85+
config.enable_autocheckpoint,
8586
)
8687

8788
return init_rng, checkpoint_manager, learning_rate_schedule, tx

0 commit comments

Comments
 (0)