Skip to content

Commit db18643

Browse files
committed
add Flax/JAX implementation of LTX-2 Latent Upsampler
1 parent 5e8334f commit db18643

File tree

10 files changed

+896
-22
lines changed

10 files changed

+896
-22
lines changed

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,19 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
7979
return restored_checkpoint, step
8080

8181
def load_checkpoint(
82-
self, step=None, vae_only=False, load_transformer=True
82+
self, step=None, vae_only=False, load_transformer=True, load_upsampler=False
8383
) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
8484
restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
8585
opt_state = None
8686

8787
if restored_checkpoint:
8888
max_logging.log("Loading LTX2 pipeline from checkpoint")
89-
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
89+
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer, load_upsampler)
9090
if "opt_state" in restored_checkpoint.ltx2_state.keys():
9191
opt_state = restored_checkpoint.ltx2_state["opt_state"]
9292
else:
9393
max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
94-
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
94+
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer, load_upsampler)
9595

9696
return pipeline, opt_state, step
9797

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,13 @@ jit_initializers: True
9292
enable_single_replica_ckpt_restoring: False
9393
seed: 0
9494
audio_format: "s16"
95+
96+
# LTX-2 Latent Upsampler
97+
run_latent_upsampler: False
98+
upsampler_model_path: "Lightricks/LTX-2"
99+
upsampler_spatial_patch_size: 1
100+
upsampler_temporal_patch_size: 1
101+
upsampler_adain_factor: 0.0
102+
upsampler_tone_map_compression_ratio: 0.0
103+
upsampler_rational_spatial_scale: 2.0
104+
upsampler_output_type: "pil"

src/maxdiffusion/generate_ltx2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def get_git_commit_hash():
8181

8282

8383
def call_pipeline(config, pipeline, prompt, negative_prompt):
84-
# Set default generation arguments
8584
generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
8685
guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0
8786

@@ -99,6 +98,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9998
decode_noise_scale=getattr(config, "decode_noise_scale", None),
10099
max_sequence_length=getattr(config, "max_sequence_length", 1024),
101100
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
101+
output_type=getattr(config, "upsampler_output_type", "pil"),
102102
)
103103
return out
104104

@@ -114,9 +114,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
114114
else:
115115
max_logging.log("Could not retrieve Git commit hash.")
116116

117+
checkpoint_loader = LTX2Checkpointer(config=config)
117118
if pipeline is None:
118-
checkpoint_loader = LTX2Checkpointer(config=config)
119-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
119+
# Use the config flag to determine if the upsampler should be loaded
120+
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
121+
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
120122

121123
pipeline.enable_vae_slicing()
122124
pipeline.enable_vae_tiling()
@@ -135,6 +137,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
135137
)
136138

137139
out = call_pipeline(config, pipeline, prompt, negative_prompt)
140+
138141
# out should have .frames and .audio
139142
videos = out.frames if hasattr(out, "frames") else out[0]
140143
audios = out.audio if hasattr(out, "audio") else None
@@ -143,6 +146,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
143146
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
144147
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
145148
max_logging.log(f"model type: {getattr(config, 'model_type', 'T2V')}")
149+
if getattr(config, "run_latent_upsampler", False):
150+
max_logging.log(f"upsampler model path: {config.upsampler_model_path}")
146151
max_logging.log(f"hardware: {jax.devices()[0].platform}")
147152
max_logging.log(f"number of devices: {jax.device_count()}")
148153
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _unflatten_heads(tensor, heads):
128128
return tensor
129129

130130

131-
def _reshape_data_for_flash(tensor, heads, num_context_shards = 1):
131+
def _reshape_data_for_flash(tensor, heads, num_context_shards=1):
132132
"""
133133
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
134134
Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""
2+
Flax/JAX implementation of the LTX-2 Latent Upsampler.
3+
"""
4+
5+
import os
6+
import json
7+
import math
8+
from typing import Optional, Tuple
9+
10+
import jax
11+
import jax.numpy as jnp
12+
import flax.linen as nn
13+
14+
from huggingface_hub import hf_hub_download
15+
from huggingface_hub.utils import EntryNotFoundError, HfHubHTTPError
16+
17+
# Supported rational spatial scales, mapped to (numerator, denominator):
# the tensor is pixel-shuffled up by `numerator` and then blur-downsampled
# by `denominator`, yielding an effective scale of numerator / denominator
# (see SpatialRationalResampler).
RATIONAL_RESAMPLER_SCALE_MAPPING = {
    0.75: (3, 4),
    1.5: (3, 2),
    2.0: (2, 1),
    4.0: (4, 1),
}
23+
24+
25+
class ResBlock(nn.Module):
  """Residual convolution block: two Conv+GroupNorm stages with SiLU.

  The skip connection is added before the final activation, so the
  output is ``silu(conv_path(x) + x)``. Works on channel-last inputs
  with ``dims`` spatial/temporal axes (3x3[x3] kernels, "same" padding),
  so the spatial shape is preserved.
  """

  channels: int
  # Width of the intermediate conv; defaults to `channels` when None.
  mid_channels: Optional[int] = None
  dims: int = 3

  @nn.compact
  def __call__(self, hidden_states: jax.Array) -> jax.Array:
    inner_width = self.channels if self.mid_channels is None else self.mid_channels
    ksize = (3,) * self.dims
    same_pad = ((1, 1),) * self.dims

    skip = hidden_states

    x = nn.Conv(inner_width, kernel_size=ksize, padding=same_pad, name="conv1")(hidden_states)
    x = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="norm1")(x)
    x = nn.silu(x)

    x = nn.Conv(self.channels, kernel_size=ksize, padding=same_pad, name="conv2")(x)
    x = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="norm2")(x)

    # Residual add happens before the last SiLU, not after.
    return nn.silu(x + skip)
49+
50+
51+
class PixelShuffleND(nn.Module):
  """Channel-last pixel shuffle over 1, 2 or 3 axes.

  Folds factors of the channel axis into the leading spatial/temporal
  axes. The channel axis is split as ``(c, p1, p2, ...)``, matching
  ``torch.nn.PixelShuffle``'s channel grouping, so weights converted
  from a channel-first checkpoint keep the same meaning.

  Shapes:
    dims == 3: (B, D, H, W, C*p1*p2*p3) -> (B, D*p1, H*p2, W*p3, C)
    dims == 2: (N, H, W, C*p1*p2)       -> (N, H*p1, W*p2, C)
    dims == 1: (B, F, H, W, C*p1)       -> (B, F*p1, H, W, C)  # temporal only

  Raises:
    ValueError: if ``dims`` is not 1, 2 or 3. (Previously every branch
      was skipped and the call silently returned ``None``.)
  """

  dims: int
  upscale_factors: Tuple[int, ...] = (2, 2, 2)

  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    if self.dims == 3:
      p1, p2, p3 = self.upscale_factors[:3]
      b, d, h, w, c_p = x.shape
      c = c_p // (p1 * p2 * p3)
      x = jnp.reshape(x, (b, d, h, w, c, p1, p2, p3))
      # Interleave each shuffle factor next to the axis it enlarges.
      x = jnp.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
      x = jnp.reshape(x, (b, d * p1, h * p2, w * p3, c))
      return x
    elif self.dims == 2:
      p1, p2 = self.upscale_factors[:2]
      b, h, w, c_p = x.shape
      c = c_p // (p1 * p2)
      x = jnp.reshape(x, (b, h, w, c, p1, p2))
      x = jnp.transpose(x, (0, 1, 4, 2, 5, 3))
      x = jnp.reshape(x, (b, h * p1, w * p2, c))
      return x
    elif self.dims == 1:
      p1 = self.upscale_factors[0]
      b, f, h, w, c_p = x.shape
      c = c_p // p1
      x = jnp.reshape(x, (b, f, h, w, c, p1))
      x = jnp.transpose(x, (0, 1, 5, 2, 3, 4))
      x = jnp.reshape(x, (b, f * p1, h, w, c))
      return x
    # Fail loudly instead of falling through and returning None.
    raise ValueError(f"`dims` must be 1, 2 or 3 but is {self.dims}")
81+
82+
83+
class BlurDownsample(nn.Module):
  """Anti-aliased spatial downsampling via a fixed binomial blur.

  Applies a depthwise (per-channel) 2D binomial low-pass filter while
  striding. Accepts (N, H, W, C) inputs when ``dims == 2`` and
  (B, F, H, W, C) inputs when ``dims == 3`` (frames are folded into the
  batch so the same 2D filter applies per frame). ``stride == 1`` is a
  no-op. The kernel is a fixed constant, not a learned parameter.
  """

  dims: int
  stride: int
  kernel_size: int = 5

  def setup(self):
    if self.dims not in (2, 3):
      raise ValueError(f"`dims` must be either 2 or 3 but is {self.dims}")
    if self.kernel_size < 3 or self.kernel_size % 2 != 1:
      raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {self.kernel_size}")

    # 1D binomial taps (a Pascal's-triangle row), outer-product into a
    # normalized 2D low-pass filter stored as [H, W, I=1, O=1].
    taps = jnp.array([math.comb(self.kernel_size - 1, i) for i in range(self.kernel_size)], dtype=jnp.float32)
    blur2d = jnp.outer(taps, taps)
    blur2d = blur2d / jnp.sum(blur2d)
    self.kernel = jnp.reshape(blur2d, (self.kernel_size, self.kernel_size, 1, 1))

  def __call__(self, x: jax.Array) -> jax.Array:
    if self.stride == 1:
      # No resampling requested; skip the blur entirely.
      return x

    halo = self.kernel_size // 2
    channels = x.shape[-1]
    # Depthwise filter: replicate the single-channel kernel once per
    # channel (shape [k, k, 1, C] with feature_group_count=C below).
    depthwise = jnp.tile(self.kernel, (1, 1, 1, channels))

    def _blur_and_stride(frames: jax.Array) -> jax.Array:
      # One strided depthwise conv shared by the 2D and 3D paths.
      return jax.lax.conv_general_dilated(
          lhs=frames,
          rhs=depthwise,
          window_strides=(self.stride, self.stride),
          padding=((halo, halo), (halo, halo)),
          feature_group_count=channels,
          dimension_numbers=("NHWC", "HWIO", "NHWC"),
      )

    if self.dims == 2:
      return _blur_and_stride(x)

    # dims == 3: fold frames into the batch, filter, then unfold.
    b, f, h, w, _ = x.shape
    flat = jnp.reshape(x, (b * f, h, w, channels))
    flat = _blur_and_stride(flat)
    return jnp.reshape(flat, (b, f, flat.shape[1], flat.shape[2], channels))
136+
137+
138+
class SpatialRationalResampler(nn.Module):
  """Resamples spatially by a rational factor num/den.

  Upsamples by ``num`` with a learned 3x3 conv followed by a 2D pixel
  shuffle, then downsamples by ``den`` with an anti-aliasing blur.
  Only the scales listed in RATIONAL_RESAMPLER_SCALE_MAPPING are valid.
  """

  mid_channels: int = 1024
  scale: float = 2.0

  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    try:
      up_factor, down_factor = RATIONAL_RESAMPLER_SCALE_MAPPING[self.scale]
    except KeyError:
      raise ValueError(f"scale {self.scale} not supported.") from None

    expanded = nn.Conv((up_factor**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="conv")(x)
    shuffled = PixelShuffleND(dims=2, upscale_factors=(up_factor, up_factor))(expanded)
    return BlurDownsample(dims=2, stride=down_factor)(shuffled)
152+
153+
154+
class LTX2LatentUpsamplerModel(nn.Module):
  """Flax implementation of the LTX-2 latent upsampler.

  Takes a channel-last latent video tensor of shape (B, F, H, W, C) and
  upsamples it spatially and/or temporally: initial conv + norm + SiLU,
  a stack of ResBlocks, an upsampling stage (rational spatial resampler,
  pixel-shuffle, or joint spatio-temporal shuffle), another ResBlock
  stack, and a final conv projecting back to ``in_channels``.

  NOTE: the `name=` arguments below define the checkpoint parameter
  layout; they must not be changed without migrating weights.
  """

  # Number of latent channels in and out of the model.
  in_channels: int = 128
  # Width of the internal conv/ResBlock trunk.
  mid_channels: int = 1024
  # ResBlocks before AND after the upsampling stage.
  num_blocks_per_stage: int = 4
  # 2 -> per-frame 2D convs (frames folded into batch); 3 -> 3D convs.
  dims: int = 3
  spatial_upsample: bool = True
  # NOTE(review): temporal_upsample is only honored on the dims == 3
  # path; the dims == 2 path silently ignores it — confirm intended.
  temporal_upsample: bool = False
  # If set, use SpatialRationalResampler at this scale; if None, use a
  # plain conv + 2x pixel shuffle instead.
  rational_spatial_scale: Optional[float] = 2.0

  @classmethod
  def load_config(cls, pretrained_model_name_or_path: str, subfolder: str = "", **kwargs):
    """Dynamically loads config.json from a local path or the Hugging Face Hub.

    Returns the parsed config dict with `kwargs` applied as overrides.
    On any load/parse failure it warns and falls back to returning just
    `kwargs` (i.e. the dataclass defaults plus overrides).
    """
    try:
      if os.path.isdir(pretrained_model_name_or_path):
        # Local checkout: read config.json directly from disk.
        config_file = os.path.join(pretrained_model_name_or_path, subfolder, "config.json")
      else:
        # Hub repo id: download (or hit the local HF cache).
        config_file = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json", subfolder=subfolder)

      with open(config_file, "r") as f:
        config_dict = json.load(f)

      # Apply any runtime overrides passed in via kwargs
      config_dict.update(kwargs)
      return config_dict

    except (OSError, json.JSONDecodeError, EntryNotFoundError, HfHubHTTPError) as e:
      print(f"Warning: Could not load upsampler config.json (using defaults). Reason: {e}")
      return kwargs

  @nn.compact
  def __call__(self, hidden_states: jax.Array) -> jax.Array:
    # Input is channel-last latent video: (batch, frames, height, width, channels).
    b, f, h, w, c = hidden_states.shape

    if self.dims == 2:
      # Per-frame processing: fold frames into the batch and use 2D convs.
      hidden_states = jnp.reshape(hidden_states, (b * f, h, w, c))

      hidden_states = nn.Conv(self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="initial_conv")(
          hidden_states
      )
      hidden_states = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="initial_norm")(hidden_states)
      hidden_states = nn.silu(hidden_states)

      for i in range(self.num_blocks_per_stage):
        hidden_states = ResBlock(channels=self.mid_channels, dims=2, name=f"res_blocks_{i}")(hidden_states)

      if self.spatial_upsample:
        if self.rational_spatial_scale is not None:
          # Rational (num/den) spatial resampling.
          hidden_states = SpatialRationalResampler(self.mid_channels, self.rational_spatial_scale, name="upsampler")(
              hidden_states
          )
        else:
          # Plain 2x spatial upsample: conv to 4x channels, then shuffle.
          hidden_states = nn.Conv(4 * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="upsampler_0")(
              hidden_states
          )
          hidden_states = PixelShuffleND(dims=2)(hidden_states)

      for i in range(self.num_blocks_per_stage):
        hidden_states = ResBlock(channels=self.mid_channels, dims=2, name=f"post_upsample_res_blocks_{i}")(hidden_states)

      hidden_states = nn.Conv(self.in_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="final_conv")(
          hidden_states
      )

      # Unfold frames back out of the batch using the (possibly scaled) H/W.
      h2, w2 = hidden_states.shape[1], hidden_states.shape[2]
      hidden_states = jnp.reshape(hidden_states, (b, f, h2, w2, self.in_channels))

    else:
      # Full 3D (spatio-temporal) convolution path.
      hidden_states = nn.Conv(
          self.mid_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="initial_conv"
      )(hidden_states)
      hidden_states = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="initial_norm")(hidden_states)
      hidden_states = nn.silu(hidden_states)

      for i in range(self.num_blocks_per_stage):
        hidden_states = ResBlock(channels=self.mid_channels, dims=3, name=f"res_blocks_{i}")(hidden_states)

      # Joint spatio-temporal upsampling: one conv to 8x channels, then a
      # 3D pixel shuffle doubles frames, height, and width at once.
      if self.spatial_upsample and self.temporal_upsample:
        hidden_states = nn.Conv(
            8 * self.mid_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="upsampler_0"
        )(hidden_states)
        hidden_states = PixelShuffleND(dims=3)(hidden_states)
        # Drop the first upsampled frame. NOTE(review): presumably aligns
        # frame counts with the reference implementation — confirm.
        hidden_states = hidden_states[:, 1:, :, :, :]
      elif self.temporal_upsample:
        # Temporal-only 2x: conv to 2x channels + 1D (frame-axis) shuffle.
        hidden_states = nn.Conv(
            2 * self.mid_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="upsampler_0"
        )(hidden_states)
        hidden_states = PixelShuffleND(dims=1)(hidden_states)
        # Drop the first upsampled frame (see note above in the joint path).
        hidden_states = hidden_states[:, 1:, :, :, :]
      elif self.spatial_upsample:
        # Spatial-only: fold frames into the batch for 2D resampling.
        # (H/W unchanged so far: all convs above are 3x3[x3], pad 1.)
        hidden_states = jnp.reshape(hidden_states, (b * f, h, w, self.mid_channels))
        if self.rational_spatial_scale is not None:
          hidden_states = SpatialRationalResampler(self.mid_channels, self.rational_spatial_scale, name="upsampler")(
              hidden_states
          )
        else:
          hidden_states = nn.Conv(4 * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="upsampler_0")(
              hidden_states
          )
          hidden_states = PixelShuffleND(dims=2)(hidden_states)
        # Unfold frames back out of the batch at the new spatial size.
        h2, w2 = hidden_states.shape[1], hidden_states.shape[2]
        hidden_states = jnp.reshape(hidden_states, (b, f, h2, w2, self.mid_channels))

      for i in range(self.num_blocks_per_stage):
        hidden_states = ResBlock(channels=self.mid_channels, dims=3, name=f"post_upsample_res_blocks_{i}")(hidden_states)

      hidden_states = nn.Conv(self.in_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="final_conv")(
          hidden_states
      )

    return hidden_states

0 commit comments

Comments
 (0)