From ffa46523b8d67f4d5749e3d4f164c05a7d5b0ff0 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 7 May 2026 15:13:26 +0000 Subject: [PATCH] Fix Stable Cascade review issues --- .../models/unets/unet_stable_cascade.py | 41 +++-- .../stable_cascade/pipeline_stable_cascade.py | 135 ++++++++++------- .../pipeline_stable_cascade_combined.py | 23 ++- .../pipeline_stable_cascade_prior.py | 101 ++++++++----- .../unets/test_models_unet_stable_cascade.py | 120 +++++++++++++++ .../test_stable_cascade_combined.py | 73 ++++++++- .../test_stable_cascade_decoder.py | 83 +++++++++- .../test_stable_cascade_prior.py | 143 ++++++++++++++++++ 8 files changed, 596 insertions(+), 123 deletions(-) create mode 100644 tests/models/unets/test_models_unet_stable_cascade.py diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index af98b7a1c602..7e8441ac17ac 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -135,6 +135,8 @@ class StableCascadeUNetOutput(BaseOutput): class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True + _supports_group_offloading = False + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( @@ -148,24 +150,24 @@ def __init__( num_attention_heads: tuple[int, ...] = (32, 32), down_num_layers_per_block: tuple[int, ...] = (8, 24), up_num_layers_per_block: tuple[int, ...] = (24, 8), - down_blocks_repeat_mappers: tuple[int] | None = ( + down_blocks_repeat_mappers: tuple[int, ...] | None = ( 1, 1, ), - up_blocks_repeat_mappers: tuple[int] | None = (1, 1), - block_types_per_layer: tuple[tuple[str]] = ( + up_blocks_repeat_mappers: tuple[int, ...] | None = (1, 1), + block_types_per_layer: tuple[tuple[str, ...], ...] = ( ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ), clip_text_in_channels: int | None = None, - clip_text_pooled_in_channels=1280, + clip_text_pooled_in_channels: int = 1280, clip_image_in_channels: int | None = None, - clip_seq=4, + clip_seq: int = 4, effnet_in_channels: int | None = None, pixel_mapper_in_channels: int | None = None, - kernel_size=3, - dropout: float | tuple[float] = (0.1, 0.1), - self_attn: bool | tuple[bool] = True, + kernel_size: int = 3, + dropout: float | tuple[float, ...] = (0.1, 0.1), + self_attn: bool | tuple[bool, ...] = True, timestep_conditioning_type: tuple[str, ...] 
= ("sca", "crp"), switch_level: tuple[bool] | None = None, ): @@ -431,20 +433,27 @@ def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000): def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None): if len(clip_txt_pooled.shape) == 2: - clip_txt_pool = clip_txt_pooled.unsqueeze(1) + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1 ) - if clip_txt is not None and clip_img is not None: + + clip = [] + if clip_txt is not None: clip_txt = self.clip_txt_mapper(clip_txt) + clip.append(clip_txt) + + clip.append(clip_txt_pool) + + if clip_img is not None: if len(clip_img.shape) == 2: clip_img = clip_img.unsqueeze(1) clip_img = self.clip_img_mapper(clip_img).view( clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1 ) - clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) - else: - clip = clip_txt_pool + clip.append(clip_img) + + clip = torch.cat(clip, dim=1) return self.clip_norm(clip) def _down_encode(self, x, r_embed, clip): @@ -548,8 +557,8 @@ def forward( crp=None, return_dict=True, ): - if pixels is None: - pixels = sample.new_zeros(sample.size(0), 3, 8, 8) + if pixels is None and hasattr(self, "pixels_mapper"): + pixels = sample.new_zeros(sample.size(0), self.config.pixel_mapper_in_channels, 8, 8) # Process the conditioning embeddings timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio) @@ -560,7 +569,7 @@ def forward( cond = crp else: cond = None - t_cond = cond or torch.zeros_like(timestep_ratio) + t_cond = cond if cond is not None else torch.zeros_like(timestep_ratio) timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1) clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index 6a4066eb6e17..3ad32bccf31d 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -17,6 +17,7 @@ import torch from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring @@ -44,12 +45,12 @@ >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 ... ).to("cuda") - >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain( + >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrained( ... "stabilityai/stable-cascade", torch_dtype=torch.float16 ... 
).to("cuda")

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
-        >>> prior_output = pipe(prompt)
+        >>> prior_output = prior_pipe(prompt)
         >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)
         ```
"""
@@ -109,6 +110,7 @@ def __init__(
             vqgan=vqgan,
         )
         self.register_to_config(latent_dim_scale=latent_dim_scale)
+        self.image_processor = VaeImageProcessor(do_normalize=False)

     def prepare_latents(
         self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
@@ -126,7 +128,7 @@ def prepare_latents(
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)

         latents = latents * scheduler.init_noise_sigma
         return latents
@@ -178,22 +180,29 @@ def encode_prompt(
             if prompt_embeds_pooled is None:
                 prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-        prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds_dtype = self.text_encoder.dtype if self.text_encoder is not None else prompt_embeds.dtype
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+        prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=prompt_embeds_dtype, device=device)
         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)

         if negative_prompt_embeds is None and do_classifier_free_guidance:
+            if self.tokenizer is None or self.text_encoder is None:
+                raise ValueError(
+                    "`negative_prompt_embeds` must be provided when classifier-free guidance is enabled and the "
+                    "pipeline does not have a tokenizer or text encoder."
+                )
+
             uncond_tokens: list[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
)
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] if prompt is not None else [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError(
                 f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -220,21 +229,12 @@ def encode_prompt(
             negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)

         if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-            seq_len = negative_prompt_embeds_pooled.shape[1]
-            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
-                dtype=self.text_encoder.dtype, device=device
-            )
-            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
-                batch_size * num_images_per_prompt, seq_len, -1
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(dtype=prompt_embeds_dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat_interleave(
+                num_images_per_prompt, dim=0
             )
-            # done duplicates

         return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled

@@ -243,7 +243,9 @@ def check_inputs(
         prompt,
         negative_prompt=None,
         prompt_embeds=None,
+        prompt_embeds_pooled=None,
         negative_prompt_embeds=None,
+        negative_prompt_embeds_pooled=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if callback_on_step_end_tensor_inputs is not None and not all(
@@ -279,6 +281,24 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

+        if prompt_embeds is not None and prompt_embeds_pooled is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `negative_prompt_embeds_pooled` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
+        if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
+            if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
+                raise ValueError(
+                    "`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed "
+                    f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} != "
+                    f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}."
+                )
+
     @property
     def guidance_scale(self):
         return self._guidance_scale

@@ -306,8 +326,9 @@ def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
     def __call__(
         self,
         image_embeddings: torch.Tensor | list[torch.Tensor],
-        prompt: str | list[str] = None,
+        prompt: str | list[str] | None = None,
         num_inference_steps: int = 10,
+        timesteps: list[float] | None = None,
         guidance_scale: float = 0.0,
         negative_prompt: str | list[str] | None = None,
         prompt_embeds: torch.Tensor | None = None,
@@ -317,7 +338,7 @@ def __call__(
         num_images_per_prompt: int = 1,
         generator: torch.Generator | list[torch.Generator] | None = None,
         latents: torch.Tensor | None = None,
-        output_type: str | None = "pil",
+        output_type: str = "pil",
         return_dict: bool = True,
         callback_on_step_end: Callable[[int, int], None] | None = None,
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
@@ -326,22 +347,24 @@ def __call__(
         Function invoked when calling the pipeline for generation.

         Args:
-            image_embedding (`torch.Tensor` or `list[torch.Tensor]`):
+            image_embeddings (`torch.Tensor` or `list[torch.Tensor]`):
                 Image Embeddings either extracted from an image or generated by a Prior Model.
             prompt (`str` or `list[str]`):
                 The prompt or prompts to guide the image generation.
-            num_inference_steps (`int`, *optional*, defaults to 12):
+            num_inference_steps (`int`, *optional*, defaults to 10):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            timesteps (`list[float]`, *optional*):
+                Custom timesteps to use for the denoising process. If provided, `num_inference_steps` is ignored.
             guidance_scale (`float`, *optional*, defaults to 0.0):
                 Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
-                equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
-                setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
-                closely linked to the text `prompt`, usually at the expense of lower image quality.
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
             negative_prompt (`str` or `list[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `decoder_guidance_scale` is less than `1`).
+                if `guidance_scale` is less than or equal to `1`).
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.
@@ -367,7 +390,7 @@ def __call__(
                 tensor will be generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
-                (`np.array`) or `"pt"` (`torch.Tensor`).
+                (`np.array`), `"pt"` (`torch.Tensor`) or `"latent"` (`torch.Tensor`).
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
@@ -384,8 +407,7 @@ def __call__(
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
-            otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image
-            embeddings.
+            otherwise a `tuple`. When returning a tuple, the first element is the generated images.
         """

         # 0. Define commonly used variables
@@ -394,17 +416,24 @@ def __call__(
         self._guidance_scale = guidance_scale
         if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
             raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
+        if output_type not in ["pt", "np", "pil", "latent"]:
+            raise ValueError(
+                f"Only the output types `pt`, `np`, `pil` and `latent` are supported, not output_type={output_type}"
+            )

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
+            prompt_embeds_pooled=prompt_embeds_pooled,
             negative_prompt_embeds=negative_prompt_embeds,
+            negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )

         if isinstance(image_embeddings, list):
             image_embeddings = torch.cat(image_embeddings, dim=0)
+        image_embeddings = image_embeddings.to(device=device, dtype=dtype)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -420,19 +449,18 @@ def __call__(
             num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)

         # 2. Encode caption
-        if prompt_embeds is None and negative_prompt_embeds is None:
-            _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
-                prompt=prompt,
-                device=device,
-                batch_size=batch_size,
-                num_images_per_prompt=num_images_per_prompt,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                negative_prompt=negative_prompt,
-                prompt_embeds=prompt_embeds,
-                prompt_embeds_pooled=prompt_embeds_pooled,
-                negative_prompt_embeds=negative_prompt_embeds,
-                negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
-            )
+        _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            batch_size=batch_size,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            prompt_embeds_pooled=prompt_embeds_pooled,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
+        )

         # The pooled embeds from the prior are pooled again before being passed to the decoder
         prompt_embeds_pooled = (
@@ -446,7 +474,7 @@ def __call__(
             else image_embeddings
         )

-        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        self.scheduler.set_timesteps(num_inference_steps, timesteps=timesteps, device=device)
         timesteps = self.scheduler.timesteps

         # 5. Prepare latents
@@ -516,20 +544,11 @@ def __call__(
         if XLA_AVAILABLE:
             xm.mark_step()

-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
-            )
-
         if not output_type == "latent":
             # 10. 
Scale and decode the image latents with vq-vae latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) - if output_type == "np": - images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesn't work - elif output_type == "pil": - images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesn't work - images = self.numpy_to_pil(images) + images = self.image_processor.postprocess(images, output_type=output_type) else: images = latents @@ -537,5 +556,5 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return images + return (images,) return ImagePipelineOutput(images) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py index 0afecad097da..98c1330ed0f6 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -125,6 +125,7 @@ def __init__( ) def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.prior_pipe.enable_xformers_memory_efficient_attention(attention_op) self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): @@ -160,12 +161,14 @@ def set_progress_bar_config(self, **kwargs): def __call__( self, prompt: str | list[str] | None = None, - images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None, + images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] | None = None, height: int = 512, width: int = 512, prior_num_inference_steps: int = 60, + prior_timesteps: list[float] | None = None, prior_guidance_scale: float = 4.0, num_inference_steps: int = 12, + timesteps: list[float] | None = None, decoder_guidance_scale: float = 0.0, negative_prompt: str | list[str] | None = None, prompt_embeds: torch.Tensor | None = None, @@ -175,7 +178,7 @@ def __call__( num_images_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, - output_type: str | None = "pil", + output_type: str = "pil", return_dict: bool = True, prior_callback_on_step_end: Callable[[int, int], None] | None = None, prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], @@ -221,12 +224,16 @@ def __call__( closely linked to the text `prompt`, usually at the expense of lower image quality. prior_num_inference_steps (`int | dict[float, int]`, *optional*, defaults to 60): The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. For more specific timestep spacing, you can pass customized - `prior_timesteps` + expense of slower inference. + prior_timesteps (`list[float]`, *optional*): + Custom timesteps to use for the prior denoising process. If provided, `prior_num_inference_steps` is + ignored. num_inference_steps (`int`, *optional*, defaults to 12): The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. For more specific timestep spacing, you can pass customized - `timesteps` + the expense of slower inference. + timesteps (`list[float]`, *optional*): + Custom timesteps to use for the decoder denoising process. 
If provided, `num_inference_steps` is + ignored. decoder_guidance_scale (`float`, *optional*, defaults to 0.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. @@ -242,7 +249,7 @@ def __call__( tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` - (`np.array`) or `"pt"` (`torch.Tensor`). + (`np.array`), `"pt"` (`torch.Tensor`) or `"latent"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. prior_callback_on_step_end (`Callable`, *optional*): @@ -281,6 +288,7 @@ def __call__( height=height, width=width, num_inference_steps=prior_num_inference_steps, + timesteps=prior_timesteps, guidance_scale=prior_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, @@ -305,6 +313,7 @@ def __call__( image_embeddings=image_embeddings, prompt=prompt if prompt_embeds is None else None, num_inference_steps=num_inference_steps, + timesteps=timesteps, guidance_scale=decoder_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py index 0c5ea9ed61b4..3265dccb6743 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -51,7 +51,7 @@ ... 
).to("cuda")

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
-        >>> prior_output = pipe(prompt)
+        >>> prior_output = prior_pipe(prompt)
         ```
"""

@@ -73,8 +73,8 @@ class StableCascadePriorPipelineOutput(BaseOutput):
     image_embeddings: torch.Tensor | np.ndarray
     prompt_embeds: torch.Tensor | np.ndarray
     prompt_embeds_pooled: torch.Tensor | np.ndarray
-    negative_prompt_embeds: torch.Tensor | np.ndarray
-    negative_prompt_embeds_pooled: torch.Tensor | np.ndarray
+    negative_prompt_embeds: torch.Tensor | np.ndarray | None
+    negative_prompt_embeds_pooled: torch.Tensor | np.ndarray | None


 class StableCascadePriorPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
@@ -147,7 +147,7 @@ def prepare_latents(
         else:
             if latents.shape != latent_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)

         latents = latents * scheduler.init_noise_sigma
         return latents
@@ -199,22 +199,29 @@ def encode_prompt(
             if prompt_embeds_pooled is None:
                 prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-        prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds_dtype = self.text_encoder.dtype if self.text_encoder is not None else prompt_embeds.dtype
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+        prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=prompt_embeds_dtype, device=device)
         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)

         if negative_prompt_embeds is None and do_classifier_free_guidance:
+            if self.tokenizer is None or self.text_encoder is None:
+                raise ValueError(
+                    "`negative_prompt_embeds` must be provided when classifier-free guidance is enabled and the "
+                    "pipeline does not have a tokenizer or text encoder."
+                )
+
             uncond_tokens: list[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
) elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] + uncond_tokens = [negative_prompt] if prompt is not None else [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" @@ -241,21 +248,12 @@ def encode_prompt( negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - seq_len = negative_prompt_embeds_pooled.shape[1] - negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( - dtype=self.text_encoder.dtype, device=device + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat_interleave( + num_images_per_prompt, dim=0 ) - negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - # done duplicates return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled @@ -342,7 +340,12 @@ def check_inputs( " only forward one of the two." ) - if images: + if images is not None: + if not isinstance(images, list): + images = [images] + if len(images) == 0: + raise ValueError("`images` cannot be an empty list.") + for i, image in enumerate(images): if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): raise TypeError( @@ -377,11 +380,11 @@ def get_timestep_ratio_conditioning(self, t, alphas_cumprod): def __call__( self, prompt: str | list[str] | None = None, - images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None, + images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] | None = None, height: int = 1024, width: int = 1024, num_inference_steps: int = 20, - timesteps: list[float] = None, + timesteps: list[float] | None = None, guidance_scale: float = 4.0, negative_prompt: str | list[str] | None = None, prompt_embeds: torch.Tensor | None = None, @@ -392,7 +395,7 @@ def __call__( num_images_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, - output_type: str | None = "pt", + output_type: str = "pt", return_dict: bool = True, callback_on_step_end: Callable[[int, int], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], @@ -403,22 +406,26 @@ def __call__( Args: prompt (`str` or `list[str]`): The prompt or prompts to guide the image generation. + images (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, *optional*): + The image conditioning inputs to guide prior generation. 
height (`int`, *optional*, defaults to 1024):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 1024):
                 The width in pixels of the generated image.
-            num_inference_steps (`int`, *optional*, defaults to 60):
+            num_inference_steps (`int`, *optional*, defaults to 20):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            guidance_scale (`float`, *optional*, defaults to 8.0):
+            timesteps (`list[float]`, *optional*):
+                Custom timesteps to use for the denoising process. If provided, `num_inference_steps` is ignored.
+            guidance_scale (`float`, *optional*, defaults to 4.0):
                 Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
-                equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
-                setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
-                closely linked to the text `prompt`, usually at the expense of lower image quality.
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages the model to generate image embeddings that
+                are closely linked to the text `prompt`, usually at the expense of lower image quality.
             negative_prompt (`str` or `list[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `decoder_guidance_scale` is less than `1`).
+                if `guidance_scale` is less than or equal to `1`).
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.
@@ -435,7 +442,7 @@ def __call__(
                 input argument.
             image_embeds (`torch.Tensor`, *optional*):
                 Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting. If
-                not provided, image embeddings will be generated from `image` input argument if existing.
+                not provided, image embeddings will be generated from `images` input argument if existing.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
@@ -445,9 +452,9 @@ def __call__(
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
-                (`np.array`) or `"pt"` (`torch.Tensor`).
+            output_type (`str`, *optional*, defaults to `"pt"`):
+                The output format of the image embeddings. Choose between: `"pt"` (`torch.Tensor`), `"latent"`
+                (`torch.Tensor`) or `"np"` (`np.array`).
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
@@ -472,6 +479,11 @@ def __call__(
         device = self._execution_device
         dtype = next(self.prior.parameters()).dtype
         self._guidance_scale = guidance_scale
+        if output_type not in ["pt", "np", "latent"]:
+            raise ValueError(
+                f"Only the output types `pt`, `np` and `latent` are supported, not output_type={output_type}"
+            )
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -479,6 +491,9 @@ def __call__(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]

+        if images is not None and not isinstance(images, list):
+            images = [images]
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -520,7 +535,9 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
             )
         elif image_embeds is not None:
-            image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
+            image_embeds_pooled = image_embeds.to(device=device, dtype=dtype).repeat(
+                batch_size * num_images_per_prompt, 1, 1
+            )
             uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled)
         else:
             image_embeds_pooled = torch.zeros(
@@ -556,7 +573,7 @@ def __call__(
         )

         # 4. Prepare and set timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        self.scheduler.set_timesteps(num_inference_steps, timesteps=timesteps, device=device)
         timesteps = self.scheduler.timesteps

         # 5. Prepare latents
@@ -630,9 +647,15 @@ def __call__(
         if output_type == "np":
             latents = latents.cpu().float().numpy()  # float() as bfloat16-> numpy doesn't work
             prompt_embeds = prompt_embeds.cpu().float().numpy()  # float() as bfloat16-> numpy doesn't work
+            prompt_embeds_pooled = prompt_embeds_pooled.cpu().float().numpy()
             negative_prompt_embeds = (
                 negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None
             )  # float() as bfloat16-> numpy doesn't work
+            negative_prompt_embeds_pooled = (
+                negative_prompt_embeds_pooled.cpu().float().numpy()
+                if negative_prompt_embeds_pooled is not None
+                else None
+            )

         if not return_dict:
             return (
diff --git a/tests/models/unets/test_models_unet_stable_cascade.py b/tests/models/unets/test_models_unet_stable_cascade.py
new file mode 100644
index 000000000000..e8b4237892e4
--- /dev/null
+++ b/tests/models/unets/test_models_unet_stable_cascade.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import StableCascadeUNet + +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class StableCascadeUNetTests(ModelTesterMixin, unittest.TestCase): + model_class = StableCascadeUNet + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 2 + sample = floats_tensor((batch_size, 4, 8, 8)).to(torch_device) + timestep_ratio = torch.ones(batch_size, device=torch_device) + clip_text_pooled = floats_tensor((batch_size, 1, 8)).to(torch_device) + + return { + "sample": sample, + "timestep_ratio": timestep_ratio, + "clip_text_pooled": clip_text_pooled, + } + + @property + def input_shape(self): + return (4, 8, 8) + + @property + def output_shape(self): + return (4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "conditioning_dim": 8, + "block_out_channels": (8,), + "num_attention_heads": (1,), + "down_num_layers_per_block": (1,), + "up_num_layers_per_block": (1,), + "down_blocks_repeat_mappers": (1,), + "up_blocks_repeat_mappers": (1,), + "block_types_per_layer": (("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),), + "clip_text_in_channels": 8, + "clip_text_pooled_in_channels": 8, + "clip_image_in_channels": 8, + "dropout": 0.0, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_get_clip_embeddings_accepts_2d_pooled_text(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + clip_text_pooled = floats_tensor((2, 8)).to(torch_device) + clip = model.get_clip_embeddings(clip_text_pooled) + + self.assertEqual(clip.shape, (2, model.config.clip_seq, model.config.conditioning_dim)) + + def test_get_clip_embeddings_accepts_optional_clip_modalities(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + clip_text_pooled = floats_tensor((2, 1, 8)).to(torch_device) + clip_text = floats_tensor((2, 3, 8)).to(torch_device) + clip_img = floats_tensor((2, 8)).to(torch_device) + + text_only_clip = model.get_clip_embeddings(clip_text_pooled, clip_txt=clip_text) + image_only_clip = model.get_clip_embeddings(clip_text_pooled, clip_img=clip_img) + + self.assertEqual(text_only_clip.shape, (2, 3 + model.config.clip_seq, model.config.conditioning_dim)) + self.assertEqual(image_only_clip.shape, (2, 2 * model.config.clip_seq, model.config.conditioning_dim)) + + def test_forward_accepts_batched_sca_crp(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict["sca"] = torch.tensor([0.1, 0.2], device=torch_device) + inputs_dict["crp"] = torch.tensor([0.3, 0.4], device=torch_device) + + with torch.no_grad(): + output = model(**inputs_dict).sample + + self.assertEqual(output.shape, (2,) + self.output_shape) + + def test_forward_uses_configured_pixel_mapper_channels(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict["pixel_mapper_in_channels"] = 5 + model = self.model_class(**init_dict).to(torch_device) + + with torch.no_grad(): + output = model(**inputs_dict).sample + + self.assertEqual(output.shape, (2,) + self.output_shape) + + def test_gradient_checkpointing_is_applied(self): + 
super().test_gradient_checkpointing_is_applied(expected_set={"StableCascadeUNet"}) diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py index d9a511ab199c..8e87dd43b8b8 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py @@ -14,6 +14,7 @@ # limitations under the License. import unittest +from unittest.mock import Mock import numpy as np import torch @@ -40,7 +41,9 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC "width", "latents", "prior_guidance_scale", + "prior_timesteps", "decoder_guidance_scale", + "timesteps", "negative_prompt", "num_inference_steps", "return_dict", @@ -136,6 +139,7 @@ def get_dummy_components(self): prior = self.dummy_prior scheduler = DDPMWuerstchenScheduler() + prior_scheduler = DDPMWuerstchenScheduler() tokenizer = self.dummy_tokenizer text_encoder = self.dummy_text_encoder decoder = self.dummy_decoder @@ -152,7 +156,7 @@ def get_dummy_components(self): "prior_text_encoder": prior_text_encoder, "prior_tokenizer": prior_tokenizer, "prior_prior": prior, - "prior_scheduler": scheduler, + "prior_scheduler": prior_scheduler, "prior_feature_extractor": None, "prior_image_encoder": None, } @@ -193,7 +197,7 @@ def test_stable_cascade(self): image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[-3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 128, 128, 3) @@ -205,6 +209,71 @@ def test_stable_cascade(self): f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" ) + def test_enable_xformers_memory_efficient_attention_delegates_to_both_pipelines(self): + pipe = StableCascadeCombinedPipeline.__new__(StableCascadeCombinedPipeline) + pipe.prior_pipe = Mock() + pipe.decoder_pipe = Mock() + attention_op = Mock() + + pipe.enable_xformers_memory_efficient_attention(attention_op) + + pipe.prior_pipe.enable_xformers_memory_efficient_attention.assert_called_once_with(attention_op) + pipe.decoder_pipe.enable_xformers_memory_efficient_attention.assert_called_once_with(attention_op) + + def test_decoder_guidance_with_prior_without_guidance(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["prior_guidance_scale"] = 1.0 + inputs["decoder_guidance_scale"] = 2.0 + inputs["prior_num_inference_steps"] = 1 + inputs["num_inference_steps"] = 1 + inputs["output_type"] = "latent" + + output = pipe(**inputs) + + self.assertEqual(output.images.shape[0], 1) + + def test_custom_timesteps_are_forwarded(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prior_timesteps = [1.0, 0.5, 0.0] + timesteps = [1.0, 0.25, 0.0] + seen_prior_timesteps = [] + seen_timesteps = [] + + def prior_callback_on_step_end(pipe, i, t, callback_kwargs): + seen_prior_timesteps.append(t.item()) + return callback_kwargs + + def callback_on_step_end(pipe, i, t, callback_kwargs): + seen_timesteps.append(t.item()) + return callback_kwargs + + inputs = self.get_dummy_inputs(device) + inputs["prior_num_inference_steps"] = 1 + 
inputs["prior_timesteps"] = prior_timesteps + inputs["num_inference_steps"] = 1 + inputs["timesteps"] = timesteps + inputs["output_type"] = "latent" + inputs["prior_callback_on_step_end"] = prior_callback_on_step_end + inputs["callback_on_step_end"] = callback_on_step_end + + pipe(**inputs) + + self.assertEqual(seen_prior_timesteps, prior_timesteps[:-1]) + self.assertEqual(seen_timesteps, timesteps[:-1]) + @require_torch_accelerator def test_offloads(self): pipes = [] diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py index b92df4c5d268..a90409893480 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py @@ -180,7 +180,7 @@ def test_wuerstchen_decoder(self): output = pipe(**self.get_dummy_inputs(device)) image = output.images - image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False) + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] @@ -191,6 +191,87 @@ def test_wuerstchen_decoder(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_prepare_latents_casts_supplied_latents(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + image_embeddings = torch.ones((1, 4, 1, 1)) + latents = torch.randn((1, 4, 4, 4), dtype=torch.float32) + + latents = pipe.prepare_latents( + batch_size=1, + image_embeddings=image_embeddings, + num_images_per_prompt=1, + dtype=torch.bfloat16, + device=torch.device("cpu"), + generator=None, + latents=latents, + scheduler=pipe.scheduler, + ) + + self.assertEqual(latents.dtype, torch.bfloat16) + + def test_precomputed_prompt_embeds_with_guidance(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt_embeds = torch.zeros(1, pipe.tokenizer.model_max_length, self.text_embedder_hidden_size) + prompt_embeds_pooled = torch.zeros(1, 1, self.text_embedder_hidden_size) + + output = pipe( + image_embeddings=torch.ones((1, 4, 4, 4)), + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + guidance_scale=2.0, + num_inference_steps=1, + output_type="latent", + ) + + self.assertEqual(output.images.shape, (1, 4, 16, 16)) + + def test_check_inputs_requires_pooled_prompt_embeds(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + with self.assertRaisesRegex(ValueError, "`prompt_embeds_pooled` must also be provided"): + pipe.check_inputs(prompt=None, prompt_embeds=torch.zeros(1, 4, self.text_embedder_hidden_size)) + + with self.assertRaisesRegex(ValueError, "`negative_prompt_embeds_pooled` must also be provided"): + pipe.check_inputs( + prompt="horse", + negative_prompt_embeds=torch.zeros(1, 4, self.text_embedder_hidden_size), + ) + + def test_custom_timesteps_are_forwarded(self): + components = self.get_dummy_components() + components["text_encoder"] = None + components["tokenizer"] = None + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt_embeds = torch.zeros(1, 4, self.text_embedder_hidden_size) + prompt_embeds_pooled = torch.zeros(1, 1, self.text_embedder_hidden_size) + timesteps = [1.0, 0.25, 0.0] + seen_timesteps = 
[] + + def callback_on_step_end(pipe, i, t, callback_kwargs): + seen_timesteps.append(t.item()) + return callback_kwargs + + pipe( + image_embeddings=torch.ones((1, 4, 4, 4)), + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + guidance_scale=1.0, + num_inference_steps=1, + timesteps=timesteps, + output_type="latent", + callback_on_step_end=callback_on_step_end, + ) + + self.assertEqual(seen_timesteps, timesteps[:-1]) + self.assertTrue(torch.allclose(pipe.scheduler.timesteps.cpu(), torch.tensor(timesteps))) + @skip_mps def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-2) diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py index 0bc821b7e64f..e1b12b0e4d9a 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py @@ -15,10 +15,12 @@ import gc import unittest +from types import SimpleNamespace import numpy as np import pytest import torch +from PIL import Image from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline @@ -49,6 +51,18 @@ enable_full_determinism() +class DummyStableCascadeImageEncoder(torch.nn.Module): + def forward(self, pixel_values): + return SimpleNamespace( + image_embeds=torch.zeros(pixel_values.shape[0], 768, device=pixel_values.device, dtype=pixel_values.dtype) + ) + + +class DummyStableCascadeFeatureExtractor: + def __call__(self, image, return_tensors=None): + return SimpleNamespace(pixel_values=torch.zeros(1, 3, 8, 8)) + + class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableCascadePriorPipeline params = ["prompt"] @@ -188,6 +202,135 @@ def test_wuerstchen_prior(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2 + def test_check_inputs_accepts_single_images(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + pipe.check_inputs(prompt="horse", images=Image.new("RGB", (8, 8))) + pipe.check_inputs(prompt="horse", images=torch.zeros(3, 8, 8)) + with self.assertRaisesRegex(ValueError, "`images` cannot be an empty list"): + pipe.check_inputs(prompt="horse", images=[]) + + def test_single_images_are_encoded(self): + components = self.get_dummy_components() + components["text_encoder"] = None + components["tokenizer"] = None + components["image_encoder"] = DummyStableCascadeImageEncoder() + components["feature_extractor"] = DummyStableCascadeFeatureExtractor() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt_embeds = torch.zeros(1, 4, self.text_embedder_hidden_size) + prompt_embeds_pooled = torch.zeros(1, 1, self.text_embedder_hidden_size) + + output = pipe( + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + images=Image.new("RGB", (8, 8)), + height=42, + width=42, + guidance_scale=1.0, + num_inference_steps=1, + output_type="pt", + ) + + self.assertEqual(output.image_embeddings.shape, (1, 16, 1, 1)) + + def test_prepare_latents_casts_supplied_latents(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + latents = torch.randn((1, pipe.prior.config.in_channels, 1, 1), dtype=torch.float32) + + latents = 
pipe.prepare_latents( + batch_size=1, + height=42, + width=42, + num_images_per_prompt=1, + dtype=torch.bfloat16, + device=torch.device("cpu"), + generator=None, + latents=latents, + scheduler=pipe.scheduler, + ) + + self.assertEqual(latents.dtype, torch.bfloat16) + + def test_custom_timesteps_are_forwarded(self): + components = self.get_dummy_components() + components["text_encoder"] = None + components["tokenizer"] = None + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt_embeds = torch.zeros(1, 4, self.text_embedder_hidden_size) + prompt_embeds_pooled = torch.zeros(1, 1, self.text_embedder_hidden_size) + timesteps = [1.0, 0.25, 0.0] + seen_timesteps = [] + + def callback_on_step_end(pipe, i, t, callback_kwargs): + seen_timesteps.append(t.item()) + return callback_kwargs + + pipe( + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + height=42, + width=42, + guidance_scale=1.0, + num_inference_steps=1, + timesteps=timesteps, + output_type="pt", + callback_on_step_end=callback_on_step_end, + ) + + self.assertEqual(seen_timesteps, timesteps[:-1]) + self.assertTrue(torch.allclose(pipe.scheduler.timesteps.cpu(), torch.tensor(timesteps))) + + def test_precomputed_prompt_embeds_accept_negative_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt_embeds = torch.zeros(1, pipe.tokenizer.model_max_length, self.text_embedder_hidden_size) + prompt_embeds_pooled = torch.zeros(1, 1, self.text_embedder_hidden_size) + + output = pipe( + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt="low quality", + height=42, + width=42, + guidance_scale=2.0, + num_inference_steps=1, + output_type="pt", + ) + + self.assertEqual(output.image_embeddings.shape, (1, 16, 1, 1)) + + def test_np_output_converts_all_prompt_embeds(self): + components = self.get_dummy_components() + components["text_encoder"] = None + components["tokenizer"] = None + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt_embeds = torch.zeros(1, 4, self.text_embedder_hidden_size) + prompt_embeds_pooled = torch.zeros(1, 1, self.text_embedder_hidden_size) + + output = pipe( + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + height=42, + width=42, + guidance_scale=1.0, + num_inference_steps=1, + output_type="np", + ) + + self.assertIsInstance(output.image_embeddings, np.ndarray) + self.assertIsInstance(output.prompt_embeds, np.ndarray) + self.assertIsInstance(output.prompt_embeds_pooled, np.ndarray) + @skip_mps def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-1)
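
For reference, a minimal end-to-end sketch of the two-stage flow this patch touches, mirroring the corrected docstring examples above. It is a sketch under stated assumptions, not part of the patch: it assumes the public `stabilityai/stable-cascade-prior` and `stabilityai/stable-cascade` checkpoints, a CUDA device, and (per the check fixed in this patch) torch>=2.2.0 when the decoder is run in `torch.bfloat16`.

```py
import torch

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Stage C (prior): text prompt -> image embeddings
prior_pipe = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")

# Stages B/A (decoder): image embeddings -> pixels
gen_pipe = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")

prompt = "an image of a shiba inu, donning a spacesuit and helmet"

# After this patch, both pipelines also accept a custom `timesteps` list
# (and the combined pipeline `prior_timesteps`/`timesteps`), which overrides
# `num_inference_steps` when provided.
prior_output = prior_pipe(prompt)
images = gen_pipe(prior_output.image_embeddings, prompt=prompt).images
```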