From 0d264255f87a3c051abb8fa9cb6aa992cc59cc62 Mon Sep 17 00:00:00 2001 From: lhy <3445729633@qq.com> Date: Sat, 9 May 2026 08:52:50 +0800 Subject: [PATCH] Enhance LMSDiscreteScheduler to support 'sample' prediction type and improve documentation - Updated the `FlaxLMSDiscreteScheduler` and `LMSDiscreteScheduler` classes to include 'sample' as a valid option for `prediction_type`. - Improved docstrings for clarity, specifying the expected types and values for various parameters. - Added a new test case to validate the full loop functionality with the 'sample' prediction type. This change ensures better flexibility in prediction methods and enhances code documentation for future reference. --- .../schedulers/scheduling_lms_discrete.py | 2 +- .../scheduling_lms_discrete_flax.py | 37 +++++++++++++------ tests/schedulers/test_scheduler_lms.py | 26 ++++++++++++- tests/schedulers/test_scheduler_lms_flax.py | 36 ++++++++++++++++++ 4 files changed, 87 insertions(+), 14 deletions(-) create mode 100644 tests/schedulers/test_scheduler_lms_flax.py diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index e917adda9516..b53151c2c287 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -593,7 +593,7 @@ def step( pred_original_sample = model_output else: raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`" ) # 2. Convert to an ODE derivative diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 65902678e1d9..8c34be59e47c 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -36,7 +36,7 @@ class LMSDiscreteSchedulerState: common: CommonSchedulerState - # setable values + # settable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray sigmas: jnp.ndarray @@ -82,8 +82,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2` (same options as [`CommonSchedulerState`]). trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. prediction_type (`str`, default `epsilon`, optional): @@ -117,6 +117,10 @@ def __init__( "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) + if prediction_type not in ("epsilon", "sample", "v_prediction"): + raise ValueError( + f"`prediction_type` must be one of `epsilon`, `sample`, `v_prediction`, got {prediction_type!r}" + ) self.dtype = dtype def create_state(self, common: CommonSchedulerState | None = None) -> LMSDiscreteSchedulerState: @@ -158,14 +162,16 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra sample = sample / ((sigma**2 + 1) ** 0.5) return sample - def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order): + def get_lms_coefficient( + self, state: LMSDiscreteSchedulerState, order: int, t: int, current_order: int + ) -> float: """ Compute a linear multistep coefficient. Args: - order (TODO): - t (TODO): - current_order (TODO): + order (`int`): Multistep order (number of derivative history terms used). + t (`int`): Current step index along the inference sigma schedule (not the training timestep id). + current_order (`int`): Index of the Lagrange basis term; must satisfy `0 <= current_order < order`. """ def lms_derivative(tau): @@ -240,7 +246,8 @@ def step( Args: state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. + timestep (`int` or scalar array): value taken from `state.timesteps` at the current inference step (not the + inference step index; same convention as [`FlaxEulerDiscreteScheduler.step`]). sample (`jnp.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. @@ -256,7 +263,10 @@ def step( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - sigma = state.sigmas[timestep] + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + sigma = state.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": @@ -264,9 +274,11 @@ def step( elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output else: raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`" ) # 2. Convert to an ODE derivative @@ -276,8 +288,9 @@ def step( state = state.replace(derivatives=jnp.delete(state.derivatives, 0)) # 3. Compute linear multistep coefficients - order = min(timestep + 1, order) - lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] + idx = int(step_index) + order = min(idx + 1, order) + lms_coeffs = [self.get_lms_coefficient(state, order, idx, curr_order) for curr_order in range(order)] # 4. Compute previous sample based on the derivatives path prev_sample = sample + sum( diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py index c4abca3ac973..dd94fbd92550 100644 --- a/tests/schedulers/test_scheduler_lms.py +++ b/tests/schedulers/test_scheduler_lms.py @@ -34,7 +34,7 @@ def test_schedules(self): self.check_over_configs(beta_schedule=schedule) def test_prediction_type(self): - for prediction_type in ["epsilon", "v_prediction"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) def test_time_indices(self): @@ -89,6 +89,30 @@ def test_full_loop_with_v_prediction(self): assert abs(result_sum.item() - 0.0017) < 1e-2 assert abs(result_mean.item() - 2.2676e-06) < 1e-3 + def test_full_loop_with_sample_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="sample") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 6.358e-06) < 1e-8 + assert abs(result_mean.item() - 8.28e-09) < 1e-11 + def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() diff --git a/tests/schedulers/test_scheduler_lms_flax.py b/tests/schedulers/test_scheduler_lms_flax.py new file mode 100644 index 000000000000..83bbf28d5312 --- /dev/null +++ b/tests/schedulers/test_scheduler_lms_flax.py @@ -0,0 +1,36 @@ +import unittest + +from ..testing_utils import require_flax + + +@require_flax +class FlaxLMSDiscreteSchedulerTest(unittest.TestCase): + def test_step_uses_timestep_identity_like_euler_flax(self): + import jax.numpy as jnp + + from diffusers import FlaxLMSDiscreteScheduler + + # `state.sigmas` is indexed by inference step (len = num_inference_steps + 1), while pipeline code passes + # values from `state.timesteps` (training timestep ids). Step must resolve the step index like + # `FlaxEulerDiscreteScheduler`. + scheduler = FlaxLMSDiscreteScheduler( + num_train_timesteps=1100, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + ) + state = scheduler.create_state() + shape = (2, 4, 8, 8) + state = scheduler.set_timesteps(state, num_inference_steps=10, shape=shape) + t = state.timesteps[3] + sample = jnp.ones(shape, dtype=jnp.float32) + model_output = jnp.zeros_like(sample) + scaled = scheduler.scale_model_input(state, sample, t) + out = scheduler.step(state, model_output, t, scaled) + self.assertEqual(tuple(out.prev_sample.shape), shape) + + def test_invalid_prediction_type_in_init(self): + from diffusers import FlaxLMSDiscreteScheduler + + with self.assertRaises(ValueError): + FlaxLMSDiscreteScheduler(prediction_type="not_a_valid_prediction_type")