 import torch.multiprocessing as mp

 from diffusers.models._modeling_parallel import ContextParallelConfig
+from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry

 from ...testing_utils import (
     is_context_parallel,
@@ -160,16 +161,21 @@ def _custom_mesh_worker(
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
-    def test_context_parallel_inference(self, cp_type):
+    def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")

         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
+        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)

         # Move all tensors to CPU for multiprocessing
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -194,6 +200,11 @@ def test_context_parallel_inference(self, cp_type):
194200 f"Context parallel inference failed: { return_dict .get ('error' , 'Unknown error' )} "
195201 )
196202
203+ @pytest .mark .xfail (reason = "Context parallel may not support batch_size > 1" )
204+ @pytest .mark .parametrize ("cp_type" , ["ulysses_degree" , "ring_degree" ], ids = ["ulysses" , "ring" ])
205+ def test_context_parallel_batch_inputs (self , cp_type ):
206+ self .test_context_parallel_inference (cp_type , batch_size = 2 )
207+
197208 @pytest .mark .parametrize (
198209 "cp_type,mesh_shape,mesh_dim_names" ,
199210 [
@@ -209,6 +220,11 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}