Skip to content

Commit 5213b90

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Add type hints to pathwaysutils
This change adds type annotations to various functions, methods, and variables across the pathwaysutils package, improving code clarity and maintainability. It also includes minor adjustments to existing type hints, such as using `|` for unions and casting where necessary. PiperOrigin-RevId: 881566030
1 parent 1cda610 commit 5213b90

10 files changed

Lines changed: 50 additions & 44 deletions

File tree

pathwaysutils/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Package of Pathways-on-Cloud utilities."""
15+
from collections.abc import Callable
1516
from pathwaysutils import _initialize
1617

17-
initialize = _initialize.initialize
18-
is_pathways_backend_used = _initialize.is_pathways_backend_used
18+
initialize: Callable[[], None] = _initialize.initialize
19+
is_pathways_backend_used: Callable[[], bool] = _initialize.is_pathways_backend_used
1920

2021
del _initialize
2122

2223
# When changing this, also update the CHANGELOG.md.
23-
__version__ = "v0.1.5"
24+
__version__: str = "v0.1.5"

pathwaysutils/collect_profile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
_logger.setLevel(logging.INFO)
2727

2828

29-
_DESCRIPTION = """
29+
_DESCRIPTION: str = """
3030
To profile running JAX programs, you first need to start the profiler server
3131
in the program of interest. You can do this via
3232
`jax.profiler.start_server(<port>)`. Once the program is running and the
@@ -36,7 +36,7 @@
3636
"""
3737

3838

39-
def _get_parser():
39+
def _get_parser() -> argparse.ArgumentParser:
4040
"""Returns an argument parser for the collect_profile script."""
4141
parser = argparse.ArgumentParser(description=_DESCRIPTION)
4242
parser.add_argument(
@@ -62,7 +62,7 @@ def _get_parser():
6262
return parser
6363

6464

65-
def main():
65+
def main() -> None:
6666
parser = _get_parser()
6767
args = parser.parse_args()
6868

pathwaysutils/experimental/reshard.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@
1414
"""Experimental resharding API for elastic device sets."""
1515

1616
import base64
17-
import collections
18-
from collections.abc import Mapping
17+
from collections.abc import Callable, Mapping, Sequence
1918
import json
2019
import logging
2120
import math
2221
import operator
23-
from typing import Any, Callable, Dict, Mapping, Sequence
22+
from typing import Any
2423

2524
import jax
2625
from pathwaysutils import jax as pw_jax
@@ -57,7 +56,7 @@ def __init__(
5756
):
5857
def ifrt_hlo_sharding(
5958
aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding
60-
) -> Dict[str, Any]:
59+
) -> dict[str, Any]:
6160
result = {
6261
"devices": {
6362
"device_ids": [

pathwaysutils/experimental/split_by_mesh_axis.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414
"""Experimental split by mesh axis API."""
1515

16-
from typing import Any, Sequence
16+
from collections.abc import Sequence
17+
from typing import Any, cast
1718

1819
import jax
1920
from pathwaysutils import jax as pw_jax
@@ -167,7 +168,8 @@ def split_by_mesh_axis(
167168
mesh_axis_sizes=mesh.axis_sizes,
168169
mesh_axis_idx=mesh_axis_idx,
169170
mesh_axis_sections=mesh_axis_sections,
170-
submesh_shardings=submesh_shardings,
171+
# TODO: b/491156211 - Remove cast once type mismatch is fixed.
172+
submesh_shardings=cast(Any, submesh_shardings),
171173
donate=donate,
172174
)
173175

pathwaysutils/lru_cache.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
# limitations under the License.
1414
"""An LRU cache that will be cleared when JAX clears its internal cache."""
1515

16+
from collections.abc import Callable
1617
import functools
17-
from typing import Any, Callable
18+
from typing import Any, TypeVar
1819

1920
from jax.extend import backend
2021

2122

23+
_F = TypeVar("_F", bound=Callable[..., Any])
24+
25+
2226
def lru_cache(
2327
maxsize: int = 4096,
24-
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
28+
) -> Callable[[_F], _F]:
2529
"""An LRU cache that will be cleared when JAX clears its internal cache.
2630
2731
Args:
@@ -32,7 +36,7 @@ def lru_cache(
3236
A function that can be used to decorate a function to cache its results.
3337
"""
3438

35-
def wrap(f):
39+
def wrap(f: _F) -> _F:
3640
cached = functools.lru_cache(maxsize=maxsize)(f)
3741
wrapper = functools.wraps(f)(cached)
3842

pathwaysutils/persistence/helper.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
"""Helper functions for persistence."""
1515

1616
import base64
17+
from collections.abc import Sequence
1718
import concurrent.futures
1819
import datetime
1920
import json
20-
from typing import Any, Sequence, Tuple, Union
21+
from typing import Any
2122

2223
import jax
2324
from jax import core
@@ -93,7 +94,7 @@ def get_hlo_sharding_string(
9394
def get_shape_info(
9495
dtype: np.dtype,
9596
dimensions: Sequence[int],
96-
) -> dict[str, Union[Sequence[int], str]]:
97+
) -> dict[str, Sequence[int] | str]:
9798
"""Returns shape info in the format expected by read requests."""
9899
return {
99100
"xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype),
@@ -107,7 +108,7 @@ def get_write_request(
107108
jax_array: jax.Array,
108109
timeout: datetime.timedelta,
109110
return_dict: bool = False,
110-
) -> Union[str, dict[str, Any]]:
111+
) -> str | dict[str, Any]:
111112
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
112113
sharding = jax_array.sharding
113114
assert isinstance(sharding, jax.sharding.Sharding), sharding
@@ -171,7 +172,7 @@ def get_read_request(
171172
devices: Sequence[jax.Device],
172173
timeout: datetime.timedelta,
173174
return_dict: bool = False,
174-
) -> Union[str, dict[str, Any]]:
175+
) -> str | dict[str, Any]:
175176
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
176177
if not isinstance(devices, np.ndarray):
177178
devices = np.array(devices)
@@ -256,9 +257,9 @@ def read_one_array(
256257
dtype: np.dtype,
257258
shape: Sequence[int],
258259
shardings: jax.sharding.Sharding,
259-
devices: Union[Sequence[jax.Device], np.ndarray],
260+
devices: Sequence[jax.Device] | np.ndarray,
260261
timeout: datetime.timedelta,
261-
):
262+
) -> jax.Array:
262263
"""Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
263264
read_request = get_read_request(
264265
location,
@@ -284,9 +285,9 @@ def read_arrays(
284285
dtypes: Sequence[np.dtype],
285286
shapes: Sequence[Sequence[int]],
286287
shardings: Sequence[jax.sharding.Sharding],
287-
devices: Union[Sequence[jax.Device], np.ndarray],
288+
devices: Sequence[jax.Device] | np.ndarray,
288289
timeout: datetime.timedelta,
289-
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
290+
) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
290291
"""Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
291292

292293
bulk_read_request = get_bulk_read_request(

pathwaysutils/plugin_executable.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
# limitations under the License.
1414
"""PluginExecutable is a class for executing plugin programs."""
1515

16+
from collections.abc import Sequence
1617
import concurrent.futures
1718
import threading
18-
from typing import List, Sequence, Tuple, Union
19-
2019
import jax
2120
from jax.extend import ifrt_programs
2221
from jax.interpreters import pxla
@@ -36,11 +35,11 @@ def __init__(self, prog_str: str):
3635

3736
def call(
3837
self,
39-
in_arr: Sequence[Union[jax.Array, List[jax.Array]]] = (),
38+
in_arr: Sequence[jax.Array | Sequence[jax.Array]] = (),
4039
out_shardings: Sequence[jax.sharding.Sharding] = (),
4140
out_avals: Sequence[jax.core.ShapedArray] = (),
4241
out_committed: bool = True,
43-
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
42+
) -> tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
4443
"""Runs the compiled IFRT program and returns the result and a future."""
4544
results_with_token = self.compiled.execute_sharded(in_arr, with_tokens=True)
4645

pathwaysutils/profiling.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import logging
2020
import os
2121
import threading
22-
import time
2322
from typing import Any
2423
import urllib.parse
2524

@@ -38,11 +37,11 @@ class _ProfileState:
3837
executable: plugin_executable.PluginExecutable | None = None
3938
lock: threading.Lock
4039

41-
def __init__(self):
40+
def __init__(self) -> None:
4241
self.executable = None
4342
self.lock = threading.Lock()
4443

45-
def reset(self):
44+
def reset(self) -> None:
4645
self.executable = None
4746

4847

@@ -52,7 +51,7 @@ def reset(self):
5251
_original_stop_trace = jax.profiler.stop_trace
5352

5453

55-
def toy_computation():
54+
def toy_computation() -> None:
5655
"""A toy computation to run before the first profile."""
5756
x = jax.jit(lambda x: x + 1)(jnp.array(1))
5857
x.block_until_ready()
@@ -154,7 +153,7 @@ def start_trace(
154153
)
155154

156155

157-
def stop_trace():
156+
def stop_trace() -> None:
158157
"""Stops the currently-running profiler trace."""
159158
try:
160159
with _profile_state.lock:
@@ -172,7 +171,7 @@ def stop_trace():
172171
_profiler_thread: threading.Thread | None = None
173172

174173

175-
def start_server(port: int):
174+
def start_server(port: int) -> None:
176175
"""Starts the profiling server on port `port`.
177176
178177
The signature is slightly different from `jax.profiler.start_server`
@@ -192,7 +191,7 @@ class ProfilingConfig:
192191
repository_path: str
193192

194193
@app.post("/profiling")
195-
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
194+
async def profiling(pc: ProfilingConfig) -> dict[str, str]: # pylint: disable=unused-variable
196195
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
197196
_logger.debug("Writing profiling data to %s", pc.repository_path)
198197
await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)
@@ -210,7 +209,7 @@ async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
210209
_profiler_thread.start()
211210

212211

213-
def stop_server():
212+
def stop_server() -> None:
214213
"""Raises an error if there is no active profiler server.
215214
216215
Pathways profiling servers are not stoppable at this time.
@@ -257,7 +256,7 @@ def collect_profile(
257256
return True
258257

259258

260-
def monkey_patch_jax():
259+
def monkey_patch_jax() -> None:
261260
"""Monkey patches JAX with Pathways versions of functions.
262261
263262
The signatures in patched functions should match the original.
@@ -279,7 +278,7 @@ def start_trace_patch(
279278
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
280279
) -> None:
281280
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
282-
return start_trace(
281+
start_trace(
283282
log_dir,
284283
create_perfetto_link=create_perfetto_link,
285284
create_perfetto_trace=create_perfetto_trace,
@@ -291,21 +290,21 @@ def start_trace_patch(
291290

292291
def stop_trace_patch() -> None:
293292
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
294-
return stop_trace()
293+
stop_trace()
295294

296295
jax.profiler.stop_trace = stop_trace_patch
297296
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access
298297

299-
def start_server_patch(port: int):
298+
def start_server_patch(port: int) -> None:
300299
_logger.debug(
301300
"jax.profile.start_server patched with pathways' start_server"
302301
)
303-
return start_server(port)
302+
start_server(port)
304303

305304
jax.profiler.start_server = start_server_patch
306305

307-
def stop_server_patch():
306+
def stop_server_patch() -> None:
308307
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
309-
return stop_server()
308+
stop_server()
310309

311310
jax.profiler.stop_server = stop_server_patch

pathwaysutils/proxy_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pathwaysutils import jax as pw_jax
1919

2020

21-
def register_backend_factory():
21+
def register_backend_factory() -> None:
2222
backend.register_backend_factory(
2323
"proxy",
2424
lambda: pw_jax.ifrt_proxy.get_client(

pathwaysutils/reshard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
"""Resharding API using the IFRT RemapArray API."""
1515

1616
import collections
17-
from typing import Any, Callable, Mapping, Sequence
17+
from collections.abc import Callable, Mapping, Sequence
18+
from typing import Any
1819

1920
import jax
2021
import pathwaysutils.jax

0 commit comments

Comments
 (0)