
Commit 0cefd7e

lukebaumann authored and copybara-github committed
Add a replica_resize decorator for fault tolerance in elastic Pathways.
PiperOrigin-RevId: 857246882
1 parent 83d0aa3 commit 0cefd7e

2 files changed

Lines changed: 192 additions & 11 deletions

pathwaysutils/elastic/elastic.py

Lines changed: 42 additions & 11 deletions
@@ -106,8 +106,14 @@ def get_active_slice_indices(
     A set of integers representing the indices of the active slices.
   """
   if slice_to_devices is None:
+    _logger.debug("slice_to_devices is None. Getting from jax.devices().")
     slice_to_devices = get_slice_to_devices(tuple(jax.devices()))
 
+  _logger.debug(
+      "Getting active slice indices for slices: %s",
+      sorted(list(slice_to_devices.keys())),
+  )
+
   active_slice_indices = set()
 
   results = {
@@ -116,17 +122,19 @@ def get_active_slice_indices(
   }
 
   for slice_index, x in results.items():
-    _logger.info("Checking slice_index=%s", slice_index)
+    _logger.debug("Checking slice_index=%s", slice_index)
     expected = (
         np.zeros(len(slice_to_devices[slice_index]), dtype=float)
         + _SIMPLE_EXECUTION_TEST_VALUE
     )
     try:
       with timing.Timer(f"Checking {slice_index=}"):
+        _logger.debug("Blocking until ready for slice_index=%s", slice_index)
         jax.block_until_ready(x)
+        _logger.debug("Execution finished for slice_index=%s", slice_index)
       if np.allclose(x, expected):
         active_slice_indices.add(slice_index)
-        _logger.info("slice_index=%s active", slice_index)
+        _logger.debug("slice_index=%s active", slice_index)
       else:
         _logger.error(
             "Error with _simple_execution for slice_index=%s. "
@@ -139,11 +147,15 @@ def get_active_slice_indices(
           f"Error with _simple_execution for slice_index={slice_index}."
       )
     except jax.errors.JaxRuntimeError as error:
+      _logger.debug(
+          "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
+      )
       if not is_error_due_to_slice_down(error):
+        _logger.info("Re-raising error for slice_index=%s", slice_index)
        raise
-      _logger.info("slice_index=%s bad", slice_index)
+      _logger.debug("slice_index=%s bad", slice_index)
 
-  _logger.info("active_slice_indices=%s", active_slice_indices)
+  _logger.debug("active_slice_indices=%s", active_slice_indices)
 
   return active_slice_indices
 
@@ -174,22 +186,36 @@ def wait_for_slices(
       active.
   """
   if slice_to_devices is None:
+    _logger.debug("slice_to_devices is None. Getting from jax.devices().")
     slice_to_devices = get_slice_to_devices(jax.devices())
 
+  _logger.info(
+      "Waiting for %s slices. Poll interval: %s, Timeout: %s",
+      slice_count,
+      poll_interval,
+      timeout,
+  )
   start_time = time.time()
 
   while True:
     check_start_time = time.time()
 
+    _logger.debug("Checking active slices...")
     active_slice_indices = get_active_slice_indices(slice_to_devices)
     if len(active_slice_indices) >= slice_count:
-      _logger.info("%s slices active.", len(active_slice_indices))
+      _logger.info(
+          "Sufficient slices active: %s >= %s. Active indices: %s",
+          len(active_slice_indices),
+          slice_count,
+          active_slice_indices,
+      )
       return active_slice_indices
 
     _logger.info(
-        "%s slices active. Wanting at least %s.",
+        "%s slices active. Wanting at least %s. Active indices: %s",
         len(active_slice_indices),
         slice_count,
+        active_slice_indices,
     )
 
     time_to_sleep = max(0, poll_interval - (time.time() - check_start_time))
@@ -206,7 +232,7 @@ def wait_for_slices(
     )
 
     if time_to_sleep > 0:
-      _logger.info("Sleeping for %.2f seconds.", time_to_sleep)
+      _logger.debug("Sleeping for %.2f seconds.", time_to_sleep)
 
      time.sleep(time_to_sleep)
 
@@ -228,10 +254,14 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
   traceback_logging_level = logging.DEBUG
 
   if isinstance(error, jax.errors.JaxRuntimeError):
+    _logger.debug("Checking if JaxRuntimeError is due to slice down: %s", error)
     if any(
         error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES
     ):
-      _logger.info("Caught an error due to slice down")
+      _logger.debug(
+          "Caught an error due to slice down (matched"
+          " _ELASTIC_DOWN_ERROR_TYPES)"
+      )
 
       error_due_to_slice_down = True
 
@@ -240,15 +270,16 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
         for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
     ):
       _logger.warning(
-          "Caught an error due that may or may not be due to slice down. This"
-          " error will be treated as due to slice down."
+          "Caught an error that may or may not be due to slice down (matched"
+          " _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES). This error will be treated"
+          " as due to slice down."
       )
       traceback_logging_level = logging.WARNING
 
       error_due_to_slice_down = True
 
   if not error_due_to_slice_down:
-    _logger.info("Caught an error not due to slice down")
+    _logger.debug("Caught an error not due to slice down")
 
   _logger.log(traceback_logging_level, "Error details:", exc_info=True)
 
pathwaysutils/elastic/manager.py

Lines changed: 150 additions & 0 deletions
@@ -18,9 +18,11 @@
 events. It also provides a utility for waiting for slices to become active.
 """
 
+import _thread
 from collections.abc import Callable, Mapping, Sequence
 import functools
 import logging
+import threading
 from typing import Any, TypeVar
 
 import jax
@@ -34,6 +36,20 @@ class ElasticRuntimeError(RuntimeError):
   """Error raised when elasticity cannot continue."""
 
 
+class ScaleUpError(RuntimeError):
+  """Signals that the workload is ready to scale up.
+
+  This exception should be raised by user code when it detects that new hardware
+  is available and it wants to restart computation to make use of it.
+  Raising this exception will interrupt the current computation and cause the
+  elasticity manager to retry it with an updated slice configuration that
+  includes the new hardware.
+  """
+
+# For backward compatibility.
+NewSliceAvailableError = ScaleUpError
+
+
 _F = TypeVar("_F", bound=Callable[..., Any])
 
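The new ScaleUpError is the hook for user-driven scale-up: the workload raises it when it notices fresh capacity, and the retry decorator extended below catches it and reruns the function. A minimal, hypothetical sketch of that pattern; run_training and its step loop are illustrative and not part of this commit:

from pathwaysutils.elastic import manager as manager_lib

def run_training(manager: manager_lib.Manager, num_steps: int) -> None:
  """Hypothetical workload that requests a scale-up mid-run."""
  for step in range(num_steps):
    # `new_slice_event` is set by the manager's background monitor thread
    # (added further down in this commit) when an inactive slice responds.
    if manager.new_slice_event.is_set():
      raise manager_lib.ScaleUpError(f"New slice available at step {step}.")
    # ... the real per-step work would go here ...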

@@ -59,6 +75,7 @@ class Manager:
   _total_slice_count: int | None = None
   slice_to_devices: Mapping[int, Sequence[jax.Device]]
   active_slice_indices: set[int]
+  new_slice_event: threading.Event
 
   def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
     """Initializes the manager.
@@ -70,9 +87,12 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
       devices = jax.devices()
     self.slice_to_devices = elastic.get_slice_to_devices(devices)
 
+    self.all_slice_indices = set(self.slice_to_devices.keys())
+
     self.active_slice_indices = elastic.get_active_slice_indices(
         slice_to_devices=self.slice_to_devices
     )
+    self.new_slice_event = threading.Event()
 
   @property
   def total_slice_count(self) -> int:
@@ -97,6 +117,11 @@ def active_slice_count(self) -> int:
     """Returns the number of slices."""
     return len(self.active_slice_indices)
 
+  @property
+  def inactive_slice_indices(self) -> set[int]:
+    """Returns the set of inactive slice indices."""
+    return self.all_slice_indices - self.active_slice_indices
+
   def scale_by_active_slices(self, x: int | float) -> int | float:
     """Scale x by the number of active slices."""
     if isinstance(x, int):
@@ -114,6 +139,20 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
     else:
       raise ValueError(f"Unsupported type: {type(x)=}")
 
+  def _cleanup_on_retry(self):
+    """Cleans up JAX caches and traces on retry."""
+    try:
+      _logger.debug("Cleaning up any ongoing traces")
+      jax.profiler.stop_trace()
+    except (RuntimeError, ValueError):
+      _logger.debug("No ongoing traces to clean up")
+    except Exception:  # pylint: disable=broad-exception-caught
+      _logger.exception("Error cleaning up ongoing traces")
+
+    jax.clear_caches()
+    for array in jax.live_arrays():
+      array.delete()
+
   def _elasticity_retry_decorator(
       self,
       max_retries: int,
@@ -148,10 +187,23 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
 
         with jax.default_device(self.default_device):
           return func(*args, **kwargs)
+      except ScaleUpError:
+        _logger.info("Scale up requested. Retrying.")
+        _elastic_event_cleanup()
+
+        if on_elastic_event_callback is not None:
+          on_elastic_event_callback()
       except jax.errors.JaxRuntimeError as error:
         if not elastic.is_error_due_to_slice_down(error):
           raise
 
+        if self.new_slice_event.is_set():
+          _logger.info(
+              "Slice down event and new slice available detected. Retrying."
+          )
+        else:
+          _logger.info("Slice down event detected. Retrying.")
+
         _elastic_event_cleanup()
 
         if on_elastic_event_callback is not None:
@@ -162,6 +214,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
       )
 
       return wrapper
+
     return decorator
 
   def pause_resume(
@@ -219,3 +272,100 @@ def internal_pre_callback():
         pre_callback=internal_pre_callback,
         on_elastic_event_callback=on_elastic_event_callback,
     )
+
+  def _monitor_new_slices(
+      self, stop_event: threading.Event, poll_interval: float | int
+  ):
+    """Monitors for new slices and sets the `new_slice_event` if found."""
+    while not stop_event.wait(poll_interval):
+      try:
+        if not self.inactive_slice_indices:
+          _logger.debug("No inactive slices to check.")
+          continue
+
+        _logger.debug(
+            "Checking inactive slices: %s", self.inactive_slice_indices
+        )
+        inactive_slice_to_devices = {
+            i: self.slice_to_devices[i] for i in self.inactive_slice_indices
+        }
+        newly_active_indices = elastic.get_active_slice_indices(
+            inactive_slice_to_devices
+        )
+
+        if newly_active_indices:
+          _logger.info(
+              "New slices found: %s. Setting new slice event.",
+              newly_active_indices,
+          )
+          self.new_slice_event.set()
+          return
+
+        _logger.debug("No new slices found.")
+      except Exception:  # pylint: disable=broad-exception-caught
+        _logger.exception("Error in monitor thread")
+
+  def replica_resize(
+      self,
+      max_resizes: int,
+      poll_interval: float = 10,
+      pre_callback: Callable[..., Any] | None = None,
+      on_elastic_event_callback: Callable[..., Any] | None = None,
+  ) -> Callable[[_F], _F]:
+    """Retries a function with replica/resize fault tolerance.
+
+    Args:
+      max_resizes: The maximum number of times to retry the function after
+        resizing the replica count.
+      poll_interval: The number of seconds to wait between active slice checks.
+        Defaults to 10 seconds.
+      pre_callback: A callback to call before the function is attempted.
+      on_elastic_event_callback: A callback to call after an elastic failure
+        occurs.
+
+    Returns:
+      The result of the wrapped function.
+
+    Raises:
+      ElasticRuntimeError: If all retry attempts fail.
+      Exception: Any other exception raised by the wrapped function that is not
+        due to a slice down event.
+    """
+
+    def internal_pre_callback():
+      self.active_slice_indices = elastic.wait_for_slices(
+          slice_count=1,
+          slice_to_devices=self.slice_to_devices,
+          poll_interval=poll_interval,
+      )
+
+      if pre_callback is not None:
+        pre_callback()
+
+    retry_decorator = self._elasticity_retry_decorator(
+        max_retries=max_resizes,
+        pre_callback=internal_pre_callback,
+        on_elastic_event_callback=on_elastic_event_callback,
+    )
+
+    def decorator(func):
+      @functools.wraps(func)
+      def wrapper(*args, **kwargs):
+        self.new_slice_event.clear()
+        stop_event = threading.Event()
+
+        monitor_thread = threading.Thread(
+            target=self._monitor_new_slices,
+            args=(stop_event, poll_interval),
+            daemon=True,
+        )
+        monitor_thread.start()
+        try:
+          return func(*args, **kwargs)
+        finally:
+          stop_event.set()
+          monitor_thread.join()
+
+      return retry_decorator(wrapper)
+
+    return decorator
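Taken together, replica_resize waits for at least one active slice before each attempt (via internal_pre_callback), starts the background monitor thread around each call, and retries through the existing elasticity decorator. A hedged usage sketch under assumed names; run_step, the batch size, and the callback body are illustrative, and the workload would raise ScaleUpError as sketched earlier once new_slice_event is set:

from pathwaysutils.elastic import manager as manager_lib

manager = manager_lib.Manager()

@manager.replica_resize(
    max_resizes=5,
    poll_interval=30,
    on_elastic_event_callback=lambda: print("Retrying after an elastic event."),
)
def train_loop() -> None:
  # Size the work to whatever is active on this attempt; before each retry the
  # pre-callback refreshes active_slice_indices via wait_for_slices.
  global_batch_size = manager.scale_by_active_slices(1024)
  for step in range(10_000):
    run_step(step, global_batch_size)  # placeholder for the real step function

train_loop()

On a slice-down JaxRuntimeError or a ScaleUpError, the decorator runs its cleanup, invokes on_elastic_event_callback if given, and reattempts the wrapped function up to max_resizes times.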
