1818events. It also provides a utility for waiting for slices to become active.
1919"""
2020
21+ import _thread
2122from collections .abc import Callable , Mapping , Sequence
2223import functools
2324import logging
25+ import threading
2426from typing import Any , TypeVar
2527
2628import jax
@@ -34,6 +36,20 @@ class ElasticRuntimeError(RuntimeError):
3436 """Error raised when elasticity cannot continue."""
3537
3638
39+ class ScaleUpError (RuntimeError ):
40+ """Signals that the workload is ready to scale up.
41+
42+ This exception should be raised by user code when it detects that new hardware
43+ is available and it wants to restart computation to make use of it.
44+ Raising this exception will interrupt the current computation and cause the
45+ elasticity manager to retry it with an updated slice configuration that
46+ includes the new hardware.
47+ """
48+
49+ # For backward compatibility.
50+ NewSliceAvailableError = ScaleUpError
51+
52+
3753_F = TypeVar ("_F" , bound = Callable [..., Any ])
3854
3955
@@ -59,6 +75,7 @@ class Manager:
5975 _total_slice_count : int | None = None
6076 slice_to_devices : Mapping [int , Sequence [jax .Device ]]
6177 active_slice_indices : set [int ]
78+ new_slice_event : threading .Event
6279
6380 def __init__ (self , devices : Sequence [jax .Device ] | None = None ) -> None :
6481 """Initializes the manager.
@@ -70,9 +87,12 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
7087 devices = jax .devices ()
7188 self .slice_to_devices = elastic .get_slice_to_devices (devices )
7289
90+ self .all_slice_indices = set (self .slice_to_devices .keys ())
91+
7392 self .active_slice_indices = elastic .get_active_slice_indices (
7493 slice_to_devices = self .slice_to_devices
7594 )
95+ self .new_slice_event = threading .Event ()
7696
7797 @property
7898 def total_slice_count (self ) -> int :
@@ -97,6 +117,11 @@ def active_slice_count(self) -> int:
97117 """Returns the number of slices."""
98118 return len (self .active_slice_indices )
99119
120+ @property
121+ def inactive_slice_indices (self ) -> set [int ]:
122+ """Returns the set of inactive slice indices."""
123+ return self .all_slice_indices - self .active_slice_indices
124+
100125 def scale_by_active_slices (self , x : int | float ) -> int | float :
101126 """Scale x by the number of active slices."""
102127 if isinstance (x , int ):
@@ -114,6 +139,20 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
114139 else :
115140 raise ValueError (f"Unsupported type: { type (x )= } " )
116141
142+ def _cleanup_on_retry (self ):
143+ """Cleans up JAX caches and traces on retry."""
144+ try :
145+ _logger .debug ("Cleaning up any ongoing traces" )
146+ jax .profiler .stop_trace ()
147+ except (RuntimeError , ValueError ):
148+ _logger .debug ("No ongoing traces to clean up" )
149+ except Exception : # pylint: disable=broad-exception-caught
150+ _logger .exception ("Error cleaning up ongoing traces" )
151+
152+ jax .clear_caches ()
153+ for array in jax .live_arrays ():
154+ array .delete ()
155+
117156 def _elasticity_retry_decorator (
118157 self ,
119158 max_retries : int ,
@@ -148,10 +187,23 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
148187
149188 with jax .default_device (self .default_device ):
150189 return func (* args , ** kwargs )
190+ except ScaleUpError :
191+ _logger .info ("Scale up requested. Retrying." )
192+ _elastic_event_cleanup ()
193+
194+ if on_elastic_event_callback is not None :
195+ on_elastic_event_callback ()
151196 except jax .errors .JaxRuntimeError as error :
152197 if not elastic .is_error_due_to_slice_down (error ):
153198 raise
154199
200+ if self .new_slice_event .is_set ():
201+ _logger .info (
202+ "Slice down event and new slice available detected. Retrying."
203+ )
204+ else :
205+ _logger .info ("Slice down event detected. Retrying." )
206+
155207 _elastic_event_cleanup ()
156208
157209 if on_elastic_event_callback is not None :
@@ -162,6 +214,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
162214 )
163215
164216 return wrapper
217+
165218 return decorator
166219
167220 def pause_resume (
@@ -219,3 +272,100 @@ def internal_pre_callback():
219272 pre_callback = internal_pre_callback ,
220273 on_elastic_event_callback = on_elastic_event_callback ,
221274 )
275+
276+ def _monitor_new_slices (
277+ self , stop_event : threading .Event , poll_interval : float | int
278+ ):
279+ """Monitors for new slices and sets the `new_slice_event` if found."""
280+ while not stop_event .wait (poll_interval ):
281+ try :
282+ if not self .inactive_slice_indices :
283+ _logger .debug ("No inactive slices to check." )
284+ continue
285+
286+ _logger .debug (
287+ "Checking inactive slices: %s" , self .inactive_slice_indices
288+ )
289+ inactive_slice_to_devices = {
290+ i : self .slice_to_devices [i ] for i in self .inactive_slice_indices
291+ }
292+ newly_active_indices = elastic .get_active_slice_indices (
293+ inactive_slice_to_devices
294+ )
295+
296+ if newly_active_indices :
297+ _logger .info (
298+ "New slices found: %s. Setting new slice event." ,
299+ newly_active_indices ,
300+ )
301+ self .new_slice_event .set ()
302+ return
303+
304+ _logger .debug ("No new slices found." )
305+ except Exception : # pylint: disable=broad-exception-caught
306+ _logger .exception ("Error in monitor thread" )
307+
308+ def replica_resize (
309+ self ,
310+ max_resizes : int ,
311+ poll_interval : float = 10 ,
312+ pre_callback : Callable [..., Any ] | None = None ,
313+ on_elastic_event_callback : Callable [..., Any ] | None = None ,
314+ ) -> Callable [[_F ], _F ]:
315+ """Retries a function with replica/resize fault tolerance.
316+
317+ Args:
318+ max_resizes: The maximum number of times to retry the function after
319+ resizing the replica count.
320+ poll_interval: The number of seconds to wait between active slice checks.
321+ Defaults to 10 seconds.
322+ pre_callback: A callback to call before the function is attempted.
323+ on_elastic_event_callback: A callback to call after an elastic failure
324+ occurs.
325+
326+ Returns:
327+ The result of the wrapped function.
328+
329+ Raises:
330+ ElasticRuntimeError: If all retry attempts fail.
331+ Exception: Any other exception raised by the wrapped function that is not
332+ due to a slice down event.
333+ """
334+
335+ def internal_pre_callback ():
336+ self .active_slice_indices = elastic .wait_for_slices (
337+ slice_count = 1 ,
338+ slice_to_devices = self .slice_to_devices ,
339+ poll_interval = poll_interval ,
340+ )
341+
342+ if pre_callback is not None :
343+ pre_callback ()
344+
345+ retry_decorator = self ._elasticity_retry_decorator (
346+ max_retries = max_resizes ,
347+ pre_callback = internal_pre_callback ,
348+ on_elastic_event_callback = on_elastic_event_callback ,
349+ )
350+
351+ def decorator (func ):
352+ @functools .wraps (func )
353+ def wrapper (* args , ** kwargs ):
354+ self .new_slice_event .clear ()
355+ stop_event = threading .Event ()
356+
357+ monitor_thread = threading .Thread (
358+ target = self ._monitor_new_slices ,
359+ args = (stop_event , poll_interval ),
360+ daemon = True ,
361+ )
362+ monitor_thread .start ()
363+ try :
364+ return func (* args , ** kwargs )
365+ finally :
366+ stop_event .set ()
367+ monitor_thread .join ()
368+
369+ return retry_decorator (wrapper )
370+
371+ return decorator
0 commit comments