
Commit d2d5fda

fix: cache modulate_index in QwenImageTransformer2DModel to avoid per-step DtoH sync
When zero_cond_t=True, the modulate_index tensor was being recreated on every transformer forward pass (once per denoising step) using:

    torch.tensor(list_comprehension, device=timestep.device, ...)

Evaluating the list comprehension and then calling torch.tensor() on the resulting Python list issues a cudaMemcpyAsync + cudaStreamSynchronize (DtoH sync), which forces the CPU to wait for all pending GPU kernels.

Since img_shapes (which fully determines modulate_index) is fixed for the entire inference run, the resulting tensor is identical across all steps. We now cache it in _modulate_index_cache keyed by (img_shapes, device), so the tensor is built only on the first step and reused thereafter. This eliminates N-1 unnecessary torch.tensor() constructions and DtoH syncs during inference, where N = num_inference_steps.

This issue was identified in the profiling guide added in huggingface#13356 and referenced in huggingface#13401. It follows the same caching pattern as _compute_video_freqs in QwenEmbedRope.
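
The pattern in isolation, as a minimal runnable sketch (not the diffusers code itself; the img_shapes value and the function names below are invented for illustration):

    # Minimal sketch of the caching pattern this commit applies.
    # img_shapes here is an invented example; in the model it is passed in
    # by the pipeline and stays fixed for the whole inference run.
    from math import prod

    import torch

    img_shapes = [[(1, 64, 64), (1, 32, 32)]]

    def build_modulate_index(img_shapes, device):
        # When called every step, torch.tensor() on a Python list materializes
        # the data on the CPU and then copies it to the GPU with an implicit
        # stream synchronization -- the per-step overhead the commit removes.
        return torch.tensor(
            [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
            device=device,
            dtype=torch.int,
        )

    _modulate_index_cache = {}

    def cached_modulate_index(img_shapes, device):
        # Nested tuples turn the unhashable list-of-lists into a valid dict key.
        key = (tuple(tuple(s) for s in img_shapes), device)
        if key not in _modulate_index_cache:
            _modulate_index_cache[key] = build_modulate_index(img_shapes, device)
        return _modulate_index_cache[key]

Since img_shapes is constant for a given run, the cache typically holds a single entry, so the memory cost is negligible.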
1 parent fbe8a75 commit d2d5fda

1 file changed: 20 additions & 5 deletions

src/diffusers/models/transformers/transformer_qwenimage.py
@@ -832,6 +832,13 @@ def __init__(
         self.gradient_checkpointing = False
         self.zero_cond_t = zero_cond_t
 
+        # Cache for modulate_index tensor to avoid rebuilding it on every forward pass.
+        # The tensor is determined solely by img_shapes (fixed during inference), so it
+        # only needs to be computed once per unique (img_shapes, device) combination.
+        # Without caching, every forward call triggers a Python list comprehension +
+        # torch.tensor() construction which is visible as CPU overhead in profiling traces.
+        self._modulate_index_cache: dict = {}
+
     @apply_lora_scale("attention_kwargs")
     def forward(
         self,
@@ -898,11 +905,19 @@ def forward(
 
         if self.zero_cond_t:
             timestep = torch.cat([timestep, timestep * 0], dim=0)
-            modulate_index = torch.tensor(
-                [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
-                device=timestep.device,
-                dtype=torch.int,
-            )
+            # Cache modulate_index to avoid rebuilding it on every forward pass.
+            # img_shapes is fixed during inference (same across all denoising steps),
+            # so we can build the tensor once and reuse it, eliminating the CPU overhead
+            # and implicit sync from torch.tensor() on each step.
+            device = timestep.device
+            cache_key = (tuple(tuple(s) for s in img_shapes), device)
+            if cache_key not in self._modulate_index_cache:
+                self._modulate_index_cache[cache_key] = torch.tensor(
+                    [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
+                    device=device,
+                    dtype=torch.int,
+                )
+            modulate_index = self._modulate_index_cache[cache_key]
         else:
             modulate_index = None
 
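
One hedged way to check the effect on your own setup is to profile a short run and confirm that the modulate_index construction (and its synchronization) no longer appears once per denoising step. This is a sketch only: `pipe` is assumed to be an already-loaded QwenImage pipeline, and the prompt and step count are arbitrary.

    from torch.profiler import ProfilerActivity, profile

    # `pipe` is assumed to be a loaded QwenImage pipeline (not shown here).
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        _ = pipe("a photo of a cat", num_inference_steps=4)

    # If the cache is working, the tensor construction and its sync should
    # show up on the first step only, not on every step of the trace.
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))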
