2121from collections .abc import Callable , Mapping , Sequence , Set
2222import functools
2323import logging
24+ import threading
2425from typing import Any , TypeVar
2526
2627import jax
@@ -34,6 +35,17 @@ class ElasticRuntimeError(RuntimeError):
3435 """Error raised when elasticity cannot continue."""
3536
3637
38+ class ScaleUpSignalError (Exception ):
39+ """Signals that the workload is ready to scale up.
40+
41+ This exception should be raised by user code when it detects that new hardware
42+ is available and it wants to restart computation to make use of it.
43+ Raising this exception will interrupt the current computation and cause the
44+ elasticity manager to retry it with an updated slice configuration that
45+ includes the new hardware.
46+ """
47+
48+
3749_F = TypeVar ("_F" , bound = Callable [..., Any ])
3850
3951
@@ -54,11 +66,21 @@ def _elastic_event_cleanup() -> None:
5466
5567
5668class Manager :
57- """Utility class for elastic training."""
69+ """Utility class for elastic training.
70+
71+ Attributes:
72+ slice_to_devices: A mapping from slice index to a sequence of `jax.Device`
73+ objects for that slice.
74+ all_slice_indices: A set of all possible slice indices.
75+ active_slice_indices: A set of indices of the currently active slices.
76+ new_slice_event: A `threading.Event` that is set when new slices become
77+ available during replica/resize mode.
78+ """
5879
59- _total_slice_count : int | None = None
6080 slice_to_devices : Mapping [int , Sequence [jax .Device ]]
81+ all_slice_indices : Set [int ]
6182 active_slice_indices : Set [int ]
83+ new_slice_event : threading .Event
6284
6385 def __init__ (self , devices : Sequence [jax .Device ] | None = None ) -> None :
6486 """Initializes the manager.
@@ -70,20 +92,21 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
7092 devices = jax .devices ()
7193 self .slice_to_devices = elastic .get_slice_to_devices (devices )
7294
95+ self .all_slice_indices = set (self .slice_to_devices .keys ())
96+
7397 self .active_slice_indices = elastic .get_active_slice_indices (
7498 slice_to_devices = self .slice_to_devices
7599 )
100+ self .new_slice_event = threading .Event ()
76101
77- @property
102+ @functools . cached_property
78103 def total_slice_count (self ) -> int :
79- """Returns the total number of slices."""
80- if self ._total_slice_count is None :
81- self ._total_slice_count = len (self .slice_to_devices )
82- return self ._total_slice_count
104+ """The total number of slices."""
105+ return len (self .slice_to_devices )
83106
84107 @property
85108 def default_device (self ) -> jax .Device :
86- """Returns the device that should be set to the default device.
109+ """The device that should be set to the default device.
87110
88111 This will be from one of the slices in `active_slice_indices`.
89112 """
@@ -94,9 +117,14 @@ def default_device(self) -> jax.Device:
94117
95118 @property
96119 def active_slice_count (self ) -> int :
97- """Returns the number of slices."""
120+ """The number of active slices."""
98121 return len (self .active_slice_indices )
99122
123+ @property
124+ def inactive_slice_indices (self ) -> set [int ]:
125+ """The set of inactive slice indices."""
126+ return self .all_slice_indices - self .active_slice_indices
127+
100128 def scale_by_active_slices (self , x : int | float ) -> int | float :
101129 """Scale x by the number of active slices."""
102130 if isinstance (x , int ):
@@ -114,6 +142,20 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
114142 else :
115143 raise ValueError (f"Unsupported type: { type (x )= } " )
116144
145+ def _cleanup_on_retry (self ):
146+ """Cleans up JAX caches and traces on retry."""
147+ try :
148+ _logger .debug ("Cleaning up any ongoing traces" )
149+ jax .profiler .stop_trace ()
150+ except (RuntimeError , ValueError ):
151+ _logger .debug ("No ongoing traces to clean up" )
152+ except Exception : # pylint: disable=broad-exception-caught
153+ _logger .exception ("Error cleaning up ongoing traces" )
154+
155+ jax .clear_caches ()
156+ for array in jax .live_arrays ():
157+ array .delete ()
158+
117159 def _elasticity_retry_decorator (
118160 self ,
119161 max_retries : int ,
@@ -148,10 +190,23 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
148190
149191 with jax .default_device (self .default_device ):
150192 return func (* args , ** kwargs )
193+ except ScaleUpSignalError :
194+ _logger .info ("Scale up requested. Retrying." )
195+ _elastic_event_cleanup ()
196+
197+ if on_elastic_event_callback is not None :
198+ on_elastic_event_callback ()
151199 except jax .errors .JaxRuntimeError as error :
152200 if not elastic .is_error_due_to_slice_down (error ):
153201 raise
154202
203+ if self .new_slice_event .is_set ():
204+ _logger .info (
205+ "Slice down event and new slice available detected. Retrying."
206+ )
207+ else :
208+ _logger .info ("Slice down event detected. Retrying." )
209+
155210 _elastic_event_cleanup ()
156211
157212 if on_elastic_event_callback is not None :
@@ -162,6 +217,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
162217 )
163218
164219 return wrapper
220+
165221 return decorator
166222
167223 def pause_resume (
@@ -197,7 +253,7 @@ def pause_resume(
197253 occurs.
198254
199255 Returns:
200- The result of the wrapped function.
256+ A decorator that retries the wrapped function.
201257
202258 Raises:
203259 ElasticRuntimeError: If all retry attempts fail.
@@ -219,3 +275,100 @@ def internal_pre_callback():
219275 pre_callback = internal_pre_callback ,
220276 on_elastic_event_callback = on_elastic_event_callback ,
221277 )
278+
279+ def _monitor_new_slices (
280+ self , stop_event : threading .Event , poll_interval : float | int
281+ ):
282+ """Monitors for new slices and sets the `new_slice_event` if found."""
283+ while not stop_event .wait (poll_interval ):
284+ try :
285+ if not self .inactive_slice_indices :
286+ _logger .debug ("No inactive slices to check." )
287+ continue
288+
289+ _logger .debug (
290+ "Checking inactive slices: %s" , self .inactive_slice_indices
291+ )
292+ inactive_slice_to_devices = {
293+ i : self .slice_to_devices [i ] for i in self .inactive_slice_indices
294+ }
295+ newly_active_indices = elastic .get_active_slice_indices (
296+ inactive_slice_to_devices
297+ )
298+
299+ if newly_active_indices :
300+ _logger .info (
301+ "New slices found: %s. Setting new slice event." ,
302+ newly_active_indices ,
303+ )
304+ self .new_slice_event .set ()
305+ return
306+
307+ _logger .debug ("No new slices found." )
308+ except Exception : # pylint: disable=broad-exception-caught
309+ _logger .exception ("Error in monitor thread" )
310+
311+ def replica_resize (
312+ self ,
313+ max_resizes : int ,
314+ poll_interval : float = 10 ,
315+ pre_callback : Callable [..., Any ] | None = None ,
316+ on_elastic_event_callback : Callable [..., Any ] | None = None ,
317+ ) -> Callable [[_F ], _F ]:
318+ """Retries a function with replica/resize fault tolerance.
319+
320+ Args:
321+ max_resizes: The maximum number of times to retry the function after
322+ resizing the replica count.
323+ poll_interval: The number of seconds to wait between active slice checks.
324+ Defaults to 10 seconds.
325+ pre_callback: A callback to call before the function is attempted.
326+ on_elastic_event_callback: A callback to call after an elastic failure
327+ occurs.
328+
329+ Returns:
330+ A decorator that retries the wrapped function.
331+
332+ Raises:
333+ ElasticRuntimeError: If all retry attempts fail.
334+ Exception: Any other exception raised by the wrapped function that is not
335+ due to a slice down event.
336+ """
337+
338+ def internal_pre_callback ():
339+ self .active_slice_indices = elastic .wait_for_slices (
340+ slice_count = 1 ,
341+ slice_to_devices = self .slice_to_devices ,
342+ poll_interval = poll_interval ,
343+ )
344+
345+ if pre_callback is not None :
346+ pre_callback ()
347+
348+ retry_decorator = self ._elasticity_retry_decorator (
349+ max_retries = max_resizes ,
350+ pre_callback = internal_pre_callback ,
351+ on_elastic_event_callback = on_elastic_event_callback ,
352+ )
353+
354+ def decorator (func ):
355+ @functools .wraps (func )
356+ def wrapper (* args , ** kwargs ):
357+ self .new_slice_event .clear ()
358+ stop_event = threading .Event ()
359+
360+ monitor_thread = threading .Thread (
361+ target = self ._monitor_new_slices ,
362+ args = (stop_event , poll_interval ),
363+ daemon = True ,
364+ )
365+ monitor_thread .start ()
366+ try :
367+ return func (* args , ** kwargs )
368+ finally :
369+ stop_event .set ()
370+ monitor_thread .join ()
371+
372+ return retry_decorator (wrapper )
373+
374+ return decorator
0 commit comments