Skip to content

Commit 6c56169

Browse files
lukebaumann authored and
copybara-github committed
Refactor elastic retry decorators into a single elastic_retry method.
This change unifies the `pause_resume` and `replica_resize` functionalities into a single `elastic_retry` decorator. The new decorator uses a `minimum_slice_count` parameter to control whether to wait for all slices (defaulting to pause/resume behavior) or a smaller subset (enabling replica/resize behavior). The old `pause_resume` and `replica_resize` methods are now deprecated and act as wrappers around `elastic_retry`. PiperOrigin-RevId: 890535020
1 parent 44d0853 commit 6c56169

1 file changed

Lines changed: 139 additions & 115 deletions

File tree

pathwaysutils/elastic/manager.py

Lines changed: 139 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import logging
2424
import threading
2525
from typing import Any, TypeVar
26+
import warnings
2627

2728
import jax
2829
from pathwaysutils.elastic import elastic
@@ -54,7 +55,7 @@ def _elastic_event_cleanup() -> None:
5455
try:
5556
_logger.info("Cleaning up any ongoing traces")
5657
jax.profiler.stop_trace()
57-
except (RuntimeError, ValueError) as e:
58+
except (RuntimeError, ValueError):
5859
_logger.info("No ongoing traces to clean up")
5960
except Exception:
6061
_logger.exception("Error cleaning up ongoing traces")
@@ -156,40 +157,141 @@ def _cleanup_on_retry(self):
156157
for array in jax.live_arrays():
157158
array.delete()
158159

159-
def _elasticity_retry_decorator(
160+
def _monitor_new_slices(
    self, stop_event: threading.Event, poll_interval: float | int
) -> None:
  """Watches inactive slices and flags `new_slice_event` once one activates.

  Polls every `poll_interval` seconds until `stop_event` is set. On the first
  poll that finds a newly active slice, sets `self.new_slice_event` and
  returns; all errors inside a poll are logged and swallowed so the monitor
  thread keeps running.

  Args:
    stop_event: Event that terminates the monitoring loop when set.
    poll_interval: Seconds to wait between activity checks.
  """
  # `Event.wait` doubles as the sleep: it returns True (ending the loop) as
  # soon as `stop_event` is set, or False after `poll_interval` seconds.
  while not stop_event.wait(poll_interval):
    try:
      candidates = self.inactive_slice_indices
      if not candidates:
        _logger.debug("No inactive slices to check.")
        continue

      _logger.debug(
          "Checking inactive slices: %s", self.inactive_slice_indices
      )
      activated = elastic.get_active_slice_indices(
          {index: self.slice_to_devices[index] for index in candidates}
      )

      if not activated:
        _logger.debug("No new slices found.")
        continue

      _logger.info(
          "New slices found: %s. Setting new slice event.",
          activated,
      )
      self.new_slice_event.set()
      return
    except Exception:  # pylint: disable=broad-exception-caught
      _logger.exception("Error in monitor thread")
191+
192+
def elastic_retry(
160193
self,
161194
max_retries: int,
195+
minimum_slice_count: int | None = None,
196+
poll_interval: float | int = 10,
197+
timeout: float | None = None,
162198
pre_callback: Callable[..., Any] | None = None,
163199
on_elastic_event_callback: Callable[..., Any] | None = None,
164200
) -> Callable[[_F], _F]:
165201
"""Retries a function with elasticity fault tolerance.
166202
203+
This decorator wraps a function to automatically retry execution in case of
204+
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
205+
`minimum_slice_count` active slices before each attempt and cleans up JAX
206+
caches on failure.
207+
208+
If `minimum_slice_count` is not met, the function will wait until at least
209+
`minimum_slice_count` slices are active before execution. If
210+
`minimum_slice_count` is None, it defaults to the total number of slices
211+
(i.e., it waits for all slices to be active).
212+
213+
When `minimum_slice_count` is less than the total number of slices, a
214+
background thread will monitor for new slices becoming available and set
215+
`self.new_slice_event`. The user code can then poll this event and raise
216+
a `ScaleUpSignalError` to gracefully interrupt the current execution and
217+
trigger a retry.
218+
219+
Often, the function will dispatch JAX operations and wait for them to
220+
complete while creating a log message. If using Python logging, it is
221+
recommended to set `logging.raiseExceptions=True` to ensure that the
222+
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
223+
call.
224+
167225
Args:
168226
max_retries: The maximum number of times to retry the function.
169-
pre_callback: A callback to call before each attempt of the wrapped
170-
function.
227+
minimum_slice_count: The minimum number of slices required to run the
228+
function. If None, defaults to the total number of slices.
229+
poll_interval: The number of seconds to wait between activity checks.
230+
Defaults to 10 seconds.
231+
timeout: The maximum number of seconds to wait for slices to become active
232+
before each retry attempt. If None, there is no timeout.
233+
pre_callback: A callback to call before the function is attempted.
171234
on_elastic_event_callback: A callback to call after an elastic failure
172235
occurs.
173236
174237
Returns:
175-
A function decorator.
238+
A decorator that retries the wrapped function.
239+
240+
Raises:
241+
ElasticRuntimeError: If all retry attempts fail.
242+
Exception: Any other exception raised by the wrapped function that is not
243+
due to a slice down event.
176244
"""
245+
target_slice_count = (
246+
self.total_slice_count
247+
if minimum_slice_count is None
248+
else minimum_slice_count
249+
)
177250

178251
if max_retries <= 0:
179252
raise ValueError("max_retries must be positive.")
253+
180254
def decorator(func: _F) -> _F:
181255
@functools.wraps(func)
182256
def wrapper(*args: Any, **kwargs: Any) -> Any:
183-
for retry_index in range(max_retries):
184-
try:
185-
_logger.info(
186-
"Elastic attempt %d out of %d", retry_index + 1, max_retries
187-
)
188-
if pre_callback is not None:
189-
pre_callback()
190257

191-
with jax.default_device(self.default_device):
258+
def attempt_execution(retry_index: int) -> Any:
259+
_logger.info(
260+
"Elastic attempt %d out of %d", retry_index + 1, max_retries
261+
)
262+
self.active_slice_indices = elastic.wait_for_slices(
263+
slice_count=target_slice_count,
264+
slice_to_devices=self.slice_to_devices,
265+
poll_interval=poll_interval,
266+
timeout=timeout,
267+
)
268+
if pre_callback is not None:
269+
pre_callback()
270+
271+
with jax.default_device(self.default_device):
272+
self.new_slice_event.clear()
273+
stop_event = threading.Event()
274+
275+
if target_slice_count < self.total_slice_count:
276+
monitor_thread = threading.Thread(
277+
target=self._monitor_new_slices,
278+
args=(stop_event, poll_interval),
279+
daemon=True,
280+
)
281+
monitor_thread.start()
282+
else:
283+
monitor_thread = None
284+
285+
try:
192286
return func(*args, **kwargs)
287+
finally:
288+
stop_event.set()
289+
if monitor_thread is not None:
290+
monitor_thread.join()
291+
292+
for retry_index in range(max_retries):
293+
try:
294+
return attempt_execution(retry_index)
193295
except ScaleUpSignalError:
194296
_logger.info("Scale up requested. Retrying.")
195297
_elastic_event_cleanup()
@@ -230,84 +332,35 @@ def pause_resume(
230332
) -> Callable[[_F], _F]:
231333
"""Retries a function with pause/resume fault tolerance.
232334
233-
This decorator wraps a function to automatically retry execution in case of
234-
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
235-
active slices before each attempt and cleans up JAX caches on failure.
236-
The function will not be attempted (or reattempted) until all of the slices
237-
are active.
238-
239-
Often, the function will dispatch JAX operations and wait for them to
240-
complete while creating a log message. If using Python logging, it is
241-
recommended to set `logging.raiseExceptions=True` to ensure that the
242-
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
243-
call.
335+
DEPRECATED: Use `elastic_retry` instead.
244336
245337
Args:
246338
max_retries: The maximum number of times to retry the function.
247339
poll_interval: The number of seconds to wait between activity checks.
248340
Defaults to 10 seconds.
249-
timeout: The maximum number of seconds to wait for slices to become
250-
active before each retry attempt. If None, there is no timeout.
341+
timeout: The maximum number of seconds to wait for slices to become active
342+
before each retry attempt. If None, there is no timeout.
251343
pre_callback: A callback to call before the function is attempted.
252344
on_elastic_event_callback: A callback to call after an elastic failure
253345
occurs.
254346
255347
Returns:
256348
A decorator that retries the wrapped function.
257-
258-
Raises:
259-
ElasticRuntimeError: If all retry attempts fail.
260-
Exception: Any other exception raised by the wrapped function that is not
261-
due to a slice down event.
262349
"""
263-
def internal_pre_callback():
264-
self.active_slice_indices = elastic.wait_for_slices(
265-
slice_count=self.total_slice_count,
266-
slice_to_devices=self.slice_to_devices,
267-
poll_interval=poll_interval,
268-
timeout=timeout,
269-
)
270-
if pre_callback is not None:
271-
pre_callback()
272-
273-
return self._elasticity_retry_decorator(
350+
warnings.warn(
351+
"`pause_resume` is deprecated. Please use `elastic_retry` instead.",
352+
DeprecationWarning,
353+
stacklevel=2,
354+
)
355+
return self.elastic_retry(
274356
max_retries=max_retries,
275-
pre_callback=internal_pre_callback,
357+
minimum_slice_count=None,
358+
poll_interval=poll_interval,
359+
timeout=timeout,
360+
pre_callback=pre_callback,
276361
on_elastic_event_callback=on_elastic_event_callback,
277362
)
278363

279-
def _monitor_new_slices(
280-
self, stop_event: threading.Event, poll_interval: float | int
281-
):
282-
"""Monitors for new slices and sets the `new_slice_event` if found."""
283-
while not stop_event.wait(poll_interval):
284-
try:
285-
if not self.inactive_slice_indices:
286-
_logger.debug("No inactive slices to check.")
287-
continue
288-
289-
_logger.debug(
290-
"Checking inactive slices: %s", self.inactive_slice_indices
291-
)
292-
inactive_slice_to_devices = {
293-
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
294-
}
295-
newly_active_indices = elastic.get_active_slice_indices(
296-
inactive_slice_to_devices
297-
)
298-
299-
if newly_active_indices:
300-
_logger.info(
301-
"New slices found: %s. Setting new slice event.",
302-
newly_active_indices,
303-
)
304-
self.new_slice_event.set()
305-
return
306-
307-
_logger.debug("No new slices found.")
308-
except Exception: # pylint: disable=broad-exception-caught
309-
_logger.exception("Error in monitor thread")
310-
311364
def replica_resize(
312365
self,
313366
max_resizes: int,
@@ -317,6 +370,8 @@ def replica_resize(
317370
) -> Callable[[_F], _F]:
318371
"""Retries a function with replica/resize fault tolerance.
319372
373+
DEPRECATED: Use `elastic_retry` instead.
374+
320375
Args:
321376
max_resizes: The maximum number of times to retry the function after
322377
resizing the replica count.
@@ -328,47 +383,16 @@ def replica_resize(
328383
329384
Returns:
330385
A decorator that retries the wrapped function.
331-
332-
Raises:
333-
ElasticRuntimeError: If all retry attempts fail.
334-
Exception: Any other exception raised by the wrapped function that is not
335-
due to a slice down event.
336386
"""
337-
338-
def internal_pre_callback():
339-
self.active_slice_indices = elastic.wait_for_slices(
340-
slice_count=1,
341-
slice_to_devices=self.slice_to_devices,
342-
poll_interval=poll_interval,
343-
)
344-
345-
if pre_callback is not None:
346-
pre_callback()
347-
348-
retry_decorator = self._elasticity_retry_decorator(
387+
warnings.warn(
388+
"`replica_resize` is deprecated. Please use `elastic_retry` instead.",
389+
DeprecationWarning,
390+
stacklevel=2,
391+
)
392+
return self.elastic_retry(
349393
max_retries=max_resizes,
350-
pre_callback=internal_pre_callback,
394+
minimum_slice_count=1,
395+
poll_interval=poll_interval,
396+
pre_callback=pre_callback,
351397
on_elastic_event_callback=on_elastic_event_callback,
352398
)
353-
354-
def decorator(func):
355-
@functools.wraps(func)
356-
def wrapper(*args, **kwargs):
357-
self.new_slice_event.clear()
358-
stop_event = threading.Event()
359-
360-
monitor_thread = threading.Thread(
361-
target=self._monitor_new_slices,
362-
args=(stop_event, poll_interval),
363-
daemon=True,
364-
)
365-
monitor_thread.start()
366-
try:
367-
return func(*args, **kwargs)
368-
finally:
369-
stop_event.set()
370-
monitor_thread.join()
371-
372-
return retry_decorator(wrapper)
373-
374-
return decorator

0 commit comments

Comments
 (0)