Skip to content

Commit b5cffaf

Browse files
lukebaumanncopybara-github
authored andcommitted
Enable Pathways profiling with jax.profiler.ProfileOptions.
This change allows users to configure Pathways profiling by passing a jax.profiler.ProfileOptions object to the start_trace function. The options are translated into the Pathways profile request, enabling control over a subset of parameters. Explicitly, `start_timestamp_ms`, `duration_ms`, `host_tracer_level`, `advanced_configuration`, and `python_tracer_level`. PiperOrigin-RevId: 885730249
1 parent a57c2a0 commit b5cffaf

2 files changed

Lines changed: 97 additions & 9 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,22 @@ def toy_computation() -> None:
5959

6060
def _create_profile_request(
6161
log_dir: os.PathLike[str] | str,
62+
profiler_options: jax.profiler.ProfileOptions | None = None,
6263
) -> Mapping[str, Any]:
6364
"""Creates a profile request mapping from the given options."""
64-
profile_request = {}
65-
profile_request["traceLocation"] = str(log_dir)
65+
if profiler_options is None:
66+
profiler_options = jax.profiler.ProfileOptions()
67+
68+
profile_request = {
69+
"traceLocation": str(log_dir),
70+
"profilingStartTimeNs": profiler_options.start_timestamp_ns,
71+
"profilingDurationMs": profiler_options.duration_ms,
72+
"hostTraceLevel": profiler_options.host_tracer_level,
73+
"pwTraceOptions": {
74+
"advancedConfiguration": profiler_options.advanced_configuration,
75+
"enablePythonTracer": bool(profiler_options.python_tracer_level),
76+
},
77+
}
6678

6779
return profile_request
6880

@@ -104,7 +116,7 @@ def start_trace(
104116
*,
105117
create_perfetto_link: bool = False,
106118
create_perfetto_trace: bool = False,
107-
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
119+
profiler_options: jax.profiler.ProfileOptions | None = None,
108120
) -> None:
109121
"""Starts a profiler trace.
110122
@@ -133,7 +145,6 @@ def start_trace(
133145
This feature is experimental for Pathways on Cloud and may not be fully
134146
supported.
135147
profiler_options: Profiler options to configure the profiler for collection.
136-
Options are not currently supported and ignored.
137148
"""
138149
if not str(log_dir).startswith("gs://"):
139150
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
@@ -144,7 +155,11 @@ def start_trace(
144155
"features for Pathways on Cloud and may not be fully supported."
145156
)
146157

147-
_start_pathways_trace_from_profile_request(_create_profile_request(log_dir))
158+
profile_request = _create_profile_request(log_dir, profiler_options)
159+
160+
_logger.debug("Profile request: %s", profile_request)
161+
162+
_start_pathways_trace_from_profile_request(profile_request)
148163

149164
_original_start_trace(
150165
log_dir=log_dir,

pathwaysutils/test/profiling_test.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,18 @@ def test_start_trace_success(self):
225225

226226
self.mock_toy_computation.assert_called_once()
227227
self.mock_plugin_executable_cls.assert_called_once_with(
228-
json.dumps(
229-
{"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}}
230-
)
228+
json.dumps({
229+
"profileRequest": {
230+
"traceLocation": "gs://test_bucket/test_dir",
231+
"profilingStartTimeNs": 0,
232+
"profilingDurationMs": 0,
233+
"hostTraceLevel": 2,
234+
"pwTraceOptions": {
235+
"advancedConfiguration": {},
236+
"enablePythonTracer": True,
237+
},
238+
}
239+
})
231240
)
232241
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
233242
self.mock_original_start_trace.assert_called_once_with(
@@ -393,7 +402,71 @@ def test_monkey_patched_stop_server(self):
393402

394403
def test_create_profile_request_no_options(self):
395404
request = profiling._create_profile_request("gs://bucket/dir")
396-
self.assertEqual(request, {"traceLocation": "gs://bucket/dir"})
405+
self.assertEqual(
406+
request,
407+
{
408+
"traceLocation": "gs://bucket/dir",
409+
"profilingStartTimeNs": 0,
410+
"profilingDurationMs": 0,
411+
"hostTraceLevel": 2,
412+
"pwTraceOptions": {
413+
"advancedConfiguration": {},
414+
"enablePythonTracer": True,
415+
},
416+
},
417+
)
418+
419+
def test_create_profile_request_default_options(self):
420+
options = jax.profiler.ProfileOptions()
421+
request = profiling._create_profile_request(
422+
"gs://bucket/dir", profiler_options=options
423+
)
424+
self.assertEqual(
425+
request,
426+
{
427+
"traceLocation": "gs://bucket/dir",
428+
"profilingStartTimeNs": 0,
429+
"profilingDurationMs": 0,
430+
"hostTraceLevel": 2,
431+
"pwTraceOptions": {
432+
"advancedConfiguration": {},
433+
"enablePythonTracer": True,
434+
},
435+
},
436+
)
437+
438+
def test_create_profile_request_with_options(self):
439+
options = jax.profiler.ProfileOptions()
440+
options.host_tracer_level = 2
441+
options.python_tracer_level = 1
442+
options.duration_ms = 2000
443+
options.start_timestamp_ns = 123456789
444+
options.advanced_configuration = {
445+
"tpu_num_chips_to_profile_per_task": 3,
446+
"tpu_num_sparse_core_tiles_to_trace": 5,
447+
"tpu_trace_mode": "TRACE_COMPUTE",
448+
}
449+
450+
request = profiling._create_profile_request(
451+
"gs://bucket/dir", profiler_options=options
452+
)
453+
self.assertEqual(
454+
request,
455+
{
456+
"traceLocation": "gs://bucket/dir",
457+
"hostTraceLevel": 2,
458+
"profilingDurationMs": 2000,
459+
"profilingStartTimeNs": 123456789,
460+
"pwTraceOptions": {
461+
"enablePythonTracer": True,
462+
"advancedConfiguration": {
463+
"tpu_num_chips_to_profile_per_task": 3,
464+
"tpu_num_sparse_core_tiles_to_trace": 5,
465+
"tpu_trace_mode": "TRACE_COMPUTE",
466+
},
467+
},
468+
},
469+
)
397470

398471
@parameterized.parameters(
399472
({"traceLocation": "gs://test_bucket/test_dir"},),

0 commit comments

Comments
 (0)