Skip to content

Commit de67fbc

Browse files
lukebaumanncopybara-github
authored andcommitted
Require JAX>=0.8.0
---- Directly use jax.extend.ifrt_proxy. This change updates pathwaysutils to import and use `jax.extend.ifrt_proxy.ifrt_proxy` directly. The re-export of this function from `pathwaysutils.jax` is removed, along with version-specific compatibility code for older JAX versions. PiperOrigin-RevId: 852927594
1 parent 83d0aa3 commit de67fbc

File tree

3 files changed

+6
-22
lines changed

3 files changed

+6
-22
lines changed

pathwaysutils/jax/__init__.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
`pathwaysutils`'s compatibility window.
1818
"""
1919

20-
import functools
2120

22-
import jax
21+
import functools
2322

2423

2524
class _FakeJaxFunction:
@@ -46,20 +45,6 @@ def __call__(self, *args, **kwargs):
4645
raise ImportError(self.error_message)
4746

4847

49-
try:
50-
# jax>=0.7.1
51-
from jax.extend import backend # pylint: disable=g-import-not-at-top
52-
53-
ifrt_proxy = backend.ifrt_proxy
54-
del backend
55-
except AttributeError:
56-
# jax<0.7.1
57-
from jax.lib import xla_extension # pylint: disable=g-import-not-at-top
58-
59-
ifrt_proxy = xla_extension.ifrt_proxy
60-
del xla_extension
61-
62-
6348
try:
6449
# jax>=0.8.0
6550
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
@@ -112,6 +97,5 @@ def ifrt_reshard_available() -> bool:
11297
del jax
11398

11499

115-
del jax
116100
del _FakeJaxFunction
117101
del functools

pathwaysutils/proxy_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
import jax
1717
from jax.extend import backend
18-
from pathwaysutils import jax as pw_jax
18+
from jax.extend.backend import ifrt_proxy
1919

2020

2121
def register_backend_factory() -> None:
2222
backend.register_backend_factory(
2323
"proxy",
24-
lambda: pw_jax.ifrt_proxy.get_client(
24+
lambda: ifrt_proxy.get_client(
2525
jax.config.read("jax_backend_target"),
26-
pw_jax.ifrt_proxy.ClientConnectionOptions(),
26+
ifrt_proxy.ClientConnectionOptions(),
2727
),
2828
priority=-1,
2929
)

pathwaysutils/test/proxy_backend_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from absl.testing import absltest
1919
import jax
2020
from jax.extend import backend
21-
from pathwaysutils import jax as pw_jax
21+
from jax.extend.backend import ifrt_proxy
2222
from pathwaysutils import proxy_backend
2323

2424

@@ -46,7 +46,7 @@ def test_no_proxy_backend_registration_raises_error(self):
4646
def test_proxy_backend_registration(self):
4747
self.enter_context(
4848
mock.patch.object(
49-
pw_jax.ifrt_proxy,
49+
ifrt_proxy,
5050
"get_client",
5151
return_value=mock.MagicMock(),
5252
)

0 commit comments

Comments
 (0)