Skip to content

Commit b9787f8

Browse files
lukebaumanncopybara-github
authored andcommitted
Refactor elasticity retry logic into a reusable private method.
PiperOrigin-RevId: 878735893
1 parent cf87743 commit b9787f8

1 file changed

Lines changed: 91 additions & 41 deletions

File tree

pathwaysutils/elastic/manager.py

Lines changed: 91 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
events. It also provides a utility for waiting for slices to become active.
1919
"""
2020

21-
from collections.abc import Mapping, Sequence
21+
from collections.abc import Callable, Mapping, Sequence
2222
import functools
2323
import logging
24-
from typing import Any
24+
from typing import Any, TypeVar
2525

2626
import jax
2727
from pathwaysutils.elastic import elastic
@@ -34,6 +34,25 @@ class ElasticRuntimeError(RuntimeError):
3434
"""Error raised when elasticity cannot continue."""
3535

3636

37+
# Type variable bound to any callable; lets decorator signatures state that
# they return the same callable type they were given.
_F = TypeVar("_F", bound=Callable[..., Any])
38+
39+
40+
def _elastic_event_cleanup() -> None:
    """Stops any active JAX profiler trace, then clears caches and live arrays.

    The retry wrapper calls this after a slice-down failure, before the next
    attempt is made.

    Raises:
        Exception: Re-raised, after being logged, when stopping the trace fails
            with anything other than ``RuntimeError`` or ``ValueError``.
    """
    _logger.info("Cleaning up any ongoing traces")
    try:
        # Only the call that can legitimately fail is guarded.
        jax.profiler.stop_trace()
    except (RuntimeError, ValueError):
        # These two types are treated as "no trace was running" and are
        # deliberately swallowed so cleanup can continue.
        _logger.info("No ongoing traces to clean up")
    except Exception:
        _logger.exception("Error cleaning up ongoing traces")
        raise

    jax.clear_caches()
    # Delete every array JAX still tracks as live.
    for array in jax.live_arrays():
        array.delete()
54+
55+
3756
class Manager:
3857
"""Utility class for elastic training."""
3958

@@ -95,12 +114,64 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
95114
else:
96115
raise ValueError(f"Unsupported type: {type(x)=}")
97116

117+
def _elasticity_retry_decorator(
    self,
    max_retries: int,
    pre_callback: Callable[..., Any] | None = None,
    on_elastic_event_callback: Callable[..., Any] | None = None,
) -> Callable[[_F], _F]:
    """Builds a decorator that retries a function on slice-down failures.

    Each attempt runs ``pre_callback`` (if given) and then the wrapped
    function under ``jax.default_device(self.default_device)``. A
    ``jax.errors.JaxRuntimeError`` that ``elastic.is_error_due_to_slice_down``
    attributes to a slice going down triggers cleanup, the optional
    ``on_elastic_event_callback``, and another attempt. Any other error
    propagates unchanged.

    Args:
        max_retries: The maximum number of times to attempt the function.
            Must be positive.
        pre_callback: A callback to call before each attempt of the wrapped
            function.
        on_elastic_event_callback: A callback to call after an elastic
            failure occurs, once cleanup has run. It is also called after
            the final failed attempt.

    Returns:
        A function decorator.

    Raises:
        ValueError: If ``max_retries`` is not positive.
    """
    if max_retries <= 0:
        raise ValueError("max_retries must be positive.")

    def decorator(func: _F) -> _F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Remember the most recent slice-down error so the final
            # ElasticRuntimeError can point at its root cause.
            last_error = None
            for retry_index in range(max_retries):
                try:
                    _logger.info(
                        "Elastic attempt %d out of %d",
                        retry_index + 1,
                        max_retries,
                    )
                    if pre_callback is not None:
                        pre_callback()

                    with jax.default_device(self.default_device):
                        return func(*args, **kwargs)
                except jax.errors.JaxRuntimeError as error:
                    if not elastic.is_error_due_to_slice_down(error):
                        raise

                    # `error` is unbound when the handler exits; keep a
                    # reference for exception chaining below.
                    last_error = error

                    _elastic_event_cleanup()

                    if on_elastic_event_callback is not None:
                        on_elastic_event_callback()

            # Reached only when every attempt failed with a slice-down
            # error: a successful attempt returns and other errors re-raise.
            raise ElasticRuntimeError(
                f"Elastic attempt {max_retries} out of {max_retries} failed."
            ) from last_error

        return wrapper

    return decorator
166+
98167
def pause_resume(
99168
self,
100169
max_retries: int,
101170
poll_interval: float | int = 10,
102171
timeout: float | None = None,
103-
) -> Any:
172+
pre_callback: Callable[..., Any] | None = None,
173+
on_elastic_event_callback: Callable[..., Any] | None = None,
174+
) -> Callable[[_F], _F]:
104175
"""Retries a function with pause/resume fault tolerance.
105176
106177
This decorator wraps a function to automatically retry execution in case of
@@ -121,6 +192,9 @@ def pause_resume(
121192
Defaults to 10 seconds.
122193
timeout: The maximum number of seconds to wait for slices to become
123194
active before each retry attempt. If None, there is no timeout.
195+
pre_callback: A callback to call before each attempt, after the slices are active.
196+
on_elastic_event_callback: A callback to call after an elastic failure
197+
occurs.
124198
125199
Returns:
126200
A decorator that adds pause/resume retry behavior to the wrapped function.
@@ -130,42 +204,18 @@ def pause_resume(
130204
Exception: Any other exception raised by the wrapped function that is not
131205
due to a slice down event.
132206
"""
133-
def decorator(func):
134-
@functools.wraps(func)
135-
def wrapper(*args, **kwargs):
136-
for retry_index in range(max_retries):
137-
try:
138-
_logger.info(
139-
"Elastic attempt %d out of %d", retry_index + 1, max_retries
140-
)
141-
142-
self.active_slice_indices = elastic.wait_for_slices(
143-
slice_count=self.total_slice_count,
144-
slice_to_devices=self.slice_to_devices,
145-
poll_interval=poll_interval,
146-
timeout=timeout,
147-
)
148-
149-
return func(*args, **kwargs)
150-
except jax.errors.JaxRuntimeError as error:
151-
if not elastic.is_error_due_to_slice_down(error):
152-
raise
153-
154-
try:
155-
_logger.info("Cleaning up any ongoing traces")
156-
jax.profiler.stop_trace()
157-
except (RuntimeError, ValueError) as e:
158-
_logger.info("No ongoing traces to clean up")
159-
except Exception:
160-
_logger.exception("Error cleaning up ongoing traces")
161-
raise
162-
163-
jax.clear_caches()
164-
for array in jax.live_arrays():
165-
array.delete()
166-
raise ElasticRuntimeError(
167-
f"Elastic attempt {max_retries} out of {max_retries} failed."
168-
)
207+
def internal_pre_callback():
208+
self.active_slice_indices = elastic.wait_for_slices(
209+
slice_count=self.total_slice_count,
210+
slice_to_devices=self.slice_to_devices,
211+
poll_interval=poll_interval,
212+
timeout=timeout,
213+
)
214+
if pre_callback is not None:
215+
pre_callback()
169216

170-
return wrapper
171-
return decorator
217+
return self._elasticity_retry_decorator(
218+
max_retries=max_retries,
219+
pre_callback=internal_pre_callback,
220+
on_elastic_event_callback=on_elastic_event_callback,
221+
)

0 commit comments

Comments
 (0)