Skip to content

Commit 37e6225

Browse files
lukebaumann authored and copybara-github committed
Refactor elastic retry decorators into a single elastic_retry method.
This change unifies the `pause_resume` and `replica_resize` functionalities into a single `elastic_retry` decorator. The new decorator uses a `minimum_slice_count` parameter to control whether to wait for all slices (the default, which gives pause/resume behavior) or for a smaller subset (which enables replica/resize behavior). The old `pause_resume` and `replica_resize` methods are now deprecated and act as wrappers around `elastic_retry`.

PiperOrigin-RevId: 885122724
1 parent a57c2a0 commit 37e6225

1 file changed

Lines changed: 131 additions & 112 deletions

File tree

pathwaysutils/elastic/manager.py

Lines changed: 131 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import logging
2424
import threading
2525
from typing import Any, TypeVar
26+
import warnings
2627

2728
import jax
2829
from pathwaysutils.elastic import elastic
@@ -156,40 +157,136 @@ def _cleanup_on_retry(self):
156157
for array in jax.live_arrays():
157158
array.delete()
158159

159-
def _elasticity_retry_decorator(
160+
def _monitor_new_slices(
161+
self, stop_event: threading.Event, poll_interval: float | int
162+
):
163+
"""Monitors for new slices and sets the `new_slice_event` if found."""
164+
while not stop_event.wait(poll_interval):
165+
try:
166+
if not self.inactive_slice_indices:
167+
_logger.debug("No inactive slices to check.")
168+
continue
169+
170+
_logger.debug(
171+
"Checking inactive slices: %s", self.inactive_slice_indices
172+
)
173+
inactive_slice_to_devices = {
174+
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
175+
}
176+
newly_active_indices = elastic.get_active_slice_indices(
177+
inactive_slice_to_devices
178+
)
179+
180+
if newly_active_indices:
181+
_logger.info(
182+
"New slices found: %s. Setting new slice event.",
183+
newly_active_indices,
184+
)
185+
self.new_slice_event.set()
186+
return
187+
188+
_logger.debug("No new slices found.")
189+
except Exception: # pylint: disable=broad-exception-caught
190+
_logger.exception("Error in monitor thread")
191+
192+
def elastic_retry(
160193
self,
161194
max_retries: int,
195+
minimum_slice_count: int | None = None,
196+
poll_interval: float | int = 10,
197+
timeout: float | None = None,
162198
pre_callback: Callable[..., Any] | None = None,
163199
on_elastic_event_callback: Callable[..., Any] | None = None,
164200
) -> Callable[[_F], _F]:
165201
"""Retries a function with elasticity fault tolerance.
166202
203+
This decorator wraps a function to automatically retry execution in case of
204+
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
205+
active slices before each attempt and cleans up JAX caches on failure.
206+
207+
If `minimum_slice_count` is not met, the function will wait until at least
208+
`minimum_slice_count` slices are active before execution. If
209+
`minimum_slice_count` is None, it defaults to the total number of slices
210+
(i.e., it waits for all slices to be active).
211+
212+
When `minimum_slice_count` is less than the total number of slices, a
213+
background thread will monitor for new slices becoming available and trigger
214+
a retry if they do.
215+
216+
Often, the function will dispatch JAX operations and wait for them to
217+
complete while creating a log message. If using Python logging, it is
218+
recommended to set `logging.raiseExceptions=True` to ensure that the
219+
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
220+
call.
221+
167222
Args:
168223
max_retries: The maximum number of times to retry the function.
169-
pre_callback: A callback to call before each attempt of the wrapped
170-
function.
224+
minimum_slice_count: The minimum number of slices required to run the
225+
function. If None, defaults to the total number of slices.
226+
poll_interval: The number of seconds to wait between activity checks.
227+
Defaults to 10 seconds.
228+
timeout: The maximum number of seconds to wait for slices to become
229+
active before each retry attempt. If None, there is no timeout.
230+
pre_callback: A callback to call before the function is attempted.
171231
on_elastic_event_callback: A callback to call after an elastic failure
172232
occurs.
173233
174234
Returns:
175-
A function decorator.
235+
A decorator that retries the wrapped function.
236+
237+
Raises:
238+
ElasticRuntimeError: If all retry attempts fail.
239+
Exception: Any other exception raised by the wrapped function that is not
240+
due to a slice down event.
176241
"""
242+
if minimum_slice_count is None:
243+
target_slice_count = self.total_slice_count
244+
else:
245+
target_slice_count = minimum_slice_count
177246

178247
if max_retries <= 0:
179248
raise ValueError("max_retries must be positive.")
249+
180250
def decorator(func: _F) -> _F:
181251
@functools.wraps(func)
182252
def wrapper(*args: Any, **kwargs: Any) -> Any:
183-
for retry_index in range(max_retries):
184-
try:
185-
_logger.info(
186-
"Elastic attempt %d out of %d", retry_index + 1, max_retries
187-
)
188-
if pre_callback is not None:
189-
pre_callback()
253+
def attempt_execution(retry_index: int):
254+
_logger.info(
255+
"Elastic attempt %d out of %d", retry_index + 1, max_retries
256+
)
257+
self.active_slice_indices = elastic.wait_for_slices(
258+
slice_count=target_slice_count,
259+
slice_to_devices=self.slice_to_devices,
260+
poll_interval=poll_interval,
261+
timeout=timeout,
262+
)
263+
if pre_callback is not None:
264+
pre_callback()
265+
266+
with jax.default_device(self.default_device):
267+
self.new_slice_event.clear()
268+
stop_event = threading.Event()
269+
270+
if target_slice_count < self.total_slice_count:
271+
monitor_thread = threading.Thread(
272+
target=self._monitor_new_slices,
273+
args=(stop_event, poll_interval),
274+
daemon=True,
275+
)
276+
monitor_thread.start()
277+
else:
278+
monitor_thread = None
190279

191-
with jax.default_device(self.default_device):
280+
try:
192281
return func(*args, **kwargs)
282+
finally:
283+
stop_event.set()
284+
if monitor_thread is not None:
285+
monitor_thread.join()
286+
287+
for retry_index in range(max_retries):
288+
try:
289+
return attempt_execution(retry_index)
193290
except ScaleUpSignalError:
194291
_logger.info("Scale up requested. Retrying.")
195292
_elastic_event_cleanup()
@@ -230,17 +327,7 @@ def pause_resume(
230327
) -> Callable[[_F], _F]:
231328
"""Retries a function with pause/resume fault tolerance.
232329
233-
This decorator wraps a function to automatically retry execution in case of
234-
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
235-
active slices before each attempt and cleans up JAX caches on failure.
236-
The function will not be attempted (or reattempted) until all of the slices
237-
are active.
238-
239-
Often, the function will dispatch JAX operations and wait for them to
240-
complete while creating a log message. If using Python logging, it is
241-
recommended to set `logging.raiseExceptions=True` to ensure that the
242-
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
243-
call.
330+
DEPRECATED: Use `elastic_retry` instead.
244331
245332
Args:
246333
max_retries: The maximum number of times to retry the function.
@@ -254,60 +341,21 @@ def pause_resume(
254341
255342
Returns:
256343
A decorator that retries the wrapped function.
257-
258-
Raises:
259-
ElasticRuntimeError: If all retry attempts fail.
260-
Exception: Any other exception raised by the wrapped function that is not
261-
due to a slice down event.
262344
"""
263-
def internal_pre_callback():
264-
self.active_slice_indices = elastic.wait_for_slices(
265-
slice_count=self.total_slice_count,
266-
slice_to_devices=self.slice_to_devices,
267-
poll_interval=poll_interval,
268-
timeout=timeout,
269-
)
270-
if pre_callback is not None:
271-
pre_callback()
272-
273-
return self._elasticity_retry_decorator(
345+
warnings.warn(
346+
"`pause_resume` is deprecated. Please use `elastic_retry` instead.",
347+
DeprecationWarning,
348+
stacklevel=2,
349+
)
350+
return self.elastic_retry(
274351
max_retries=max_retries,
275-
pre_callback=internal_pre_callback,
352+
minimum_slice_count=None,
353+
poll_interval=poll_interval,
354+
timeout=timeout,
355+
pre_callback=pre_callback,
276356
on_elastic_event_callback=on_elastic_event_callback,
277357
)
278358

279-
def _monitor_new_slices(
280-
self, stop_event: threading.Event, poll_interval: float | int
281-
):
282-
"""Monitors for new slices and sets the `new_slice_event` if found."""
283-
while not stop_event.wait(poll_interval):
284-
try:
285-
if not self.inactive_slice_indices:
286-
_logger.debug("No inactive slices to check.")
287-
continue
288-
289-
_logger.debug(
290-
"Checking inactive slices: %s", self.inactive_slice_indices
291-
)
292-
inactive_slice_to_devices = {
293-
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
294-
}
295-
newly_active_indices = elastic.get_active_slice_indices(
296-
inactive_slice_to_devices
297-
)
298-
299-
if newly_active_indices:
300-
_logger.info(
301-
"New slices found: %s. Setting new slice event.",
302-
newly_active_indices,
303-
)
304-
self.new_slice_event.set()
305-
return
306-
307-
_logger.debug("No new slices found.")
308-
except Exception: # pylint: disable=broad-exception-caught
309-
_logger.exception("Error in monitor thread")
310-
311359
def replica_resize(
312360
self,
313361
max_resizes: int,
@@ -317,6 +365,8 @@ def replica_resize(
317365
) -> Callable[[_F], _F]:
318366
"""Retries a function with replica/resize fault tolerance.
319367
368+
DEPRECATED: Use `elastic_retry` instead.
369+
320370
Args:
321371
max_resizes: The maximum number of times to retry the function after
322372
resizing the replica count.
@@ -328,47 +378,16 @@ def replica_resize(
328378
329379
Returns:
330380
A decorator that retries the wrapped function.
331-
332-
Raises:
333-
ElasticRuntimeError: If all retry attempts fail.
334-
Exception: Any other exception raised by the wrapped function that is not
335-
due to a slice down event.
336381
"""
337-
338-
def internal_pre_callback():
339-
self.active_slice_indices = elastic.wait_for_slices(
340-
slice_count=1,
341-
slice_to_devices=self.slice_to_devices,
342-
poll_interval=poll_interval,
343-
)
344-
345-
if pre_callback is not None:
346-
pre_callback()
347-
348-
retry_decorator = self._elasticity_retry_decorator(
382+
warnings.warn(
383+
"`replica_resize` is deprecated. Please use `elastic_retry` instead.",
384+
DeprecationWarning,
385+
stacklevel=2,
386+
)
387+
return self.elastic_retry(
349388
max_retries=max_resizes,
350-
pre_callback=internal_pre_callback,
389+
minimum_slice_count=1,
390+
poll_interval=poll_interval,
391+
pre_callback=pre_callback,
351392
on_elastic_event_callback=on_elastic_event_callback,
352393
)
353-
354-
def decorator(func):
355-
@functools.wraps(func)
356-
def wrapper(*args, **kwargs):
357-
self.new_slice_event.clear()
358-
stop_event = threading.Event()
359-
360-
monitor_thread = threading.Thread(
361-
target=self._monitor_new_slices,
362-
args=(stop_event, poll_interval),
363-
daemon=True,
364-
)
365-
monitor_thread.start()
366-
try:
367-
return func(*args, **kwargs)
368-
finally:
369-
stop_event.set()
370-
monitor_thread.join()
371-
372-
return retry_decorator(wrapper)
373-
374-
return decorator

0 commit comments

Comments
 (0)