@@ -163,6 +163,8 @@ def __init__(
163163 self ._proxy_port = None
164164 self .proxy_server_image = proxy_server_image
165165 self .proxy_options = proxy_options or ProxyOptions ()
166+ self ._old_jax_platforms = None
167+ self ._old_jax_backend_target = None
166168
167169 def __repr__ (self ):
168170 return (
@@ -198,12 +200,21 @@ def __enter__(self):
198200 gke_utils .enable_port_forwarding (proxy_pod , PROXY_SERVER_PORT )
199201 )
200202
203+ self ._old_jax_platforms = os .environ .get (_JAX_PLATFORMS_KEY )
204+ self ._old_jax_backend_target = os .environ .get (_JAX_BACKEND_TARGET_KEY )
205+
201206 # Update the JAX backend to use the proxy.
207+ os .environ [_JAX_PLATFORMS_KEY ] = _JAX_PLATFORM_PROXY
208+ os .environ [
209+ _JAX_BACKEND_TARGET_KEY
210+ ] = f"{ _JAX_BACKEND_TARGET_HOSTNAME } :{ self ._proxy_port } "
211+
202212 jax .config .update (_JAX_PLATFORMS_KEY , _JAX_PLATFORM_PROXY )
203213 jax .config .update (
204214 _JAX_BACKEND_TARGET_KEY ,
205215 f"{ _JAX_BACKEND_TARGET_HOSTNAME } :{ self ._proxy_port } " ,
206216 )
217+
207218 pathwaysutils .initialize ()
208219 _logger .info (
209220 "Interactive supercomputing proxy client ready for cluster '%s'." ,
@@ -221,6 +232,17 @@ def __exit__(self, exc_type, exc_value, traceback):
221232 _logger .info ("Exiting ISCPathways context." )
222233 self ._cleanup ()
223234
235+ def _restore_env_var (self , key : str , original_value : str | None ):
236+ """Restores an environment variable to its original value or unsets it."""
237+ if original_value is None :
238+ _logger .info ("Unsetting environment variable: %s" , key )
239+ os .environ .pop (key , None )
240+ else :
241+ _logger .info (
242+ "Restoring environment variable '%s' to '%s'" , key , original_value
243+ )
244+ os .environ [key ] = original_value
245+
224246 def _cleanup (self ):
225247 """Cleans up resources created by the ISCPathways context."""
226248 # 1. Clear JAX caches and run garbage collection.
@@ -248,6 +270,12 @@ def _cleanup(self):
248270 gke_utils .delete_gke_job (self ._proxy_job_name )
249271 _logger .info ("Pathways proxy GKE job deletion complete." )
250272
273+ # 4. Restore environment variables.
274+ _logger .info ("Restoring environment variables." )
275+ self ._restore_env_var (_JAX_PLATFORMS_KEY , self ._old_jax_platforms )
276+ self ._restore_env_var (_JAX_BACKEND_TARGET_KEY , self ._old_jax_backend_target )
277+ _logger .info ("Environment variables restored." )
278+
251279
252280@contextlib .contextmanager
253281def connect (
0 commit comments