Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def step(
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
)

# 2. Convert to an ODE derivative
Expand Down
37 changes: 25 additions & 12 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class LMSDiscreteSchedulerState:
common: CommonSchedulerState

# setable values
# settable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
sigmas: jnp.ndarray
Expand Down Expand Up @@ -82,8 +82,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2` (same options as [`CommonSchedulerState`]).
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
Expand Down Expand Up @@ -117,6 +117,10 @@ def __init__(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
if prediction_type not in ("epsilon", "sample", "v_prediction"):
raise ValueError(
f"`prediction_type` must be one of `epsilon`, `sample`, `v_prediction`, got {prediction_type!r}"
)
self.dtype = dtype

def create_state(self, common: CommonSchedulerState | None = None) -> LMSDiscreteSchedulerState:
Expand Down Expand Up @@ -158,14 +162,16 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample

def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
def get_lms_coefficient(
self, state: LMSDiscreteSchedulerState, order: int, t: int, current_order: int
) -> float:
"""
Compute a linear multistep coefficient.

Args:
order (TODO):
t (TODO):
current_order (TODO):
order (`int`): Multistep order (number of derivative history terms used).
t (`int`): Current step index along the inference sigma schedule (not the training timestep id).
current_order (`int`): Index of the Lagrange basis term; must satisfy `0 <= current_order < order`.
"""

def lms_derivative(tau):
Expand Down Expand Up @@ -240,7 +246,8 @@ def step(
Args:
state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
timestep (`int` or scalar array): value taken from `state.timesteps` at the current inference step (not the
inference step index; same convention as [`FlaxEulerDiscreteScheduler.step`]).
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
Expand All @@ -256,17 +263,22 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

sigma = state.sigmas[timestep]
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]

sigma = state.sigmas[step_index]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
)

# 2. Convert to an ODE derivative
Expand All @@ -276,8 +288,9 @@ def step(
state = state.replace(derivatives=jnp.delete(state.derivatives, 0))

# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
idx = int(step_index)
order = min(idx + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, order, idx, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
Expand Down
26 changes: 25 additions & 1 deletion tests/schedulers/test_scheduler_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_schedules(self):
self.check_over_configs(beta_schedule=schedule)

def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)

def test_time_indices(self):
Expand Down Expand Up @@ -89,6 +89,30 @@ def test_full_loop_with_v_prediction(self):
assert abs(result_sum.item() - 0.0017) < 1e-2
assert abs(result_mean.item() - 2.2676e-06) < 1e-3

def test_full_loop_with_sample_prediction(self):
    """Run a full LMS denoising loop with `prediction_type="sample"` and pin the output statistics.

    Mirrors `test_full_loop_with_v_prediction`: deterministic initial sample, dummy model,
    then check sum/mean of absolute values against reference numbers.
    """
    scheduler_class = self.scheduler_classes[0]
    scheduler_config = self.get_scheduler_config(prediction_type="sample")
    scheduler = scheduler_class(**scheduler_config)

    scheduler.set_timesteps(self.num_inference_steps)

    model = self.dummy_model()
    # Scale the deterministic initial latents to the scheduler's starting noise level.
    sample = self.dummy_sample_deter * scheduler.init_noise_sigma

    # The loop index was unused; iterate the timesteps directly.
    for t in scheduler.timesteps:
        sample = scheduler.scale_model_input(sample, t)

        model_output = model(sample, t)

        output = scheduler.step(model_output, t, sample)
        sample = output.prev_sample

    result_sum = torch.sum(torch.abs(sample))
    result_mean = torch.mean(torch.abs(sample))

    # With `sample` prediction the model output is taken as x_0 directly, so the
    # trajectory collapses toward it and the final magnitudes are tiny.
    assert abs(result_sum.item() - 6.358e-06) < 1e-8
    assert abs(result_mean.item() - 8.28e-09) < 1e-11

def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
Expand Down
36 changes: 36 additions & 0 deletions tests/schedulers/test_scheduler_lms_flax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

from ..testing_utils import require_flax


@require_flax
class FlaxLMSDiscreteSchedulerTest(unittest.TestCase):
    def test_step_uses_timestep_identity_like_euler_flax(self):
        import jax.numpy as jnp

        from diffusers import FlaxLMSDiscreteScheduler

        # Pipelines pass `step` a value taken from `state.timesteps` (a training
        # timestep id), while `state.sigmas` has one entry per inference step
        # (len = num_inference_steps + 1). `step` must therefore resolve the
        # inference-step index from the timestep, following the same convention
        # as `FlaxEulerDiscreteScheduler`.
        scheduler = FlaxLMSDiscreteScheduler(
            num_train_timesteps=1100,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
        )
        state = scheduler.create_state()

        latent_shape = (2, 4, 8, 8)
        state = scheduler.set_timesteps(state, num_inference_steps=10, shape=latent_shape)

        timestep = state.timesteps[3]
        latents = jnp.ones(latent_shape, dtype=jnp.float32)
        noise_pred = jnp.zeros_like(latents)

        scaled_latents = scheduler.scale_model_input(state, latents, timestep)
        step_output = scheduler.step(state, noise_pred, timestep, scaled_latents)

        self.assertEqual(tuple(step_output.prev_sample.shape), latent_shape)

    def test_invalid_prediction_type_in_init(self):
        from diffusers import FlaxLMSDiscreteScheduler

        # An unknown prediction type must be rejected eagerly at construction time.
        self.assertRaises(
            ValueError,
            FlaxLMSDiscreteScheduler,
            prediction_type="not_a_valid_prediction_type",
        )
Loading