|
16 | 16 | TextualInversionLoaderMixin, |
17 | 17 | ) |
18 | 18 | from diffusers.models import AutoencoderKL, UNet2DConditionModel |
19 | | -from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor |
20 | 19 | from diffusers.models.lora import adjust_lora_scale_text_encoder |
21 | 20 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
22 | 21 | from diffusers.schedulers import KarrasDiffusionSchedulers |
23 | 22 | from diffusers.utils import ( |
| 23 | + deprecate, |
24 | 24 | is_accelerate_available, |
25 | 25 | is_accelerate_version, |
26 | 26 | is_invisible_watermark_available, |
@@ -614,18 +614,7 @@ def tiled_decode(self, latents, current_height, current_width): |
614 | 614 |
|
615 | 615 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae |
616 | 616 | def upcast_vae(self): |
617 | | - dtype = self.vae.dtype |
618 | | - self.vae.to(dtype=torch.float32) |
619 | | - use_torch_2_0_or_xformers = isinstance( |
620 | | - self.vae.decoder.mid_block.attentions[0].processor, |
621 | | - (AttnProcessor2_0, XFormersAttnProcessor), |
622 | | - ) |
623 | | - # if xformers or torch_2_0 is used attention block does not need |
624 | | - # to be in float32 which can save lots of memory |
625 | | - if use_torch_2_0_or_xformers: |
626 | | - self.vae.post_quant_conv.to(dtype) |
627 | | - self.vae.decoder.conv_in.to(dtype) |
628 | | - self.vae.decoder.mid_block.to(dtype) |
| 617 | + deprecate("`upcast_vae` is deprecated") |
629 | 618 |
|
630 | 619 | @torch.no_grad() |
631 | 620 | @replace_example_docstring(EXAMPLE_DOC_STRING) |
@@ -997,7 +986,7 @@ def __call__( |
997 | 986 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast |
998 | 987 |
|
999 | 988 | if needs_upcasting: |
1000 | | - self.upcast_vae() |
| 989 | + self.vae.to(torch.float32) |
1001 | 990 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) |
1002 | 991 | print("### Phase 1 Decoding ###") |
1003 | 992 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
@@ -1257,7 +1246,7 @@ def __call__( |
1257 | 1246 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast |
1258 | 1247 |
|
1259 | 1248 | if needs_upcasting: |
1260 | | - self.upcast_vae() |
| 1249 | + self.vae.to(torch.float32) |
1261 | 1250 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) |
1262 | 1251 |
|
1263 | 1252 | print("### Phase {} Decoding ###".format(current_scale_num)) |
|
0 commit comments