1414"""Helper functions for persistence."""
1515
1616import base64
17+ from collections .abc import Sequence
1718import concurrent .futures
1819import datetime
1920import json
20- from typing import Any , Sequence , Tuple , Union
21+ from typing import Any
2122
2223import jax
2324from jax import core
@@ -93,7 +94,7 @@ def get_hlo_sharding_string(
9394def get_shape_info (
9495 dtype : np .dtype ,
9596 dimensions : Sequence [int ],
96- ) -> dict [str , Union [ Sequence [int ], str ] ]:
97+ ) -> dict [str , Sequence [int ] | str ]:
9798 """Returns shape info in the format expected by read requests."""
9899 return {
99100 "xla_primitive_type_str" : dtype_to_xla_primitive_type_str (dtype ),
@@ -107,7 +108,7 @@ def get_write_request(
107108 jax_array : jax .Array ,
108109 timeout : datetime .timedelta ,
109110 return_dict : bool = False ,
110- ) -> Union [ str , dict [str , Any ] ]:
111+ ) -> str | dict [str , Any ]:
111112 """Returns a string representation of the plugin program which writes the given jax_array to the given location."""
112113 sharding = jax_array .sharding
113114 assert isinstance (sharding , jax .sharding .Sharding ), sharding
@@ -171,7 +172,7 @@ def get_read_request(
171172 devices : Sequence [jax .Device ],
172173 timeout : datetime .timedelta ,
173174 return_dict : bool = False ,
174- ) -> Union [ str , dict [str , Any ] ]:
175+ ) -> str | dict [str , Any ]:
175176 """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
176177 if not isinstance (devices , np .ndarray ):
177178 devices = np .array (devices )
@@ -256,9 +257,9 @@ def read_one_array(
256257 dtype : np .dtype ,
257258 shape : Sequence [int ],
258259 shardings : jax .sharding .Sharding ,
259- devices : Union [ Sequence [jax .Device ], np .ndarray ] ,
260+ devices : Sequence [jax .Device ] | np .ndarray ,
260261 timeout : datetime .timedelta ,
261- ):
262+ ) -> jax . Array :
262263 """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
263264 read_request = get_read_request (
264265 location ,
@@ -284,9 +285,9 @@ def read_arrays(
284285 dtypes : Sequence [np .dtype ],
285286 shapes : Sequence [Sequence [int ]],
286287 shardings : Sequence [jax .sharding .Sharding ],
287- devices : Union [ Sequence [jax .Device ], np .ndarray ] ,
288+ devices : Sequence [jax .Device ] | np .ndarray ,
288289 timeout : datetime .timedelta ,
289- ) -> Tuple [Sequence [jax .Array ], concurrent .futures .Future [None ]]:
290+ ) -> tuple [Sequence [jax .Array ], concurrent .futures .Future [None ]]:
290291 """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
291292
292293 bulk_read_request = get_bulk_read_request (
0 commit comments