Skip to content

Commit 557d60b

Browse files
committed
Try to use SLURM_STEP_GPUS for device list if CUDA_VISIBLE_DEVICES is not set
1 parent 9777a4c commit 557d60b

2 files changed

Lines changed: 66 additions & 12 deletions

File tree

src/maxtext/utils/max_utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
"""Common Max Utils needed by multiple modules.
16-
All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file."""
16+
All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file.
17+
"""
1718

1819
import collections
1920
from collections.abc import Sequence
@@ -250,12 +251,18 @@ def initialize_jax_for_gpu(raw_keys):
250251
if os.environ.get("JAX_COORDINATOR_IP") is not None:
251252
coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
252253
coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
253-
devices = os.getenv("CUDA_VISIBLE_DEVICES")
254+
env_var_list = ["CUDA_VISIBLE_DEVICES", "SLURM_STEP_GPUS"]
255+
for env_var in env_var_list:
256+
devices = os.getenv(env_var)
257+
if devices is not None:
258+
max_logging.log(f"Using {env_var} to initialize JAX distributed system: {devices}")
259+
break
260+
254261
if devices is not None:
255262
try:
256263
devices = [int(x) for x in devices.split(",")]
257264
except (ValueError, TypeError) as e:
258-
max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
265+
max_logging.log(f"Error parsing {env_var}: {e}")
259266
devices = None
260267

261268
jax.distributed.initialize(
@@ -848,7 +855,14 @@ def reorder_causal_load_balanced(batch, cp_size):
848855
cp_size=cp_size,
849856
)
850857
if key
851-
in ["inputs", "targets", "inputs_position", "targets_position", "inputs_segmentation", "targets_segmentation"]
858+
in [
859+
"inputs",
860+
"targets",
861+
"inputs_position",
862+
"targets_position",
863+
"inputs_segmentation",
864+
"targets_segmentation",
865+
]
852866
else value
853867
for key, value in batch.items()
854868
}

tests/unit/max_utils_test.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_unscan_train_state_params(self):
201201

202202

203203
class TestGpuDistributedInitialization(unittest.TestCase):
204-
"""Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize."""
204+
"""Tests using CUDA_VISIBLE_DEVICES / SLURM_STEP_GPUS for jax.distributed.initialize."""
205205

206206
@mock.patch.dict(
207207
os.environ,
@@ -210,7 +210,7 @@ class TestGpuDistributedInitialization(unittest.TestCase):
210210
"JAX_COORDINATOR_PORT": "1234",
211211
"NNODES": "1",
212212
"NODE_RANK": "0",
213-
"CUDA_VISIBLE_DEVICES": "0,2,3", # Simulating Slurm/orchestrator assignment
213+
"CUDA_VISIBLE_DEVICES": "0,2,3",
214214
},
215215
)
216216
@mock.patch("jax.distributed.initialize")
@@ -220,7 +220,6 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
220220
"""Verifies that a comma-separated string of IDs is correctly parsed."""
221221
raw_keys = {"jax_distributed_initialization_timeout": 300}
222222
max_utils.initialize_jax_for_gpu(raw_keys)
223-
# Check that local_device_ids was passed correctly as a list of integers
224223
_, kwargs = mock_init.call_args
225224
self.assertEqual(kwargs["local_device_ids"], [0, 2, 3])
226225
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
@@ -232,17 +231,16 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
232231
"JAX_COORDINATOR_PORT": "1234",
233232
"NNODES": "1",
234233
"NODE_RANK": "0",
235-
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...", # Invalid format for integer parsing
234+
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...",
236235
},
237236
)
238237
@mock.patch("jax.distributed.initialize")
239238
@mock.patch("jax.devices")
240239
@mock.patch("maxtext.utils.max_logging.log")
241-
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, mock_init):
240+
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, _mock_devices, mock_init):
242241
"""Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
243242
raw_keys = {"jax_distributed_initialization_timeout": 300}
244243
max_utils.initialize_jax_for_gpu(raw_keys)
245-
# Check that it falls back to None (JAX auto-detection default) on error
246244
_, kwargs = mock_init.call_args
247245
self.assertIsNone(kwargs.get("local_device_ids"))
248246
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
@@ -259,14 +257,56 @@ def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, m
259257
@mock.patch("jax.distributed.initialize")
260258
@mock.patch("jax.devices")
261259
@mock.patch("maxtext.utils.max_logging.log")
262-
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, mock_devices, mock_init):
263-
"""Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set"""
260+
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, _mock_devices, mock_init):
261+
"""Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set."""
264262
raw_keys = {"jax_distributed_initialization_timeout": 300}
265263
max_utils.initialize_jax_for_gpu(raw_keys)
266264
_, kwargs = mock_init.call_args
267265
self.assertIsNone(kwargs.get("local_device_ids"))
268266
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
269267

268+
@mock.patch("jax.distributed.initialize")
269+
@mock.patch("jax.devices")
270+
@mock.patch("maxtext.utils.max_logging.log")
271+
def test_initialize_jax_for_gpu_uses_slurm_when_cuda_unset(self, mock_log, _mock_devices, mock_init):
272+
"""Uses SLURM_STEP_GPUS when CUDA_VISIBLE_DEVICES is absent (loop over env_var_list)."""
273+
env = {
274+
"JAX_COORDINATOR_IP": "10.0.0.1",
275+
"JAX_COORDINATOR_PORT": "1234",
276+
"NNODES": "1",
277+
"NODE_RANK": "0",
278+
"SLURM_STEP_GPUS": "1,3",
279+
}
280+
with mock.patch.dict(os.environ, env, clear=False):
281+
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
282+
raw_keys = {"jax_distributed_initialization_timeout": 300}
283+
max_utils.initialize_jax_for_gpu(raw_keys)
284+
_, kwargs = mock_init.call_args
285+
self.assertEqual(kwargs["local_device_ids"], [1, 3])
286+
mock_log.assert_any_call("Using SLURM_STEP_GPUS to initialize JAX distributed system: 1,3")
287+
288+
@mock.patch.dict(
289+
os.environ,
290+
{
291+
"JAX_COORDINATOR_IP": "10.0.0.1",
292+
"JAX_COORDINATOR_PORT": "1234",
293+
"NNODES": "1",
294+
"NODE_RANK": "0",
295+
"CUDA_VISIBLE_DEVICES": "0,2",
296+
"SLURM_STEP_GPUS": "4,5,6",
297+
},
298+
)
299+
@mock.patch("jax.distributed.initialize")
300+
@mock.patch("jax.devices")
301+
@mock.patch("maxtext.utils.max_logging.log")
302+
def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop(self, mock_log, _mock_devices, mock_init):
303+
"""First matching env var in the list wins; CUDA_VISIBLE_DEVICES is checked before SLURM_STEP_GPUS."""
304+
raw_keys = {"jax_distributed_initialization_timeout": 300}
305+
max_utils.initialize_jax_for_gpu(raw_keys)
306+
_, kwargs = mock_init.call_args
307+
self.assertEqual(kwargs["local_device_ids"], [0, 2])
308+
mock_log.assert_any_call("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2")
309+
270310

271311
if __name__ == "__main__":
272312
unittest.main()

0 commit comments

Comments (0)