41 changes: 29 additions & 12 deletions src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
@@ -39,6 +39,19 @@ def _wn_conv_transpose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def _normalize_vae_strides(c_mults: list[int], strides: list[int] | None = None) -> list[int]:
# Pad a short `strides` list with its last entry (or 2 if empty) and truncate a long one
# so that len(strides) == len(c_mults) - 1, the number of up/down blocks.
default_strides = [2, 4, 4, 8, 8]
num_blocks = len(c_mults) - 1
if strides is None:
strides = default_strides
strides = list(strides)
if len(strides) < num_blocks:
strides.extend([strides[-1] if strides else 2] * (num_blocks - len(strides)))
else:
strides = strides[:num_blocks]
return strides
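For reference, a quick sketch of how the new helper behaves; the argument values here are illustrative, not taken from the PR:

```py
_normalize_vae_strides([1, 1, 2, 4, 8, 16])              # -> [2, 4, 4, 8, 8] (defaults)
_normalize_vae_strides([1, 1, 2, 4, 8, 16], [2, 4])      # -> [2, 4, 4, 4, 4] (padded with last entry)
_normalize_vae_strides([1, 1, 2], [2, 4, 4, 8, 8])       # -> [2, 4] (truncated to len(c_mults) - 1)
```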


class Snake1d(nn.Module):
def __init__(self, channels: int, alpha_logscale: bool = True):
super().__init__()
@@ -200,11 +213,7 @@ def __init__(
):
super().__init__()
c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
strides = list(strides or [2] * (len(c_mults) - 1))
if len(strides) < len(c_mults) - 1:
strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
else:
strides = strides[: len(c_mults) - 1]
strides = _normalize_vae_strides(c_mults, strides)
channels_base = channels
layers = [_wn_conv1d(in_channels, c_mults[0] * channels_base, kernel_size=7, padding=3)]
for idx in range(len(c_mults) - 1):
@@ -249,11 +258,7 @@ def __init__(
):
super().__init__()
c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
strides = list(strides or [2] * (len(c_mults) - 1))
if len(strides) < len(c_mults) - 1:
strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
else:
strides = strides[: len(c_mults) - 1]
strides = _normalize_vae_strides(c_mults, strides)
channels_base = channels

self.shortcut = (
@@ -317,6 +322,18 @@ def __init__(
scale: float = 0.71,
):
super().__init__()
c_mults = c_mults or [1, 2, 4, 8, 16]
normalized_strides = _normalize_vae_strides([1] + c_mults, strides)
actual_downsampling_ratio = math.prod(normalized_strides)
if actual_downsampling_ratio != downsampling_ratio:
raise ValueError(
f"`downsampling_ratio` must match the product of normalized `strides`. Got "
f"`downsampling_ratio={downsampling_ratio}` but `strides={normalized_strides}` have product "
f"{actual_downsampling_ratio}."
)
self.register_to_config(
c_mults=c_mults, strides=normalized_strides, downsampling_ratio=actual_downsampling_ratio
)
if act_fn is None:
if use_snake is None:
act_fn = "snake"
@@ -326,7 +343,7 @@ def __init__(
in_channels=in_channels,
channels=channels,
c_mults=c_mults,
strides=strides,
strides=normalized_strides,
latent_dim=latent_dim,
encoder_latent_dim=encoder_latent_dim,
act_fn=act_fn,
@@ -337,7 +354,7 @@ def __init__(
in_channels=in_channels,
channels=channels,
c_mults=c_mults,
strides=strides,
strides=normalized_strides,
latent_dim=latent_dim,
act_fn=act_fn,
in_shortcut=in_shortcut,
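The new constructor check makes `downsampling_ratio` redundant with `strides` and validates it eagerly; with the defaults the product is 2 * 4 * 4 * 8 * 8 = 2048. A minimal sketch of the invariant, using the constructor kwargs from the new test file below:

```py
import math

from diffusers import LongCatAudioDiTVae

vae = LongCatAudioDiTVae(channels=1, latent_dim=2, encoder_latent_dim=4)
assert math.prod(vae.config.strides) == vae.config.downsampling_ratio == 2048

# Inconsistent configs now fail at construction time instead of producing
# silently mis-shaped latents (math.prod([2, 2, 2, 2, 2]) == 32 != 2048):
# LongCatAudioDiTVae(channels=1, latent_dim=2, encoder_latent_dim=4,
#                    strides=[2, 2, 2, 2, 2], downsampling_ratio=2048)  # raises ValueError
```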
20 changes: 16 additions & 4 deletions src/diffusers/models/transformers/transformer_longcat_audio_dit.py
@@ -25,7 +25,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ..attention import AttentionModuleMixin
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
@@ -228,6 +228,12 @@ def __call__(


class AudioDiTAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = AudioDiTSelfAttnProcessor
_available_processors = [
AudioDiTSelfAttnProcessor,
]
_supports_qkv_fusion = False

def __init__(
self,
q_dim: int,
@@ -238,12 +244,13 @@ def __init__(
bias: bool = True,
qk_norm: bool = False,
eps: float = 1e-6,
processor: AttentionModuleMixin | None = None,
processor: "AudioDiTSelfAttnProcessor | AudioDiTCrossAttnProcessor | None" = None,
):
super().__init__()
kv_dim = q_dim if kv_dim is None else kv_dim
self.heads = heads
self.inner_dim = dim_head * heads
self.use_bias = bias
self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias)
@@ -252,7 +259,9 @@ def __init__(
self.q_norm = RMSNorm(self.inner_dim, eps=eps)
self.k_norm = RMSNorm(self.inner_dim, eps=eps)
self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)])
self.set_processor(processor or AudioDiTSelfAttnProcessor())
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)

def forward(
self,
@@ -331,6 +340,9 @@ def __call__(
return hidden_states


# AudioDiTCrossAttnProcessor is defined after the class body, so the complete
# processor list is attached to the class here.
AudioDiTAttention._available_processors = [AudioDiTSelfAttnProcessor, AudioDiTCrossAttnProcessor]
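With both processors registered, switching an attention module to the cross-attention path is a plain `set_processor` call. A hedged sketch, assuming the cross-attn processor is constructed with no arguments like the self-attn default, and with illustrative dimensions:

```py
attn = AudioDiTAttention(q_dim=64, kv_dim=32, heads=4, dim_head=16)
attn.set_processor(AudioDiTCrossAttnProcessor())
# Subsequent calls now route through AudioDiTCrossAttnProcessor.__call__.
```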


class AudioDiTFeedForward(nn.Module):
def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True):
super().__init__()
@@ -452,7 +464,7 @@ def forward(
return hidden_states


class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin, AttentionMixin):
_supports_gradient_checkpointing = False
_repeated_blocks = ["AudioDiTBlock"]

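With `AttentionMixin` added to the bases, the model exposes the standard diffusers processor plumbing (`attn_processors`, `set_attn_processor`), which the new transformer tests rely on. A sketch with the constructor kwargs used in those tests:

```py
model = LongCatAudioDiTTransformer(
    dit_dim=64, dit_depth=2, dit_heads=4, dit_text_dim=32, latent_dim=8, text_conv=False
)
processors = model.attn_processors          # mapping of module path -> processor instance
model.set_attn_processor(dict(processors))  # round-trips, as exercised in the new test
```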
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
# https://github.com/meituan-longcat/LongCat-AudioDiT

import re
from typing import Callable
from typing import Any, Callable

import torch
import torch.nn.functional as F
@@ -32,6 +32,8 @@

logger = logging.get_logger(__name__)

PipelineCallback = Callable[[Any, int, torch.Tensor, dict[str, torch.Tensor]], dict[str, torch.Tensor]]

EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -148,8 +150,7 @@ def encode_prompt(self, prompt: str | list[str], device: torch.device) -> tuple[
)
input_ids = text_inputs.input_ids.to(device)
attention_mask = text_inputs.attention_mask.to(device)
with torch.no_grad():
output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
prompt_embeds = output.last_hidden_state
if self.text_norm_feat:
prompt_embeds = F.layer_norm(prompt_embeds, (prompt_embeds.shape[-1],), eps=1e-6)
@@ -229,7 +230,7 @@ def __call__(
generator: torch.Generator | list[torch.Generator] | None = None,
output_type: str = "np",
return_dict: bool = True,
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end: PipelineCallback | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
):
r"""
@@ -296,9 +297,13 @@ def __call__(
negative_prompt_embeds_len, length=negative_prompt_embeds.shape[1]
)

latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=prompt_embeds.dtype)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(dtype=transformer_dtype)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=transformer_dtype)

latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=transformer_dtype)
latents = self.prepare_latents(
batch_size, duration, device, prompt_embeds.dtype, generator=generator, latents=latents
batch_size, duration, device, transformer_dtype, generator=generator, latents=latents
)
if num_inference_steps < 1:
raise ValueError("num_inference_steps must be a positive integer.")
@@ -311,9 +316,7 @@

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
curr_t = (
(t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=prompt_embeds.dtype)
)
curr_t = (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=transformer_dtype)
pred = self.transformer(
hidden_states=latents,
encoder_hidden_states=prompt_embeds,
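The new `PipelineCallback` alias pins down the contract for `callback_on_step_end`: per the annotation it receives four arguments and returns a dict of tensors; the leading `Any` is presumably the pipeline instance, and the dict holds the tensors named in `callback_on_step_end_tensor_inputs`. A minimal sketch under that reading:

```py
def inspect_step(pipe, step, timestep, tensors):
    # `tensors` holds whatever was requested via callback_on_step_end_tensor_inputs
    # (default: ["latents"]); the returned dict can override those tensors.
    print(f"step={step} latents={tuple(tensors['latents'].shape)}")
    return tensors

# audio = pipe(prompt="rain on glass", callback_on_step_end=inspect_step).audios
```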
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2026 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from diffusers import LongCatAudioDiTVae


def test_longcat_audio_vae_default_strides_match_downsampling_ratio():
vae = LongCatAudioDiTVae(channels=1, latent_dim=2, encoder_latent_dim=4)

assert vae.config.strides == [2, 4, 4, 8, 8]
assert vae.config.downsampling_ratio == 2048


def test_longcat_audio_vae_raises_when_downsampling_ratio_mismatches_strides():
with pytest.raises(ValueError, match="downsampling_ratio"):
LongCatAudioDiTVae(
channels=1,
latent_dim=2,
encoder_latent_dim=4,
strides=[2, 2, 2, 2, 2],
downsampling_ratio=2048,
)
Original file line number Diff line number Diff line change
@@ -102,10 +102,17 @@ class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterCo


def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention
from diffusers.models.transformers.transformer_longcat_audio_dit import (
AudioDiTAttention,
AudioDiTSelfAttnProcessor,
)

attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4, dropout=0.0, bias=False)

assert attn._default_processor_cls is AudioDiTSelfAttnProcessor
assert AudioDiTSelfAttnProcessor in attn._available_processors
assert attn.use_bias is False

eye = torch.eye(4)
with torch.no_grad():
attn.to_q.weight.copy_(eye)
@@ -119,3 +126,30 @@ def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
output = attn(hidden_states=hidden_states, attention_mask=attention_mask)

assert torch.allclose(output[:, 1], torch.zeros_like(output[:, 1]))


def test_longcat_audio_attention_direct_fuse_projections_noops():
from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention

attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4)

attn.fuse_projections()

assert not attn.fused_projections
assert not hasattr(attn, "to_qkv")


def test_longcat_audio_transformer_exposes_attention_processors():
model = LongCatAudioDiTTransformer(
dit_dim=64,
dit_depth=2,
dit_heads=4,
dit_text_dim=32,
latent_dim=8,
text_conv=False,
)

processors = model.attn_processors

assert len(processors) == 4
model.set_attn_processor(dict(processors))
59 changes: 57 additions & 2 deletions tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
@@ -67,7 +67,7 @@ def get_dummy_components(self):
strides=[2],
latent_dim=8,
encoder_latent_dim=16,
downsampling_ratio=2,
downsampling_ratio=4,
sample_rate=24000,
)

@@ -158,6 +158,60 @@ def test_num_images_per_prompt(self):
def test_encode_prompt_works_in_isolation(self):
self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.")

def test_encode_prompt_returns_grad_bearing_embeds(self):
device = "cpu"
pipe = self.pipeline_class(**self.get_dummy_components())
pipe.to(device)

with torch.enable_grad():
prompt_embeds, _ = pipe.encode_prompt("soft ocean ambience", torch.device(device))
loss = prompt_embeds.float().sum()

self.assertTrue(prompt_embeds.requires_grad)
loss.backward()
self.assertTrue(any(param.grad is not None for param in pipe.text_encoder.parameters()))

def test_transformer_inputs_use_transformer_dtype(self):
device = "cpu"
pipe = self.pipeline_class(**self.get_dummy_components())
pipe.to(device)
pipe.transformer.to(dtype=torch.bfloat16)

observed_dtypes = []

def record_transformer_inputs(module, args, kwargs):
observed_dtypes.append(
{
"hidden_states": kwargs["hidden_states"].dtype,
"encoder_hidden_states": kwargs["encoder_hidden_states"].dtype,
"timestep": kwargs["timestep"].dtype,
"latent_cond": kwargs["latent_cond"].dtype,
}
)

hook = pipe.transformer.register_forward_pre_hook(record_transformer_inputs, with_kwargs=True)
inputs = self.get_dummy_inputs(device)
inputs.update(
{
"negative_prompt": "noise",
"guidance_scale": 4.0,
"output_type": "latent",
}
)

try:
output = pipe(**inputs).audios
finally:
hook.remove()

self.assertEqual(output.dtype, torch.bfloat16)
self.assertGreaterEqual(len(observed_dtypes), 2)
for dtypes in observed_dtypes:
self.assertEqual(dtypes["hidden_states"], torch.bfloat16)
self.assertEqual(dtypes["encoder_hidden_states"], torch.bfloat16)
self.assertEqual(dtypes["timestep"], torch.bfloat16)
self.assertEqual(dtypes["latent_cond"], torch.bfloat16)

def test_uniform_flow_match_scheduler_grid_matches_manual_updates(self):
num_inference_steps = 6
scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)
@@ -203,9 +257,10 @@ def test_longcat_audio_pipeline_from_pretrained_real_local_weights(self):
if not tokenizer_path.exists():
raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
pipe = LongCatAudioDiTPipeline.from_pretrained(
model_path,
tokenizer=tokenizer_path,
tokenizer=tokenizer,
torch_dtype=torch.float16,
local_files_only=True,
)