
Commit 9d4c9dc

change QwenImageTransformer UT to batch inputs (#13312)
* UT expands to batch inputs
* update according to suggestion
* update according to suggestion 2
* fix CI
* update according to suggestion 3
* clean line
1 parent ef309a1 commit 9d4c9dc

4 files changed

Lines changed: 30 additions & 14 deletions


tests/models/testing_utils/parallelism.py

Lines changed: 18 additions & 2 deletions
@@ -22,6 +22,7 @@
 import torch.multiprocessing as mp
 
 from diffusers.models._modeling_parallel import ContextParallelConfig
+from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
 
 from ...testing_utils import (
     is_context_parallel,

@@ -160,16 +161,21 @@ def _custom_mesh_worker(
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
-    def test_context_parallel_inference(self, cp_type):
+    def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")
 
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
 
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
+        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
 
         # Move all tensors to CPU for multiprocessing
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

@@ -194,6 +200,11 @@ def test_context_parallel_inference(self, cp_type):
             f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
         )
 
+    @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_batch_inputs(self, cp_type):
+        self.test_context_parallel_inference(cp_type, batch_size=2)
+
     @pytest.mark.parametrize(
         "cp_type,mesh_shape,mesh_dim_names",
         [

@@ -209,6 +220,11 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
 
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
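The ring-attention guard added to both tests above follows a single pattern, shown standalone below. This is a minimal sketch reusing only names imported in the diff; note that _AttentionBackendRegistry is a private registry, and skip_if_ring_unsupported is a hypothetical helper name, not part of the change (the diff inlines the check in each test).

import pytest

from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry


def skip_if_ring_unsupported(cp_type: str) -> None:
    # Hypothetical helper; the diff inlines this same check in both tests.
    # Ring attention requires a non-native attention backend, so skip
    # when PyTorch's built-in SDPA backend is the one currently active.
    if cp_type == "ring_degree":
        active_backend, _ = _AttentionBackendRegistry.get_active_backend()
        if active_backend == AttentionBackendName.NATIVE:
            pytest.skip("Ring attention is not supported with the native attention backend.")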

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 1 addition & 2 deletions
@@ -150,8 +150,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": [4, 4, 8],
         }
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         height = width = 4
         num_latent_channels = 4
         num_image_channels = 3
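The same signature change recurs in the Flux2 and QwenImage testers below: batch_size moves from a hard-coded local into a keyword argument defaulting to 1, so existing call sites keep their old behavior. The method bodies are not part of this diff; the sketch below only illustrates the intended shape of the change, with hypothetical tensor names and dimensions.

import torch


def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
    # Illustrative only: the real Flux dummy inputs contain more entries
    # (ids, timesteps, pooled projections) with model-specific shapes.
    height = width = 4
    num_latent_channels = 4
    sequence_length = 8  # hypothetical text sequence length
    embedding_dim = 32  # hypothetical embedding width
    return {
        # batch_size now drives the leading dimension of every tensor.
        "hidden_states": torch.randn(batch_size, height * width, num_latent_channels),
        "encoder_hidden_states": torch.randn(batch_size, sequence_length, embedding_dim),
    }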

tests/models/transformers/test_models_transformer_flux2.py

Lines changed: 1 addition & 2 deletions
@@ -90,8 +90,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": [4, 4, 4, 4],
         }
 
-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_latent_channels = 4
         sequence_length = 48
         embedding_dim = 32

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 10 additions & 8 deletions
@@ -14,6 +14,7 @@
 
 import warnings
 
+import pytest
 import torch
 
 from diffusers import QwenImageTransformer2DModel

@@ -77,8 +78,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": (8, 4, 4),
         }
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_latent_channels = embedding_dim = 16
         height = width = 4
         sequence_length = 8

@@ -106,9 +106,10 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
 
 
 class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
-    def test_infers_text_seq_len_from_mask(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_infers_text_seq_len_from_mask(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()

@@ -122,7 +123,7 @@ def test_infers_text_seq_len_from_mask(self):
         assert isinstance(per_sample_len, torch.Tensor)
         assert int(per_sample_len.max().item()) == 2
         assert normalized_mask.dtype == torch.bool
-        assert normalized_mask.sum().item() == 2
+        assert normalized_mask.sum().item() == 2 * batch_size
         assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
 
         inputs["encoder_hidden_states_mask"] = normalized_mask

@@ -139,7 +140,7 @@ def test_infers_text_seq_len_from_mask(self):
         )
 
         assert int(per_sample_len2.max().item()) == 8
-        assert normalized_mask2.sum().item() == 5
+        assert normalized_mask2.sum().item() == 5 * batch_size
 
         rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], None

@@ -149,9 +150,10 @@ def test_infers_text_seq_len_from_mask(self):
         assert per_sample_len_none is None
         assert normalized_mask_none is None
 
-    def test_non_contiguous_attention_mask(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_non_contiguous_attention_mask(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
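The updated assertions encode simple arithmetic: the test activates a fixed number of text tokens per sample, so the total count of True entries in the normalized boolean mask scales linearly with batch size. Below is a small self-contained illustration of that invariant using plain tensors, independent of the compute_text_seq_len_from_mask helper; the shapes are hypothetical but mirror the test setup.

import torch

batch_size = 2
sequence_length = 8

# Mirror the test setup: 2 active text tokens per sample, replicated
# across the batch.
mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
mask[:, :2] = True

per_sample_len = mask.sum(dim=1)  # tensor([2, 2]): per-sample maximum stays 2
assert int(per_sample_len.max().item()) == 2
assert mask.sum().item() == 2 * batch_size  # batch total scales linearly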
