Skip to content

Commit 0a695cf

Browse files
committed
up
1 parent 2f9971d commit 0a695cf

2 files changed

Lines changed: 137 additions & 2 deletions

File tree

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import inspect
1617
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1718

1819
import numpy as np
@@ -27,7 +28,6 @@
2728
from ...schedulers import FlowMatchEulerDiscreteScheduler
2829
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2930
from ...utils.torch_utils import randn_tensor
30-
from ..pipeline_utils import retrieve_timesteps
3131
from .pipeline_output import CogView4PipelineOutput
3232

3333

@@ -67,6 +67,73 @@ def calculate_shift(
6767
return mu
6868

6969

70+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and read the resulting timestep schedule back from
    the scheduler. Custom `timesteps` and/or `sigmas` schedules are supported; any extra keyword
    arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Inspect the scheduler's signature once to learn which custom schedules it supports.
    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
    accepts_timesteps = "timesteps" in set_timesteps_params
    accepts_sigmas = "sigmas" in set_timesteps_params

    # Default path: no custom schedule requested — let the scheduler build its own.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None and sigmas is not None:
        # CogView4 intentionally allows passing both schedules through together.
        if not (accepts_timesteps or accepts_sigmas):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
    elif timesteps is not None:
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # For every custom-schedule path, the step count is derived from the scheduler's output.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
135+
136+
70137
class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
71138
r"""
72139
Pipeline for text-to-image generation using CogView4.

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import inspect
1617
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1718

1819
import numpy as np
@@ -26,7 +27,6 @@
2627
from ...schedulers import FlowMatchEulerDiscreteScheduler
2728
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2829
from ...utils.torch_utils import randn_tensor
29-
from ..pipeline_utils import retrieve_timesteps
3030
from .pipeline_output import CogView4PipelineOutput
3131

3232

@@ -68,6 +68,74 @@ def calculate_shift(
6868
return mu
6969

7070

71+
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and read the resulting timestep schedule back from
    the scheduler. Custom `timesteps` and/or `sigmas` schedules are supported; any extra keyword
    arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Inspect the scheduler's signature once to learn which custom schedules it supports.
    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
    accepts_timesteps = "timesteps" in set_timesteps_params
    accepts_sigmas = "sigmas" in set_timesteps_params

    # Default path: no custom schedule requested — let the scheduler build its own.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None and sigmas is not None:
        # CogView4 intentionally allows passing both schedules through together.
        if not (accepts_timesteps or accepts_sigmas):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
    elif timesteps is not None:
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # For every custom-schedule path, the step count is derived from the scheduler's output.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
137+
138+
71139
class CogView4ControlPipeline(DiffusionPipeline):
72140
r"""
73141
Pipeline for text-to-image generation using CogView4.

0 commit comments

Comments
 (0)