|
| 1 | +""" |
| 2 | +Flax/JAX implementation of the LTX-2 Latent Upsampler. |
| 3 | +""" |
| 4 | + |
| 5 | +import os |
| 6 | +import json |
| 7 | +import math |
| 8 | +from typing import Optional, Tuple |
| 9 | + |
| 10 | +import jax |
| 11 | +import jax.numpy as jnp |
| 12 | +import flax.linen as nn |
| 13 | + |
| 14 | +from huggingface_hub import hf_hub_download |
| 15 | +from huggingface_hub.utils import EntryNotFoundError, HfHubHTTPError |
| 16 | + |
| 17 | +RATIONAL_RESAMPLER_SCALE_MAPPING = { |
| 18 | + 0.75: (3, 4), |
| 19 | + 1.5: (3, 2), |
| 20 | + 2.0: (2, 1), |
| 21 | + 4.0: (4, 1), |
| 22 | +} |
| 23 | + |
| 24 | + |
| 25 | +class ResBlock(nn.Module): |
| 26 | + channels: int |
| 27 | + mid_channels: Optional[int] = None |
| 28 | + dims: int = 3 |
| 29 | + |
| 30 | + @nn.compact |
| 31 | + def __call__(self, hidden_states: jax.Array) -> jax.Array: |
| 32 | + mid_channels = self.mid_channels if self.mid_channels is not None else self.channels |
| 33 | + |
| 34 | + kernel_size = (3,) * self.dims |
| 35 | + padding = ((1, 1),) * self.dims |
| 36 | + |
| 37 | + residual = hidden_states |
| 38 | + |
| 39 | + hidden_states = nn.Conv(mid_channels, kernel_size=kernel_size, padding=padding, name="conv1")(hidden_states) |
| 40 | + hidden_states = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="norm1")(hidden_states) |
| 41 | + hidden_states = nn.silu(hidden_states) |
| 42 | + |
| 43 | + hidden_states = nn.Conv(self.channels, kernel_size=kernel_size, padding=padding, name="conv2")(hidden_states) |
| 44 | + hidden_states = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="norm2")(hidden_states) |
| 45 | + |
| 46 | + hidden_states = nn.silu(hidden_states + residual) |
| 47 | + |
| 48 | + return hidden_states |
| 49 | + |
| 50 | + |
| 51 | +class PixelShuffleND(nn.Module): |
| 52 | + dims: int |
| 53 | + upscale_factors: Tuple[int, ...] = (2, 2, 2) |
| 54 | + |
| 55 | + @nn.compact |
| 56 | + def __call__(self, x: jax.Array) -> jax.Array: |
| 57 | + if self.dims == 3: |
| 58 | + p1, p2, p3 = self.upscale_factors[:3] |
| 59 | + b, d, h, w, c_p = x.shape |
| 60 | + c = c_p // (p1 * p2 * p3) |
| 61 | + x = jnp.reshape(x, (b, d, h, w, c, p1, p2, p3)) |
| 62 | + x = jnp.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4)) |
| 63 | + x = jnp.reshape(x, (b, d * p1, h * p2, w * p3, c)) |
| 64 | + return x |
| 65 | + elif self.dims == 2: |
| 66 | + p1, p2 = self.upscale_factors[:2] |
| 67 | + b, h, w, c_p = x.shape |
| 68 | + c = c_p // (p1 * p2) |
| 69 | + x = jnp.reshape(x, (b, h, w, c, p1, p2)) |
| 70 | + x = jnp.transpose(x, (0, 1, 4, 2, 5, 3)) |
| 71 | + x = jnp.reshape(x, (b, h * p1, w * p2, c)) |
| 72 | + return x |
| 73 | + elif self.dims == 1: |
| 74 | + p1 = self.upscale_factors[0] |
| 75 | + b, f, h, w, c_p = x.shape |
| 76 | + c = c_p // p1 |
| 77 | + x = jnp.reshape(x, (b, f, h, w, c, p1)) |
| 78 | + x = jnp.transpose(x, (0, 1, 5, 2, 3, 4)) |
| 79 | + x = jnp.reshape(x, (b, f * p1, h, w, c)) |
| 80 | + return x |
| 81 | + |
| 82 | + |
| 83 | +class BlurDownsample(nn.Module): |
| 84 | + dims: int |
| 85 | + stride: int |
| 86 | + kernel_size: int = 5 |
| 87 | + |
| 88 | + def setup(self): |
| 89 | + if self.dims not in (2, 3): |
| 90 | + raise ValueError(f"`dims` must be either 2 or 3 but is {self.dims}") |
| 91 | + if self.kernel_size < 3 or self.kernel_size % 2 != 1: |
| 92 | + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {self.kernel_size}") |
| 93 | + |
| 94 | + k = jnp.array([math.comb(self.kernel_size - 1, i) for i in range(self.kernel_size)], dtype=jnp.float32) |
| 95 | + k2d = jnp.outer(k, k) |
| 96 | + k2d = k2d / jnp.sum(k2d) |
| 97 | + self.kernel = jnp.reshape(k2d, (self.kernel_size, self.kernel_size, 1, 1)) |
| 98 | + |
| 99 | + def __call__(self, x: jax.Array) -> jax.Array: |
| 100 | + if self.stride == 1: |
| 101 | + return x |
| 102 | + |
| 103 | + pad = self.kernel_size // 2 |
| 104 | + |
| 105 | + c = x.shape[-1] |
| 106 | + # Tile the single-channel kernel to match the required output channels |
| 107 | + kernel_broadcast = jnp.tile(self.kernel, (1, 1, 1, c)) |
| 108 | + |
| 109 | + if self.dims == 2: |
| 110 | + x = jax.lax.conv_general_dilated( |
| 111 | + lhs=x, |
| 112 | + rhs=kernel_broadcast, |
| 113 | + window_strides=(self.stride, self.stride), |
| 114 | + padding=((pad, pad), (pad, pad)), |
| 115 | + feature_group_count=c, |
| 116 | + dimension_numbers=("NHWC", "HWIO", "NHWC"), |
| 117 | + ) |
| 118 | + else: |
| 119 | + b, f, h, w, _ = x.shape |
| 120 | + x = jnp.reshape(x, (b * f, h, w, c)) |
| 121 | + |
| 122 | + # For depthwise convolution: kernel remains [H, W, 1, 1] |
| 123 | + x = jax.lax.conv_general_dilated( |
| 124 | + lhs=x, |
| 125 | + rhs=kernel_broadcast, |
| 126 | + window_strides=(self.stride, self.stride), |
| 127 | + padding=((pad, pad), (pad, pad)), |
| 128 | + feature_group_count=c, |
| 129 | + dimension_numbers=("NHWC", "HWIO", "NHWC"), |
| 130 | + ) |
| 131 | + |
| 132 | + h2, w2 = x.shape[1], x.shape[2] |
| 133 | + x = jnp.reshape(x, (b, f, h2, w2, c)) |
| 134 | + |
| 135 | + return x |
| 136 | + |
| 137 | + |
| 138 | +class SpatialRationalResampler(nn.Module): |
| 139 | + mid_channels: int = 1024 |
| 140 | + scale: float = 2.0 |
| 141 | + |
| 142 | + @nn.compact |
| 143 | + def __call__(self, x: jax.Array) -> jax.Array: |
| 144 | + if self.scale not in RATIONAL_RESAMPLER_SCALE_MAPPING: |
| 145 | + raise ValueError(f"scale {self.scale} not supported.") |
| 146 | + num, den = RATIONAL_RESAMPLER_SCALE_MAPPING[self.scale] |
| 147 | + |
| 148 | + x = nn.Conv((num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="conv")(x) |
| 149 | + x = PixelShuffleND(dims=2, upscale_factors=(num, num))(x) |
| 150 | + x = BlurDownsample(dims=2, stride=den)(x) |
| 151 | + return x |
| 152 | + |
| 153 | + |
| 154 | +class LTX2LatentUpsamplerModel(nn.Module): |
| 155 | + in_channels: int = 128 |
| 156 | + mid_channels: int = 1024 |
| 157 | + num_blocks_per_stage: int = 4 |
| 158 | + dims: int = 3 |
| 159 | + spatial_upsample: bool = True |
| 160 | + temporal_upsample: bool = False |
| 161 | + rational_spatial_scale: Optional[float] = 2.0 |
| 162 | + |
| 163 | + @classmethod |
| 164 | + def load_config(cls, pretrained_model_name_or_path: str, subfolder: str = "", **kwargs): |
| 165 | + """Dynamically loads config.json from a local path or the Hugging Face Hub.""" |
| 166 | + try: |
| 167 | + if os.path.isdir(pretrained_model_name_or_path): |
| 168 | + config_file = os.path.join(pretrained_model_name_or_path, subfolder, "config.json") |
| 169 | + else: |
| 170 | + config_file = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json", subfolder=subfolder) |
| 171 | + |
| 172 | + with open(config_file, "r") as f: |
| 173 | + config_dict = json.load(f) |
| 174 | + |
| 175 | + # Apply any runtime overrides passed in via kwargs |
| 176 | + config_dict.update(kwargs) |
| 177 | + return config_dict |
| 178 | + |
| 179 | + except (OSError, json.JSONDecodeError, EntryNotFoundError, HfHubHTTPError) as e: |
| 180 | + print(f"Warning: Could not load upsampler config.json (using defaults). Reason: {e}") |
| 181 | + return kwargs |
| 182 | + |
| 183 | + @nn.compact |
| 184 | + def __call__(self, hidden_states: jax.Array) -> jax.Array: |
| 185 | + b, f, h, w, c = hidden_states.shape |
| 186 | + |
| 187 | + if self.dims == 2: |
| 188 | + hidden_states = jnp.reshape(hidden_states, (b * f, h, w, c)) |
| 189 | + |
| 190 | + hidden_states = nn.Conv(self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="initial_conv")( |
| 191 | + hidden_states |
| 192 | + ) |
| 193 | + hidden_states = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="initial_norm")(hidden_states) |
| 194 | + hidden_states = nn.silu(hidden_states) |
| 195 | + |
| 196 | + for i in range(self.num_blocks_per_stage): |
| 197 | + hidden_states = ResBlock(channels=self.mid_channels, dims=2, name=f"res_blocks_{i}")(hidden_states) |
| 198 | + |
| 199 | + if self.spatial_upsample: |
| 200 | + if self.rational_spatial_scale is not None: |
| 201 | + hidden_states = SpatialRationalResampler(self.mid_channels, self.rational_spatial_scale, name="upsampler")( |
| 202 | + hidden_states |
| 203 | + ) |
| 204 | + else: |
| 205 | + hidden_states = nn.Conv(4 * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="upsampler_0")( |
| 206 | + hidden_states |
| 207 | + ) |
| 208 | + hidden_states = PixelShuffleND(dims=2)(hidden_states) |
| 209 | + |
| 210 | + for i in range(self.num_blocks_per_stage): |
| 211 | + hidden_states = ResBlock(channels=self.mid_channels, dims=2, name=f"post_upsample_res_blocks_{i}")(hidden_states) |
| 212 | + |
| 213 | + hidden_states = nn.Conv(self.in_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="final_conv")( |
| 214 | + hidden_states |
| 215 | + ) |
| 216 | + |
| 217 | + h2, w2 = hidden_states.shape[1], hidden_states.shape[2] |
| 218 | + hidden_states = jnp.reshape(hidden_states, (b, f, h2, w2, self.in_channels)) |
| 219 | + |
| 220 | + else: |
| 221 | + hidden_states = nn.Conv( |
| 222 | + self.mid_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="initial_conv" |
| 223 | + )(hidden_states) |
| 224 | + hidden_states = nn.GroupNorm(epsilon=1e-5, num_groups=32, name="initial_norm")(hidden_states) |
| 225 | + hidden_states = nn.silu(hidden_states) |
| 226 | + |
| 227 | + for i in range(self.num_blocks_per_stage): |
| 228 | + hidden_states = ResBlock(channels=self.mid_channels, dims=3, name=f"res_blocks_{i}")(hidden_states) |
| 229 | + |
| 230 | + # FIX: Added Missing Joint Spatiotemporal logic! |
| 231 | + if self.spatial_upsample and self.temporal_upsample: |
| 232 | + hidden_states = nn.Conv( |
| 233 | + 8 * self.mid_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="upsampler_0" |
| 234 | + )(hidden_states) |
| 235 | + hidden_states = PixelShuffleND(dims=3)(hidden_states) |
| 236 | + hidden_states = hidden_states[:, 1:, :, :, :] |
| 237 | + elif self.temporal_upsample: |
| 238 | + hidden_states = nn.Conv( |
| 239 | + 2 * self.mid_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="upsampler_0" |
| 240 | + )(hidden_states) |
| 241 | + hidden_states = PixelShuffleND(dims=1)(hidden_states) |
| 242 | + hidden_states = hidden_states[:, 1:, :, :, :] |
| 243 | + elif self.spatial_upsample: |
| 244 | + hidden_states = jnp.reshape(hidden_states, (b * f, h, w, self.mid_channels)) |
| 245 | + if self.rational_spatial_scale is not None: |
| 246 | + hidden_states = SpatialRationalResampler(self.mid_channels, self.rational_spatial_scale, name="upsampler")( |
| 247 | + hidden_states |
| 248 | + ) |
| 249 | + else: |
| 250 | + hidden_states = nn.Conv(4 * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), name="upsampler_0")( |
| 251 | + hidden_states |
| 252 | + ) |
| 253 | + hidden_states = PixelShuffleND(dims=2)(hidden_states) |
| 254 | + h2, w2 = hidden_states.shape[1], hidden_states.shape[2] |
| 255 | + hidden_states = jnp.reshape(hidden_states, (b, f, h2, w2, self.mid_channels)) |
| 256 | + |
| 257 | + for i in range(self.num_blocks_per_stage): |
| 258 | + hidden_states = ResBlock(channels=self.mid_channels, dims=3, name=f"post_upsample_res_blocks_{i}")(hidden_states) |
| 259 | + |
| 260 | + hidden_states = nn.Conv(self.in_channels, kernel_size=(3, 3, 3), padding=((1, 1), (1, 1), (1, 1)), name="final_conv")( |
| 261 | + hidden_states |
| 262 | + ) |
| 263 | + |
| 264 | + return hidden_states |
0 commit comments