Skip to content

Commit 5dcbd1c

Browse files
lukebaumann authored and copybara-github committed
Fix JaxRuntimeError during profiler stop_trace with profile options
PiperOrigin-RevId: 888268746
1 parent bc6b635 commit 5dcbd1c

File tree

2 files changed

+97
-36
lines changed

2 files changed

+97
-36
lines changed

pathwaysutils/profiling.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,13 @@
1414
"""Profiling Utilities."""
1515

1616
import asyncio
17+
from collections.abc import Mapping
1718
import dataclasses
1819
import json
1920
import logging
2021
import os
2122
import threading
22-
from typing import Any, Mapping
23+
from typing import Any
2324
import urllib.parse
2425

2526
import fastapi
@@ -35,14 +36,17 @@
3536

3637
class _ProfileState:
3738
executable: plugin_executable.PluginExecutable | None = None
39+
profile_request: Mapping[str, Any] | None = None
3840
lock: threading.Lock
3941

4042
def __init__(self) -> None:
4143
self.executable = None
44+
self.profile_request = None
4245
self.lock = threading.Lock()
4346

4447
def reset(self) -> None:
4548
self.executable = None
49+
self.profile_request = None
4650

4751

4852
_first_profile_start = True
@@ -153,6 +157,7 @@ def _start_pathways_trace_from_profile_request(
153157
_profile_state.executable = plugin_executable.PluginExecutable(
154158
json.dumps({"profileRequest": profile_request})
155159
)
160+
_profile_state.profile_request = profile_request
156161
try:
157162
_, result_future = _profile_state.executable.call()
158163
result_future.result()
@@ -233,7 +238,19 @@ def stop_trace() -> None:
233238
if _profile_state.executable is None:
234239
raise ValueError("stop_trace called before a trace is being taken!")
235240
try:
236-
_, result_future = _profile_state.executable.call()
241+
if (
242+
_profile_state.profile_request is not None
243+
and "xprofTraceOptions" in _profile_state.profile_request
244+
):
245+
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
246+
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
247+
else:
248+
out_avals = ()
249+
out_shardings = ()
250+
251+
_, result_future = _profile_state.executable.call(
252+
out_avals=out_avals, out_shardings=out_shardings
253+
)
237254
result_future.result()
238255
finally:
239256
_profile_state.reset()

pathwaysutils/test/profiling_test.py

Lines changed: 78 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,12 @@
1414

1515
import json
1616
import logging
17-
import unittest
1817
from unittest import mock
1918

2019
from absl.testing import absltest
2120
from absl.testing import parameterized
2221
import jax
22+
from jax import numpy as jnp
2323
from pathwaysutils import profiling
2424
import requests
2525

@@ -213,10 +213,13 @@ def test_lock_released_on_stop_failure(self):
213213
"""Tests that the lock is released if stop_trace fails."""
214214
profiling.start_trace("gs://test_bucket/test_dir3")
215215
self.assertFalse(profiling._profile_state.lock.locked())
216-
mock_result = (
217-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
216+
mock_result_fail = mock.MagicMock()
217+
mock_result_fail.result.side_effect = RuntimeError("stop failed")
218+
self.mock_plugin_executable_cls.return_value.call.return_value = (
219+
mock.MagicMock(),
220+
mock_result_fail,
218221
)
219-
mock_result.result.side_effect = RuntimeError("stop failed")
222+
self.mock_plugin_executable_cls.return_value.call.side_effect = None
220223
with self.assertRaisesRegex(RuntimeError, "stop failed"):
221224
profiling.stop_trace()
222225
self.assertFalse(profiling._profile_state.lock.locked())
@@ -277,6 +280,44 @@ def test_stop_trace_success(self):
277280
with self.subTest("executable_is_none"):
278281
self.assertIsNone(profiling._profile_state.executable)
279282

283+
@absltest.skipIf(
284+
jax.version.__version_info__ < (0, 9, 2),
285+
"ProfileOptions requires JAX 0.9.2 or newer",
286+
)
287+
def test_stop_trace_with_xprof_options_passes_out_avals(self):
288+
options = jax.profiler.ProfileOptions()
289+
options.duration_ms = 2000
290+
291+
with mock.patch.object(
292+
profiling, "_profile_state", autospec=True
293+
) as mock_profile_state:
294+
request = profiling._create_profile_request(
295+
"gs://test_bucket/test_dir", options
296+
)
297+
mock_profile_state.profile_request = request
298+
mock_profile_state.executable = (
299+
self.mock_plugin_executable_cls.return_value
300+
)
301+
mock_profile_state.lock = mock.MagicMock()
302+
mock_profile_state.lock.locked.return_value = True
303+
mock_profile_state.lock.__enter__.return_value = None
304+
mock_profile_state.lock.__exit__.return_value = None
305+
306+
profiling.stop_trace()
307+
308+
with self.subTest("plugin_executable_called"):
309+
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
310+
_, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args
311+
self.assertIn("out_avals", kwargs)
312+
self.assertIn("out_shardings", kwargs)
313+
314+
with self.subTest("out_avals_properties"):
315+
_, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args
316+
self.assertLen(kwargs["out_avals"], 1)
317+
(out_aval,) = kwargs["out_avals"]
318+
self.assertEqual(out_aval.shape, (1,))
319+
self.assertEqual(out_aval.dtype, jnp.object_)
320+
280321
def test_stop_trace_before_start_error(self):
281322
with self.assertRaisesRegex(
282323
ValueError, "stop_trace called before a trace is being taken!"
@@ -406,7 +447,7 @@ def test_create_profile_request_default_options(self, profiler_options):
406447
},
407448
)
408449

409-
@unittest.skipIf(
450+
@absltest.skipIf(
410451
jax.version.__version_info__ < (0, 9, 2),
411452
"ProfileOptions requires JAX 0.9.2 or newer",
412453
)
@@ -444,41 +485,45 @@ def test_create_profile_request_with_options(self):
444485
},
445486
)
446487

447-
@unittest.skipIf(
488+
@absltest.skipIf(
448489
jax.version.__version_info__ < (0, 9, 2),
449490
"ProfileOptions requires JAX 0.9.2 or newer",
450491
)
451492
@parameterized.parameters(
452493
({"traceLocation": "gs://test_bucket/test_dir"},),
453-
({
454-
"traceLocation": "gs://test_bucket/test_dir",
455-
"blockUntilStart": True,
456-
"maxDurationSecs": 10.0,
457-
"devices": {"deviceIds": [1, 2]},
458-
"includeResourceManagers": True,
459-
"maxNumHosts": 5,
460-
"xprofTraceOptions": {
494+
(
495+
{
496+
"traceLocation": "gs://test_bucket/test_dir",
461497
"blockUntilStart": True,
462-
"traceDirectory": "gs://test_bucket/test_dir",
498+
"maxDurationSecs": 10.0,
499+
"devices": {"deviceIds": [1, 2]},
500+
"includeResourceManagers": True,
501+
"maxNumHosts": 5,
502+
"xprofTraceOptions": {
503+
"blockUntilStart": True,
504+
"traceDirectory": "gs://test_bucket/test_dir",
505+
},
463506
},
464-
},),
465-
({
466-
"traceLocation": "gs://bucket/dir",
467-
"xprofTraceOptions": {
468-
"hostTraceLevel": 0,
469-
"traceOptions": {
470-
"traceMode": "TRACE_COMPUTE",
471-
"numSparseCoresToTrace": 1,
472-
"numSparseCoreTilesToTrace": 2,
473-
"numChipsToProfilePerTask": 3,
474-
"powerTraceLevel": 4,
475-
"enableFwThrottleEvent": True,
476-
"enableFwPowerLevelEvent": True,
477-
"enableFwThermalEvent": True,
507+
),
508+
(
509+
{
510+
"traceLocation": "gs://bucket/dir",
511+
"xprofTraceOptions": {
512+
"hostTraceLevel": 0,
513+
"traceOptions": {
514+
"traceMode": "TRACE_COMPUTE",
515+
"numSparseCoresToTrace": 1,
516+
"numSparseCoreTilesToTrace": 2,
517+
"numChipsToProfilePerTask": 3,
518+
"powerTraceLevel": 4,
519+
"enableFwThrottleEvent": True,
520+
"enableFwPowerLevelEvent": True,
521+
"enableFwThermalEvent": True,
522+
},
523+
"traceDirectory": "gs://bucket/dir",
478524
},
479-
"traceDirectory": "gs://bucket/dir",
480525
},
481-
},),
526+
),
482527
)
483528

484529
def test_start_pathways_trace_from_profile_request(self, profile_request):
@@ -496,10 +541,9 @@ def test_original_stop_trace_called_on_stop_failure(self):
496541
"""Tests that original_stop_trace is called if pathways stop_trace fails."""
497542
profiling.start_trace("gs://test_bucket/test_dir")
498543
self.assertFalse(profiling._profile_state.lock.locked())
499-
mock_result = (
500-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
544+
self.mock_plugin_executable_cls.return_value.call.side_effect = (
545+
RuntimeError("stop failed")
501546
)
502-
mock_result.result.side_effect = RuntimeError("stop failed")
503547
with self.assertRaisesRegex(RuntimeError, "stop failed"):
504548
profiling.stop_trace()
505549
self.mock_original_stop_trace.assert_called_once()

0 commit comments

Comments
 (0)