Skip to content

Commit d54669a

Browse files
kashif and Sayak Paul authored
[Qwen] avoid creating attention masks when there is no padding (#12987)
* avoid creating attention masks when there is no padding * make fix-copies * torch compile tests * set all ones mask to none * fix positional encoding from becoming > 4096 * fix from review * slice freqs_cis to match the input sequence length * keep only attenton masking change --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 22ac6fa commit d54669a

10 files changed

Lines changed: 98 additions & 0 deletions

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ def encode_prompt(
262262
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
263263
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
264264

265+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
266+
prompt_embeds_mask = None
267+
265268
return prompt_embeds, prompt_embeds_mask
266269

267270
def check_inputs(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,9 @@ def encode_prompt(
324324
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
325325
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
326326

327+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
328+
prompt_embeds_mask = None
329+
327330
return prompt_embeds, prompt_embeds_mask
328331

329332
def check_inputs(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ def encode_prompt(
305305
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
306306
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
307307

308+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
309+
prompt_embeds_mask = None
310+
308311
return prompt_embeds, prompt_embeds_mask
309312

310313
def check_inputs(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ def encode_prompt(
309309
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
310310
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
311311

312+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
313+
prompt_embeds_mask = None
314+
312315
return prompt_embeds, prompt_embeds_mask
313316

314317
def check_inputs(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,9 @@ def encode_prompt(
321321
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
322322
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
323323

324+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
325+
prompt_embeds_mask = None
326+
324327
return prompt_embeds, prompt_embeds_mask
325328

326329
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ def encode_prompt(
323323
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
324324
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
325325

326+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
327+
prompt_embeds_mask = None
328+
326329
return prompt_embeds, prompt_embeds_mask
327330

328331
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ def encode_prompt(
305305
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
306306
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
307307

308+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
309+
prompt_embeds_mask = None
310+
308311
return prompt_embeds, prompt_embeds_mask
309312

310313
def check_inputs(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ def encode_prompt(
316316
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
317317
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
318318

319+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
320+
prompt_embeds_mask = None
321+
319322
return prompt_embeds, prompt_embeds_mask
320323

321324
def check_inputs(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ def encode_prompt(
328328
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
329329
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
330330

331+
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
332+
prompt_embeds_mask = None
333+
331334
return prompt_embeds, prompt_embeds_mask
332335

333336
def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,74 @@ def prepare_dummy_input(self, height, width):
276276

277277
def test_torch_compile_recompilation_and_graph_break(self):
278278
super().test_torch_compile_recompilation_and_graph_break()
279+
280+
def test_torch_compile_with_and_without_mask(self):
    """Test that torch.compile handles a None mask, an all-ones mask, and a real
    padding mask without recompilation, and that padding actually changes the output.

    For each mask variant the model is run twice: the first call is allowed to
    compile; the second runs under ``error_on_recompile=True`` so any
    recompilation fails the test.
    """
    init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict).to(torch_device)
    model.eval()
    model.compile(mode="default", fullgraph=True)

    # Output sequence length should always match the input latent sequence length.
    expected_seq_len = inputs["hidden_states"].shape[1]

    def run_and_check_no_recompile(model_inputs):
        # Run twice: the first call may trigger compilation; the second must not
        # recompile (error_on_recompile raises if it does).
        with torch.no_grad():
            first = model(**model_inputs)
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
            torch.no_grad(),
        ):
            second = model(**model_inputs)
        self.assertEqual(first.sample.shape[1], expected_seq_len)
        self.assertEqual(second.sample.shape[1], expected_seq_len)
        return first

    # Case 1: None mask (no padding, all tokens are valid).
    inputs_no_mask = inputs.copy()
    inputs_no_mask["encoder_hidden_states_mask"] = None
    output_no_mask = run_and_check_no_recompile(inputs_no_mask)

    # Case 2: all-ones mask (should behave like None). Keep the dummy mask but
    # sanity-check that it really is all ones.
    inputs_all_ones = inputs.copy()
    self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
    run_and_check_no_recompile(inputs_all_ones)

    # Case 3: actual padding mask — every token from index 4 onward is padding
    # (works for any sequence length > 4, not just 7).
    inputs_with_padding = inputs.copy()
    mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
    mask_with_padding[:, 4:] = 0
    inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
    output_with_padding = run_and_check_no_recompile(inputs_with_padding)

    # The padding mask must affect the result: masked and unmasked runs differ.
    self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))

0 commit comments

Comments (0)