2323import logging
2424import threading
2525from typing import Any , TypeVar
26+ import warnings
2627
2728import jax
2829from pathwaysutils .elastic import elastic
@@ -54,7 +55,7 @@ def _elastic_event_cleanup() -> None:
5455 try :
5556 _logger .info ("Cleaning up any ongoing traces" )
5657 jax .profiler .stop_trace ()
57- except (RuntimeError , ValueError ) as e :
58+ except (RuntimeError , ValueError ):
5859 _logger .info ("No ongoing traces to clean up" )
5960 except Exception :
6061 _logger .exception ("Error cleaning up ongoing traces" )
@@ -156,40 +157,141 @@ def _cleanup_on_retry(self):
156157 for array in jax .live_arrays ():
157158 array .delete ()
158159
159- def _elasticity_retry_decorator (
160+ def _monitor_new_slices (
161+ self , stop_event : threading .Event , poll_interval : float | int
162+ ) -> None :
163+ """Monitors for new slices and sets the `new_slice_event` if found."""
164+ while not stop_event .wait (poll_interval ):
165+ try :
166+ if not self .inactive_slice_indices :
167+ _logger .debug ("No inactive slices to check." )
168+ continue
169+
170+ _logger .debug (
171+ "Checking inactive slices: %s" , self .inactive_slice_indices
172+ )
173+ inactive_slice_to_devices = {
174+ i : self .slice_to_devices [i ] for i in self .inactive_slice_indices
175+ }
176+ newly_active_indices = elastic .get_active_slice_indices (
177+ inactive_slice_to_devices
178+ )
179+
180+ if newly_active_indices :
181+ _logger .info (
182+ "New slices found: %s. Setting new slice event." ,
183+ newly_active_indices ,
184+ )
185+ self .new_slice_event .set ()
186+ return
187+
188+ _logger .debug ("No new slices found." )
189+ except Exception : # pylint: disable=broad-exception-caught
190+ _logger .exception ("Error in monitor thread" )
191+
192+ def elastic_retry (
160193 self ,
161194 max_retries : int ,
195+ minimum_slice_count : int | None = None ,
196+ poll_interval : float | int = 10 ,
197+ timeout : float | None = None ,
162198 pre_callback : Callable [..., Any ] | None = None ,
163199 on_elastic_event_callback : Callable [..., Any ] | None = None ,
164200 ) -> Callable [[_F ], _F ]:
165201 """Retries a function with elasticity fault tolerance.
166202
203+ This decorator wraps a function to automatically retry execution in case of
204+ `jax.errors.JaxRuntimeError` caused by slice down events. It waits for
205+ `minimum_slice_count` active slices before each attempt and cleans up JAX
206+ caches on failure.
207+
208+ If `minimum_slice_count` is not met, the function will wait until at least
209+ `minimum_slice_count` slices are active before execution. If
210+ `minimum_slice_count` is None, it defaults to the total number of slices
211+ (i.e., it waits for all slices to be active).
212+
213+ When `minimum_slice_count` is less than the total number of slices, a
214+ background thread will monitor for new slices becoming available and set
215+ `self.new_slice_event`. The user code can then poll this event and raise
216+ a `ScaleUpSignalError` to gracefully interrupt the current execution and
217+ trigger a retry.
218+
219+ Often, the function will dispatch JAX operations and wait for them to
220+ complete while creating a log message. If using Python logging, it is
221+ recommended to set `logging.raiseExceptions=True` to ensure that the
222+ `jax.errors.JaxRuntimeError` is not silently ignored within the logging
223+ call.
224+
167225 Args:
168226 max_retries: The maximum number of times to retry the function.
169- pre_callback: A callback to call before each attempt of the wrapped
170- function.
227+ minimum_slice_count: The minimum number of slices required to run the
228+ function. If None, defaults to the total number of slices.
229+ poll_interval: The number of seconds to wait between activity checks.
230+ Defaults to 10 seconds.
231+ timeout: The maximum number of seconds to wait for slices to become active
232+ before each retry attempt. If None, there is no timeout.
233+ pre_callback: A callback to call before the function is attempted.
171234 on_elastic_event_callback: A callback to call after an elastic failure
172235 occurs.
173236
174237 Returns:
175- A function decorator.
238+ A decorator that retries the wrapped function.
239+
240+ Raises:
241+ ElasticRuntimeError: If all retry attempts fail.
242+ Exception: Any other exception raised by the wrapped function that is not
243+ due to a slice down event.
176244 """
245+ target_slice_count = (
246+ self .total_slice_count
247+ if minimum_slice_count is None
248+ else minimum_slice_count
249+ )
177250
178251 if max_retries <= 0 :
179252 raise ValueError ("max_retries must be positive." )
253+
180254 def decorator (func : _F ) -> _F :
181255 @functools .wraps (func )
182256 def wrapper (* args : Any , ** kwargs : Any ) -> Any :
183- for retry_index in range (max_retries ):
184- try :
185- _logger .info (
186- "Elastic attempt %d out of %d" , retry_index + 1 , max_retries
187- )
188- if pre_callback is not None :
189- pre_callback ()
190257
191- with jax .default_device (self .default_device ):
258+ def attempt_execution (retry_index : int ) -> Any :
259+ _logger .info (
260+ "Elastic attempt %d out of %d" , retry_index + 1 , max_retries
261+ )
262+ self .active_slice_indices = elastic .wait_for_slices (
263+ slice_count = target_slice_count ,
264+ slice_to_devices = self .slice_to_devices ,
265+ poll_interval = poll_interval ,
266+ timeout = timeout ,
267+ )
268+ if pre_callback is not None :
269+ pre_callback ()
270+
271+ with jax .default_device (self .default_device ):
272+ self .new_slice_event .clear ()
273+ stop_event = threading .Event ()
274+
275+ if target_slice_count < self .total_slice_count :
276+ monitor_thread = threading .Thread (
277+ target = self ._monitor_new_slices ,
278+ args = (stop_event , poll_interval ),
279+ daemon = True ,
280+ )
281+ monitor_thread .start ()
282+ else :
283+ monitor_thread = None
284+
285+ try :
192286 return func (* args , ** kwargs )
287+ finally :
288+ stop_event .set ()
289+ if monitor_thread is not None :
290+ monitor_thread .join ()
291+
292+ for retry_index in range (max_retries ):
293+ try :
294+ return attempt_execution (retry_index )
193295 except ScaleUpSignalError :
194296 _logger .info ("Scale up requested. Retrying." )
195297 _elastic_event_cleanup ()
@@ -230,84 +332,35 @@ def pause_resume(
230332 ) -> Callable [[_F ], _F ]:
231333 """Retries a function with pause/resume fault tolerance.
232334
233- This decorator wraps a function to automatically retry execution in case of
234- `jax.errors.JaxRuntimeError` caused by slice down events. It waits for
235- active slices before each attempt and cleans up JAX caches on failure.
236- The function will not be attempted (or reattempted) until all of the slices
237- are active.
238-
239- Often, the function will dispatch JAX operations and wait for them to
240- complete while creating a log message. If using Python logging, it is
241- recommended to set `logging.raiseExceptions=True` to ensure that the
242- `jax.errors.JaxRuntimeError` is not silently ignored within the logging
243- call.
335+ DEPRECATED: Use `elastic_retry` instead.
244336
245337 Args:
246338 max_retries: The maximum number of times to retry the function.
247339 poll_interval: The number of seconds to wait between activity checks.
248340 Defaults to 10 seconds.
249- timeout: The maximum number of seconds to wait for slices to become
250- active before each retry attempt. If None, there is no timeout.
341+ timeout: The maximum number of seconds to wait for slices to become active
342+ before each retry attempt. If None, there is no timeout.
251343 pre_callback: A callback to call before the function is attempted.
252344 on_elastic_event_callback: A callback to call after an elastic failure
253345 occurs.
254346
255347 Returns:
256348 A decorator that retries the wrapped function.
257-
258- Raises:
259- ElasticRuntimeError: If all retry attempts fail.
260- Exception: Any other exception raised by the wrapped function that is not
261- due to a slice down event.
262349 """
263- def internal_pre_callback ():
264- self .active_slice_indices = elastic .wait_for_slices (
265- slice_count = self .total_slice_count ,
266- slice_to_devices = self .slice_to_devices ,
267- poll_interval = poll_interval ,
268- timeout = timeout ,
269- )
270- if pre_callback is not None :
271- pre_callback ()
272-
273- return self ._elasticity_retry_decorator (
350+ warnings .warn (
351+ "`pause_resume` is deprecated. Please use `elastic_retry` instead." ,
352+ DeprecationWarning ,
353+ stacklevel = 2 ,
354+ )
355+ return self .elastic_retry (
274356 max_retries = max_retries ,
275- pre_callback = internal_pre_callback ,
357+ minimum_slice_count = None ,
358+ poll_interval = poll_interval ,
359+ timeout = timeout ,
360+ pre_callback = pre_callback ,
276361 on_elastic_event_callback = on_elastic_event_callback ,
277362 )
278363
279- def _monitor_new_slices (
280- self , stop_event : threading .Event , poll_interval : float | int
281- ):
282- """Monitors for new slices and sets the `new_slice_event` if found."""
283- while not stop_event .wait (poll_interval ):
284- try :
285- if not self .inactive_slice_indices :
286- _logger .debug ("No inactive slices to check." )
287- continue
288-
289- _logger .debug (
290- "Checking inactive slices: %s" , self .inactive_slice_indices
291- )
292- inactive_slice_to_devices = {
293- i : self .slice_to_devices [i ] for i in self .inactive_slice_indices
294- }
295- newly_active_indices = elastic .get_active_slice_indices (
296- inactive_slice_to_devices
297- )
298-
299- if newly_active_indices :
300- _logger .info (
301- "New slices found: %s. Setting new slice event." ,
302- newly_active_indices ,
303- )
304- self .new_slice_event .set ()
305- return
306-
307- _logger .debug ("No new slices found." )
308- except Exception : # pylint: disable=broad-exception-caught
309- _logger .exception ("Error in monitor thread" )
310-
311364 def replica_resize (
312365 self ,
313366 max_resizes : int ,
@@ -317,6 +370,8 @@ def replica_resize(
317370 ) -> Callable [[_F ], _F ]:
318371 """Retries a function with replica/resize fault tolerance.
319372
373+ DEPRECATED: Use `elastic_retry` instead.
374+
320375 Args:
321376 max_resizes: The maximum number of times to retry the function after
322377 resizing the replica count.
@@ -328,47 +383,16 @@ def replica_resize(
328383
329384 Returns:
330385 A decorator that retries the wrapped function.
331-
332- Raises:
333- ElasticRuntimeError: If all retry attempts fail.
334- Exception: Any other exception raised by the wrapped function that is not
335- due to a slice down event.
336386 """
337-
338- def internal_pre_callback ():
339- self .active_slice_indices = elastic .wait_for_slices (
340- slice_count = 1 ,
341- slice_to_devices = self .slice_to_devices ,
342- poll_interval = poll_interval ,
343- )
344-
345- if pre_callback is not None :
346- pre_callback ()
347-
348- retry_decorator = self ._elasticity_retry_decorator (
387+ warnings .warn (
388+ "`replica_resize` is deprecated. Please use `elastic_retry` instead." ,
389+ DeprecationWarning ,
390+ stacklevel = 2 ,
391+ )
392+ return self .elastic_retry (
349393 max_retries = max_resizes ,
350- pre_callback = internal_pre_callback ,
394+ minimum_slice_count = 1 ,
395+ poll_interval = poll_interval ,
396+ pre_callback = pre_callback ,
351397 on_elastic_event_callback = on_elastic_event_callback ,
352398 )
353-
354- def decorator (func ):
355- @functools .wraps (func )
356- def wrapper (* args , ** kwargs ):
357- self .new_slice_event .clear ()
358- stop_event = threading .Event ()
359-
360- monitor_thread = threading .Thread (
361- target = self ._monitor_new_slices ,
362- args = (stop_event , poll_interval ),
363- daemon = True ,
364- )
365- monitor_thread .start ()
366- try :
367- return func (* args , ** kwargs )
368- finally :
369- stop_event .set ()
370- monitor_thread .join ()
371-
372- return retry_decorator (wrapper )
373-
374- return decorator
0 commit comments