There are widely used models on HF that ship with custom modeling code that relies on the original `flash-attn` package, either by calling `transformers.utils.is_flash_attn_2_available()` and/or by relying on the `AutoModel.from_pretrained()` parameter `attn_implementation` being set to `flash_attention_2` (for example, Jina Embeddings v4: https://huggingface.co/jinaai/jina-embeddings-v4/blob/main/qwen2_5_vl.py).
I wanted to switch to `kernels-community/flash-attn2` to avoid building `flash-attn` or using pre-built wheels from https://github.com/mjun0812/flash-attention-prebuild-wheels (which has no wheels index, so you have to hardcode the path to a particular .whl file), and I managed to do it with a few lines of code.
Have you thought about adding documentation or an example for this case? In my experience, model authors rarely update the bundled custom modeling code to support new tooling (e.g., Transformers v5 or `kernels`), so unless you want to maintain your own local modeling code, you're stuck with whatever the authors used when they published the model.
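For context on why this can be done without installing anything: Python's `import` statement consults `sys.modules` before searching installed packages, so registering a module object under the `flash_attn` name is enough for later imports in the custom modeling code to resolve to it. A stdlib-only sketch of that mechanism (the placeholder module here is a hypothetical stand-in for what `get_kernel()` returns):

```python
import sys
import types

# Hypothetical placeholder standing in for the module returned by
# kernels.get_kernel(); any module object works for this demo.
fake = types.ModuleType("flash_attn")

# The import statement checks sys.modules first, so every subsequent
# `import flash_attn` resolves to the registered object.
sys.modules["flash_attn"] = fake

import flash_attn
print(flash_attn is fake)  # → True
```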
Here's the code that "exposes" `kernels-community/flash-attn2` as the original `flash-attn`:
```python
from pathlib import Path
import sys

from kernels import get_kernel


def setup_flash_attn_2() -> None:
    # Download the kernel from the Hub and register it under the name
    # that the custom modeling code imports.
    flash_attn = get_kernel("kernels-community/flash-attn2", version=1)
    sys.modules["flash_attn"] = flash_attn

    # Create a minimal flash_attn.dist-info so importlib.metadata (and
    # hence transformers' availability/version checks) sees flash-attn
    # as an installed package with a compatible version.
    dist_info = Path(flash_attn.__file__).parent / "flash_attn.dist-info"
    dist_info.mkdir(parents=True, exist_ok=True)
    (dist_info / "METADATA").write_text(
        "Metadata-Version: 2.5\nName: flash-attn\nVersion: 2.8.3"
    )
    sys.path.append(str(dist_info.parent))
```
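The dist-info part works because `importlib.metadata` scans `sys.path` entries for `*.dist-info` directories, so a version query for `flash-attn` succeeds even though no wheel was ever installed. A stdlib-only sketch of that mechanism, using a temporary directory in place of the kernel's install path:

```python
import importlib.metadata
import sys
import tempfile
from pathlib import Path

# Write the same METADATA that setup_flash_attn_2() writes, but into a
# temporary directory so the sketch is self-contained.
root = Path(tempfile.mkdtemp())
dist_info = root / "flash_attn.dist-info"
dist_info.mkdir()
(dist_info / "METADATA").write_text(
    "Metadata-Version: 2.5\nName: flash-attn\nVersion: 2.8.3"
)

# Once the parent directory is on sys.path, importlib.metadata resolves
# the package name and reports the version from METADATA.
sys.path.append(str(root))
print(importlib.metadata.version("flash-attn"))  # → 2.8.3
```

With both pieces in place, calling `setup_flash_attn_2()` once before `AutoModel.from_pretrained(..., trust_remote_code=True, attn_implementation="flash_attention_2")` lets the bundled modeling code run unchanged.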