@@ -216,6 +216,7 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         use_dynamic_shifting: bool = False,
         time_shift_type: str = "exponential",
+        shift_terminal: Optional[float] = None,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -235,6 +236,8 @@ def __init__(
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+        if shift_terminal is not None and not use_flow_sigmas:
+            raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")

         if rescale_betas_zero_snr:
             self.betas = rescale_zero_terminal_snr(self.betas)
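For context, a minimal sketch of how the new guard behaves, assuming this file is diffusers' `UniPCMultistepScheduler` (suggested by the `multistep_uni_p_bh_update` reference further down; values are illustrative):

```python
from diffusers import UniPCMultistepScheduler

# OK: shift_terminal is paired with flow sigmas.
scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, shift_terminal=0.1)

# Raises: use_flow_sigmas defaults to False, so the new guard fires.
try:
    UniPCMultistepScheduler(shift_terminal=0.1)
except ValueError as err:
    print(err)  # `shift_terminal` is only supported when `use_flow_sigmas=True`.
```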
@@ -303,7 +306,12 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index

     def set_timesteps(
-        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
+        sigmas: Optional[List[float]] = None,
+        timesteps: Optional[List[float]] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -314,10 +322,23 @@ def set_timesteps(
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
+
+        if sigmas is not None or timesteps is not None:
+            if not self.config.use_flow_sigmas:
+                raise ValueError(
+                    "Passing `sigmas` or `timesteps` is only supported when `use_flow_sigmas=True`. "
+                    "Please set `use_flow_sigmas=True` during scheduler initialization."
+                )
+            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
+            if sigmas is not None and timesteps is not None:
+                if len(sigmas) != len(timesteps):
+                    raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+        is_timesteps_provided = timesteps is not None
+
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
-        if mu is not None:
-            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
-            self.config.flow_shift = np.exp(mu)
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
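A sketch of the intended call patterns under this validation (same class-name assumption as above; the sigma values are illustrative, not from the source):

```python
import numpy as np

scheduler = UniPCMultistepScheduler(use_flow_sigmas=True)

# num_inference_steps is inferred from the custom schedule (here: 4).
scheduler.set_timesteps(sigmas=np.linspace(1.0, 0.25, 4).tolist())

# Raises: sigmas and timesteps must have matching lengths when both are given.
try:
    scheduler.set_timesteps(sigmas=[0.9, 0.5], timesteps=[900.0])
except ValueError as err:
    print(err)  # `sigmas` and `timesteps` should have the same length
```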
@@ -342,7 +363,8 @@ def set_timesteps(
                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
             )

-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if sigmas is None:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
@@ -386,10 +408,21 @@ def set_timesteps(
             )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_flow_sigmas:
-            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
-            sigmas = 1.0 - alphas
-            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
-            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            if sigmas is None:
+                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
+            if self.config.use_dynamic_shifting:
+                sigmas = self.time_shift(mu, 1.0, sigmas)
+            else:
+                sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
+            if self.config.shift_terminal:
+                sigmas = self.stretch_shift_to_terminal(sigmas)
+            eps = 1e-6
+            if np.fabs(sigmas[0] - 1) < eps:
+                sigmas[0] -= (
+                    eps  # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
+                )
+            if not is_timesteps_provided:
+                timesteps = (sigmas * self.config.num_train_timesteps).copy()
             if self.config.final_sigmas_type == "sigma_min":
                 sigma_last = sigmas[-1]
             elif self.config.final_sigmas_type == "zero":
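The flow branch first applies the static shift `s * sigma / (1 + (s - 1) * sigma)` (or the dynamic `time_shift` when enabled), then optionally stretches the schedule so it ends at `shift_terminal`. A self-contained numeric check of those two transforms (constants are illustrative):

```python
import numpy as np

s = 3.0  # stands in for config.flow_shift
sigmas = np.linspace(1, 1 / 1000, 5 + 1)[:-1]  # num_train_timesteps=1000, 5 steps

# Static shift: maps sigma=1 to 1 and pushes intermediate sigmas upward for s > 1.
shifted = s * sigmas / (1 + (s - 1) * sigmas)

# Terminal stretch: rescales 1 - sigma so the last entry lands exactly on shift_terminal.
shift_terminal = 0.1
one_minus = 1 - shifted
stretched = 1 - one_minus / (one_minus[-1] / (1 - shift_terminal))
print(stretched[-1])  # 0.1, up to float precision
```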
@@ -429,6 +462,43 @@ def set_timesteps(
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        if self.config.time_shift_type == "exponential":
+            return self._time_shift_exponential(mu, sigma, t)
+        elif self.config.time_shift_type == "linear":
+            return self._time_shift_linear(mu, sigma, t)
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
+    def _time_shift_exponential(self, mu, sigma, t):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
+    def _time_shift_linear(self, mu, sigma, t):
+        return mu / (mu + (1 / t - 1) ** sigma)
+
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
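Worth noting: the exponential helper reproduces the old `self.config.flow_shift = np.exp(mu)` mutation that this diff removes. With `sigma = 1`, `exp(mu) / (exp(mu) + (1/t - 1))` is algebraically equal to the static shift with `s = exp(mu)`, so dynamic shifting no longer needs to overwrite the config. A quick check (values are illustrative):

```python
import math

def static_shift(s, t):
    return s * t / (1 + (s - 1) * t)

def exp_shift(mu, t):
    # _time_shift_exponential with sigma = 1, applied to a scalar
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** 1.0)

mu, t = 1.2, 0.35
assert abs(static_shift(math.exp(mu), t) - exp_shift(mu, t)) < 1e-12
```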