@@ -221,6 +221,7 @@ def create_orbax_checkpoint_manager(
221221 enable_single_controller : bool = False ,
222222 colocated_python_checkpointing : bool = False ,
223223 enable_single_replica_ckpt_restoring : bool = False ,
224+ enable_autocheckpoint : bool = False ,
224225):
225226 """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
226227 if not enable_checkpointing :
@@ -248,11 +249,21 @@ def create_orbax_checkpoint_manager(
248249 # local storage checkpoint needs parent directory created
249250 p = gcs_utils .mkdir_and_check_permissions (checkpoint_dir )
250251 if enable_continuous_checkpointing :
252+ max_logging .log ("Enabling policy for continuous checkpointing." )
251253 save_decision_policy = save_decision_policy_lib .ContinuousCheckpointingPolicy ()
252- preservation_policy = preservation_policy_lib .LatestN (max_num_checkpoints_to_keep )
254+ elif enable_autocheckpoint :
255+ max_logging .log ("Enabling policy for autocheckpoint." )
256+ save_decision_policy = save_decision_policy_lib .AnySavePolicy (
257+ [
258+ save_decision_policy_lib .PreemptionCheckpointingPolicy (),
259+ save_decision_policy_lib .FixedIntervalPolicy (save_interval_steps ),
260+ ]
261+ )
253262 else :
263+ max_logging .log ("Enabling policy for fixed interval checkpointing." )
254264 save_decision_policy = save_decision_policy_lib .FixedIntervalPolicy (interval = save_interval_steps )
255- preservation_policy = preservation_policy_lib .LatestN (max_num_checkpoints_to_keep )
265+ preservation_policy = preservation_policy_lib .LatestN (max_num_checkpoints_to_keep )
266+
256267 async_options = None
257268 if enable_continuous_checkpointing :
258269 async_options = ocp .AsyncOptions (
@@ -752,6 +763,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
752763 or (step % config .checkpoint_period == 0 and not config .enable_continuous_checkpointing )
753764 or (step % config .checkpoint_period == 0 )
754765 or (config .enable_emergency_checkpoint and step % config .local_checkpoint_period == 0 )
766+ or (config .enable_autocheckpoint and checkpoint_manager .reached_preemption (step ))
755767 ):
756768 blocking_until_ready_start = time .time ()
757769 max_logging .log (f"Waiting for step { step } to finish before checkpoint..." )
0 commit comments