2323import logging
2424import threading
2525from typing import Any , TypeVar
26+ import warnings
2627
2728import jax
2829from pathwaysutils .elastic import elastic
@@ -156,40 +157,136 @@ def _cleanup_on_retry(self):
156157 for array in jax .live_arrays ():
157158 array .delete ()
158159
159- def _elasticity_retry_decorator (
160+ def _monitor_new_slices (
161+ self , stop_event : threading .Event , poll_interval : float | int
162+ ):
163+ """Monitors for new slices and sets the `new_slice_event` if found."""
164+ while not stop_event .wait (poll_interval ):
165+ try :
166+ if not self .inactive_slice_indices :
167+ _logger .debug ("No inactive slices to check." )
168+ continue
169+
170+ _logger .debug (
171+ "Checking inactive slices: %s" , self .inactive_slice_indices
172+ )
173+ inactive_slice_to_devices = {
174+ i : self .slice_to_devices [i ] for i in self .inactive_slice_indices
175+ }
176+ newly_active_indices = elastic .get_active_slice_indices (
177+ inactive_slice_to_devices
178+ )
179+
180+ if newly_active_indices :
181+ _logger .info (
182+ "New slices found: %s. Setting new slice event." ,
183+ newly_active_indices ,
184+ )
185+ self .new_slice_event .set ()
186+ return
187+
188+ _logger .debug ("No new slices found." )
189+ except Exception : # pylint: disable=broad-exception-caught
190+ _logger .exception ("Error in monitor thread" )
191+
192+ def elastic_retry (
160193 self ,
161194 max_retries : int ,
195+ minimum_slice_count : int | None = None ,
196+ poll_interval : float | int = 10 ,
197+ timeout : float | None = None ,
162198 pre_callback : Callable [..., Any ] | None = None ,
163199 on_elastic_event_callback : Callable [..., Any ] | None = None ,
164200 ) -> Callable [[_F ], _F ]:
165201 """Retries a function with elasticity fault tolerance.
166202
203+ This decorator wraps a function to automatically retry execution in case of
204+ `jax.errors.JaxRuntimeError` caused by slice down events. It waits for
205+ active slices before each attempt and cleans up JAX caches on failure.
206+
207+ If `minimum_slice_count` is not met, the function will wait until at least
208+ `minimum_slice_count` slices are active before execution. If
209+ `minimum_slice_count` is None, it defaults to the total number of slices
210+ (i.e., it waits for all slices to be active).
211+
212+ When `minimum_slice_count` is less than the total number of slices, a
213+ background thread will monitor for new slices becoming available and trigger
214+ a retry if they do.
215+
216+ Often, the function will dispatch JAX operations and wait for them to
217+ complete while creating a log message. If using Python logging, it is
218+ recommended to set `logging.raiseExceptions=True` to ensure that the
219+ `jax.errors.JaxRuntimeError` is not silently ignored within the logging
220+ call.
221+
167222 Args:
168223 max_retries: The maximum number of times to retry the function.
169- pre_callback: A callback to call before each attempt of the wrapped
170- function.
224+ minimum_slice_count: The minimum number of slices required to run the
225+ function. If None, defaults to the total number of slices.
226+ poll_interval: The number of seconds to wait between activity checks.
227+ Defaults to 10 seconds.
228+ timeout: The maximum number of seconds to wait for slices to become
229+ active before each retry attempt. If None, there is no timeout.
230+ pre_callback: A callback to call before the function is attempted.
171231 on_elastic_event_callback: A callback to call after an elastic failure
172232 occurs.
173233
174234 Returns:
175- A function decorator.
235+ A decorator that retries the wrapped function.
236+
237+ Raises:
238+ ElasticRuntimeError: If all retry attempts fail.
239+ Exception: Any other exception raised by the wrapped function that is not
240+ due to a slice down event.
176241 """
242+ if minimum_slice_count is None :
243+ target_slice_count = self .total_slice_count
244+ else :
245+ target_slice_count = minimum_slice_count
177246
178247 if max_retries <= 0 :
179248 raise ValueError ("max_retries must be positive." )
249+
180250 def decorator (func : _F ) -> _F :
181251 @functools .wraps (func )
182252 def wrapper (* args : Any , ** kwargs : Any ) -> Any :
183- for retry_index in range (max_retries ):
184- try :
185- _logger .info (
186- "Elastic attempt %d out of %d" , retry_index + 1 , max_retries
187- )
188- if pre_callback is not None :
189- pre_callback ()
253+ def attempt_execution (retry_index : int ):
254+ _logger .info (
255+ "Elastic attempt %d out of %d" , retry_index + 1 , max_retries
256+ )
257+ self .active_slice_indices = elastic .wait_for_slices (
258+ slice_count = target_slice_count ,
259+ slice_to_devices = self .slice_to_devices ,
260+ poll_interval = poll_interval ,
261+ timeout = timeout ,
262+ )
263+ if pre_callback is not None :
264+ pre_callback ()
265+
266+ with jax .default_device (self .default_device ):
267+ self .new_slice_event .clear ()
268+ stop_event = threading .Event ()
269+
270+ if target_slice_count < self .total_slice_count :
271+ monitor_thread = threading .Thread (
272+ target = self ._monitor_new_slices ,
273+ args = (stop_event , poll_interval ),
274+ daemon = True ,
275+ )
276+ monitor_thread .start ()
277+ else :
278+ monitor_thread = None
190279
191- with jax . default_device ( self . default_device ) :
280+ try :
192281 return func (* args , ** kwargs )
282+ finally :
283+ stop_event .set ()
284+ if monitor_thread is not None :
285+ monitor_thread .join ()
286+
287+ for retry_index in range (max_retries ):
288+ try :
289+ return attempt_execution (retry_index )
193290 except ScaleUpSignalError :
194291 _logger .info ("Scale up requested. Retrying." )
195292 _elastic_event_cleanup ()
@@ -230,17 +327,7 @@ def pause_resume(
230327 ) -> Callable [[_F ], _F ]:
231328 """Retries a function with pause/resume fault tolerance.
232329
233- This decorator wraps a function to automatically retry execution in case of
234- `jax.errors.JaxRuntimeError` caused by slice down events. It waits for
235- active slices before each attempt and cleans up JAX caches on failure.
236- The function will not be attempted (or reattempted) until all of the slices
237- are active.
238-
239- Often, the function will dispatch JAX operations and wait for them to
240- complete while creating a log message. If using Python logging, it is
241- recommended to set `logging.raiseExceptions=True` to ensure that the
242- `jax.errors.JaxRuntimeError` is not silently ignored within the logging
243- call.
330+ DEPRECATED: Use `elastic_retry` instead.
244331
245332 Args:
246333 max_retries: The maximum number of times to retry the function.
@@ -254,60 +341,21 @@ def pause_resume(
254341
255342 Returns:
256343 A decorator that retries the wrapped function.
257-
258- Raises:
259- ElasticRuntimeError: If all retry attempts fail.
260- Exception: Any other exception raised by the wrapped function that is not
261- due to a slice down event.
262344 """
263- def internal_pre_callback ():
264- self .active_slice_indices = elastic .wait_for_slices (
265- slice_count = self .total_slice_count ,
266- slice_to_devices = self .slice_to_devices ,
267- poll_interval = poll_interval ,
268- timeout = timeout ,
269- )
270- if pre_callback is not None :
271- pre_callback ()
272-
273- return self ._elasticity_retry_decorator (
345+ warnings .warn (
346+ "`pause_resume` is deprecated. Please use `elastic_retry` instead." ,
347+ DeprecationWarning ,
348+ stacklevel = 2 ,
349+ )
350+ return self .elastic_retry (
274351 max_retries = max_retries ,
275- pre_callback = internal_pre_callback ,
352+ minimum_slice_count = None ,
353+ poll_interval = poll_interval ,
354+ timeout = timeout ,
355+ pre_callback = pre_callback ,
276356 on_elastic_event_callback = on_elastic_event_callback ,
277357 )
278358
279- def _monitor_new_slices (
280- self , stop_event : threading .Event , poll_interval : float | int
281- ):
282- """Monitors for new slices and sets the `new_slice_event` if found."""
283- while not stop_event .wait (poll_interval ):
284- try :
285- if not self .inactive_slice_indices :
286- _logger .debug ("No inactive slices to check." )
287- continue
288-
289- _logger .debug (
290- "Checking inactive slices: %s" , self .inactive_slice_indices
291- )
292- inactive_slice_to_devices = {
293- i : self .slice_to_devices [i ] for i in self .inactive_slice_indices
294- }
295- newly_active_indices = elastic .get_active_slice_indices (
296- inactive_slice_to_devices
297- )
298-
299- if newly_active_indices :
300- _logger .info (
301- "New slices found: %s. Setting new slice event." ,
302- newly_active_indices ,
303- )
304- self .new_slice_event .set ()
305- return
306-
307- _logger .debug ("No new slices found." )
308- except Exception : # pylint: disable=broad-exception-caught
309- _logger .exception ("Error in monitor thread" )
310-
311359 def replica_resize (
312360 self ,
313361 max_resizes : int ,
@@ -317,6 +365,8 @@ def replica_resize(
317365 ) -> Callable [[_F ], _F ]:
318366 """Retries a function with replica/resize fault tolerance.
319367
368+ DEPRECATED: Use `elastic_retry` instead.
369+
320370 Args:
321371 max_resizes: The maximum number of times to retry the function after
322372 resizing the replica count.
@@ -328,47 +378,16 @@ def replica_resize(
328378
329379 Returns:
330380 A decorator that retries the wrapped function.
331-
332- Raises:
333- ElasticRuntimeError: If all retry attempts fail.
334- Exception: Any other exception raised by the wrapped function that is not
335- due to a slice down event.
336381 """
337-
338- def internal_pre_callback ():
339- self .active_slice_indices = elastic .wait_for_slices (
340- slice_count = 1 ,
341- slice_to_devices = self .slice_to_devices ,
342- poll_interval = poll_interval ,
343- )
344-
345- if pre_callback is not None :
346- pre_callback ()
347-
348- retry_decorator = self ._elasticity_retry_decorator (
382+ warnings .warn (
383+ "`replica_resize` is deprecated. Please use `elastic_retry` instead." ,
384+ DeprecationWarning ,
385+ stacklevel = 2 ,
386+ )
387+ return self .elastic_retry (
349388 max_retries = max_resizes ,
350- pre_callback = internal_pre_callback ,
389+ minimum_slice_count = 1 ,
390+ poll_interval = poll_interval ,
391+ pre_callback = pre_callback ,
351392 on_elastic_event_callback = on_elastic_event_callback ,
352393 )
353-
354- def decorator (func ):
355- @functools .wraps (func )
356- def wrapper (* args , ** kwargs ):
357- self .new_slice_event .clear ()
358- stop_event = threading .Event ()
359-
360- monitor_thread = threading .Thread (
361- target = self ._monitor_new_slices ,
362- args = (stop_event , poll_interval ),
363- daemon = True ,
364- )
365- monitor_thread .start ()
366- try :
367- return func (* args , ** kwargs )
368- finally :
369- stop_event .set ()
370- monitor_thread .join ()
371-
372- return retry_decorator (wrapper )
373-
374- return decorator
0 commit comments