@@ -57,12 +57,75 @@ def toy_computation() -> None:
5757 x .block_until_ready ()
5858
5959
60+ def _is_default_profile_options (
61+ profiler_options : jax .profiler .ProfileOptions ,
62+ ) -> bool :
63+ if jax .version .__version_info__ < (0 , 9 , 2 ):
64+ return True
65+
66+ default_options = jax .profiler .ProfileOptions ()
67+ return (
68+ profiler_options .host_tracer_level == default_options .host_tracer_level
69+ and profiler_options .python_tracer_level
70+ == default_options .python_tracer_level
71+ and profiler_options .duration_ms == default_options .duration_ms
72+ and not getattr (profiler_options , "advanced_configuration" , None )
73+ )
74+
75+
6076def _create_profile_request (
6177 log_dir : os .PathLike [str ] | str ,
78+ profiler_options : jax .profiler .ProfileOptions | None = None ,
6279) -> Mapping [str , Any ]:
6380 """Creates a profile request mapping from the given options."""
64- profile_request = {}
65- profile_request ["traceLocation" ] = str (log_dir )
81+ profile_request : dict [str , Any ] = {
82+ "traceLocation" : str (log_dir ),
83+ }
84+
85+ if profiler_options is None or _is_default_profile_options (profiler_options ):
86+ return profile_request
87+
88+ advanced_config = None
89+ if getattr (profiler_options , "advanced_configuration" , None ):
90+ advanced_config = {}
91+ for k , v in getattr (profiler_options , "advanced_configuration" ).items ():
92+ # Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
93+ # json-compatible dict
94+ if isinstance (v , bool ):
95+ advanced_config [k ] = {"boolValue" : v }
96+ elif isinstance (v , int ):
97+ advanced_config [k ] = {"intValue" : v }
98+ elif isinstance (v , str ):
99+ advanced_config [k ] = {"stringValue" : v }
100+ else :
101+ raise ValueError (
102+ f"Unsupported advanced configuration value type: { type (v )} . "
103+ "Supported types are bool, int, and str."
104+ )
105+
106+ xprof_options : dict [str , Any ] = {
107+ "traceDirectory" : str (log_dir ),
108+ }
109+
110+ if profiler_options .host_tracer_level != 2 :
111+ xprof_options ["hostTraceLevel" ] = profiler_options .host_tracer_level
112+
113+ pw_trace_opts : dict [str , Any ] = {}
114+ if profiler_options .python_tracer_level :
115+ pw_trace_opts ["enablePythonTracer" ] = bool (
116+ profiler_options .python_tracer_level
117+ )
118+
119+ if advanced_config :
120+ pw_trace_opts ["advancedConfiguration" ] = advanced_config
121+
122+ if pw_trace_opts :
123+ xprof_options ["pwTraceOptions" ] = pw_trace_opts
124+
125+ profile_request ["xprofTraceOptions" ] = xprof_options
126+
127+ if profiler_options .duration_ms > 0 :
128+ profile_request ["maxDurationSecs" ] = profiler_options .duration_ms / 1000.0
66129
67130 return profile_request
68131
@@ -104,7 +167,7 @@ def start_trace(
104167 * ,
105168 create_perfetto_link : bool = False ,
106169 create_perfetto_trace : bool = False ,
107- profiler_options : jax .profiler .ProfileOptions | None = None , # pylint: disable=unused-argument
170+ profiler_options : jax .profiler .ProfileOptions | None = None ,
108171) -> None :
109172 """Starts a profiler trace.
110173
@@ -133,7 +196,6 @@ def start_trace(
133196 This feature is experimental for Pathways on Cloud and may not be fully
134197 supported.
135198 profiler_options: Profiler options to configure the profiler for collection.
136- Options are not currently supported and ignored.
137199 """
138200 if not str (log_dir ).startswith ("gs://" ):
139201 raise ValueError (f"log_dir must be a GCS bucket path, got { log_dir } " )
@@ -144,7 +206,18 @@ def start_trace(
144206 "features for Pathways on Cloud and may not be fully supported."
145207 )
146208
147- _start_pathways_trace_from_profile_request (_create_profile_request (log_dir ))
209+ if jax .version .__version_info__ < (0 , 9 , 2 ) and profiler_options is not None :
210+ _logger .warning (
211+ "ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
212+ "Some options can be specified via command line flags."
213+ )
214+ profiler_options = None
215+
216+ profile_request = _create_profile_request (log_dir , profiler_options )
217+
218+ _logger .debug ("Profile request: %s" , profile_request )
219+
220+ _start_pathways_trace_from_profile_request (profile_request )
148221
149222 _original_start_trace (
150223 log_dir = log_dir ,
0 commit comments