@@ -201,7 +201,7 @@ def test_unscan_train_state_params(self):
201201
202202
203203class TestGpuDistributedInitialization (unittest .TestCase ):
204- """Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize."""
204+ """Tests using CUDA_VISIBLE_DEVICES / SLURM_STEP_GPUS for jax.distributed.initialize."""
205205
206206 @mock .patch .dict (
207207 os .environ ,
@@ -210,7 +210,7 @@ class TestGpuDistributedInitialization(unittest.TestCase):
210210 "JAX_COORDINATOR_PORT" : "1234" ,
211211 "NNODES" : "1" ,
212212 "NODE_RANK" : "0" ,
213- "CUDA_VISIBLE_DEVICES" : "0,2,3" , # Simulating Slurm/orchestrator assignment
213+ "CUDA_VISIBLE_DEVICES" : "0,2,3" ,
214214 },
215215 )
216216 @mock .patch ("jax.distributed.initialize" )
@@ -220,7 +220,6 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
220220 """Verifies that a comma-separated string of IDs is correctly parsed."""
221221 raw_keys = {"jax_distributed_initialization_timeout" : 300 }
222222 max_utils .initialize_jax_for_gpu (raw_keys )
223- # Check that local_device_ids was passed correctly as a list of integers
224223 _ , kwargs = mock_init .call_args
225224 self .assertEqual (kwargs ["local_device_ids" ], [0 , 2 , 3 ])
226225 self .assertEqual (kwargs ["coordinator_address" ], "10.0.0.1:1234" )
@@ -232,17 +231,16 @@ def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mo
232231 "JAX_COORDINATOR_PORT" : "1234" ,
233232 "NNODES" : "1" ,
234233 "NODE_RANK" : "0" ,
235- "CUDA_VISIBLE_DEVICES" : "GPU-8f2e3072-..." , # Invalid format for integer parsing
234+ "CUDA_VISIBLE_DEVICES" : "GPU-8f2e3072-..." ,
236235 },
237236 )
238237 @mock .patch ("jax.distributed.initialize" )
239238 @mock .patch ("jax.devices" )
240239 @mock .patch ("maxtext.utils.max_logging.log" )
241- def test_initialize_jax_for_gpu_invalid_devices (self , _mock_log , mock_devices , mock_init ):
240+ def test_initialize_jax_for_gpu_invalid_devices (self , _mock_log , _mock_devices , mock_init ):
242241 """Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
243242 raw_keys = {"jax_distributed_initialization_timeout" : 300 }
244243 max_utils .initialize_jax_for_gpu (raw_keys )
245- # Check that it falls back to None (JAX auto-detection default) on error
246244 _ , kwargs = mock_init .call_args
247245 self .assertIsNone (kwargs .get ("local_device_ids" ))
248246 self .assertEqual (kwargs ["coordinator_address" ], "10.0.0.1:1234" )
@@ -255,17 +253,69 @@ def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, m
255253 "NNODES" : "1" ,
256254 "NODE_RANK" : "0" ,
257255 },
256+ clear = True ,
258257 )
258+ @mock .patch ("jax.config.update" )
259259 @mock .patch ("jax.distributed.initialize" )
260260 @mock .patch ("jax.devices" )
261261 @mock .patch ("maxtext.utils.max_logging.log" )
262- def test_initialize_jax_for_gpu_no_devices (self , _mock_log , mock_devices , mock_init ):
263- """Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set"""
262+ def test_initialize_jax_for_gpu_no_devices (self , _mock_log , _mock_devices , mock_init , mock_config_update ):
263+ """When coordinator env is set but neither CUDA_VISIBLE_DEVICES nor SLURM_STEP_GPUS is set, JAX uses all devices
264+ (config) and init gets no local ids.
265+ """
264266 raw_keys = {"jax_distributed_initialization_timeout" : 300 }
265267 max_utils .initialize_jax_for_gpu (raw_keys )
266268 _ , kwargs = mock_init .call_args
267269 self .assertIsNone (kwargs .get ("local_device_ids" ))
268270 self .assertEqual (kwargs ["coordinator_address" ], "10.0.0.1:1234" )
271+ mock_config_update .assert_has_calls (
272+ [
273+ mock .call ("jax_cuda_visible_devices" , "all" ),
274+ mock .call ("jax_rocm_visible_devices" , "all" ),
275+ ]
276+ )
277+
278+ @mock .patch ("jax.distributed.initialize" )
279+ @mock .patch ("jax.devices" )
280+ @mock .patch ("maxtext.utils.max_logging.log" )
281+ def test_initialize_jax_for_gpu_uses_slurm_when_cuda_unset (self , mock_log , _mock_devices , mock_init ):
282+ """Uses SLURM_STEP_GPUS when CUDA_VISIBLE_DEVICES is absent (loop over env_var_list)."""
283+ env = {
284+ "JAX_COORDINATOR_IP" : "10.0.0.1" ,
285+ "JAX_COORDINATOR_PORT" : "1234" ,
286+ "NNODES" : "1" ,
287+ "NODE_RANK" : "0" ,
288+ "SLURM_STEP_GPUS" : "1,3" ,
289+ }
290+ with mock .patch .dict (os .environ , env , clear = False ):
291+ os .environ .pop ("CUDA_VISIBLE_DEVICES" , None )
292+ raw_keys = {"jax_distributed_initialization_timeout" : 300 }
293+ max_utils .initialize_jax_for_gpu (raw_keys )
294+ _ , kwargs = mock_init .call_args
295+ self .assertEqual (kwargs ["local_device_ids" ], [1 , 3 ])
296+ mock_log .assert_any_call ("Using SLURM_STEP_GPUS to initialize JAX distributed system: 1,3" )
297+
298+ @mock .patch .dict (
299+ os .environ ,
300+ {
301+ "JAX_COORDINATOR_IP" : "10.0.0.1" ,
302+ "JAX_COORDINATOR_PORT" : "1234" ,
303+ "NNODES" : "1" ,
304+ "NODE_RANK" : "0" ,
305+ "CUDA_VISIBLE_DEVICES" : "0,2" ,
306+ "SLURM_STEP_GPUS" : "4,5,6" ,
307+ },
308+ )
309+ @mock .patch ("jax.distributed.initialize" )
310+ @mock .patch ("jax.devices" )
311+ @mock .patch ("maxtext.utils.max_logging.log" )
312+ def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop (self , mock_log , _mock_devices , mock_init ):
313+ """First matching env var in the list wins; CUDA_VISIBLE_DEVICES is checked before SLURM_STEP_GPUS."""
314+ raw_keys = {"jax_distributed_initialization_timeout" : 300 }
315+ max_utils .initialize_jax_for_gpu (raw_keys )
316+ _ , kwargs = mock_init .call_args
317+ self .assertEqual (kwargs ["local_device_ids" ], [0 , 2 ])
318+ mock_log .assert_any_call ("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2" )
269319
270320
271321if __name__ == "__main__" :
0 commit comments