|
22 | 22 | import jax.numpy as jnp |
23 | 23 |
|
24 | 24 | from ..configuration_utils import ConfigMixin, register_to_config |
| 25 | +from ..utils import logging |
25 | 26 | from .scheduling_utils_flax import ( |
26 | 27 | CommonSchedulerState, |
27 | 28 | FlaxKarrasDiffusionSchedulers, |
|
31 | 32 | ) |
32 | 33 |
|
33 | 34 |
|
| 35 | +logger = logging.get_logger(__name__) |
| 36 | + |
| 37 | + |
34 | 38 | @flax.struct.dataclass |
35 | 39 | class DPMSolverMultistepSchedulerState: |
36 | 40 | common: CommonSchedulerState |
@@ -171,6 +175,10 @@ def __init__( |
171 | 175 | timestep_spacing: str = "linspace", |
172 | 176 | dtype: jnp.dtype = jnp.float32, |
173 | 177 | ): |
| 178 | + logger.warning( |
| 179 | + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " |
| 180 | + "recommend migrating to PyTorch classes or pinning your version of Diffusers." |
| 181 | + ) |
174 | 182 | self.dtype = dtype |
175 | 183 |
|
176 | 184 | def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: |
@@ -203,7 +211,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolv |
203 | 211 | ) |
204 | 212 |
|
205 | 213 | def set_timesteps( |
206 | | - self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple |
| 214 | + self, |
| 215 | + state: DPMSolverMultistepSchedulerState, |
| 216 | + num_inference_steps: int, |
| 217 | + shape: Tuple, |
207 | 218 | ) -> DPMSolverMultistepSchedulerState: |
208 | 219 | """ |
209 | 220 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. |
@@ -301,10 +312,13 @@ def convert_model_output( |
301 | 312 | if self.config.thresholding: |
302 | 313 | # Dynamic thresholding in https://huggingface.co/papers/2205.11487 |
303 | 314 | dynamic_max_val = jnp.percentile( |
304 | | - jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) |
| 315 | + jnp.abs(x0_pred), |
| 316 | + self.config.dynamic_thresholding_ratio, |
| 317 | + axis=tuple(range(1, x0_pred.ndim)), |
305 | 318 | ) |
306 | 319 | dynamic_max_val = jnp.maximum( |
307 | | - dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) |
| 320 | + dynamic_max_val, |
| 321 | + self.config.sample_max_value * jnp.ones_like(dynamic_max_val), |
308 | 322 | ) |
309 | 323 | x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val |
310 | 324 | return x0_pred |
@@ -385,7 +399,11 @@ def multistep_dpm_solver_second_order_update( |
385 | 399 | """ |
386 | 400 | t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] |
387 | 401 | m0, m1 = model_output_list[-1], model_output_list[-2] |
388 | | - lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] |
| 402 | + lambda_t, lambda_s0, lambda_s1 = ( |
| 403 | + state.lambda_t[t], |
| 404 | + state.lambda_t[s0], |
| 405 | + state.lambda_t[s1], |
| 406 | + ) |
389 | 407 | alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] |
390 | 408 | sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] |
391 | 409 | h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 |
@@ -443,7 +461,12 @@ def multistep_dpm_solver_third_order_update( |
443 | 461 | Returns: |
444 | 462 | `jnp.ndarray`: the sample tensor at the previous timestep. |
445 | 463 | """ |
446 | | - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] |
| 464 | + t, s0, s1, s2 = ( |
| 465 | + prev_timestep, |
| 466 | + timestep_list[-1], |
| 467 | + timestep_list[-2], |
| 468 | + timestep_list[-3], |
| 469 | + ) |
447 | 470 | m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] |
448 | 471 | lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( |
449 | 472 | state.lambda_t[t], |
@@ -615,7 +638,10 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: |
615 | 638 | return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) |
616 | 639 |
|
617 | 640 | def scale_model_input( |
618 | | - self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None |
| 641 | + self, |
| 642 | + state: DPMSolverMultistepSchedulerState, |
| 643 | + sample: jnp.ndarray, |
| 644 | + timestep: Optional[int] = None, |
619 | 645 | ) -> jnp.ndarray: |
620 | 646 | """ |
621 | 647 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the |
|
0 commit comments