diff --git a/docs/source/en/api/models/motif_video_transformer_3d.md b/docs/source/en/api/models/motif_video_transformer_3d.md
new file mode 100644
index 000000000000..011058832ee2
--- /dev/null
+++ b/docs/source/en/api/models/motif_video_transformer_3d.md
@@ -0,0 +1,32 @@
+
+
+# MotifVideoTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Motif-Video](https://arxiv.org/abs/2604.16503) by the Motif Technologies Team.
+
+The model uses a three-stage architecture with 12 dual-stream + 16 single-stream + 8 DDT decoder layers and rotary positional embeddings (RoPE) for video generation.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import MotifVideoTransformer3DModel
+
+transformer = MotifVideoTransformer3DModel.from_pretrained(
+    "Motif-Technologies/Motif-Video-2B", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+```
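+
+The three-stage layout maps onto the model config: `num_layers` dual-stream blocks and `num_single_layers` single-stream blocks, the last `num_decoder_layers` of which form the DDT decoder. A minimal sketch to read these values from a downloaded checkpoint:
+
+```python
+import torch
+
+from diffusers import MotifVideoTransformer3DModel
+
+transformer = MotifVideoTransformer3DModel.from_pretrained(
+    "Motif-Technologies/Motif-Video-2B", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+config = transformer.config
+# Dual-stream blocks, total single-stream blocks, and DDT decoder blocks (the tail of the single stream)
+print(config.num_layers, config.num_single_layers, config.num_decoder_layers)
+```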
+
+## MotifVideoTransformer3DModel
+
+[[autodoc]] MotifVideoTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/motif_video.md b/docs/source/en/api/pipelines/motif_video.md
new file mode 100644
index 000000000000..9e0929599ea2
--- /dev/null
+++ b/docs/source/en/api/pipelines/motif_video.md
@@ -0,0 +1,123 @@
+
+
+# Motif-Video
+
+[Technical Report](https://arxiv.org/abs/2604.16503)
+
+Motif-Video is a 2B-parameter diffusion transformer designed for text-to-video and image-to-video generation. It uses a three-stage architecture of 12 dual-stream, 16 single-stream, and 8 DDT decoder layers, Shared Cross-Attention for stable text-video alignment over long video sequences, a T5Gemma2 text encoder, and rectified flow matching for velocity prediction.
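+
+These pieces are registered on the pipeline as standard diffusers components (`transformer`, `vae`, `text_encoder`, `tokenizer`, `scheduler`, `guider`, and an optional `feature_extractor` for image conditioning), so they can be inspected or swapped like any other pipeline component. A minimal sketch:
+
+```python
+import torch
+
+from diffusers import MotifVideoPipeline
+
+pipe = MotifVideoPipeline.from_pretrained("Motif-Technologies/Motif-Video-2B", torch_dtype=torch.bfloat16)
+
+print(list(pipe.components))  # registered component names
+print(type(pipe.guider).__name__)  # guidance method bundled with the checkpoint, e.g. classifier-free guidance
+```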
+
+## Text-to-Video Generation
+
+Use `MotifVideoPipeline` for text-to-video generation:
+
+```python
+import torch
+from diffusers import MotifVideoPipeline
+from diffusers.utils import export_to_video
+
+
+pipe = MotifVideoPipeline.from_pretrained(
+ "Motif-Technologies/Motif-Video-2B",
+ torch_dtype=torch.bfloat16,
+)
+pipe.to("cuda")
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair."
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1280,
+ height=736,
+ num_frames=121,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
+
+## Image-to-Video Generation
+
+Use `MotifVideoImage2VideoPipeline` for image-to-video generation:
+
+```python
+import torch
+from diffusers import MotifVideoImage2VideoPipeline
+from diffusers.utils import export_to_video, load_image
+
+
+pipe = MotifVideoImage2VideoPipeline.from_pretrained(
+ "Motif-Technologies/Motif-Video-2B",
+ torch_dtype=torch.bfloat16,
+)
+pipe.to("cuda")
+
+image = load_image("input_image.png")
+prompt = "A cinematic scene with vivid colors."
+negative_prompt = "worst quality, blurry, jittery, distorted"
+
+video = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1280,
+ height=736,
+ num_frames=121,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "i2v_output.mp4", fps=24)
+```
+
+## Memory-Efficient Inference
+
+For GPUs with less than 30 GB of VRAM (e.g., an RTX 4090), use model CPU offloading:
+
+```bash
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+```
+
+```python
+import torch
+from diffusers import MotifVideoPipeline
+from diffusers.utils import export_to_video
+
+
+pipe = MotifVideoPipeline.from_pretrained(
+ "Motif-Technologies/Motif-Video-2B",
+ torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair."
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1280,
+ height=736,
+ num_frames=121,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
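+
+If model CPU offloading is still not enough, sequential CPU offloading moves submodules on and off the GPU one at a time, trading speed for a lower peak VRAM footprint, and tiled VAE decoding can cut memory during the decode step. A hedged sketch (use it in place of `enable_model_cpu_offload()`; whether the Wan VAE here exposes `enable_tiling` should be checked against your diffusers version):
+
+```python
+# Instead of pipe.enable_model_cpu_offload():
+pipe.enable_sequential_cpu_offload()
+
+# Optional: tiled decoding, assuming the VAE exposes `enable_tiling` like other diffusers VAEs.
+pipe.vae.enable_tiling()
+```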
+
+## MotifVideoPipeline
+
+[[autodoc]] MotifVideoPipeline
+ - all
+ - __call__
+
+## MotifVideoImage2VideoPipeline
+
+[[autodoc]] MotifVideoImage2VideoPipeline
+ - all
+ - __call__
+
+## MotifVideoPipelineOutput
+
+[[autodoc]] pipelines.motif_video.pipeline_output.MotifVideoPipelineOutput
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1b1f6b3032b3..c5371837e76a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -265,6 +265,7 @@
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
+ "MotifVideoTransformer3DModel",
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
@@ -637,6 +638,9 @@
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
+ "MotifVideoImage2VideoPipeline",
+ "MotifVideoPipeline",
+ "MotifVideoPipelineOutput",
"MusicLDMPipeline",
"NucleusMoEImagePipeline",
"OmniGenPipeline",
@@ -1087,6 +1091,7 @@
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
+ MotifVideoTransformer3DModel,
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
@@ -1434,6 +1439,9 @@
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
+ MotifVideoImage2VideoPipeline,
+ MotifVideoPipeline,
+ MotifVideoPipelineOutput,
MusicLDMPipeline,
NucleusMoEImagePipeline,
OmniGenPipeline,
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index 0267dd481a89..372ce4f76e91 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -188,6 +188,10 @@ def _register_transformer_blocks_metadata():
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
+ from ..models.transformers.transformer_motif_video import (
+ MotifVideoSingleTransformerBlock,
+ MotifVideoTransformerBlock,
+ )
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
@@ -290,6 +294,22 @@ def _register_transformer_blocks_metadata():
),
)
+ # MotifVideo
+ TransformerBlockRegistry.register(
+ model_class=MotifVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=MotifVideoSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index c7bb2de4437a..43fc8d897fe6 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -21,7 +21,11 @@
from typing_extensions import Self
from .. import __version__
-from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
+from ..models.model_loading_utils import (
+ _caching_allocator_warmup,
+ _determine_device_map,
+ _expand_device_map,
+)
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
@@ -194,6 +198,10 @@
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
"default_subfolder": "audio_vae",
},
+ "MotifVideoTransformer3DModel": {
+ "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
+ "default_subfolder": "transformer",
+ },
}
@@ -336,7 +344,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No
disable_mmap = kwargs.pop("disable_mmap", False)
device_map = kwargs.pop("device_map", None)
- user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "single_file",
+ "framework": "pytorch",
+ }
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if quantization_config is not None:
user_agent["quant"] = quantization_config.quant_method.value
@@ -393,7 +405,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
diffusers_model_config = config_mapping_fn(
- original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
+ original_config=original_config,
+ checkpoint=checkpoint,
+ **config_mapping_kwargs,
)
else:
if config is not None:
@@ -465,7 +479,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No
if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
- config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+ config=diffusers_model_config,
+ checkpoint=checkpoint,
+ **checkpoint_mapping_kwargs,
)
else:
diffusers_format_checkpoint = checkpoint
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index bb765c56d013..ff8e16aad447 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -121,6 +121,7 @@
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
+ _import_structure["transformers.transformer_motif_video"] = ["MotifVideoTransformer3DModel"]
_import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
@@ -247,6 +248,7 @@
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
+ MotifVideoTransformer3DModel,
NucleusMoEImageTransformer2DModel,
OmniGenTransformer2DModel,
OvisImageTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 5c64b5fc99fa..156b54e7f07d 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -44,6 +44,7 @@
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
+ from .transformer_motif_video import MotifVideoTransformer3DModel
from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_ovis_image import OvisImageTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_motif_video.py b/src/diffusers/models/transformers/transformer_motif_video.py
new file mode 100644
index 000000000000..77f1abb0674f
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_motif_video.py
@@ -0,0 +1,1049 @@
+# Copyright 2026 Motif Technologies and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ PixArtAlphaTextProjection,
+ TimestepEmbedding,
+ Timesteps,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin, get_parameter_dtype
+from ..normalization import (
+ AdaLayerNormContinuous,
+ AdaLayerNormZero,
+ AdaLayerNormZeroSingle,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MotifVideoCrossAttnProcessor2_0:
+ """Attention processor for Motif-Video text cross-attention."""
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "MotifVideoCrossAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: "MotifVideoCrossAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ image_embed_seq_len: int = 0,
+ ) -> torch.Tensor:
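+        # encoder_hidden_states is laid out as [image tokens | text tokens]; only the text tokens serve
+        # as keys/values here, so slice them off (and the matching tail of the attention mask below).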
+ txt_kv = encoder_hidden_states[:, image_embed_seq_len:, :]
+
+ text_mask = None
+ if attention_mask is not None:
+ text_mask = attention_mask[:, :, :, image_embed_seq_len - encoder_hidden_states.shape[1] :]
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(txt_kv)
+ value = attn.to_v(txt_kv)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=text_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class MotifVideoAttnProcessor2_0:
+ """Attention processor for Motif-Video self-attention."""
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: "MotifVideoAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Concatenate hidden states with encoder hidden states for joint attention if needed
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ # Project QKV
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ # Normalize QK
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE
+ if image_rotary_emb is not None:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ split_idx = -encoder_hidden_states.shape[1]
+ query = torch.cat(
+ [
+ apply_rotary_emb(query[:, :split_idx, :, :], image_rotary_emb, sequence_dim=1),
+ query[:, split_idx:, :, :],
+ ],
+ dim=1,
+ )
+ key = torch.cat(
+ [
+ apply_rotary_emb(key[:, :split_idx, :, :], image_rotary_emb, sequence_dim=1),
+ key[:, split_idx:, :, :],
+ ],
+ dim=1,
+ )
+ else:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # Add encoder conditioning QKV projections and normalization
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=1)
+ key = torch.cat([key, encoder_key], dim=1)
+ value = torch.cat([value, encoder_value], dim=1)
+
+ # Compute attention with backend dispatch
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Apply output projections and split encoder states
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if attn.to_out is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.to_add_out is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+ if attn.to_out is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class MotifVideoCrossAttention(nn.Module, AttentionModuleMixin):
+ """Dedicated cross-attention module for Motif-Video text cross-attention."""
+
+ _default_processor_cls = MotifVideoCrossAttnProcessor2_0
+ _available_processors = [MotifVideoCrossAttnProcessor2_0]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ qk_norm: str = "rms_norm",
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if qk_norm == "rms_norm":
+ self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ else:
+ self.norm_q = None
+ self.norm_k = None
+
+ self.to_out = nn.ModuleList(
+ [
+ nn.Linear(self.inner_dim, query_dim, bias=out_bias),
+ nn.Dropout(dropout),
+ ]
+ )
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_embed_seq_len: int = 0,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ image_rotary_emb,
+ image_embed_seq_len,
+ )
+
+
+class MotifVideoAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = MotifVideoAttnProcessor2_0
+ _available_processors = [MotifVideoAttnProcessor2_0]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: int | None = None,
+ added_proj_bias: bool | None = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+        out_dim: int | None = None,
+ elementwise_affine: bool = True,
+ pre_only: bool = False,
+ context_pre_only: bool = False,
+ qk_norm: str = "rms_norm",
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.pre_only = pre_only
+
+ self.use_bias = bias
+ self.dropout = dropout
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+ self.context_pre_only = context_pre_only
+
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ # QK Norm
+ if qk_norm == "rms_norm":
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "layer_norm":
+ self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ else:
+ self.norm_q = None
+ self.norm_k = None
+
+ if not pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+ else:
+ self.to_out = None
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ if not context_pre_only:
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+ else:
+ self.to_add_out = None
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+ self.add_q_proj = None
+ self.add_k_proj = None
+ self.add_v_proj = None
+ self.to_add_out = None
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ image_rotary_emb: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class MotifVideoPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: Union[int, Tuple[int, int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ super().__init__()
+
+ patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ return hidden_states
+
+
+class MotifVideoAdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class MotifVideoConditionEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ ) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+ compute_dtype = get_parameter_dtype(self.timestep_embedder)
+ if compute_dtype != torch.int8:
+ timesteps_proj = timesteps_proj.to(compute_dtype)
+ conditioning = self.timestep_embedder(timesteps_proj) # (N, D)
+
+ return conditioning
+
+
+class MotifVideoRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int,
+ patch_size_t: int,
+ rope_dim: List[int],
+ theta: float = 256.0,
+ ):
+ """
+ Rotary Positional Embedding (RoPE) for video latents.
+
+ Args:
+ patch_size (`int`): Spatial patch size.
+ patch_size_t (`int`): Temporal patch size.
+ rope_dim (`List[int]`): Dimensions for RoPE across [Time, Height, Width] axes.
+ theta (`float`, *optional*, defaults to 256.0): Base frequency for rotary embeddings.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ rope_sizes = [
+ num_frames // self.patch_size_t,
+ height // self.patch_size,
+ width // self.patch_size,
+ ]
+
+ axes_grids = []
+ for i in range(3):
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+ axes_grids.append(grid)
+ grid = torch.meshgrid(*axes_grids, indexing="ij")
+ grid = torch.stack(grid, dim=0)
+
+ freqs = []
+ is_mps = hidden_states.device.type == "mps"
+ is_npu = hidden_states.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ for i in range(3):
+ freq = get_1d_rotary_pos_embed(
+ dim=self.rope_dim[i],
+ pos=grid[i].reshape(-1),
+ theta=self.theta,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
+ return freqs_cos, freqs_sin
+
+
+class MotifVideoImageProjection(nn.Module):
+ def __init__(self, in_features: int, hidden_size: int):
+ super().__init__()
+ self.norm_in = nn.LayerNorm(in_features)
+ self.linear_1 = nn.Linear(in_features, in_features)
+ self.act_fn = nn.GELU()
+ self.linear_2 = nn.Linear(in_features, hidden_size)
+ self.norm_out = nn.LayerNorm(hidden_size)
+
+ def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm_in(image_embeds)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.norm_out(hidden_states)
+ return hidden_states
+
+
+class MotifVideoSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ norm_type: str = "layer_norm",
+ enable_text_cross_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+
+ self.attn = MotifVideoAttention(
+ query_dim=hidden_size,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ out_dim=hidden_size,
+ bias=True,
+ pre_only=True,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ processor=MotifVideoAttnProcessor2_0(),
+ )
+
+ self.cross_attn = (
+ MotifVideoCrossAttention(
+ query_dim=hidden_size,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=True,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+ if enable_text_cross_attention
+ else None
+ )
+
+ self.enable_text_cross_attention = enable_text_cross_attention
+
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_embed_seq_len: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ encoder_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-encoder_seq_length, :],
+ norm_hidden_states[:, -encoder_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 3. Text cross-attention
+ if self.cross_attn is not None:
+ cross_output = self.cross_attn(
+ hidden_states=attn_output,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ image_embed_seq_len=image_embed_seq_len,
+ )
+ attn_output = attn_output + cross_output
+
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+ # 4. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-encoder_seq_length, :],
+ hidden_states[:, -encoder_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class MotifVideoTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ norm_type: str = "layer_norm",
+ enable_text_cross_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
+
+ self.attn = MotifVideoAttention(
+ query_dim=hidden_size,
+ added_kv_proj_dim=hidden_size,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ out_dim=hidden_size,
+ bias=True,
+ context_pre_only=False,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ processor=MotifVideoAttnProcessor2_0(),
+ )
+
+ self.cross_attn = (
+ MotifVideoCrossAttention(
+ query_dim=hidden_size,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=True,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+ if enable_text_cross_attention
+ else None
+ )
+
+ self.enable_text_cross_attention = enable_text_cross_attention
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_embed_seq_len: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+
+ # 4. Text cross-attention
+ if self.cross_attn is not None:
+ cross_output = self.cross_attn(
+ hidden_states=attn_output,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ image_embed_seq_len=image_embed_seq_len,
+ )
+ hidden_states = hidden_states + cross_output
+
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 5. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+class MotifVideoTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the Motif-Video model.
+
+ Args:
+ in_channels (`int`, defaults to `33`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_single_layers (`int`, defaults to `40`):
+ The number of layers of single-stream blocks to use.
+ num_decoder_layers (`int`, defaults to `0`):
+            The number of DDT decoder layers at the end of the single-stream blocks.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the temporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ image_embed_dim (`int`, *optional*):
+ Input dimension of image embeddings from a vision encoder. If provided, enables image conditioning.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _repeated_blocks = ["MotifVideoSingleTransformerBlock", "MotifVideoTransformerBlock"]
+ _no_split_modules = [
+ "MotifVideoTransformerBlock",
+ "MotifVideoSingleTransformerBlock",
+ "MotifVideoPatchEmbed",
+ ]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 33,
+ out_channels: int = 16,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_decoder_layers: int = 0,
+ mlp_ratio: float = 4.0,
+ patch_size: int = 2,
+ patch_size_t: int = 1,
+ qk_norm: str = "rms_norm",
+ norm_type: str = "layer_norm",
+ text_embed_dim: int = 4096,
+ image_embed_dim: int | None = None,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
+ enable_text_cross_attention_dual: bool = False,
+ enable_text_cross_attention_single: bool = False,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+ self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
+
+ # First frame conditioning: Image conditioning embedders
+ self.image_embed_dim = image_embed_dim
+ if image_embed_dim is not None:
+ self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
+
+ self.time_text_embed = MotifVideoConditionEmbedding(inner_dim)
+
+ # 2. RoPE
+ self.rope = MotifVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+ # Cross-attention config
+ self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
+ self.enable_text_cross_attention_single = enable_text_cross_attention_single
+
+ # 3. Dual stream transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ MotifVideoTransformerBlock(
+ num_attention_heads,
+ attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ norm_type=norm_type,
+ enable_text_cross_attention=enable_text_cross_attention_dual,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ # Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
+ num_encoder_single = num_single_layers - num_decoder_layers
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ MotifVideoSingleTransformerBlock(
+ num_attention_heads,
+ attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ norm_type=norm_type,
+ enable_text_cross_attention=enable_text_cross_attention_single
+ if i < num_encoder_single
+ else False,
+ )
+ for i in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(
+ inner_dim,
+ inner_dim,
+ elementwise_affine=False,
+ eps=1e-6,
+ norm_type=norm_type,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+ # Verify cross-attention config matches actual block state.
+ # Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
+ for i, block in enumerate(self.transformer_blocks):
+ if block.enable_text_cross_attention != enable_text_cross_attention_dual:
+ raise ValueError(
+ f"transformer_blocks[{i}].enable_text_cross_attention="
+ f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
+ f"Check checkpoint config.json key names match __init__ parameters."
+ )
+ for i, block in enumerate(self.single_transformer_blocks):
+ expected = enable_text_cross_attention_single if i < num_encoder_single else False
+ if block.enable_text_cross_attention != expected:
+ raise ValueError(
+ f"single_transformer_blocks[{i}].enable_text_cross_attention="
+ f"{block.enable_text_cross_attention}, expected {expected}. "
+ f"Check checkpoint config.json key names match __init__ parameters."
+ )
+
+ self.gradient_checkpointing = False
+ self.num_decoder_layers = num_decoder_layers
+
+ def _maybe_gradient_checkpoint_block(self, block, *args):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ return self._gradient_checkpointing_func(block, *args)
+ return block(*args)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor | None = None,
+ image_embeds: torch.Tensor | None = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ Forward pass of the MotifVideoTransformer3DModel.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input latent tensor of shape `(batch_size, channels, num_frames, height, width)`.
+ timestep (`torch.LongTensor`):
+ Diffusion timesteps of shape `(batch_size,)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Text conditioning of shape `(batch_size, sequence_length, embed_dim)`.
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+ Mask for text conditioning of shape `(batch_size, sequence_length)`.
+ image_embeds (`torch.Tensor`, *optional*):
+ Image embeddings from vision encoder of shape `(batch_size, num_tokens, embed_dim)`.
+ attention_kwargs (`dict`, *optional*):
+ Additional arguments for attention processors.
+ return_dict (`bool`, defaults to `True`):
+ Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`].
+
+ Returns:
+ [`~models.modeling_outputs.Transformer2DModelOutput`] or `tuple`:
+ The predicted samples.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, _, num_frames, height, width = hidden_states.shape
+ p, p_t = self.config.patch_size, self.config.patch_size_t
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Conditional embeddings
+ temb = self.time_text_embed(timestep)
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # First frame conditioning: Image embeddings from vision encoder
+ if image_embeds is not None:
+ image_embeds = self.image_embedder(image_embeds)
+ encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
+ if encoder_attention_mask is not None:
+ image_mask = torch.ones(
+ image_embeds.shape[0],
+ image_embeds.shape[1],
+ device=encoder_attention_mask.device,
+ dtype=encoder_attention_mask.dtype,
+ )
+ encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
+
+ # image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
+ image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
+
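+        # Keep a copy of the patch embeddings: the DDT decoder blocks re-read them in step 5, while the
+        # single-stream encoder output is passed to them as the conditioning stream.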
+ decoder_hidden_states = hidden_states.clone()
+
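+        # Joint attention mask over [video tokens | encoder tokens]: video positions are always attended
+        # to (left-padded with True), followed by the image/text token mask.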
+ if encoder_attention_mask is not None:
+ attention_mask = F.pad(
+ encoder_attention_mask.to(torch.bool),
+ (hidden_states.shape[1], 0),
+ value=True,
+ )
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+ else:
+ attention_mask = None
+
+ # 3. Dual stream transformer blocks
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ image_embed_seq_len,
+ )
+
+ # 4. Single stream transformer blocks (Encoder)
+ single_transformer_blocks = self.single_transformer_blocks
+
+ for block in single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]:
+ hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ image_embed_seq_len,
+ )
+
+ # 5. Single stream transformer blocks (Decoder)
+ if self.num_decoder_layers > 0:
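+            # DDT decoder stage: condition on the encoder's video stream; the text stream and its mask are dropped.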
+ encoder_hidden_states = hidden_states
+ attention_mask = None
+
+ for block in single_transformer_blocks[-self.num_decoder_layers :]:
+ decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
+ block,
+ decoder_hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ )
+
+ hidden_states = decoder_hidden_states
+
+ # 6. Output projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
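+        # Unpatchify: fold the (p_t, p, p) patch dimensions back into (num_frames, height, width)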
+ hidden_states = hidden_states.reshape(
+ batch_size,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ -1,
+ p_t,
+ p,
+ p,
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(
+ sample=hidden_states,
+ )
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index f0fc7585bf31..d26940e7d613 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -347,6 +347,11 @@
]
)
_import_structure["mochi"] = ["MochiPipeline"]
+ _import_structure["motif_video"] = [
+ "MotifVideoPipeline",
+ "MotifVideoImage2VideoPipeline",
+ "MotifVideoPipelineOutput",
+ ]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["ernie_image"] = ["ErnieImagePipeline"]
_import_structure["ovis_image"] = ["OvisImagePipeline"]
@@ -792,6 +797,11 @@
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
+ from .motif_video import (
+ MotifVideoImage2VideoPipeline,
+ MotifVideoPipeline,
+ MotifVideoPipelineOutput,
+ )
from .nucleusmoe_image import NucleusMoEImagePipeline
from .omnigen import OmniGenPipeline
from .ovis_image import OvisImagePipeline
diff --git a/src/diffusers/pipelines/motif_video/__init__.py b/src/diffusers/pipelines/motif_video/__init__.py
new file mode 100644
index 000000000000..ee1d7c72ee65
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_motif_video"] = ["MotifVideoPipeline"]
+ _import_structure["pipeline_motif_video_image2video"] = ["MotifVideoImage2VideoPipeline"]
+ _import_structure["pipeline_output"] = ["MotifVideoPipelineOutput"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_motif_video import MotifVideoPipeline
+ from .pipeline_motif_video_image2video import MotifVideoImage2VideoPipeline
+ from .pipeline_output import MotifVideoPipelineOutput
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/motif_video/pipeline_motif_video.py b/src/diffusers/pipelines/motif_video/pipeline_motif_video.py
new file mode 100644
index 000000000000..b3ce381549c3
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/pipeline_motif_video.py
@@ -0,0 +1,797 @@
+# Copyright 2026 Motif Technologies, Inc. and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+# NOTE: This pipeline requires transformers>=5.1.0 for T5Gemma2Encoder support.
+# The T5Gemma2Encoder class is only available in transformers 5.1.0 and later.
+from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Encoder
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import BaseGuidance
+from ...models import AutoencoderKLWan
+from ...models.transformers import MotifVideoTransformer3DModel
+from ...schedulers import SchedulerMixin
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import MotifVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import MotifVideoPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the Motif-Video pipeline
+ >>> motif_video_model_id = "Motif-Technologies/Motif-Video-2B"
+ >>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> video = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=1280,
+ ... height=736,
+ ... num_frames=121,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: int | None = None,
+ device: str | torch.device | None = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`list[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`list[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class MotifVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using Motif-Video.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`MotifVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents. Should be an
+ instance of a class inheriting from `SchedulerMixin`, such as [`DPMSolverMultistepScheduler`]. If not
+ provided, uses the scheduler attached to the pretrained model.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5Gemma2Encoder`]):
+ Primary text encoder for encoding text prompts into embeddings.
+ tokenizer ([`PreTrainedTokenizerBase`]):
+ Tokenizer corresponding to the primary text encoder.
+ guider ([`BaseGuidance`]):
+ The guidance method to use. Should be an instance of a class inheriting from `BaseGuidance`, such as
+ [`ClassifierFreeGuidance`], [`AdaptiveProjectedGuidance`], or [`SkipLayerGuidance`]. If not provided,
+ defaults to `ClassifierFreeGuidance`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = ["feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: SchedulerMixin,
+ vae: AutoencoderKLWan,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ transformer: MotifVideoTransformer3DModel,
+ guider: BaseGuidance,
+ feature_extractor: Optional[SiglipImageProcessor] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ guider=guider,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
+ )
+
+ def _get_prompt_embeds(
+ self,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ prompt: Optional[Union[str, List[str]]] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_inputs = BatchEncoding(
+ {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
+ )
+
+ prompt_embeds = text_encoder(**text_inputs)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, text_inputs.attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be encoded.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when the guider uses only a single condition (i.e., when
+                classifier-free guidance is disabled).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ max_sequence_length (`int`, defaults to 512):
+ Maximum sequence length for the tokenizer.
+ device (`torch.device`, *optional*):
+ Device to place tensors on.
+ dtype (`torch.dtype`, *optional*):
+ Data type for tensors.
+
+ Returns:
+ `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]`:
+ A tuple containing:
+ - `prompt_embeds`: The text embeddings for the positive prompt
+ - `negative_prompt_embeds`: The text embeddings for the negative prompt (None if not using guidance)
+ - `prompt_attention_mask`: The attention mask for the positive prompt
+ - `negative_prompt_attention_mask`: The attention mask for the negative prompt (None if not using
+ guidance)
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ seq_len = prompt_embeds.shape[1]
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.bool()
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
+
+ # Compute negative embeddings if needed
+ if negative_prompt_embeds is None and negative_prompt is not None:
+            # Broadcast negative_prompt to match batch_size
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+            else:
+                negative_prompt = list(negative_prompt)
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.bool()
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(batch_size, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat_interleave(
+ num_videos_per_prompt, dim=0
+ )
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ batch_size,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % self.vae_scale_factor_spatial != 0 or width % self.vae_scale_factor_spatial != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor_spatial} but are {height} and {width}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None:
+ if not isinstance(negative_prompt, (str, list)):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
+ raise ValueError(
+ f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) / latents_std
+ return latents
+
+ @staticmethod
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is None:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+ return latents
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+        vae_batch_size: Optional[int] = None,
+ ):
+ r"""
+ The call function to the pipeline for text-to-video generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance.
+ height (`int`, defaults to `736`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `121`):
+ The number of video frames to generate.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ PyTorch Generator object(s) for deterministic generation.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `"pil"`, `"np"`, or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~MotifVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ Arguments passed to the attention processor.
+ callback_on_step_end (`Callable`, *optional*):
+ A function or subclass of `PipelineCallback` or `MultiPipelineCallbacks` called at the end of each
+ denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ Maximum sequence length for the tokenizer.
+ vae_batch_size (`int`, *optional*):
+ Batch size for VAE decoding. If provided and latents batch size is larger, VAE decoding will be done in
+ chunks.
+
+ Examples:
+
+ Returns:
+ [`~MotifVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~MotifVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list of generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 2. Check inputs
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ batch_size=batch_size,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ # Ensure negative prompt is provided for multi-condition guiders
+ if (
+ self.guider is not None
+ and self.guider.num_conditions > 1
+ and negative_prompt_embeds is None
+ and negative_prompt is None
+ ):
+ negative_prompt = ""
+
+ prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = (
+ self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ )
+
+ # 4. Prepare latents
+ num_channels_latents = self.vae.config.z_dim
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.transformer.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
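+        # Packed token count seen by the transformer; used below to scale the timestep shift.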
+ packed_latent_height = latent_height // self.transformer_spatial_patch_size
+ packed_latent_width = latent_width // self.transformer_spatial_patch_size
+ packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
+ video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
+
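+        # Rectified-flow sigma schedule, linearly spaced from 1.0 down to 1 / num_inference_steps.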
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
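+        # Sequence-length-dependent timestep shift (mu), forwarded to the scheduler through `retrieve_timesteps`.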
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ # Prepare conditioning tensors (T2V mode: no first-frame conditioning)
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape
+ latent_condition = torch.zeros(
+ batch_size,
+ latent_channels,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+ latent_mask = torch.zeros(
+ batch_size,
+ 1,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Concatenate current latents with conditioning: [latents | latent_condition | latent_mask]
+ hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
+
+ timestep = t.expand(latents.shape[0])
+
+ # Guider: collect model inputs
+ if self.guider is not None and self.guider.num_conditions == 1:
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds,),
+ "encoder_attention_mask": (prompt_attention_mask,),
+ }
+ else:
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ),
+ }
+
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+ guider_state = self.guider.prepare_inputs(guider_inputs)
+
+ for guider_state_batch in guider_state:
+ self.guider.prepare_models(self.transformer)
+
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ noise_pred = self.transformer(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0].clone()
+
+ guider_state_batch.noise_pred = noise_pred
+ self.guider.cleanup_models(self.transformer)
+
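+                # Combine the per-condition predictions (e.g. conditional/unconditional) into the guided prediction.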
+ noise_pred = self.guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ if "negative_prompt_embeds" in callback_outputs:
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
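+            # Optionally decode in smaller batches to bound peak VAE memory usage.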
+ if vae_batch_size is not None and latents.shape[0] > vae_batch_size:
+ video_chunks = []
+ for i in range(0, latents.shape[0], vae_batch_size):
+ chunk = latents[i : i + vae_batch_size]
+ video_chunks.append(self.vae.decode(chunk, return_dict=False)[0])
+ video = torch.cat(video_chunks, dim=0)
+ del video_chunks
+ else:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return MotifVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py b/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py
new file mode 100644
index 000000000000..57acbb2189aa
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py
@@ -0,0 +1,912 @@
+# Copyright 2026 Motif Technologies, Inc. and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+# NOTE: This pipeline requires transformers>=5.1.0 for T5Gemma2Encoder support.
+# The T5Gemma2Encoder class is only available in transformers 5.1.0 and later.
+from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Encoder
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import BaseGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
+from ...models.transformers import MotifVideoTransformer3DModel
+from ...schedulers import SchedulerMixin
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import MotifVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import MotifVideoImage2VideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> # Load the Motif-Video image-to-video pipeline
+ >>> motif_video_model_id = "Motif-Technologies/Motif-Video-2B"
+ >>> pipe = MotifVideoImage2VideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Load an image
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.png"
+ ... )
+
+ >>> prompt = "An astronaut is walking on the moon surface, kicking up dust with each step"
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> video = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=1280,
+ ... height=736,
+ ... num_frames=121,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: int | None = None,
+ device: str | torch.device | None = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`list[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`list[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class MotifVideoImage2VideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-video generation using Motif-Video with first frame conditioning.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`MotifVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents. Should be an
+ instance of a class inheriting from `SchedulerMixin`, such as [`DPMSolverMultistepScheduler`]. If not
+ provided, uses the scheduler attached to the pretrained model.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5Gemma2Encoder`]):
+ Primary text encoder for encoding text prompts into embeddings.
+ tokenizer ([`PreTrainedTokenizerBase`]):
+ Tokenizer corresponding to the primary text encoder.
+ feature_extractor ([`SiglipImageProcessor`]):
+ Image processor for the SigLIP vision encoder.
+ guider ([`BaseGuidance`]):
+ The guidance method to use. Should be an instance of a class inheriting from `BaseGuidance`, such as
+ [`ClassifierFreeGuidance`], [`AdaptiveProjectedGuidance`], or [`SkipLayerGuidance`]. If not provided,
+ defaults to `ClassifierFreeGuidance`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: SchedulerMixin,
+ vae: AutoencoderKLWan,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ transformer: MotifVideoTransformer3DModel,
+ guider: BaseGuidance,
+ feature_extractor: SiglipImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ guider=guider,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
+ )
+
+ # Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline._get_prompt_embeds
+ def _get_prompt_embeds(
+ self,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ prompt: Optional[Union[str, List[str]]] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_inputs = BatchEncoding(
+ {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
+ )
+
+ prompt_embeds = text_encoder(**text_inputs)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, text_inputs.attention_mask
+
+ # Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be encoded.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when the guider uses only a single condition (i.e., when
+                classifier-free guidance is disabled).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ max_sequence_length (`int`, defaults to 512):
+ Maximum sequence length for the tokenizer.
+ device (`torch.device`, *optional*):
+ Device to place tensors on.
+ dtype (`torch.dtype`, *optional*):
+ Data type for tensors.
+
+ Returns:
+ `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]`:
+ A tuple containing:
+ - `prompt_embeds`: The text embeddings for the positive prompt
+ - `negative_prompt_embeds`: The text embeddings for the negative prompt (None if not using guidance)
+ - `prompt_attention_mask`: The attention mask for the positive prompt
+ - `negative_prompt_attention_mask`: The attention mask for the negative prompt (None if not using
+ guidance)
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ seq_len = prompt_embeds.shape[1]
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.bool()
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
+
+ # Compute negative embeddings if needed
+ if negative_prompt_embeds is None and negative_prompt is not None:
+            # Broadcast negative_prompt to match batch_size
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+            else:
+                negative_prompt = list(negative_prompt)
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.bool()
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(batch_size, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat_interleave(
+ num_videos_per_prompt, dim=0
+ )
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ @staticmethod
+ def _get_image_embeds(
+ image_encoder,
+ feature_extractor: SiglipImageProcessor,
+ image,
+ device: torch.device,
+ ) -> torch.Tensor:
+ """Helper to encode single image with SigLIP."""
+ image_encoder_dtype = next(image_encoder.parameters()).dtype
+
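+        # Inputs are expected to already be in [0, 1]; tensors are cast to float and rescaling is disabled below.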
+ if isinstance(image, torch.Tensor):
+ image = image.float()
+ image = feature_extractor.preprocess(
+ images=image,
+ do_resize=True,
+ do_rescale=False,
+ do_normalize=True,
+ do_convert_rgb=True,
+ return_tensors="pt",
+ )
+
+ image = image.to(device=device, dtype=image_encoder_dtype)
+ return image_encoder(**image).last_hidden_state
+
+ def _prepare_first_frame_conditioning(
+ self,
+ video: torch.Tensor,
+ latents: torch.Tensor,
+ use_conditioning: bool,
+ generator: Optional[torch.Generator] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Prepare first frame conditioning tensors.
+
+ For I2V mode:
+ 1. Extract and VAE-encode first frame from video
+ 2. Create latent_condition with first frame latents at frame 0
+ 3. Create latent_mask with 1.0 at frame 0
+ 4. Get image_embeds from vision encoder
+
+ For T2V mode:
+ 1. Return zeros for latent_condition and latent_mask, None for image_embeds
+
+ Args:
+ video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1]
+ latents: Latents [batch_size, channels, num_frames, height, width]
+ use_conditioning: Whether to use first-frame conditioning (True for I2V)
+ generator: Optional random number generator
+
+ Returns:
+ Tuple of (latent_condition, latent_mask, image_embeds).
+ """
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape
+ device = latents.device
+ dtype = latents.dtype
+
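+        # First-frame conditioning only applies when there is more than one latent frame.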
+ use_conditioning = use_conditioning and (latent_num_frames > 1)
+
+ latent_condition = torch.zeros(
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
+ )
+ latent_mask = torch.zeros(
+ batch_size, 1, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
+ )
+ image_embeds = None
+
+ if use_conditioning:
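+            # VAE-encode the first frame, normalize it, and place it at temporal index 0 of the conditioning tensor.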
+ # video shape: [B, F, C, H, W] -> [B, C, F, H, W] for VAE
+ first_frame_latents = self.vae.encode(video[:, 0:1].permute(0, 2, 1, 3, 4)).latent_dist.sample(
+ generator=generator
+ )
+ first_frame_latents = self._normalize_latents(
+ latents=first_frame_latents,
+ latents_mean=self.vae.config.latents_mean,
+ latents_std=self.vae.config.latents_std,
+ )
+
+ latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
+ latent_condition[:, :, 1:, :, :] = 0
+
+ latent_mask[:, :, 0] = 1.0
+
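+            # The SigLIP feature extractor expects pixel values in [0, 1]; the preprocessed video is in [-1, 1].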
+ first_frame_vision = video[:, 0] # [B, C, H, W]
+ first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1)
+
+ if self.text_encoder is not None:
+ image_embeds = self._get_image_embeds(
+ image_encoder=self.text_encoder.vision_tower,
+ feature_extractor=self.feature_extractor,
+ image=first_frame_vision,
+ device=device,
+ )
+
+ return latent_condition, latent_mask, image_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ batch_size,
+ image,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % self.vae_scale_factor_spatial != 0 or width % self.vae_scale_factor_spatial != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor_spatial} but are {height} and {width}."
+ )
+
+ if image is None:
+ raise ValueError("`image` is required for image-to-video generation.")
+
+        if isinstance(image, list):
+            if len(image) != 1:
+                raise ValueError(
+                    f"`image` must be a single image, got a list of {len(image)} images. "
+                    "For image-to-video generation, only a single first frame is supported."
+                )
+        elif isinstance(image, torch.Tensor):
+            if image.dim() not in (3, 4):
+                raise ValueError(
+                    f"`image` must be a 3D tensor [C, H, W] or 4D tensor [B, C, H, W], got {image.dim()}D"
+                )
+            if image.dim() == 4 and image.shape[0] != 1:
+                raise ValueError(f"`image` batch size must be 1 when passed as a 4D tensor, got {image.shape[0]}")
+        elif isinstance(image, np.ndarray):
+            if image.ndim not in (3, 4):
+                raise ValueError(
+                    f"`image` must be a 3D array [H, W, C] or 4D array [B, H, W, C], got {image.ndim}D"
+                )
+            if image.ndim == 4 and image.shape[0] != 1:
+                raise ValueError(f"`image` batch size must be 1 when passed as a 4D array, got {image.shape[0]}")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None:
+ if not isinstance(negative_prompt, (str, list)):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
+ raise ValueError(
+ f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape, "
+ f"got {prompt_embeds.shape} and {negative_prompt_embeds.shape}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std + latents_mean
+ return latents
+
+ # Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is None:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+ return latents
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for image-to-video generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to use as the first frame for video generation.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the video generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation.
+ height (`int`, defaults to `736`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `121`):
+ The number of video frames to generate.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ PyTorch Generator object(s) for deterministic generation.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~MotifVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ Arguments passed to the attention processor.
+ callback_on_step_end (`Callable`, *optional*):
+ A function or subclass of `PipelineCallback` called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ Maximum sequence length for the tokenizer.
+
+ Examples:
+
+ Returns:
+ [`~MotifVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~MotifVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list of generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 2. Check inputs
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ batch_size=batch_size,
+ image=image,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ device = self._execution_device
+
+ # 3. Preprocess image
+ # preprocess_video expects a list of video frames
+ if not isinstance(image, list):
+ image = [image]
+
+ video = self.video_processor.preprocess_video(image, height=height, width=width)
+ # preprocess_video returns (B, C, T, H, W), permute to (B, T, C, H, W)
+ video = video.permute(0, 2, 1, 3, 4)
+ video = video.to(device=device, dtype=self.transformer.dtype)
+
+ # 4. Prepare latents
+ num_channels_latents = self.vae.config.z_dim
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.transformer.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare text embeddings
+ # Ensure negative prompt is provided for multi-condition guiders
+ if (
+ self.guider is not None
+ and self.guider.num_conditions > 1
+ and negative_prompt_embeds is None
+ and negative_prompt is None
+ ):
+ negative_prompt = ""
+
+ prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = (
+ self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ )
+
+ # 6. First frame conditioning
+ latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning(
+ video,
+ latents,
+ use_conditioning=True,
+ generator=generator,
+ )
+
+ # Repeat conditioning tensors for each generation per prompt
+ if num_videos_per_prompt > 1:
+ latent_condition = latent_condition.repeat_interleave(num_videos_per_prompt, dim=0)
+ latent_mask = latent_mask.repeat_interleave(num_videos_per_prompt, dim=0)
+ if image_embeds is not None:
+ image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
+
+ # 7. Prepare timesteps
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ packed_latent_height = latent_height // self.transformer_spatial_patch_size
+ packed_latent_width = latent_width // self.transformer_spatial_patch_size
+ packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
+ video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
+
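+        # Rectified-flow sigma schedule with a sequence-length-dependent timestep shift (mu), as in the T2V pipeline.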
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 8. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Concatenate: [latents | latent_condition | latent_mask]
+ hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
+
+ timestep = t.expand(latents.shape[0])
+
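+                # Guider: collect model inputs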
+ if self.guider is not None and self.guider.num_conditions == 1:
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds,),
+ "encoder_attention_mask": (prompt_attention_mask,),
+ }
+ else:
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_attention_mask, negative_prompt_attention_mask),
+ }
+
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+ guider_state = self.guider.prepare_inputs(guider_inputs)
+
+ for guider_state_batch in guider_state:
+ self.guider.prepare_models(self.transformer)
+
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ noise_pred = self.transformer(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ image_embeds=image_embeds,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0].clone()
+
+ guider_state_batch.noise_pred = noise_pred
+ self.guider.cleanup_models(self.transformer)
+
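+                # Combine the per-condition predictions (e.g. conditional/unconditional) into the guided prediction.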
+ noise_pred = self.guider(guider_state)[0]
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ if "negative_prompt_embeds" in callback_outputs:
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return MotifVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/motif_video/pipeline_output.py b/src/diffusers/pipelines/motif_video/pipeline_output.py
new file mode 100644
index 000000000000..aa0b2b83b323
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class MotifVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for Motif-Video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 9bfb73c1999e..0ce20a4f7d97 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1560,6 +1560,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class MotifVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class MotionAdapter(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index cfa1318783f3..407a13b7496d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2807,6 +2807,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class MotifVideoImage2VideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class MotifVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class MotifVideoPipelineOutput(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class MusicLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/transformers/test_models_transformer_motif_video.py b/tests/models/transformers/test_models_transformer_motif_video.py
new file mode 100644
index 000000000000..d3ac3a874927
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_motif_video.py
@@ -0,0 +1,191 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from diffusers import MotifVideoTransformer3DModel
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
+from ..testing_utils import (
+ AttentionTesterMixin,
+ BaseModelTesterConfig,
+ LoraTesterMixin,
+ MemoryTesterMixin,
+ ModelTesterMixin,
+ TorchCompileTesterMixin,
+ TrainingTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class MotifVideoTransformerTesterConfig(BaseModelTesterConfig):
+ @property
+ def model_class(self):
+ return MotifVideoTransformer3DModel
+
+ @property
+ def pretrained_model_name_or_path(self):
+ return "" # TODO: Set Hub repository ID
+
+ @property
+ def pretrained_model_kwargs(self):
+ return {"subfolder": "transformer"}
+
+ @property
+ def generator(self):
+ return torch.Generator("cpu").manual_seed(0)
+
+ @property
+ def main_input_name(self) -> str:
+ return "hidden_states"
+
+ @property
+ def input_shape(self) -> tuple[int, ...]:
+ return (1, 33, 9, 16, 16)
+
+ @property
+ def output_shape(self) -> tuple[int, ...]:
+ return (1, 16, 9, 16, 16)
+
+ def get_init_dict(self) -> dict[str, int | float | str | bool | tuple[int, ...]]:
+ return {
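+ # NOTE: 33 input channels assumes 16 latent + 16 conditioning-latent + 1 mask channels,
+ # mirroring the [latents | latent_condition | latent_mask] concatenation in the pipeline.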
+ "in_channels": 33,
+ "out_channels": 16,
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_decoder_layers": 0,
+ "mlp_ratio": 4.0,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "qk_norm": "rms_norm",
+ "norm_type": "layer_norm",
+ "text_embed_dim": 32,
+ "image_embed_dim": 4,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (4, 4, 4),
+ "enable_text_cross_attention_dual": False,
+ "enable_text_cross_attention_single": False,
+ }
+
+ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
+ batch_size = 1
+ num_channels = 33
+ num_frames = 9
+ height = 16
+ width = 16
+ text_embed_dim = 32
+ sequence_length = 12
+
+ return {
+ "hidden_states": randn_tensor(
+ (batch_size, num_channels, num_frames, height, width),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "encoder_hidden_states": randn_tensor(
+ (batch_size, sequence_length, text_embed_dim),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+ }
+
+
+class TestMotifVideoTransformerModel(MotifVideoTransformerTesterConfig, ModelTesterMixin):
+ pass
+
+
+class TestMotifVideoTransformerMemory(MotifVideoTransformerTesterConfig, MemoryTesterMixin):
+ pass
+
+
+class TestMotifVideoTransformerTorchCompile(MotifVideoTransformerTesterConfig, TorchCompileTesterMixin):
+ @property
+ def different_shapes_for_compilation(self):
+ return [(4, 4), (4, 8), (8, 8)]
+
+ def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
+ batch_size = 1
+ num_channels = 33
+ num_frames = 9
+ text_embed_dim = 32
+ sequence_length = 12
+
+ return {
+ "hidden_states": randn_tensor(
+ (batch_size, num_channels, num_frames, height, width),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "encoder_hidden_states": randn_tensor(
+ (batch_size, sequence_length, text_embed_dim),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+ }
+
+
+class TestMotifVideoTransformerLora(MotifVideoTransformerTesterConfig, LoraTesterMixin):
+ pass
+
+
+class TestMotifVideoTransformerTraining(MotifVideoTransformerTesterConfig, TrainingTesterMixin):
+ pass
+
+
+class TestMotifVideoTransformerAttention(MotifVideoTransformerTesterConfig, AttentionTesterMixin):
+ pass
+
+
+class TestMotifVideoTransformerLoraHotSwappingForModel(
+ MotifVideoTransformerTesterConfig, LoraHotSwappingForModelTesterMixin
+):
+ @property
+ def different_shapes_for_compilation(self):
+ return [(4, 4), (4, 8), (8, 8)]
+
+ def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
+ batch_size = 1
+ num_channels = 33
+ num_frames = 9
+ text_embed_dim = 32
+ sequence_length = 12
+
+ return {
+ "hidden_states": randn_tensor(
+ (batch_size, num_channels, num_frames, height, width),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "encoder_hidden_states": randn_tensor(
+ (batch_size, sequence_length, text_embed_dim),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+ }
diff --git a/tests/pipelines/motif_video/__init__.py b/tests/pipelines/motif_video/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/motif_video/test_motif_video.py b/tests/pipelines/motif_video/test_motif_video.py
new file mode 100644
index 000000000000..7bd4332ee29f
--- /dev/null
+++ b/tests/pipelines/motif_video/test_motif_video.py
@@ -0,0 +1,144 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import (
+ AutoTokenizer,
+ T5Gemma2Encoder,
+ T5Gemma2EncoderConfig,
+ T5Gemma2TextConfig,
+)
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, MotifVideoPipeline
+from diffusers.guiders import AdaptiveProjectedGuidance
+from diffusers.models.transformers.transformer_motif_video import MotifVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class MotifVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = MotifVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "guidance_scale"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ # Build a tiny T5Gemma2Encoder to match the pipeline's expected text_encoder type
+ text_config = T5Gemma2TextConfig(
+ hidden_size=32,
+ num_hidden_layers=1,
+ num_attention_heads=2,
+ intermediate_size=64,
+ vocab_size=1104,
+ max_position_embeddings=128,
+ head_dim=16,
+ num_key_value_heads=2,
+ dropout_rate=0.0,
+ )
+ encoder_config = T5Gemma2EncoderConfig(text_config=text_config)
+ text_encoder = T5Gemma2Encoder(encoder_config)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = MotifVideoTransformer3DModel(
+ in_channels=33,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=1,
+ num_single_layers=1,
+ mlp_ratio=4.0,
+ patch_size=1,
+ patch_size_t=1,
+ qk_norm="rms_norm",
+ text_embed_dim=32,
+ rope_axes_dim=(4, 4, 4),
+ )
+
+ guider = AdaptiveProjectedGuidance()
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": None,
+ "guider": guider,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A test video",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 16, 16, 3))
diff --git a/tests/pipelines/motif_video/test_motif_video_image2video.py b/tests/pipelines/motif_video/test_motif_video_image2video.py
new file mode 100644
index 000000000000..91e5ca88988e
--- /dev/null
+++ b/tests/pipelines/motif_video/test_motif_video_image2video.py
@@ -0,0 +1,199 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionConfig,
+ T5Gemma2Encoder,
+ T5Gemma2EncoderConfig,
+ T5Gemma2TextConfig,
+)
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, MotifVideoImage2VideoPipeline
+from diffusers.guiders import AdaptiveProjectedGuidance
+from diffusers.models.transformers.transformer_motif_video import MotifVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class MotifVideoImage2VideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = MotifVideoImage2VideoPipeline
+ params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"guidance_scale"}
+ batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ # Build a tiny T5Gemma2Encoder to match the pipeline's expected text_encoder type
+ text_config = T5Gemma2TextConfig(
+ hidden_size=32,
+ num_hidden_layers=1,
+ num_attention_heads=2,
+ intermediate_size=64,
+ vocab_size=1104,
+ max_position_embeddings=128,
+ head_dim=16,
+ num_key_value_heads=2,
+ dropout_rate=0.0,
+ )
+
+ vision_config = SiglipVisionConfig(
+ hidden_size=4,
+ num_hidden_layers=1,
+ num_attention_heads=2,
+ intermediate_size=64,
+ image_size=16,
+ patch_size=4,
+ num_channels=3,
+ )
+
+ encoder_config = T5Gemma2EncoderConfig(
+ text_config=text_config,
+ vision_config=vision_config,
+ )
+ text_encoder = T5Gemma2Encoder(encoder_config)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ feature_extractor = SiglipImageProcessor(
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ size={"height": 16, "width": 16},
+ )
+
+ torch.manual_seed(0)
+ transformer = MotifVideoTransformer3DModel(
+ in_channels=33,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=1,
+ num_single_layers=1,
+ mlp_ratio=4.0,
+ patch_size=1,
+ patch_size_t=1,
+ qk_norm="rms_norm",
+ text_embed_dim=32,
+ image_embed_dim=4,
+ rope_axes_dim=(4, 4, 4),
+ )
+
+ guider = AdaptiveProjectedGuidance()
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": feature_extractor,
+ "guider": guider,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = Image.new("RGB", (16, 16))
+
+ inputs = {
+ "image": image,
+ "prompt": "A test video",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 16, 16, 3))
+
+ @unittest.skip("MotifVideo I2V only supports a single conditioning image")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip("MotifVideo I2V only supports a single conditioning image")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip("MotifVideo I2V requires vision tower for image conditioning - cannot work without text_encoder")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("T5Gemma2Encoder's vision_tower doesn't support block-level or leaf-level offloading")
+ def test_pipeline_level_group_offloading_inference(self):
+ pass
+
+ @unittest.skip("T5Gemma2Encoder's vision_tower doesn't support block-level or leaf-level offloading")
+ def test_sequential_cpu_offload_forward_pass(self):
+ pass
+
+ @unittest.skip("T5Gemma2Encoder's vision_tower doesn't support block-level or leaf-level offloading")
+ def test_sequential_offload_forward_pass_twice(self):
+ pass