Skip to content

Commit 9f04db3

Browse files
guptaakacopybara-github
authored andcommitted
Update environment variables for JAX backend
PiperOrigin-RevId: 886935376
1 parent a57c2a0 commit 9f04db3

1 file changed

Lines changed: 28 additions & 0 deletions

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def __init__(
163163
self._proxy_port = None
164164
self.proxy_server_image = proxy_server_image
165165
self.proxy_options = proxy_options or ProxyOptions()
166+
self._old_jax_platforms = None
167+
self._old_jax_backend_target = None
166168

167169
def __repr__(self):
168170
return (
@@ -198,12 +200,21 @@ def __enter__(self):
198200
gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT)
199201
)
200202

203+
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY)
204+
self._old_jax_backend_target = os.environ.get(_JAX_BACKEND_TARGET_KEY)
205+
201206
# Update the JAX backend to use the proxy.
207+
os.environ[_JAX_PLATFORMS_KEY] = _JAX_PLATFORM_PROXY
208+
os.environ[
209+
_JAX_BACKEND_TARGET_KEY
210+
] = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}"
211+
202212
jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY)
203213
jax.config.update(
204214
_JAX_BACKEND_TARGET_KEY,
205215
f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}",
206216
)
217+
207218
pathwaysutils.initialize()
208219
_logger.info(
209220
"Interactive supercomputing proxy client ready for cluster '%s'.",
@@ -221,6 +232,17 @@ def __exit__(self, exc_type, exc_value, traceback):
221232
_logger.info("Exiting ISCPathways context.")
222233
self._cleanup()
223234

235+
def _restore_env_var(self, key: str, original_value: str | None):
236+
"""Restores an environment variable to its original value or unsets it."""
237+
if original_value is None:
238+
_logger.info("Unsetting environment variable: %s", key)
239+
os.environ.pop(key, None)
240+
else:
241+
_logger.info(
242+
"Restoring environment variable '%s' to '%s'", key, original_value
243+
)
244+
os.environ[key] = original_value
245+
224246
def _cleanup(self):
225247
"""Cleans up resources created by the ISCPathways context."""
226248
# 1. Clear JAX caches and run garbage collection.
@@ -248,6 +270,12 @@ def _cleanup(self):
248270
gke_utils.delete_gke_job(self._proxy_job_name)
249271
_logger.info("Pathways proxy GKE job deletion complete.")
250272

273+
# 4. Restore environment variables.
274+
_logger.info("Restoring environment variables.")
275+
self._restore_env_var(_JAX_PLATFORMS_KEY, self._old_jax_platforms)
276+
self._restore_env_var(_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target)
277+
_logger.info("Environment variables restored.")
278+
251279

252280
@contextlib.contextmanager
253281
def connect(

0 commit comments

Comments
 (0)