Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
af96109
add a profiling workflow.
sayakpaul Mar 26, 2026
eddef12
fix
sayakpaul Mar 27, 2026
e4d6293
fix
sayakpaul Mar 27, 2026
b2b6330
more clarification
sayakpaul Mar 27, 2026
60d4148
add points.
sayakpaul Mar 27, 2026
179fa51
up
sayakpaul Mar 27, 2026
96506c8
cache hooks
sayakpaul Mar 27, 2026
6a23a77
improve readme.
sayakpaul Mar 27, 2026
bf5131f
propagate deletion.
sayakpaul Mar 27, 2026
bfbaf07
up
sayakpaul Mar 27, 2026
a410b49
up
sayakpaul Mar 27, 2026
35437a8
wan fixes.
sayakpaul Mar 27, 2026
142f417
more
sayakpaul Mar 27, 2026
9ba98a2
up
sayakpaul Mar 27, 2026
12ba8be
add more traces.
sayakpaul Mar 28, 2026
43e16fb
up
sayakpaul Mar 28, 2026
e26d5c6
better title
sayakpaul Mar 28, 2026
1131acd
cuda graphs.
sayakpaul Mar 29, 2026
c642cd0
up
sayakpaul Mar 30, 2026
ed8241a
Apply suggestions from code review
sayakpaul Mar 31, 2026
3ae7d9b
add torch.compile link.
sayakpaul Mar 31, 2026
bfb19af
approach -> How the tooling works
sayakpaul Mar 31, 2026
40a525e
table
sayakpaul Mar 31, 2026
3bdd529
Merge branch 'main' into profiling-workflow
sayakpaul Mar 31, 2026
6cf1429
unavoidable gaps.
sayakpaul Mar 31, 2026
fb6afa6
make important
sayakpaul Mar 31, 2026
40c330a
note on regional compilation
sayakpaul Mar 31, 2026
3fc1a04
Merge branch 'main' into profiling-workflow
sayakpaul Mar 31, 2026
131831f
Apply suggestions from code review
sayakpaul Mar 31, 2026
6d8e371
Merge branch 'main' into profiling-workflow
sayakpaul Apr 3, 2026
1a3ffdd
make regional compilation note clearer.
sayakpaul Apr 3, 2026
ed2ef83
Apply suggestions from code review
sayakpaul Apr 3, 2026
7c0a8b5
clarify scheduler related changes.
sayakpaul Apr 3, 2026
c29324d
Apply suggestions from code review
sayakpaul Apr 3, 2026
4269a8b
Update examples/profiling/README.md
sayakpaul Apr 3, 2026
f071cbe
up
sayakpaul Apr 3, 2026
aeccb6b
formatting
sayakpaul Apr 3, 2026
a58ebbd
benchmarking runtime
sayakpaul Apr 3, 2026
55af3b6
up
sayakpaul Apr 3, 2026
9f6fe28
up
sayakpaul Apr 3, 2026
a73c89f
up
sayakpaul Apr 3, 2026
f115a0d
up
sayakpaul Apr 3, 2026
5c0c562
Update examples/profiling/README.md
sayakpaul Apr 3, 2026
0a74391
Merge branch 'main' into profiling-workflow
sayakpaul Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 320 additions & 0 deletions examples/profiling/README.md

Large diffs are not rendered by default.

181 changes: 181 additions & 0 deletions examples/profiling/profiling_pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
Profile diffusers pipelines with torch.profiler.

Usage:
python profiling/profiling_pipelines.py --pipeline flux --mode eager
python profiling/profiling_pipelines.py --pipeline flux --mode compile
python profiling/profiling_pipelines.py --pipeline flux --mode both
python profiling/profiling_pipelines.py --pipeline all --mode eager
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
"""

import argparse
import copy
import logging

import torch
from profiling_utils import PipelineProfiler, PipelineProfilingConfig


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

PROMPT = "A cat holding a sign that says hello world"


def build_registry():
    """Return the mapping of pipeline name -> profiling configuration.

    The diffusers imports live inside this function so that importing this
    module stays cheap and pipelines are only loaded when requested.
    """
    from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline

    def _init_kwargs(repo_id):
        # Every pipeline is loaded in bfloat16 from its canonical Hub repo.
        return {"pretrained_model_name_or_path": repo_id, "torch_dtype": torch.bfloat16}

    registry = {}
    registry["flux"] = PipelineProfilingConfig(
        name="flux",
        pipeline_cls=FluxPipeline,
        pipeline_init_kwargs=_init_kwargs("black-forest-labs/FLUX.1-dev"),
        pipeline_call_kwargs={
            "prompt": PROMPT,
            "height": 1024,
            "width": 1024,
            "num_inference_steps": 4,
            "guidance_scale": 3.5,
            "output_type": "latent",
        },
    )
    registry["flux2"] = PipelineProfilingConfig(
        name="flux2",
        pipeline_cls=Flux2KleinPipeline,
        pipeline_init_kwargs=_init_kwargs("black-forest-labs/FLUX.2-klein-base-9B"),
        pipeline_call_kwargs={
            "prompt": PROMPT,
            "height": 1024,
            "width": 1024,
            "num_inference_steps": 4,
            "guidance_scale": 3.5,
            "output_type": "latent",
        },
    )
    registry["wan"] = PipelineProfilingConfig(
        name="wan",
        pipeline_cls=WanPipeline,
        pipeline_init_kwargs=_init_kwargs("Wan-AI/Wan2.1-T2V-14B-Diffusers"),
        pipeline_call_kwargs={
            "prompt": PROMPT,
            "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
            "height": 480,
            "width": 832,
            "num_frames": 81,
            "num_inference_steps": 4,
            "output_type": "latent",
        },
    )
    registry["ltx2"] = PipelineProfilingConfig(
        name="ltx2",
        pipeline_cls=LTX2Pipeline,
        pipeline_init_kwargs=_init_kwargs("Lightricks/LTX-2"),
        pipeline_call_kwargs={
            "prompt": PROMPT,
            "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
            "height": 512,
            "width": 768,
            "num_frames": 121,
            "num_inference_steps": 4,
            "guidance_scale": 4.0,
            "output_type": "latent",
        },
    )
    registry["qwenimage"] = PipelineProfilingConfig(
        name="qwenimage",
        pipeline_cls=QwenImagePipeline,
        pipeline_init_kwargs=_init_kwargs("Qwen/Qwen-Image"),
        pipeline_call_kwargs={
            "prompt": PROMPT,
            "negative_prompt": " ",
            "height": 1024,
            "width": 1024,
            "num_inference_steps": 4,
            "true_cfg_scale": 4.0,
            "output_type": "latent",
        },
    )
    return registry


def main():
    """CLI entry point: parse arguments and profile each selected pipeline/mode.

    A failure in one pipeline/mode combination is logged (with traceback) and
    does not abort the remaining runs.
    """
    parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
    parser.add_argument(
        "--pipeline",
        choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
        required=True,
        help="Which pipeline to profile",
    )
    parser.add_argument(
        "--mode",
        choices=["eager", "compile", "both"],
        default="eager",
        help="Run in eager mode, compile mode, or both",
    )
    parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
    parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
    parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
    parser.add_argument(
        "--compile_mode",
        default="default",
        choices=["default", "reduce-overhead", "max-autotune"],
        help="torch.compile mode",
    )
    parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
    parser.add_argument(
        "--compile_regional",
        action="store_true",
        help="Use compile_repeated_blocks() instead of full model compile",
    )
    args = parser.parse_args()

    registry = build_registry()

    pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
    modes = ["eager", "compile"] if args.mode == "both" else [args.mode]

    for pipeline_name in pipeline_names:
        for mode in modes:
            # Deep-copy so per-run overrides never leak into the shared registry.
            config = copy.deepcopy(registry[pipeline_name])

            # Apply CLI overrides on top of the registry defaults.
            if args.num_steps is not None:
                config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
            if args.full_decode:
                config.pipeline_call_kwargs["output_type"] = "pil"
            if mode == "compile":
                config.compile_kwargs = {
                    "fullgraph": args.compile_fullgraph,
                    "mode": args.compile_mode,
                }
                config.compile_regional = args.compile_regional

            logger.info("Profiling %s in %s mode...", pipeline_name, mode)
            profiler = PipelineProfiler(config, args.output_dir)
            try:
                trace_file = profiler.run()
                logger.info("Done: %s", trace_file)
            except Exception:
                # logger.exception keeps the full traceback (logger.error with
                # just the message would lose it), which matters for diagnosing
                # OOMs / compile failures after an unattended "all" run.
                logger.exception("Failed to profile %s (%s)", pipeline_name, mode)


if __name__ == "__main__":
    main()
148 changes: 148 additions & 0 deletions examples/profiling/profiling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.profiler


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def annotate(func, name):
    """Return *func* wrapped so every call appears as a span *name* in traces."""

    def _traced(*args, **kwargs):
        # record_function is a cheap no-op outside an active profiler session,
        # so the wrapper is safe to leave installed permanently.
        with torch.profiler.record_function(name):
            return func(*args, **kwargs)

    return functools.wraps(func)(_traced)


def annotate_pipeline(pipe):
    """Attach named profiler spans to the pipeline's hot methods.

    Rebinds methods on the live component objects, so the spans show up in
    traces without any modification to library source code.
    """
    targets = (
        ("transformer", "forward", "transformer_forward"),
        ("vae", "decode", "vae_decode"),
        ("vae", "encode", "vae_encode"),
        ("scheduler", "step", "scheduler_step"),
    )

    # Wrap sub-component methods; missing components/methods are skipped so
    # this works across pipelines with different component sets.
    for owner_name, attr_name, span_label in targets:
        owner = getattr(pipe, owner_name, None)
        if owner is None:
            continue
        bound = getattr(owner, attr_name, None)
        if bound is None:
            continue
        setattr(owner, attr_name, annotate(bound, span_label))

    # Wrap pipeline-level methods.
    if hasattr(pipe, "encode_prompt"):
        pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")


def flush():
    """Collect garbage, free cached CUDA memory, and reset peak-memory stats.

    Safe on CPU-only hosts: the CUDA calls are skipped when no GPU is
    available (the original unconditionally touched torch.cuda).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is a deprecated alias of
        # reset_peak_memory_stats(), so a single call covers both.
        torch.cuda.reset_peak_memory_stats()


@dataclass
class PipelineProfilingConfig:
    """Everything needed to load, optionally compile, and call one pipeline."""

    # Short identifier used in log messages and trace file names.
    name: str
    # Pipeline class to instantiate; typed Any because diffusers classes are
    # imported lazily by the caller (see build_registry in profiling_pipelines).
    pipeline_cls: Any
    # Kwargs forwarded to ``pipeline_cls.from_pretrained(...)``.
    pipeline_init_kwargs: dict[str, Any]
    # Kwargs forwarded to the pipeline's ``__call__``.
    pipeline_call_kwargs: dict[str, Any]
    # Kwargs for ``transformer.compile(...)``; None means run in eager mode.
    compile_kwargs: dict[str, Any] | None = field(default=None)
    # When True, use compile_repeated_blocks() instead of full-model compile.
    compile_regional: bool = False


class PipelineProfiler:
    """Profile a single diffusers pipeline call with ``torch.profiler``.

    Loads the pipeline described by ``config``, optionally compiles its
    transformer, runs one warmup call, then profiles a second call and
    exports a Chrome trace into ``output_dir``.
    """

    def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
        self.config = config
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def setup_pipeline(self):
        """Load the pipeline from pretrained, optionally compile, and annotate.

        Returns the pipeline moved to CUDA with profiler span annotations
        applied (see ``annotate_pipeline``).
        """
        logger.info(f"Loading pipeline: {self.config.name}")
        pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
        pipe.to("cuda")

        if self.config.compile_kwargs:
            if self.config.compile_regional:
                # Regional compilation compiles only the repeated transformer
                # blocks rather than the whole model.
                logger.info(
                    f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}"
                )
                pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
            else:
                logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
                pipe.transformer.compile(**self.config.compile_kwargs)

        # Disable tqdm progress bar to avoid CPU overhead / IO between steps
        pipe.set_progress_bar_config(disable=True)

        annotate_pipeline(pipe)
        return pipe

    def run(self):
        """Execute the profiling run: warmup, then profile one pipeline call.

        Returns the path of the exported Chrome trace. The pipeline is moved
        back to CPU and freed even when a call raises, so a failure in one
        run does not leak GPU memory into subsequent runs (callers catch the
        exception and continue with the next pipeline/mode).
        """
        pipe = self.setup_pipeline()
        flush()

        mode = "compile" if self.config.compile_kwargs else "eager"
        trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")

        try:
            # Warmup (pipeline __call__ is already decorated with @torch.no_grad());
            # in compile mode this also triggers compilation outside the trace.
            logger.info("Running warmup...")
            pipe(**self.config.pipeline_call_kwargs)
            flush()

            # Profile a single end-to-end call.
            logger.info("Running profiled iteration...")
            activities = [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
            with torch.profiler.profile(
                activities=activities,
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            ) as prof:
                with torch.profiler.record_function("pipeline_call"):
                    pipe(**self.config.pipeline_call_kwargs)

            # Export trace
            prof.export_chrome_trace(trace_file)
            logger.info(f"Chrome trace saved to: {trace_file}")

            # Print summary
            print("\n" + "=" * 80)
            print(f"Profile summary: {self.config.name} ({mode})")
            print("=" * 80)
            print(
                prof.key_averages().table(
                    sort_by="cuda_time_total",
                    row_limit=20,
                )
            )
        finally:
            # Cleanup always runs: previously a failed call left the model on
            # the GPU while the caller swallowed the exception and moved on.
            pipe.to("cpu")
            del pipe
            flush()

        return trace_file
46 changes: 46 additions & 0 deletions examples/profiling/run_profiling.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/bin/bash
# Run profiling across all pipelines in eager and compile (regional) modes.
#
# Usage:
#   bash profiling/run_profiling.sh
#   bash profiling/run_profiling.sh --output_dir my_results

set -euo pipefail

# Resolve the directory this script lives in so it can be invoked from any
# working directory (the old hard-coded relative path only worked from the
# examples/ directory).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

OUTPUT_DIR="profiling_results"
while [[ $# -gt 0 ]]; do
    case "$1" in
        --output_dir) OUTPUT_DIR="$2"; shift 2 ;;
        *) echo "Unknown arg: $1"; exit 1 ;;
    esac
done

# Two steps are enough to see the per-step structure in a trace while
# keeping each run short.
NUM_STEPS=2
# Run every supported pipeline, as the header promises (a debug leftover had
# pinned this to just "wan").
PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
MODES=("eager" "compile")

for pipeline in "${PIPELINES[@]}"; do
    for mode in "${MODES[@]}"; do
        echo "============================================================"
        echo "Profiling: ${pipeline} | mode: ${mode}"
        echo "============================================================"

        COMPILE_ARGS=""
        if [ "$mode" = "compile" ]; then
            COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
        fi

        python "$SCRIPT_DIR/profiling_pipelines.py" \
            --pipeline "$pipeline" \
            --mode "$mode" \
            --output_dir "$OUTPUT_DIR" \
            --num_steps "$NUM_STEPS" \
            $COMPILE_ARGS

        echo ""
    done
done

echo "============================================================"
echo "All traces saved to: ${OUTPUT_DIR}/"
echo "============================================================"
Loading
Loading