Skip to content

Commit 4940b21

Browse files
committed
attention dispatcher support
1 parent 6c841e8 commit 4940b21

2 files changed

Lines changed: 52 additions & 37 deletions

File tree

src/diffusers/models/attention.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, List, Optional, Tuple, Union
15+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1616

1717
import torch
1818
import torch.nn as nn
1919
import torch.nn.functional as F
2020

2121
from ..utils import deprecate, logging
22-
from ..utils.import_utils import is_xformers_available
22+
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
2323
from ..utils.torch_utils import maybe_allow_in_graph
2424
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
2525
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 50 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
2525
from ...utils.torch_utils import maybe_allow_in_graph
2626
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
27+
from ..attention_dispatch import dispatch_attention_fn
2728
from ..cache_utils import CacheMixin
2829
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
2930
from ..modeling_outputs import Transformer2DModelOutput
@@ -34,42 +35,44 @@
3435
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3536

3637

38+
def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
39+
# encoder_hidden_states is only passed for cross-attention
40+
if encoder_hidden_states is None:
41+
encoder_hidden_states = hidden_states
42+
43+
if attn.fused_projections:
44+
if attn.cross_attention_dim_head is None:
45+
# In self-attention layers, we can fuse the entire QKV projection into a single linear
46+
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
47+
else:
48+
# In cross-attention layers, we can only fuse the KV projections into a single linear
49+
query = attn.to_q(hidden_states)
50+
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
51+
else:
52+
query = attn.to_q(hidden_states)
53+
key = attn.to_k(encoder_hidden_states)
54+
value = attn.to_v(encoder_hidden_states)
55+
return query, key, value
56+
57+
58+
def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
59+
if attn.fused_projections:
60+
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
61+
else:
62+
key_img = attn.add_k_proj(encoder_hidden_states_img)
63+
value_img = attn.add_v_proj(encoder_hidden_states_img)
64+
return key_img, value_img
65+
66+
3767
class WanAttnProcessor:
68+
_attention_backend = None
69+
3870
def __init__(self):
3971
if not hasattr(F, "scaled_dot_product_attention"):
4072
raise ImportError(
4173
"WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
4274
)
4375

44-
def get_qkv_projections(
45-
self, attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
46-
):
47-
# encoder_hidden_states is only passed for cross-attention
48-
if encoder_hidden_states is None:
49-
encoder_hidden_states = hidden_states
50-
51-
if attn.fused_projections:
52-
if attn.cross_attention_dim_head is None:
53-
# In self-attention layers, we can fuse the entire QKV projection into a single linear
54-
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
55-
else:
56-
# In cross-attention layers, we can only fuse the KV projections into a single linear
57-
query = attn.to_q(hidden_states)
58-
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
59-
else:
60-
query = attn.to_q(hidden_states)
61-
key = attn.to_k(encoder_hidden_states)
62-
value = attn.to_v(encoder_hidden_states)
63-
return query, key, value
64-
65-
def get_added_kv_projections(self, attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
66-
if attn.fused_projections:
67-
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
68-
else:
69-
key_img = attn.add_k_proj(encoder_hidden_states_img)
70-
value_img = attn.add_v_proj(encoder_hidden_states_img)
71-
return key_img, value_img
72-
7376
def __call__(
7477
self,
7578
attn: "WanAttention",
@@ -85,7 +88,7 @@ def __call__(
8588
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
8689
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
8790

88-
query, key, value = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
91+
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
8992

9093
query = attn.norm_q(query)
9194
key = attn.norm_k(key)
@@ -116,20 +119,32 @@ def apply_rotary_emb(
116119
# I2V task
117120
hidden_states_img = None
118121
if encoder_hidden_states_img is not None:
119-
key_img, value_img = self.get_added_kv_projections(attn, encoder_hidden_states_img)
122+
key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
120123
key_img = attn.norm_added_k(key_img)
121124

122125
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
123126
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
124127

125-
hidden_states_img = F.scaled_dot_product_attention(
126-
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
128+
hidden_states_img = dispatch_attention_fn(
129+
query,
130+
key_img,
131+
value_img,
132+
attn_mask=None,
133+
dropout_p=0.0,
134+
is_causal=False,
135+
backend=self._attention_backend,
127136
)
128137
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
129138
hidden_states_img = hidden_states_img.type_as(query)
130139

131-
hidden_states = F.scaled_dot_product_attention(
132-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
140+
hidden_states = dispatch_attention_fn(
141+
query,
142+
key,
143+
value,
144+
attn_mask=attention_mask,
145+
dropout_p=0.0,
146+
is_causal=False,
147+
backend=self._attention_backend,
133148
)
134149
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
135150
hidden_states = hidden_states.type_as(query)

0 commit comments

Comments
 (0)