 import inspect
 import math
 from enum import Enum
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
-if is_kernels_available():
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    if flash_attn_interface_hub is not None:
-        flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
-        flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
-    else:
-        flash_attn_3_hub_func = None
-        flash_attn_3_varlen_hub_func = None
-else:
-    flash_attn_3_hub_func = None
-    flash_attn_3_varlen_hub_func = None
-
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -148,6 +133,7 @@ def wrap(func):
     _custom_op = custom_op_no_op
     _register_fake = register_fake_no_op
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
@@ -169,7 +155,7 @@ class AttentionBackendName(str, Enum):
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
 
     # PyTorch native
     FLEX = "flex"
@@ -224,6 +210,22 @@ def list_backends(cls):
         return list(cls._backends.keys())
 
 
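+# Lazily resolve the FA3 kernel from the Hub on first use; `lru_cache` keeps the loaded module so later calls reuse it.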
+@lru_cache(maxsize=None)
+def _load_fa3_hub():
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
+    if fa3_hub is None:
+        raise RuntimeError(
+            "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
+        )
+    return fa3_hub
+
+
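+# Thin wrapper so callers can use `flash_attn_3_hub_func` directly without importing the Hub kernel eagerly.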
+def flash_attn_3_hub_func(*args, **kwargs):
+    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -374,12 +376,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
-        if flash_attn_3_hub_func is None:
-            raise RuntimeError(
-                "`flash_attn_3_hub_func` wasn't available. Please double if `kernels` was able to successfully pull the FA3 kernel from kernels-community/vllm-flash-attn3."
-            )
-    elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
-        raise NotImplementedError
 
     elif backend in [
         AttentionBackendName.SAGE,
@@ -544,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)
 
 
-@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -553,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse
 
 
-@_register_fake("vllm_flash_attn3::_flash_attn_forward")
+@_register_fake("vllm_flash_attn3::flash_attn")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)