|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import inspect |
16 | 17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
17 | 18 |
|
18 | 19 | import numpy as np |
|
26 | 27 | from ...schedulers import FlowMatchEulerDiscreteScheduler |
27 | 28 | from ...utils import is_torch_xla_available, logging, replace_example_docstring |
28 | 29 | from ...utils.torch_utils import randn_tensor |
29 | | -from ..pipeline_utils import retrieve_timesteps |
30 | 30 | from .pipeline_output import CogView4PipelineOutput |
31 | 31 |
|
32 | 32 |
|
@@ -68,6 +68,74 @@ def calculate_shift( |
68 | 68 | return mu |
69 | 69 |
|
70 | 70 |
|
| 71 | +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configures the scheduler via `set_timesteps` and returns the resulting timestep schedule.

    Custom `timesteps` and/or `sigmas` override the scheduler's default spacing strategy; any
    extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Inspect the scheduler's `set_timesteps` signature once to learn which
    # custom-schedule keyword arguments it supports.
    set_timesteps_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    supports_timesteps = "timesteps" in set_timesteps_params
    supports_sigmas = "sigmas" in set_timesteps_params

    # Default path: no custom schedule, let the scheduler space `num_inference_steps` steps.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom-schedule paths: validate support, configure, then derive the step
    # count from whatever schedule the scheduler actually produced.
    if timesteps is not None and sigmas is not None:
        if not (supports_timesteps or supports_sigmas):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
    elif timesteps is not None:
        if not supports_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not supports_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
| 137 | + |
| 138 | + |
71 | 139 | class CogView4ControlPipeline(DiffusionPipeline): |
72 | 140 | r""" |
73 | 141 | Pipeline for text-to-image generation using CogView4. |
|
0 commit comments