
Commit a2ea45a

rootonchair, sayakpaul, and dg845 authored
LTX2 distilled checkpoint support (#12934)
* add constants for distilled sigma values and allow the LTX pipeline to pass in sigmas
* add time conditioning conversion and token packing for latents
* make style & quality
* remove prenorm
* add sigma param to ltx2 i2v
* fix copies and add pack latents to i2v
* Apply suggestions from code review (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* Infer latent dims if latents/audio_latents is supplied
* add note for predefined sigmas
* run make style and quality
* revert distilled timesteps & set original_state_dict_repo_id to default None
* add latent normalize
* add create noised state, delete last sigmas
* remove normalize step in latent upsample pipeline and move it to the ltx2 pipeline
* add create noised latent to i2v pipeline
* fix copies
* parse "none" value in weight conversion script
* explicit shape handling
* Apply suggestions from code review (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* make style
* add two-stage inference tests
* add ltx2 documentation
* update i2v expected_audio_slice
* Apply suggestions from code review (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* Apply suggestion from @dg845 (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* Update ltx2.md to remove one-stage example: removed the one-stage generation example code and added comments for noise scale in two-stage generation

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Daniel Gu <dgu8957@gmail.com>
1 parent a58d0b9 commit a2ea45a

9 files changed

Lines changed: 508 additions & 74 deletions

File tree

docs/source/en/api/pipelines/ltx2.md

Lines changed: 173 additions & 0 deletions
@@ -24,6 +24,179 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https

The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).

## Two-Stage Generation

This is the recommended pipeline for production-quality generation. It is composed of two stages:

- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning.
- Stage 2: Upsample the Stage 1 output by 2x and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness.

Sample usage of the two-stage text-to-video pipeline:

```py
import torch

from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

# Stage 1: default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=40,
    sigmas=None,
    guidance_scale=4.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load the Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling is usually necessary to avoid OOM errors during VAE decoding
pipe.vae.enable_tiling()
# Reconfigure the scheduler so the Stage 2 distilled sigmas are used as-is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2: inference with the distilled LoRA and sigmas
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    # Renoise with the first sigma value; see
    # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_lora_distilled_sample.mp4",
)
```
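
A note on `noise_scale`: Stage 2 starts from the upscaled Stage 1 latents rather than pure noise, so the pipeline first re-noises them at the first distilled sigma. Below is a minimal, hypothetical sketch of that renoising step, assuming the linear flow-matching interpolation convention; the actual implementation lives inside the pipeline (see the `ti2vid_two_stages.py` link above), and `create_noised_latents` is an illustrative name, not a public diffusers API.

```py
import torch

def create_noised_latents(latents: torch.Tensor, noise_scale: float, generator=None) -> torch.Tensor:
    # Blend the clean upscaled latents with fresh Gaussian noise at sigma = noise_scale,
    # so the short Stage 2 schedule only has to remove the noise that was just added.
    noise = torch.randn(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
    return (1.0 - noise_scale) * latents + noise_scale * noise
```
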

## Distilled checkpoint generation

The fastest two-stage generation setup, using a distilled checkpoint:

```py
import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"

pipe = LTX2Pipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

# Stage 1: distilled inference with the predefined distilled sigmas
frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    model_path,
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Stage 2: refine the upscaled latents with the Stage 2 distilled sigmas
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    # Renoise with the first sigma value; see
    # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    generator=generator,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)
```
## LTX2Pipeline

[[autodoc]] LTX2Pipeline

scripts/convert_ltx2_to_diffusers.py

Lines changed: 27 additions & 11 deletions
```diff
@@ -63,6 +63,8 @@
     "up_blocks.4": "up_blocks.1",
     "up_blocks.5": "up_blocks.2.upsamplers.0",
     "up_blocks.6": "up_blocks.2",
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
     # Common
     # For all 3D ResNets
     "res_blocks": "resnets",
@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
     return connectors


-def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+def get_ltx2_video_vae_config(
+    version: str, timestep_conditioning: bool = False
+) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     if version == "test":
         config = {
             "model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
     return config, rename_dict, special_keys_remap


-def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
-    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
+def convert_ltx2_video_vae(
+    original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
+) -> Dict[str, Any]:
+    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
     diffusers_config = config["diffusers_config"]

     with init_empty_weights():
@@ -659,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
 def get_args():
     parser = argparse.ArgumentParser()

+    def none_or_str(value: str):
+        if isinstance(value, str) and value.lower() == "none":
+            return None
+        return value
+
     parser.add_argument(
         "--original_state_dict_repo_id",
         default="Lightricks/LTX-2",
-        type=str,
+        type=none_or_str,
         help="HF Hub repo id with LTX 2.0 checkpoint",
     )
     parser.add_argument(
@@ -682,7 +693,7 @@ def get_args():
     parser.add_argument(
         "--combined_filename",
         default="ltx-2-19b-dev.safetensors",
-        type=str,
+        type=none_or_str,
         help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
     )
     parser.add_argument("--vae_prefix", default="vae.", type=str)
@@ -701,22 +712,25 @@ def get_args():
     parser.add_argument(
         "--text_encoder_model_id",
         default="google/gemma-3-12b-it-qat-q4_0-unquantized",
-        type=str,
+        type=none_or_str,
         help="HF Hub id for the LTX 2.0 base text encoder model",
     )
     parser.add_argument(
         "--tokenizer_id",
         default="google/gemma-3-12b-it-qat-q4_0-unquantized",
-        type=str,
+        type=none_or_str,
         help="HF Hub id for the LTX 2.0 text tokenizer",
     )
     parser.add_argument(
         "--latent_upsampler_filename",
         default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
-        type=str,
+        type=none_or_str,
         help="Latent upsampler filename",
     )

+    parser.add_argument(
+        "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
+    )
     parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
     parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
     parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +800,9 @@ def main(args):
         original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
     elif combined_ckpt is not None:
         original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
-    vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
+    vae = convert_ltx2_video_vae(
+        original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
+    )
     if not args.full_pipeline and not args.upsample_pipeline:
         vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
```
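The `none_or_str` helper lets callers pass the literal string `none` on the command line to disable an optional component. A standalone sketch of its behavior (the function body is copied from the script above; the example invocation in the comments is illustrative, not a documented recipe):

```py
def none_or_str(value: str):
    # argparse passes raw strings; map the literal "none" (any casing) to None.
    if isinstance(value, str) and value.lower() == "none":
        return None
    return value

assert none_or_str("none") is None
assert none_or_str("None") is None
assert none_or_str("ltx-2-19b-dev.safetensors") == "ltx-2-19b-dev.safetensors"

# Hypothetical invocation combining the new options, e.g. converting the video VAE
# with timestep conditioning while skipping the combined checkpoint:
#   python scripts/convert_ltx2_to_diffusers.py --vae --timestep_conditioning \
#       --combined_filename none --output_path ./converted
```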

src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -743,8 +743,8 @@ def __init__(

         # Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
         # computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
-        latents_std = torch.zeros((base_channels,))
-        latents_mean = torch.ones((base_channels,))
+        latents_std = torch.ones((base_channels,))
+        latents_mean = torch.zeros((base_channels,))
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)
```
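
The swapped defaults matter because these buffers normalize and denormalize audio latents: with the old initialization (`latents_std` all zeros), any division by the std would blow up before the real checkpoint statistics were loaded, while the corrected defaults (mean 0, std 1) make the transform an identity. A minimal sketch, assuming the usual `(x - mean) / std` convention (the channel count and tensor shape here are illustrative):

```py
import torch

base_channels = 8  # hypothetical channel count
latents_mean = torch.zeros((base_channels,))  # corrected default
latents_std = torch.ones((base_channels,))    # corrected default

latents = torch.randn(2, base_channels, 16)  # (batch, channels, time)
normalized = (latents - latents_mean[None, :, None]) / latents_std[None, :, None]
restored = normalized * latents_std[None, :, None] + latents_mean[None, :, None]

# With mean=0/std=1 this round-trips exactly; the old zero std would divide by zero.
assert torch.allclose(latents, restored)
```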
