Skip to content

Commit a1e1faf

Browse files
committed
up
1 parent 93c3eb9 commit a1e1faf

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def wrap(func):
141141
return wrap if fn is None else fn
142142

143143
_custom_op = custom_op_no_op
144+
_register_fake = register_fake_no_op
144145

145146

146147
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -366,6 +367,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
366367

367368
# TODO: add support for the Hub variant of FA3 varlen later
368369
elif backend in [AttentionBackendName._FLASH_3_HUB]:
370+
if not DIFFUSERS_ENABLE_HUB_KERNELS:
371+
raise RuntimeError(
372+
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it `DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
373+
)
369374
if not is_kernels_available():
370375
raise RuntimeError(
371376
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
@@ -708,9 +713,11 @@ def _flash_attention_3_hub(
708713
pack_gqa=None,
709714
deterministic=deterministic,
710715
sm_margin=0,
716+
return_attn_probs=return_attn_probs,
711717
)
712-
lse = None
713-
return (out, lse) if return_attn_probs else out
718+
# When `return_attn_probs` is True, the above returns a tuple of
719+
# the attention output and the log-sum-exp (lse).
720+
return (out[0], out[1]) if return_attn_probs else out
714721

715722

716723
@_AttentionBackendRegistry.register(

src/diffusers/utils/kernels_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def _get_fa3_from_hub():
1616

1717
try:
1818
# TODO: temporary revision for now. Remove when merged upstream into `main`.
19-
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops")
19+
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-prob")
2020
return flash_attn_3_hub
2121
except Exception as e:
2222
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")

0 commit comments

Comments
 (0)