Skip to content

Commit 2b9cebb

Browse files
Merge pull request #3577 from ROCm:gw_check_slurm_gpus
PiperOrigin-RevId: 896643586
2 parents 18e9fa1 + f765297 commit 2b9cebb

File tree

2 files changed

+79
-12
lines changed

2 files changed

+79
-12
lines changed

src/maxtext/utils/max_utils.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
"""Common Max Utils needed by multiple modules.
16-
All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file."""
16+
All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file.
17+
"""
1718

1819
import collections
1920
from collections.abc import Sequence
@@ -288,12 +289,21 @@ def initialize_jax_for_gpu(raw_keys):
288289
if os.environ.get("JAX_COORDINATOR_IP") is not None:
289290
coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
290291
coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
291-
devices = os.getenv("CUDA_VISIBLE_DEVICES")
292+
env_var_list = ["CUDA_VISIBLE_DEVICES", "SLURM_STEP_GPUS"]
293+
for env_var in env_var_list:
294+
devices = os.getenv(env_var)
295+
if devices is not None:
296+
max_logging.log(f"Using {env_var} to initialize JAX distributed system: {devices}")
297+
break
298+
if devices is None:
299+
jax.config.update("jax_cuda_visible_devices", "all")
300+
jax.config.update("jax_rocm_visible_devices", "all")
301+
292302
if devices is not None:
293303
try:
294304
devices = [int(x) for x in devices.split(",")]
295305
except (ValueError, TypeError) as e:
296-
max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
306+
max_logging.log(f"Error parsing {env_var}: {e}")
297307
devices = None
298308

299309
jax.distributed.initialize(
@@ -886,7 +896,14 @@ def reorder_causal_load_balanced(batch, cp_size):
886896
cp_size=cp_size,
887897
)
888898
if key
889-
in ["inputs", "targets", "inputs_position", "targets_position", "inputs_segmentation", "targets_segmentation"]
899+
in [
900+
"inputs",
901+
"targets",
902+
"inputs_position",
903+
"targets_position",
904+
"inputs_segmentation",
905+
"targets_segmentation",
906+
]
890907
else value
891908
for key, value in batch.items()
892909
}

tests/unit/max_utils_test.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_unscan_train_state_params(self):
201201

202202

203203
class TestGpuDistributedInitialization(unittest.TestCase):
204-
"""Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize."""
204+
"""Tests using CUDA_VISIBLE_DEVICES / SLURM_STEP_GPUS for jax.distributed.initialize."""
205205

206206
@mock.patch.dict(
207207
os.environ,
@@ -210,7 +210,7 @@ class TestGpuDistributedInitialization(unittest.TestCase):
210210
"JAX_COORDINATOR_PORT": "1234",
211211
"NNODES": "1",
212212
"NODE_RANK": "0",
213-
"CUDA_VISIBLE_DEVICES": "0,2,3", # Simulating Slurm/orchestrator assignment
213+
"CUDA_VISIBLE_DEVICES": "0,2,3",
214214
},
215215
)
216216
@mock.patch("jax.distributed.initialize")
@@ -220,7 +220,6 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
220220
"""Verifies that a comma-separated string of IDs is correctly parsed."""
221221
raw_keys = {"jax_distributed_initialization_timeout": 300}
222222
max_utils.initialize_jax_for_gpu(raw_keys)
223-
# Check that local_device_ids was passed correctly as a list of integers
224223
_, kwargs = mock_init.call_args
225224
self.assertEqual(kwargs["local_device_ids"], [0, 2, 3])
226225
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
@@ -232,17 +231,16 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
232231
"JAX_COORDINATOR_PORT": "1234",
233232
"NNODES": "1",
234233
"NODE_RANK": "0",
235-
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...", # Invalid format for integer parsing
234+
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...",
236235
},
237236
)
238237
@mock.patch("jax.distributed.initialize")
239238
@mock.patch("jax.devices")
240239
@mock.patch("maxtext.utils.max_logging.log")
241-
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, mock_init):
240+
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, _mock_devices, mock_init):
242241
"""Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
243242
raw_keys = {"jax_distributed_initialization_timeout": 300}
244243
max_utils.initialize_jax_for_gpu(raw_keys)
245-
# Check that it falls back to None (JAX auto-detection default) on error
246244
_, kwargs = mock_init.call_args
247245
self.assertIsNone(kwargs.get("local_device_ids"))
248246
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
@@ -255,17 +253,69 @@ def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, m
255253
"NNODES": "1",
256254
"NODE_RANK": "0",
257255
},
256+
clear=True,
258257
)
258+
@mock.patch("jax.config.update")
259259
@mock.patch("jax.distributed.initialize")
260260
@mock.patch("jax.devices")
261261
@mock.patch("maxtext.utils.max_logging.log")
262-
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, mock_devices, mock_init):
263-
"""Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set"""
262+
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, _mock_devices, mock_init, mock_config_update):
263+
"""When coordinator env is set but neither CUDA_VISIBLE_DEVICES nor SLURM_STEP_GPUS is set, JAX uses all devices
264+
(config) and init gets no local ids.
265+
"""
264266
raw_keys = {"jax_distributed_initialization_timeout": 300}
265267
max_utils.initialize_jax_for_gpu(raw_keys)
266268
_, kwargs = mock_init.call_args
267269
self.assertIsNone(kwargs.get("local_device_ids"))
268270
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
271+
mock_config_update.assert_has_calls(
272+
[
273+
mock.call("jax_cuda_visible_devices", "all"),
274+
mock.call("jax_rocm_visible_devices", "all"),
275+
]
276+
)
277+
278+
@mock.patch("jax.distributed.initialize")
279+
@mock.patch("jax.devices")
280+
@mock.patch("maxtext.utils.max_logging.log")
281+
def test_initialize_jax_for_gpu_uses_slurm_when_cuda_unset(self, mock_log, _mock_devices, mock_init):
282+
"""Uses SLURM_STEP_GPUS when CUDA_VISIBLE_DEVICES is absent (loop over env_var_list)."""
283+
env = {
284+
"JAX_COORDINATOR_IP": "10.0.0.1",
285+
"JAX_COORDINATOR_PORT": "1234",
286+
"NNODES": "1",
287+
"NODE_RANK": "0",
288+
"SLURM_STEP_GPUS": "1,3",
289+
}
290+
with mock.patch.dict(os.environ, env, clear=False):
291+
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
292+
raw_keys = {"jax_distributed_initialization_timeout": 300}
293+
max_utils.initialize_jax_for_gpu(raw_keys)
294+
_, kwargs = mock_init.call_args
295+
self.assertEqual(kwargs["local_device_ids"], [1, 3])
296+
mock_log.assert_any_call("Using SLURM_STEP_GPUS to initialize JAX distributed system: 1,3")
297+
298+
@mock.patch.dict(
299+
os.environ,
300+
{
301+
"JAX_COORDINATOR_IP": "10.0.0.1",
302+
"JAX_COORDINATOR_PORT": "1234",
303+
"NNODES": "1",
304+
"NODE_RANK": "0",
305+
"CUDA_VISIBLE_DEVICES": "0,2",
306+
"SLURM_STEP_GPUS": "4,5,6",
307+
},
308+
)
309+
@mock.patch("jax.distributed.initialize")
310+
@mock.patch("jax.devices")
311+
@mock.patch("maxtext.utils.max_logging.log")
312+
def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop(self, mock_log, _mock_devices, mock_init):
313+
"""First matching env var in the list wins; CUDA_VISIBLE_DEVICES is checked before SLURM_STEP_GPUS."""
314+
raw_keys = {"jax_distributed_initialization_timeout": 300}
315+
max_utils.initialize_jax_for_gpu(raw_keys)
316+
_, kwargs = mock_init.call_args
317+
self.assertEqual(kwargs["local_device_ids"], [0, 2])
318+
mock_log.assert_any_call("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2")
269319

270320

271321
if __name__ == "__main__":

0 commit comments

Comments (0)