Skip to content

Commit 19921e9

Browse files
committed
fold Unions into |
1 parent 5aa4f1d commit 19921e9

482 files changed

Lines changed: 3214 additions & 3310 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/diffusers/configuration_utils.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import re
2525
from collections import OrderedDict
2626
from pathlib import Path
27-
from typing import Any, Optional, Union
27+
from typing import Any, Optional
2828

2929
import numpy as np
3030
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
@@ -143,7 +143,7 @@ def __getattr__(self, name: str) -> Any:
143143

144144
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
145145

146-
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
146+
def save_config(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
147147
"""
148148
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
149149
[`~ConfigMixin.from_config`] class method.
@@ -189,8 +189,8 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
189189

190190
@classmethod
191191
def from_config(
192-
cls, config: Union[FrozenDict, dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
193-
) -> Union[Self, tuple[Self, dict[str, Any]]]:
192+
cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs
193+
) -> Self | tuple[Self, dict[str, Any]]:
194194
r"""
195195
Instantiate a Python class from a config dictionary.
196196
@@ -292,7 +292,7 @@ def get_config_dict(cls, *args, **kwargs):
292292
@validate_hf_hub_args
293293
def load_config(
294294
cls,
295-
pretrained_model_name_or_path: Union[str, os.PathLike],
295+
pretrained_model_name_or_path: str | os.PathLike,
296296
return_unused_kwargs=False,
297297
return_commit_hash=False,
298298
**kwargs,
@@ -563,9 +563,7 @@ def extract_init_dict(cls, config_dict, **kwargs):
563563
return init_dict, unused_kwargs, hidden_config_dict
564564

565565
@classmethod
566-
def _dict_from_json_file(
567-
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[dict[str, DDUFEntry]] = None
568-
):
566+
def _dict_from_json_file(cls, json_file: str | os.PathLike, dduf_entries: Optional[dict[str, DDUFEntry]] = None):
569567
if dduf_entries:
570568
text = dduf_entries[json_file].read_text()
571569
else:
@@ -625,7 +623,7 @@ def to_json_saveable(value):
625623

626624
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
627625

628-
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
626+
def to_json_file(self, json_file_path: str | os.PathLike):
629627
"""
630628
Save the configuration instance's parameters to a JSON file.
631629
@@ -756,7 +754,7 @@ class LegacyConfigMixin(ConfigMixin):
756754
"""
757755

758756
@classmethod
759-
def from_config(cls, config: Union[FrozenDict, dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
757+
def from_config(cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs):
760758
# To prevent dependency import problem.
761759
from .models.model_loading_utils import _fetch_remapped_cls_from_config
762760

src/diffusers/guiders/__init__.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@
2828
from .smoothed_energy_guidance import SmoothedEnergyGuidance
2929
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
3030

31-
GuiderType = Union[
32-
AdaptiveProjectedGuidance,
33-
AutoGuidance,
34-
ClassifierFreeGuidance,
35-
ClassifierFreeZeroStarGuidance,
36-
FrequencyDecoupledGuidance,
37-
PerturbedAttentionGuidance,
38-
SkipLayerGuidance,
39-
SmoothedEnergyGuidance,
40-
TangentialClassifierFreeGuidance,
41-
]
31+
GuiderType = (
32+
AdaptiveProjectedGuidance
33+
| AutoGuidance
34+
| ClassifierFreeGuidance
35+
| ClassifierFreeZeroStarGuidance
36+
| FrequencyDecoupledGuidance
37+
| PerturbedAttentionGuidance
38+
| SkipLayerGuidance
39+
| SmoothedEnergyGuidance
40+
| TangentialClassifierFreeGuidance
41+
)

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Optional, Union
16+
from typing import TYPE_CHECKING, Optional
1717

1818
import torch
1919

@@ -77,7 +77,7 @@ def __init__(
7777
self.momentum_buffer = None
7878

7979
def prepare_inputs(
80-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
80+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
8181
) -> list["BlockState"]:
8282
if input_fields is None:
8383
input_fields = self._input_fields

src/diffusers/guiders/auto_guidance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Any, Optional, Union
16+
from typing import TYPE_CHECKING, Any, Optional
1717

1818
import torch
1919

@@ -65,8 +65,8 @@ class AutoGuidance(BaseGuidance):
6565
def __init__(
6666
self,
6767
guidance_scale: float = 7.5,
68-
auto_guidance_layers: Optional[Union[int, list[int]]] = None,
69-
auto_guidance_config: Union[LayerSkipConfig, list[LayerSkipConfig], dict[str, Any]] = None,
68+
auto_guidance_layers: Optional[int | list[int]] = None,
69+
auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
7070
dropout: Optional[float] = None,
7171
guidance_rescale: float = 0.0,
7272
use_original_formulation: bool = False,
@@ -133,7 +133,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
133133
registry.remove_hook(name, recurse=True)
134134

135135
def prepare_inputs(
136-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
136+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
137137
) -> list["BlockState"]:
138138
if input_fields is None:
139139
input_fields = self._input_fields

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Optional, Union
16+
from typing import TYPE_CHECKING, Optional
1717

1818
import torch
1919

@@ -84,7 +84,7 @@ def __init__(
8484
self.use_original_formulation = use_original_formulation
8585

8686
def prepare_inputs(
87-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
87+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
8888
) -> list["BlockState"]:
8989
if input_fields is None:
9090
input_fields = self._input_fields

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Optional, Union
16+
from typing import TYPE_CHECKING, Optional
1717

1818
import torch
1919

@@ -77,7 +77,7 @@ def __init__(
7777
self.use_original_formulation = use_original_formulation
7878

7979
def prepare_inputs(
80-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
80+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
8181
) -> list["BlockState"]:
8282
if input_fields is None:
8383
input_fields = self._input_fields

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Optional, Union
16+
from typing import TYPE_CHECKING, Optional
1717

1818
import torch
1919

@@ -141,12 +141,12 @@ class FrequencyDecoupledGuidance(BaseGuidance):
141141
@register_to_config
142142
def __init__(
143143
self,
144-
guidance_scales: Union[list[float], tuple[float]] = [10.0, 5.0],
145-
guidance_rescale: Union[float, list[float], tuple[float]] = 0.0,
146-
parallel_weights: Optional[Union[float, list[float], tuple[float]]] = None,
144+
guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
145+
guidance_rescale: float | list[float] | tuple[float] = 0.0,
146+
parallel_weights: Optional[float | list[float] | tuple[float]] = None,
147147
use_original_formulation: bool = False,
148-
start: Union[float, list[float], tuple[float]] = 0.0,
149-
stop: Union[float, list[float], tuple[float]] = 1.0,
148+
start: float | list[float] | tuple[float] = 0.0,
149+
stop: float | list[float] | tuple[float] = 1.0,
150150
guidance_rescale_space: str = "data",
151151
upcast_to_double: bool = True,
152152
):
@@ -218,7 +218,7 @@ def __init__(
218218
)
219219

220220
def prepare_inputs(
221-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
221+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
222222
) -> list["BlockState"]:
223223
if input_fields is None:
224224
input_fields = self._input_fields

src/diffusers/guiders/guider_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16-
from typing import TYPE_CHECKING, Any, Optional, Union
16+
from typing import TYPE_CHECKING, Any, Optional
1717

1818
import torch
1919
from huggingface_hub.utils import validate_hf_hub_args
@@ -47,7 +47,7 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
4747
self._num_inference_steps: int = None
4848
self._timestep: torch.LongTensor = None
4949
self._count_prepared = 0
50-
self._input_fields: dict[str, Union[str, tuple[str, str]]] = None
50+
self._input_fields: dict[str, str | tuple[str, str]] = None
5151
self._enabled = True
5252

5353
if not (0.0 <= start < 1.0):
@@ -72,7 +72,7 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen
7272
self._timestep = timestep
7373
self._count_prepared = 0
7474

75-
def set_input_fields(self, **kwargs: dict[str, Union[str, tuple[str, str]]]) -> None:
75+
def set_input_fields(self, **kwargs: dict[str, str | tuple[str, str]]) -> None:
7676
"""
7777
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
7878
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
@@ -155,7 +155,7 @@ def num_conditions(self) -> int:
155155
@classmethod
156156
def _prepare_batch(
157157
cls,
158-
input_fields: dict[str, Union[str, tuple[str, str]]],
158+
input_fields: dict[str, str | tuple[str, str]],
159159
data: "BlockState",
160160
tuple_index: int,
161161
identifier: str,
@@ -205,7 +205,7 @@ def _prepare_batch(
205205
@validate_hf_hub_args
206206
def from_pretrained(
207207
cls,
208-
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
208+
pretrained_model_name_or_path: Optional[str | os.PathLike] = None,
209209
subfolder: Optional[str] = None,
210210
return_unused_kwargs=False,
211211
**kwargs,
@@ -262,7 +262,7 @@ def from_pretrained(
262262
)
263263
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
264264

265-
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
265+
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
266266
"""
267267
Save a guider configuration object to a directory so that it can be reloaded using the
268268
[`~BaseGuidance.from_pretrained`] class method.

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Any, Optional, Union
16+
from typing import TYPE_CHECKING, Any, Optional
1717

1818
import torch
1919

@@ -92,8 +92,8 @@ def __init__(
9292
perturbed_guidance_scale: float = 2.8,
9393
perturbed_guidance_start: float = 0.01,
9494
perturbed_guidance_stop: float = 0.2,
95-
perturbed_guidance_layers: Optional[Union[int, list[int]]] = None,
96-
perturbed_guidance_config: Union[LayerSkipConfig, list[LayerSkipConfig], dict[str, Any]] = None,
95+
perturbed_guidance_layers: Optional[int | list[int]] = None,
96+
perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
9797
guidance_rescale: float = 0.0,
9898
use_original_formulation: bool = False,
9999
start: float = 0.0,
@@ -169,7 +169,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
169169

170170
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
171171
def prepare_inputs(
172-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
172+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
173173
) -> list["BlockState"]:
174174
if input_fields is None:
175175
input_fields = self._input_fields

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Any, Optional, Union
16+
from typing import TYPE_CHECKING, Any, Optional
1717

1818
import torch
1919

@@ -94,8 +94,8 @@ def __init__(
9494
skip_layer_guidance_scale: float = 2.8,
9595
skip_layer_guidance_start: float = 0.01,
9696
skip_layer_guidance_stop: float = 0.2,
97-
skip_layer_guidance_layers: Optional[Union[int, list[int]]] = None,
98-
skip_layer_config: Union[LayerSkipConfig, list[LayerSkipConfig], dict[str, Any]] = None,
97+
skip_layer_guidance_layers: Optional[int | list[int]] = None,
98+
skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
9999
guidance_rescale: float = 0.0,
100100
use_original_formulation: bool = False,
101101
start: float = 0.0,
@@ -165,7 +165,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
165165
registry.remove_hook(hook_name, recurse=True)
166166

167167
def prepare_inputs(
168-
self, data: "BlockState", input_fields: Optional[dict[str, Union[str, tuple[str, str]]]] = None
168+
self, data: "BlockState", input_fields: Optional[dict[str, str | tuple[str, str]]] = None
169169
) -> list["BlockState"]:
170170
if input_fields is None:
171171
input_fields = self._input_fields

0 commit comments

Comments (0)