@@ -201,7 +201,7 @@ def test_unscan_train_state_params(self):
201201
202202
203203class TestGpuDistributedInitialization (unittest .TestCase ):
204- """Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize."""
204+ """Tests using CUDA_VISIBLE_DEVICES / SLURM_STEP_GPUS for jax.distributed.initialize."""
205205
206206 @mock .patch .dict (
207207 os .environ ,
@@ -210,7 +210,7 @@ class TestGpuDistributedInitialization(unittest.TestCase):
210210 "JAX_COORDINATOR_PORT" : "1234" ,
211211 "NNODES" : "1" ,
212212 "NODE_RANK" : "0" ,
213- "CUDA_VISIBLE_DEVICES" : "0,2,3" , # Simulating Slurm/orchestrator assignment
213+ "CUDA_VISIBLE_DEVICES" : "0,2,3" ,
214214 },
215215 )
216216 @mock .patch ("jax.distributed.initialize" )
@@ -220,7 +220,6 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
220220 """Verifies that a comma-separated string of IDs is correctly parsed."""
221221 raw_keys = {"jax_distributed_initialization_timeout" : 300 }
222222 max_utils .initialize_jax_for_gpu (raw_keys )
223- # Check that local_device_ids was passed correctly as a list of integers
224223 _ , kwargs = mock_init .call_args
225224 self .assertEqual (kwargs ["local_device_ids" ], [0 , 2 , 3 ])
226225 self .assertEqual (kwargs ["coordinator_address" ], "10.0.0.1:1234" )
@@ -232,17 +231,16 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
232231 "JAX_COORDINATOR_PORT" : "1234" ,
233232 "NNODES" : "1" ,
234233 "NODE_RANK" : "0" ,
235- "CUDA_VISIBLE_DEVICES" : "GPU-8f2e3072-..." , # Invalid format for integer parsing
234+ "CUDA_VISIBLE_DEVICES" : "GPU-8f2e3072-..." ,
236235 },
237236 )
238237 @mock .patch ("jax.distributed.initialize" )
239238 @mock .patch ("jax.devices" )
240239 @mock .patch ("maxtext.utils.max_logging.log" )
241- def test_initialize_jax_for_gpu_invalid_devices (self , _mock_log , mock_devices , mock_init ):
240+ def test_initialize_jax_for_gpu_invalid_devices (self , _mock_log , _mock_devices , mock_init ):
242241 """Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
243242 raw_keys = {"jax_distributed_initialization_timeout" : 300 }
244243 max_utils .initialize_jax_for_gpu (raw_keys )
245- # Check that it falls back to None (JAX auto-detection default) on error
246244 _ , kwargs = mock_init .call_args
247245 self .assertIsNone (kwargs .get ("local_device_ids" ))
248246 self .assertEqual (kwargs ["coordinator_address" ], "10.0.0.1:1234" )
@@ -259,14 +257,56 @@ def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, m
259257 @mock .patch ("jax.distributed.initialize" )
260258 @mock .patch ("jax.devices" )
261259 @mock .patch ("maxtext.utils.max_logging.log" )
262- def test_initialize_jax_for_gpu_no_devices (self , _mock_log , mock_devices , mock_init ):
263- """Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set"""
260+ def test_initialize_jax_for_gpu_no_devices (self , _mock_log , _mock_devices , mock_init ):
261+ """Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set. """
264262 raw_keys = {"jax_distributed_initialization_timeout" : 300 }
265263 max_utils .initialize_jax_for_gpu (raw_keys )
266264 _ , kwargs = mock_init .call_args
267265 self .assertIsNone (kwargs .get ("local_device_ids" ))
268266 self .assertEqual (kwargs ["coordinator_address" ], "10.0.0.1:1234" )
269267
268+ @mock .patch ("jax.distributed.initialize" )
269+ @mock .patch ("jax.devices" )
270+ @mock .patch ("maxtext.utils.max_logging.log" )
271+ def test_initialize_jax_for_gpu_uses_slurm_when_cuda_unset (self , mock_log , _mock_devices , mock_init ):
272+ """Uses SLURM_STEP_GPUS when CUDA_VISIBLE_DEVICES is absent (loop over env_var_list)."""
273+ env = {
274+ "JAX_COORDINATOR_IP" : "10.0.0.1" ,
275+ "JAX_COORDINATOR_PORT" : "1234" ,
276+ "NNODES" : "1" ,
277+ "NODE_RANK" : "0" ,
278+ "SLURM_STEP_GPUS" : "1,3" ,
279+ }
280+ with mock .patch .dict (os .environ , env , clear = False ):
281+ os .environ .pop ("CUDA_VISIBLE_DEVICES" , None )
282+ raw_keys = {"jax_distributed_initialization_timeout" : 300 }
283+ max_utils .initialize_jax_for_gpu (raw_keys )
284+ _ , kwargs = mock_init .call_args
285+ self .assertEqual (kwargs ["local_device_ids" ], [1 , 3 ])
286+ mock_log .assert_any_call ("Using SLURM_STEP_GPUS to initialize JAX distributed system: 1,3" )
287+
288+ @mock .patch .dict (
289+ os .environ ,
290+ {
291+ "JAX_COORDINATOR_IP" : "10.0.0.1" ,
292+ "JAX_COORDINATOR_PORT" : "1234" ,
293+ "NNODES" : "1" ,
294+ "NODE_RANK" : "0" ,
295+ "CUDA_VISIBLE_DEVICES" : "0,2" ,
296+ "SLURM_STEP_GPUS" : "4,5,6" ,
297+ },
298+ )
299+ @mock .patch ("jax.distributed.initialize" )
300+ @mock .patch ("jax.devices" )
301+ @mock .patch ("maxtext.utils.max_logging.log" )
302+ def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop (self , mock_log , _mock_devices , mock_init ):
303+ """First matching env var in the list wins; CUDA_VISIBLE_DEVICES is checked before SLURM_STEP_GPUS."""
304+ raw_keys = {"jax_distributed_initialization_timeout" : 300 }
305+ max_utils .initialize_jax_for_gpu (raw_keys )
306+ _ , kwargs = mock_init .call_args
307+ self .assertEqual (kwargs ["local_device_ids" ], [0 , 2 ])
308+ mock_log .assert_any_call ("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2" )
309+
270310
271311if __name__ == "__main__" :
272312 unittest .main ()
0 commit comments