1414
1515import json
1616import logging
17- import unittest
1817from unittest import mock
1918
2019from absl .testing import absltest
2120from absl .testing import parameterized
2221import jax
22+ from jax import numpy as jnp
2323from pathwaysutils import profiling
2424import requests
2525
@@ -213,10 +213,13 @@ def test_lock_released_on_stop_failure(self):
213213 """Tests that the lock is released if stop_trace fails."""
214214 profiling .start_trace ("gs://test_bucket/test_dir3" )
215215 self .assertFalse (profiling ._profile_state .lock .locked ())
216- mock_result = (
217- self .mock_plugin_executable_cls .return_value .call .return_value [1 ]
216+ mock_result_fail = mock .MagicMock ()
217+ mock_result_fail .result .side_effect = RuntimeError ("stop failed" )
218+ self .mock_plugin_executable_cls .return_value .call .return_value = (
219+ mock .MagicMock (),
220+ mock_result_fail ,
218221 )
219- mock_result . result . side_effect = RuntimeError ( "stop failed" )
222+ self . mock_plugin_executable_cls . return_value . call . side_effect = None
220223 with self .assertRaisesRegex (RuntimeError , "stop failed" ):
221224 profiling .stop_trace ()
222225 self .assertFalse (profiling ._profile_state .lock .locked ())
@@ -277,6 +280,44 @@ def test_stop_trace_success(self):
277280 with self .subTest ("executable_is_none" ):
278281 self .assertIsNone (profiling ._profile_state .executable )
279282
283+ @absltest .skipIf (
284+ jax .version .__version_info__ < (0 , 9 , 2 ),
285+ "ProfileOptions requires JAX 0.9.2 or newer" ,
286+ )
287+ def test_stop_trace_with_xprof_options_passes_out_avals (self ):
288+ options = jax .profiler .ProfileOptions ()
289+ options .duration_ms = 2000
290+
291+ with mock .patch .object (
292+ profiling , "_profile_state" , autospec = True
293+ ) as mock_profile_state :
294+ request = profiling ._create_profile_request (
295+ "gs://test_bucket/test_dir" , options
296+ )
297+ mock_profile_state .profile_request = request
298+ mock_profile_state .executable = (
299+ self .mock_plugin_executable_cls .return_value
300+ )
301+ mock_profile_state .lock = mock .MagicMock ()
302+ mock_profile_state .lock .locked .return_value = True
303+ mock_profile_state .lock .__enter__ .return_value = None
304+ mock_profile_state .lock .__exit__ .return_value = None
305+
306+ profiling .stop_trace ()
307+
308+ with self .subTest ("plugin_executable_called" ):
309+ self .mock_plugin_executable_cls .return_value .call .assert_called_once ()
310+ _ , kwargs = self .mock_plugin_executable_cls .return_value .call .call_args
311+ self .assertIn ("out_avals" , kwargs )
312+ self .assertIn ("out_shardings" , kwargs )
313+
314+ with self .subTest ("out_avals_properties" ):
315+ _ , kwargs = self .mock_plugin_executable_cls .return_value .call .call_args
316+ self .assertLen (kwargs ["out_avals" ], 1 )
317+ (out_aval ,) = kwargs ["out_avals" ]
318+ self .assertEqual (out_aval .shape , (1 ,))
319+ self .assertEqual (out_aval .dtype , jnp .object_ )
320+
280321 def test_stop_trace_before_start_error (self ):
281322 with self .assertRaisesRegex (
282323 ValueError , "stop_trace called before a trace is being taken!"
@@ -406,7 +447,7 @@ def test_create_profile_request_default_options(self, profiler_options):
406447 },
407448 )
408449
409- @unittest .skipIf (
450+ @absltest .skipIf (
410451 jax .version .__version_info__ < (0 , 9 , 2 ),
411452 "ProfileOptions requires JAX 0.9.2 or newer" ,
412453 )
@@ -444,41 +485,45 @@ def test_create_profile_request_with_options(self):
444485 },
445486 )
446487
447- @unittest .skipIf (
488+ @absltest .skipIf (
448489 jax .version .__version_info__ < (0 , 9 , 2 ),
449490 "ProfileOptions requires JAX 0.9.2 or newer" ,
450491 )
451492 @parameterized .parameters (
452493 ({"traceLocation" : "gs://test_bucket/test_dir" },),
453- ({
454- "traceLocation" : "gs://test_bucket/test_dir" ,
455- "blockUntilStart" : True ,
456- "maxDurationSecs" : 10.0 ,
457- "devices" : {"deviceIds" : [1 , 2 ]},
458- "includeResourceManagers" : True ,
459- "maxNumHosts" : 5 ,
460- "xprofTraceOptions" : {
494+ (
495+ {
496+ "traceLocation" : "gs://test_bucket/test_dir" ,
461497 "blockUntilStart" : True ,
462- "traceDirectory" : "gs://test_bucket/test_dir" ,
498+ "maxDurationSecs" : 10.0 ,
499+ "devices" : {"deviceIds" : [1 , 2 ]},
500+ "includeResourceManagers" : True ,
501+ "maxNumHosts" : 5 ,
502+ "xprofTraceOptions" : {
503+ "blockUntilStart" : True ,
504+ "traceDirectory" : "gs://test_bucket/test_dir" ,
505+ },
463506 },
464- },),
465- ({
466- "traceLocation" : "gs://bucket/dir" ,
467- "xprofTraceOptions" : {
468- "hostTraceLevel" : 0 ,
469- "traceOptions" : {
470- "traceMode" : "TRACE_COMPUTE" ,
471- "numSparseCoresToTrace" : 1 ,
472- "numSparseCoreTilesToTrace" : 2 ,
473- "numChipsToProfilePerTask" : 3 ,
474- "powerTraceLevel" : 4 ,
475- "enableFwThrottleEvent" : True ,
476- "enableFwPowerLevelEvent" : True ,
477- "enableFwThermalEvent" : True ,
507+ ),
508+ (
509+ {
510+ "traceLocation" : "gs://bucket/dir" ,
511+ "xprofTraceOptions" : {
512+ "hostTraceLevel" : 0 ,
513+ "traceOptions" : {
514+ "traceMode" : "TRACE_COMPUTE" ,
515+ "numSparseCoresToTrace" : 1 ,
516+ "numSparseCoreTilesToTrace" : 2 ,
517+ "numChipsToProfilePerTask" : 3 ,
518+ "powerTraceLevel" : 4 ,
519+ "enableFwThrottleEvent" : True ,
520+ "enableFwPowerLevelEvent" : True ,
521+ "enableFwThermalEvent" : True ,
522+ },
523+ "traceDirectory" : "gs://bucket/dir" ,
478524 },
479- "traceDirectory" : "gs://bucket/dir" ,
480525 },
481- }, ),
526+ ),
482527 )
483528
484529 def test_start_pathways_trace_from_profile_request (self , profile_request ):
@@ -496,10 +541,9 @@ def test_original_stop_trace_called_on_stop_failure(self):
496541 """Tests that original_stop_trace is called if pathways stop_trace fails."""
497542 profiling .start_trace ("gs://test_bucket/test_dir" )
498543 self .assertFalse (profiling ._profile_state .lock .locked ())
499- mock_result = (
500- self . mock_plugin_executable_cls . return_value . call . return_value [ 1 ]
544+ self . mock_plugin_executable_cls . return_value . call . side_effect = (
545+ RuntimeError ( "stop failed" )
501546 )
502- mock_result .result .side_effect = RuntimeError ("stop failed" )
503547 with self .assertRaisesRegex (RuntimeError , "stop failed" ):
504548 profiling .stop_trace ()
505549 self .mock_original_stop_trace .assert_called_once ()
0 commit comments