Commit d2d5fda
fix: cache modulate_index in QwenImageTransformer2DModel to avoid a per-step host-device sync
When zero_cond_t=True, the modulate_index tensor was being recreated on
every transformer forward pass (once per denoising step) using:
torch.tensor(list_comprehension, device=timestep.device, ...)
Building the tensor from a Python list stages the values on the host and
copies them to the device, which shows up as a cudaMemcpyAsync followed by a
cudaStreamSynchronize (a blocking host-device sync) that forces the CPU to
wait for all pending GPU kernels.
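For context, a minimal, hypothetical repro of that pattern (not taken from
the diff); the shapes and index values are illustrative:

```python
import torch

# Enqueue some asynchronous GPU work; these launches return immediately.
device = torch.device("cuda")
x = torch.randn(2048, 2048, device=device)
for _ in range(20):
    x = x @ x

# Building a CUDA tensor from a Python list stages the values on the host and
# copies them to the device. Under a profiler this appears as cudaMemcpyAsync
# followed by cudaStreamSynchronize, so the CPU stalls here until all the
# kernels queued above have finished.
modulate_index = torch.tensor([0, 1, 1, 1], device=device, dtype=torch.long)
```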
Since img_shapes (which fully determines modulate_index) is fixed for the
entire inference run, the resulting tensor is identical across all steps.
We cache it in _modulate_index_cache keyed by (img_shapes, device), so
the tensor is built only on the first step and reused thereafter.
This eliminates N-1 unnecessary torch.tensor() constructions and host-device
syncs during inference (where N = num_inference_steps).
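A minimal sketch of the cached variant: the attribute name and cache key
follow the commit message, but the class skeleton and the index construction
are assumptions, since the diff body is not reproduced here:

```python
import torch


class QwenImageTransformer2DModelSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Maps (img_shapes, device) -> the modulate_index tensor for that run.
        self._modulate_index_cache = {}

    def _get_modulate_index(self, img_shapes, device):
        # img_shapes is fixed for the whole inference run, so the tensor is
        # built (one host-device copy + sync) only on the first step and
        # served from the cache on every later step.
        key = (tuple(img_shapes), device)
        index = self._modulate_index_cache.get(key)
        if index is None:
            # Hypothetical construction; the real layout lives in the diff.
            index = torch.tensor(
                [0 if i == 0 else 1 for i in range(len(img_shapes))],
                device=device,
                dtype=torch.long,
            )
            self._modulate_index_cache[key] = index
        return index
```

Keying on device as well as img_shapes keeps the cache correct if the module
is moved between devices mid-run.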
This issue was identified in the profiling guide added in huggingface#13356 and
referenced in huggingface#13401.
Follows the same caching pattern as _compute_video_freqs in QwenEmbedRope.
1 parent fbe8a75
1 file changed
Lines changed: 20 additions & 5 deletions
(Diff table not captured in extraction: 7 lines added as new lines 835-841, and original lines 901-905 replaced by new lines 908-920 in the changed file.)