Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 80 additions & 34 deletions tests/models/transformers/test_models_transformer_sana.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,57 +13,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import SanaTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.9]
class SanaTransformer2DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaTransformer2DModel

@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
def output_shape(self) -> tuple[int, ...]:
return (4, 32, 32)

@property
def input_shape(self) -> tuple[int, ...]:
return (4, 32, 32)

hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def uses_custom_attn_processor(self) -> bool:
return True

@property
def input_shape(self):
return (4, 32, 32)
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.9]

@property
def output_shape(self):
return (4, 32, 32)
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"patch_size": 1,
"in_channels": 4,
"out_channels": 4,
Expand All @@ -75,9 +77,53 @@ def prepare_init_args_and_inputs_for_common(self):
"caption_channels": 8,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8

return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}


class TestSanaTransformer2D(SanaTransformer2DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Transformer 2D."""


class TestSanaTransformer2DMemory(SanaTransformer2DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Transformer 2D."""


class TestSanaTransformer2DTraining(SanaTransformer2DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Transformer 2D."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestSanaTransformer2DAttention(SanaTransformer2DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Transformer 2D."""


class TestSanaTransformer2DCompile(SanaTransformer2DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Transformer 2D."""


class TestSanaTransformer2DBitsAndBytes(SanaTransformer2DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Transformer 2D."""


class TestSanaTransformer2DTorchAo(SanaTransformer2DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Transformer 2D."""
116 changes: 77 additions & 39 deletions tests/models/transformers/test_models_transformer_sana_video.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,57 +13,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import SanaVideoTransformer3DModel

from ...testing_utils import (
enable_full_determinism,
torch_device,
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class SanaVideoTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaVideoTransformer3DModel

@property
def dummy_input(self):
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)

hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def main_input_name(self) -> str:
return "hidden_states"

@property
def input_shape(self):
return (16, 2, 16, 16)
def uses_custom_attn_processor(self) -> bool:
return True

@property
def output_shape(self):
return (16, 2, 16, 16)
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict[str, int | float | list[int] | tuple | str | bool]:
return {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 2,
Expand All @@ -82,16 +80,56 @@ def prepare_init_args_and_inputs_for_common(self):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12

return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}


class TestSanaVideoTransformer3D(SanaVideoTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Video Transformer 3D."""


class TestSanaVideoTransformer3DMemory(SanaVideoTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Video Transformer 3D."""


class TestSanaVideoTransformer3DTraining(SanaVideoTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Video Transformer 3D."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
class TestSanaVideoTransformer3DAttention(SanaVideoTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Video Transformer 3D."""


class TestSanaVideoTransformer3DCompile(SanaVideoTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Video Transformer 3D."""


class TestSanaVideoTransformer3DBitsAndBytes(SanaVideoTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Video Transformer 3D."""


def prepare_init_args_and_inputs_for_common(self):
return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestSanaVideoTransformer3DTorchAo(SanaVideoTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Video Transformer 3D."""
Loading