Skip to content

Commit da43a48

Browse files
lukebaumann authored and copybara-github committed
Add a replica_resize decorator for fault tolerance in elastic Pathways.
PiperOrigin-RevId: 857246882
1 parent 2c8593c commit da43a48

2 files changed

Lines changed: 205 additions & 21 deletions

File tree

pathwaysutils/elastic/elastic.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,14 @@ def get_active_slice_indices(
106106
A set of integers representing the indices of the active slices.
107107
"""
108108
if slice_to_devices is None:
109+
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
109110
slice_to_devices = get_slice_to_devices(tuple(jax.devices()))
110111

112+
_logger.debug(
113+
"Getting active slice indices for slices: %s",
114+
sorted(list(slice_to_devices.keys())),
115+
)
116+
111117
active_slice_indices = set()
112118

113119
results = {
@@ -116,17 +122,19 @@ def get_active_slice_indices(
116122
}
117123

118124
for slice_index, x in results.items():
119-
_logger.info("Checking slice_index=%s", slice_index)
125+
_logger.debug("Checking slice_index=%s", slice_index)
120126
expected = (
121127
np.zeros(len(slice_to_devices[slice_index]), dtype=float)
122128
+ _SIMPLE_EXECUTION_TEST_VALUE
123129
)
124130
try:
125131
with timing.Timer(f"Checking {slice_index=}"):
132+
_logger.debug("Blocking until ready for slice_index=%s", slice_index)
126133
jax.block_until_ready(x)
134+
_logger.debug("Execution finished for slice_index=%s", slice_index)
127135
if np.allclose(x, expected):
128136
active_slice_indices.add(slice_index)
129-
_logger.info("slice_index=%s active", slice_index)
137+
_logger.debug("slice_index=%s active", slice_index)
130138
else:
131139
_logger.error(
132140
"Error with _simple_execution for slice_index=%s. "
@@ -139,11 +147,15 @@ def get_active_slice_indices(
139147
f"Error with _simple_execution for slice_index={slice_index}."
140148
)
141149
except jax.errors.JaxRuntimeError as error:
150+
_logger.debug(
151+
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
152+
)
142153
if not is_error_due_to_slice_down(error):
154+
_logger.info("Re-raising error for slice_index=%s", slice_index)
143155
raise
144-
_logger.info("slice_index=%s bad", slice_index)
156+
_logger.debug("slice_index=%s bad", slice_index)
145157

146-
_logger.info("active_slice_indices=%s", active_slice_indices)
158+
_logger.debug("active_slice_indices=%s", active_slice_indices)
147159

148160
return active_slice_indices
149161

@@ -174,22 +186,36 @@ def wait_for_slices(
174186
active.
175187
"""
176188
if slice_to_devices is None:
189+
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
177190
slice_to_devices = get_slice_to_devices(jax.devices())
178191

192+
_logger.info(
193+
"Waiting for %s slices. Poll interval: %s, Timeout: %s",
194+
slice_count,
195+
poll_interval,
196+
timeout,
197+
)
179198
start_time = time.time()
180199

181200
while True:
182201
check_start_time = time.time()
183202

203+
_logger.debug("Checking active slices...")
184204
active_slice_indices = get_active_slice_indices(slice_to_devices)
185205
if len(active_slice_indices) >= slice_count:
186-
_logger.info("%s slices active.", len(active_slice_indices))
206+
_logger.info(
207+
"Sufficient slices active: %s >= %s. Active indices: %s",
208+
len(active_slice_indices),
209+
slice_count,
210+
active_slice_indices,
211+
)
187212
return active_slice_indices
188213

189214
_logger.info(
190-
"%s slices active. Wanting at least %s.",
215+
"%s slices active. Wanting at least %s. Active indices: %s",
191216
len(active_slice_indices),
192217
slice_count,
218+
active_slice_indices,
193219
)
194220

195221
time_to_sleep = max(0, poll_interval - (time.time() - check_start_time))
@@ -206,7 +232,7 @@ def wait_for_slices(
206232
)
207233

208234
if time_to_sleep > 0:
209-
_logger.info("Sleeping for %.2f seconds.", time_to_sleep)
235+
_logger.debug("Sleeping for %.2f seconds.", time_to_sleep)
210236

211237
time.sleep(time_to_sleep)
212238

@@ -228,10 +254,14 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
228254
traceback_logging_level = logging.DEBUG
229255

230256
if isinstance(error, jax.errors.JaxRuntimeError):
257+
_logger.debug("Checking if JaxRuntimeError is due to slice down: %s", error)
231258
if any(
232259
error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES
233260
):
234-
_logger.info("Caught an error due to slice down")
261+
_logger.debug(
262+
"Caught an error due to slice down (matched"
263+
" _ELASTIC_DOWN_ERROR_TYPES)"
264+
)
235265

236266
error_due_to_slice_down = True
237267

@@ -240,15 +270,16 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
240270
for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
241271
):
242272
_logger.warning(
243-
"Caught an error due that may or may not be due to slice down. This"
244-
" error will be treated as due to slice down."
273+
"Caught an error that may or may not be due to slice down (matched"
274+
" _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES). This error will be treated"
275+
" as due to slice down."
245276
)
246277
traceback_logging_level = logging.WARNING
247278

248279
error_due_to_slice_down = True
249280

250281
if not error_due_to_slice_down:
251-
_logger.info("Caught an error not due to slice down")
282+
_logger.debug("Caught an error not due to slice down")
252283

253284
_logger.log(traceback_logging_level, "Error details:", exc_info=True)
254285

pathwaysutils/elastic/manager.py

Lines changed: 163 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections.abc import Callable, Mapping, Sequence, Set
2222
import functools
2323
import logging
24+
import threading
2425
from typing import Any, TypeVar
2526

2627
import jax
@@ -34,6 +35,17 @@ class ElasticRuntimeError(RuntimeError):
3435
"""Error raised when elasticity cannot continue."""
3536

3637

38+
class ScaleUpSignalError(Exception):
39+
"""Signals that the workload is ready to scale up.
40+
41+
This exception should be raised by user code when it detects that new hardware
42+
is available and it wants to restart computation to make use of it.
43+
Raising this exception will interrupt the current computation and cause the
44+
elasticity manager to retry it with an updated slice configuration that
45+
includes the new hardware.
46+
"""
47+
48+
3749
_F = TypeVar("_F", bound=Callable[..., Any])
3850

3951

@@ -54,11 +66,21 @@ def _elastic_event_cleanup() -> None:
5466

5567

5668
class Manager:
57-
"""Utility class for elastic training."""
69+
"""Utility class for elastic training.
70+
71+
Attributes:
72+
slice_to_devices: A mapping from slice index to a sequence of `jax.Device`
73+
objects for that slice.
74+
all_slice_indices: A set of all possible slice indices.
75+
active_slice_indices: A set of indices of the currently active slices.
76+
new_slice_event: A `threading.Event` that is set when new slices become
77+
available during replica/resize mode.
78+
"""
5879

59-
_total_slice_count: int | None = None
6080
slice_to_devices: Mapping[int, Sequence[jax.Device]]
81+
all_slice_indices: Set[int]
6182
active_slice_indices: Set[int]
83+
new_slice_event: threading.Event
6284

6385
def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
6486
"""Initializes the manager.
@@ -70,20 +92,21 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
7092
devices = jax.devices()
7193
self.slice_to_devices = elastic.get_slice_to_devices(devices)
7294

95+
self.all_slice_indices = set(self.slice_to_devices.keys())
96+
7397
self.active_slice_indices = elastic.get_active_slice_indices(
7498
slice_to_devices=self.slice_to_devices
7599
)
100+
self.new_slice_event = threading.Event()
76101

77-
@property
102+
@functools.cached_property
78103
def total_slice_count(self) -> int:
79-
"""Returns the total number of slices."""
80-
if self._total_slice_count is None:
81-
self._total_slice_count = len(self.slice_to_devices)
82-
return self._total_slice_count
104+
"""The total number of slices."""
105+
return len(self.slice_to_devices)
83106

84107
@property
85108
def default_device(self) -> jax.Device:
86-
"""Returns the device that should be set to the default device.
109+
"""The device that should be set to the default device.
87110
88111
This will be from one of the slices in `active_slice_indices`.
89112
"""
@@ -94,9 +117,14 @@ def default_device(self) -> jax.Device:
94117

95118
@property
96119
def active_slice_count(self) -> int:
97-
"""Returns the number of slices."""
120+
"""The number of active slices."""
98121
return len(self.active_slice_indices)
99122

123+
@property
124+
def inactive_slice_indices(self) -> set[int]:
125+
"""The set of inactive slice indices."""
126+
return self.all_slice_indices - self.active_slice_indices
127+
100128
def scale_by_active_slices(self, x: int | float) -> int | float:
101129
"""Scale x by the number of active slices."""
102130
if isinstance(x, int):
@@ -114,6 +142,20 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
114142
else:
115143
raise ValueError(f"Unsupported type: {type(x)=}")
116144

145+
def _cleanup_on_retry(self):
146+
"""Cleans up JAX caches and traces on retry."""
147+
try:
148+
_logger.debug("Cleaning up any ongoing traces")
149+
jax.profiler.stop_trace()
150+
except (RuntimeError, ValueError):
151+
_logger.debug("No ongoing traces to clean up")
152+
except Exception: # pylint: disable=broad-exception-caught
153+
_logger.exception("Error cleaning up ongoing traces")
154+
155+
jax.clear_caches()
156+
for array in jax.live_arrays():
157+
array.delete()
158+
117159
def _elasticity_retry_decorator(
118160
self,
119161
max_retries: int,
@@ -148,10 +190,23 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
148190

149191
with jax.default_device(self.default_device):
150192
return func(*args, **kwargs)
193+
except ScaleUpSignalError:
194+
_logger.info("Scale up requested. Retrying.")
195+
_elastic_event_cleanup()
196+
197+
if on_elastic_event_callback is not None:
198+
on_elastic_event_callback()
151199
except jax.errors.JaxRuntimeError as error:
152200
if not elastic.is_error_due_to_slice_down(error):
153201
raise
154202

203+
if self.new_slice_event.is_set():
204+
_logger.info(
205+
"Slice down event and new slice available detected. Retrying."
206+
)
207+
else:
208+
_logger.info("Slice down event detected. Retrying.")
209+
155210
_elastic_event_cleanup()
156211

157212
if on_elastic_event_callback is not None:
@@ -162,6 +217,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
162217
)
163218

164219
return wrapper
220+
165221
return decorator
166222

167223
def pause_resume(
@@ -197,7 +253,7 @@ def pause_resume(
197253
occurs.
198254
199255
Returns:
200-
The result of the wrapped function.
256+
A decorator that retries the wrapped function.
201257
202258
Raises:
203259
ElasticRuntimeError: If all retry attempts fail.
@@ -219,3 +275,100 @@ def internal_pre_callback():
219275
pre_callback=internal_pre_callback,
220276
on_elastic_event_callback=on_elastic_event_callback,
221277
)
278+
279+
def _monitor_new_slices(
280+
self, stop_event: threading.Event, poll_interval: float | int
281+
):
282+
"""Monitors for new slices and sets the `new_slice_event` if found."""
283+
while not stop_event.wait(poll_interval):
284+
try:
285+
if not self.inactive_slice_indices:
286+
_logger.debug("No inactive slices to check.")
287+
continue
288+
289+
_logger.debug(
290+
"Checking inactive slices: %s", self.inactive_slice_indices
291+
)
292+
inactive_slice_to_devices = {
293+
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
294+
}
295+
newly_active_indices = elastic.get_active_slice_indices(
296+
inactive_slice_to_devices
297+
)
298+
299+
if newly_active_indices:
300+
_logger.info(
301+
"New slices found: %s. Setting new slice event.",
302+
newly_active_indices,
303+
)
304+
self.new_slice_event.set()
305+
return
306+
307+
_logger.debug("No new slices found.")
308+
except Exception: # pylint: disable=broad-exception-caught
309+
_logger.exception("Error in monitor thread")
310+
311+
def replica_resize(
312+
self,
313+
max_resizes: int,
314+
poll_interval: float = 10,
315+
pre_callback: Callable[..., Any] | None = None,
316+
on_elastic_event_callback: Callable[..., Any] | None = None,
317+
) -> Callable[[_F], _F]:
318+
"""Retries a function with replica/resize fault tolerance.
319+
320+
Args:
321+
max_resizes: The maximum number of times to retry the function after
322+
resizing the replica count.
323+
poll_interval: The number of seconds to wait between active slice checks.
324+
Defaults to 10 seconds.
325+
pre_callback: A callback to call before the function is attempted.
326+
on_elastic_event_callback: A callback to call after an elastic failure
327+
occurs.
328+
329+
Returns:
330+
A decorator that retries the wrapped function.
331+
332+
Raises:
333+
ElasticRuntimeError: If all retry attempts fail.
334+
Exception: Any other exception raised by the wrapped function that is not
335+
due to a slice down event.
336+
"""
337+
338+
def internal_pre_callback():
339+
self.active_slice_indices = elastic.wait_for_slices(
340+
slice_count=1,
341+
slice_to_devices=self.slice_to_devices,
342+
poll_interval=poll_interval,
343+
)
344+
345+
if pre_callback is not None:
346+
pre_callback()
347+
348+
retry_decorator = self._elasticity_retry_decorator(
349+
max_retries=max_resizes,
350+
pre_callback=internal_pre_callback,
351+
on_elastic_event_callback=on_elastic_event_callback,
352+
)
353+
354+
def decorator(func):
355+
@functools.wraps(func)
356+
def wrapper(*args, **kwargs):
357+
self.new_slice_event.clear()
358+
stop_event = threading.Event()
359+
360+
monitor_thread = threading.Thread(
361+
target=self._monitor_new_slices,
362+
args=(stop_event, poll_interval),
363+
daemon=True,
364+
)
365+
monitor_thread.start()
366+
try:
367+
return func(*args, **kwargs)
368+
finally:
369+
stop_event.set()
370+
monitor_thread.join()
371+
372+
return retry_decorator(wrapper)
373+
374+
return decorator

0 commit comments

Comments
 (0)