1818events. It also provides a utility for waiting for slices to become active.
1919"""
2020
21- from collections .abc import Mapping , Sequence
21+ from collections .abc import Callable , Mapping , Sequence
2222import functools
2323import logging
24- from typing import Any
24+ from typing import Any , TypeVar
2525
2626import jax
2727from pathwaysutils .elastic import elastic
@@ -34,6 +34,25 @@ class ElasticRuntimeError(RuntimeError):
3434 """Error raised when elasticity cannot continue."""
3535
3636
37+ _F = TypeVar ("_F" , bound = Callable [..., Any ])
38+
39+
def _elastic_event_cleanup() -> None:
    """Releases JAX state left over from a failed elastic attempt.

    Stops any in-progress profiler trace, clears JAX's caches, and deletes
    every array returned by ``jax.live_arrays()``.

    Raises:
        Exception: Re-raised, after being logged, when stopping the trace
            fails with anything other than ``RuntimeError`` or ``ValueError``.
    """
    _logger.info("Cleaning up any ongoing traces")
    try:
        # Only the call that can actually fail is guarded.
        jax.profiler.stop_trace()
    except (RuntimeError, ValueError):
        # Treated as "there was no trace to stop", not as a cleanup failure.
        _logger.info("No ongoing traces to clean up")
    except Exception:
        _logger.exception("Error cleaning up ongoing traces")
        raise

    jax.clear_caches()
    for array in jax.live_arrays():
        array.delete()
54+
55+
3756class Manager :
3857 """Utility class for elastic training."""
3958
@@ -95,12 +114,64 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
95114 else :
96115 raise ValueError (f"Unsupported type: { type (x )= } " )
97116
def _elasticity_retry_decorator(
    self,
    max_retries: int,
    pre_callback: Callable[..., Any] | None = None,
    on_elastic_event_callback: Callable[..., Any] | None = None,
) -> Callable[[_F], _F]:
    """Retries a function with elasticity fault tolerance.

    Each attempt calls ``pre_callback`` (if given) and then runs the wrapped
    function with ``self.default_device`` as the JAX default device. When an
    attempt fails with a ``jax.errors.JaxRuntimeError`` that
    ``elastic.is_error_due_to_slice_down`` recognises, JAX state is cleaned
    up, ``on_elastic_event_callback`` (if given) is called, and the next
    attempt starts. Both steps also run after the final failed attempt,
    before the error is raised.

    Args:
        max_retries: The maximum number of times to attempt the function.
            Must be positive.
        pre_callback: A callback to call, with no arguments, before each
            attempt of the wrapped function. It runs inside the guarded
            region, so a slice-down error it raises also triggers a retry.
        on_elastic_event_callback: A callback to call, with no arguments,
            after an elastic failure occurs and cleanup has run.

    Returns:
        A function decorator.

    Raises:
        ValueError: If ``max_retries`` is not positive.
        ElasticRuntimeError: Raised by the decorated function when every
            attempt failed because of a slice-down event. The last such
            error is attached as ``__cause__``.
    """
    if max_retries <= 0:
        raise ValueError("max_retries must be positive.")

    def decorator(func: _F) -> _F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Kept so the final ElasticRuntimeError can name its root cause.
            last_error: BaseException | None = None

            for retry_index in range(max_retries):
                try:
                    _logger.info(
                        "Elastic attempt %d out of %d",
                        retry_index + 1,
                        max_retries,
                    )
                    if pre_callback is not None:
                        pre_callback()

                    with jax.default_device(self.default_device):
                        return func(*args, **kwargs)
                except jax.errors.JaxRuntimeError as error:
                    # Anything that is not a slice-down event is not ours to
                    # handle; let it propagate unchanged.
                    if not elastic.is_error_due_to_slice_down(error):
                        raise
                    last_error = error

                # Reached only after a slice-down failure (success returns).
                _elastic_event_cleanup()

                if on_elastic_event_callback is not None:
                    on_elastic_event_callback()

            # max_retries > 0 is validated above, so the loop ran at least
            # once and every attempt failed with a slice-down error.
            raise ElasticRuntimeError(
                f"Elastic attempt {max_retries} out of {max_retries} failed."
            ) from last_error

        return wrapper

    return decorator
166+
98167 def pause_resume (
99168 self ,
100169 max_retries : int ,
101170 poll_interval : float | int = 10 ,
102171 timeout : float | None = None ,
103- ) -> Any :
172+ pre_callback : Callable [..., Any ] | None = None ,
173+ on_elastic_event_callback : Callable [..., Any ] | None = None ,
174+ ) -> Callable [[_F ], _F ]:
104175 """Retries a function with pause/resume fault tolerance.
105176
106177 This decorator wraps a function to automatically retry execution in case of
@@ -121,6 +192,9 @@ def pause_resume(
121192 Defaults to 10 seconds.
122193 timeout: The maximum number of seconds to wait for slices to become
123194 active before each retry attempt. If None, there is no timeout.
195+ pre_callback: A callback to call before the function is attempted.
196+ on_elastic_event_callback: A callback to call after an elastic failure
197+ occurs.
124198
125199 Returns:
126200 The result of the wrapped function.
@@ -130,42 +204,18 @@ def pause_resume(
130204 Exception: Any other exception raised by the wrapped function that is not
131205 due to a slice down event.
132206 """
133- def decorator (func ):
134- @functools .wraps (func )
135- def wrapper (* args , ** kwargs ):
136- for retry_index in range (max_retries ):
137- try :
138- _logger .info (
139- "Elastic attempt %d out of %d" , retry_index + 1 , max_retries
140- )
141-
142- self .active_slice_indices = elastic .wait_for_slices (
143- slice_count = self .total_slice_count ,
144- slice_to_devices = self .slice_to_devices ,
145- poll_interval = poll_interval ,
146- timeout = timeout ,
147- )
148-
149- return func (* args , ** kwargs )
150- except jax .errors .JaxRuntimeError as error :
151- if not elastic .is_error_due_to_slice_down (error ):
152- raise
153-
154- try :
155- _logger .info ("Cleaning up any ongoing traces" )
156- jax .profiler .stop_trace ()
157- except (RuntimeError , ValueError ) as e :
158- _logger .info ("No ongoing traces to clean up" )
159- except Exception :
160- _logger .exception ("Error cleaning up ongoing traces" )
161- raise
162-
163- jax .clear_caches ()
164- for array in jax .live_arrays ():
165- array .delete ()
166- raise ElasticRuntimeError (
167- f"Elastic attempt { max_retries } out of { max_retries } failed."
168- )
207+ def internal_pre_callback ():
208+ self .active_slice_indices = elastic .wait_for_slices (
209+ slice_count = self .total_slice_count ,
210+ slice_to_devices = self .slice_to_devices ,
211+ poll_interval = poll_interval ,
212+ timeout = timeout ,
213+ )
214+ if pre_callback is not None :
215+ pre_callback ()
169216
170- return wrapper
171- return decorator
217+ return self ._elasticity_retry_decorator (
218+ max_retries = max_retries ,
219+ pre_callback = internal_pre_callback ,
220+ on_elastic_event_callback = on_elastic_event_callback ,
221+ )
0 commit comments