 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -34,42 +35,44 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
 class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
                 "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
             )
 
-    def get_qkv_projections(
-        self, attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
-    ):
-        # encoder_hidden_states is only passed for cross-attention
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        if attn.fused_projections:
-            if attn.cross_attention_dim_head is None:
-                # In self-attention layers, we can fuse the entire QKV projection into a single linear
-                query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
-            else:
-                # In cross-attention layers, we can only fuse the KV projections into a single linear
-                query = attn.to_q(hidden_states)
-                key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
-        else:
-            query = attn.to_q(hidden_states)
-            key = attn.to_k(encoder_hidden_states)
-            value = attn.to_v(encoder_hidden_states)
-        return query, key, value
-
-    def get_added_kv_projections(self, attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
-        if attn.fused_projections:
-            key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
-        else:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-        return key_img, value_img
-
     def __call__(
         self,
         attn: "WanAttention",
@@ -85,7 +88,7 @@ def __call__(
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
 
-        query, key, value = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
         query = attn.norm_q(query)
         key = attn.norm_k(key)
@@ -116,20 +119,32 @@ def apply_rotary_emb(
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img, value_img = self.get_added_kv_projections(attn, encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
 
             key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
             value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
             hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
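Two notes for readers of this diff. First, the new class attribute `_attention_backend = None` is the value the processor forwards as `backend=` to `dispatch_attention_fn`; with `None` the dispatcher is presumably left to pick its default backend. Second, the comment in `_get_qkv_projections` about fusing the entire QKV projection into a single linear refers to the usual trick of concatenating the three projection weights. Below is a minimal, self-contained sketch of that equivalence using only `torch`; the names (`dim`, `to_q`, `to_qkv`) are illustrative and not the diffusers modules touched by this commit.

```python
# Sketch only: shows why a fused QKV linear chunked into three pieces matches
# three separate Q/K/V linears. Names here are hypothetical, not diffusers API.
import torch
import torch.nn as nn

dim = 64
x = torch.randn(2, 16, dim)  # (batch, seq_len, dim)

# Unfused path: three independent projections.
to_q = nn.Linear(dim, dim)
to_k = nn.Linear(dim, dim)
to_v = nn.Linear(dim, dim)

# Fused path: concatenate the weights into one (3*dim, dim) linear,
# then split the output back out with chunk(3, dim=-1).
to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

query, key, value = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(query, to_q(x), atol=1e-6)
assert torch.allclose(key, to_k(x), atol=1e-6)
assert torch.allclose(value, to_v(x), atol=1e-6)
```

The benefit of the fused form is a single matmul instead of three, which is why the helper only falls back to separate `to_q`/`to_k`/`to_v` calls when `attn.fused_projections` is not set, and only fuses K and V in cross-attention, where query and key/value come from different inputs.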