
Commit ed881a1

Commit message: up up up
1 parent fe2a6a3 commit ed881a1

6 files changed

Lines changed: 187 additions & 159 deletions


src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 131 additions & 126 deletions
Large diffs are not rendered by default.

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -105,9 +105,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
         if not block_state.output_type == "latent":
             latents = block_state.latents
             # make sure the VAE is in float32 mode, as it overflows in float16
-            block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
+            needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
 
-            if block_state.needs_upcasting:
+            if needs_upcasting:
                 self.upcast_vae(components)
                 latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
             elif latents.dtype != components.vae.dtype:
@@ -117,21 +117,21 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
 
             # unscale/denormalize the latents
             # denormalize with the mean and std if available and not None
-            block_state.has_latents_mean = (
+            has_latents_mean = (
                 hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
             )
-            block_state.has_latents_std = (
+            has_latents_std = (
                 hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
             )
-            if block_state.has_latents_mean and block_state.has_latents_std:
-                block_state.latents_mean = (
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
                     torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                 )
-                block_state.latents_std = (
+                latents_std = (
                     torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                 )
                 latents = (
-                    latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
+                    latents * latents_std / components.vae.config.scaling_factor + latents_mean
                 )
             else:
                 latents = latents / components.vae.config.scaling_factor
```
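These two hunks only move scratch values (`needs_upcasting`, `latents_mean`, `latents_std`) off `block_state` into locals; the decode math is unchanged. For reference, a self-contained sketch of that math (the `0.13025` scaling factor is the stock SDXL value, assumed here; the real values come from `components.vae.config`):

```python
import torch

# Stand-ins for components.vae.config (assumed values; SDXL's VAE uses
# scaling_factor=0.13025 and usually publishes no latents_mean/std).
scaling_factor = 0.13025
latents_mean_cfg = None  # e.g. a list of 4 floats when the model provides it
latents_std_cfg = None

latents = torch.randn(1, 4, 128, 128, dtype=torch.float16)

# fp16 VAEs with config.force_upcast overflow during decode, so the block
# upcasts to float32 first; here we only mimic the dtype switch.
needs_upcasting = latents.dtype == torch.float16
if needs_upcasting:
    latents = latents.float()

# unscale/denormalize as in the hunk: use mean/std when available,
# otherwise just divide by the scaling factor.
if latents_mean_cfg is not None and latents_std_cfg is not None:
    latents_mean = torch.tensor(latents_mean_cfg).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(latents_std_cfg).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / scaling_factor + latents_mean
else:
    latents = latents / scaling_factor
```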

src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -67,7 +67,7 @@ def intermediate_inputs(self) -> List[str]:
 
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
-        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+        block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)
 
         return components, block_state
@@ -134,10 +134,10 @@ def check_inputs(components, block_state):
     def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
         self.check_inputs(components, block_state)
 
-        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+        block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)
         if components.num_channels_unet == 9:
-            block_state.scaled_latents = torch.cat(
-                [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
+            block_state.latent_model_input = torch.cat(
+                [block_state.latent_model_input, block_state.mask, block_state.masked_image_latents], dim=1
             )
 
         return components, block_state
```
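The rename from `scaled_latents` to `latent_model_input` matches the variable name used in the standard SDXL pipelines. A minimal sketch of what the renamed value holds, including the 9-channel inpainting case (shapes are illustrative):

```python
import torch
from diffusers import EulerDiscreteScheduler

# Scheduler with its default config (assumption: any scheduler implementing
# scale_model_input behaves the same way for this purpose).
scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)
t = scheduler.timesteps[0]

latents = torch.randn(1, 4, 128, 128)
latent_model_input = scheduler.scale_model_input(latents, t)

# Inpainting UNets (num_channels_unet == 9) expect 4 latent channels plus
# 1 mask channel plus 4 masked-image-latent channels, concatenated as in the hunk.
mask = torch.zeros(1, 1, 128, 128)
masked_image_latents = torch.randn(1, 4, 128, 128)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
print(latent_model_input.shape)  # torch.Size([1, 9, 128, 128])
```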
```diff
@@ -232,7 +232,7 @@ def __call__(
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
             guider_state_batch.noise_pred = components.unet(
-                block_state.scaled_latents,
+                block_state.latent_model_input,
                 t,
                 encoder_hidden_states=prompt_embeds,
                 timestep_cond=block_state.timestep_cond,
@@ -410,7 +410,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
                 mid_block_res_sample = block_state.mid_block_res_sample_zeros
             else:
                 down_block_res_samples, mid_block_res_sample = components.controlnet(
-                    block_state.scaled_latents,
+                    block_state.latent_model_input,
                     t,
                     encoder_hidden_states=guider_state_batch.prompt_embeds,
                     controlnet_cond=block_state.controlnet_cond,
@@ -430,7 +430,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl
             # Predict the noise
             # store the noise_pred in guider_state_batch so we can apply guidance across all batches
             guider_state_batch.noise_pred = components.unet(
-                block_state.scaled_latents,
+                block_state.latent_model_input,
                 t,
                 encoder_hidden_states=guider_state_batch.prompt_embeds,
                 timestep_cond=block_state.timestep_cond,
```
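These hunks only thread the renamed input through the per-batch UNet and ControlNet calls; the guider then combines the per-batch `noise_pred` tensors. A sketch of the classic classifier-free-guidance combine a guider performs when it has two conditions (the guider API itself is not reproduced here):

```python
import torch

# One noise prediction per guidance batch, as stored on guider_state_batch.noise_pred.
noise_pred_uncond = torch.randn(1, 4, 128, 128)
noise_pred_text = torch.randn(1, 4, 128, 128)
guidance_scale = 5.0

# Standard CFG: push the prediction away from the unconditional direction.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```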

src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py

Lines changed: 15 additions & 17 deletions
```diff
@@ -390,7 +390,6 @@ def encode_prompt(
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
         """
-        device = device or components._execution_device
         dtype = components.text_encoder_2.dtype
 
 
@@ -526,7 +525,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)
 
         device = components._execution_device
-        dtype = components.text_encoder_2.dtype
 
         # Encode input prompt
         lora_scale = (
@@ -542,8 +540,8 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         ) = self.encode_prompt(
             components,
             prompt=block_state.prompt,
-            prompt2=block_state.prompt_2,
-            device = device,
+            prompt_2=block_state.prompt_2,
+            device=device,
             requires_unconditional_embeds=components.requires_unconditional_embeds,
             negative_prompt=block_state.negative_prompt,
             negative_prompt_2=block_state.negative_prompt_2,
@@ -604,11 +602,11 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         device = components._execution_device
         dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
 
-        block_state.processed_image = components.image_processor.preprocess(block_state.image)
+        image = components.image_processor.preprocess(block_state.image)
 
         # Encode image into latents
         block_state.image_latents = encode_vae_image(
-            image=block_state.processed_image,
+            image=image,
             vae=components.vae,
             generator=block_state.generator,
             dtype=dtype,
@@ -681,7 +679,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
             ),
             OutputParam(
-                "mask_latents",
+                "mask",
                 type_hint=torch.Tensor,
                 description="The mask to apply on the latents for the inpainting generation.",
             ),
```
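The img2img hunks stop staging the preprocessed image on `block_state` and pass a local into `encode_vae_image`. A sketch of what such a helper typically does (assumption: the real helper also handles per-sample generators and the fp16 upcasting dance):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")  # any KL VAE works here
image = torch.rand(1, 3, 1024, 1024) * 2 - 1  # image_processor output is in [-1, 1]

with torch.no_grad():
    # Sample from the posterior and scale into the UNet's latent space.
    latents = vae.encode(image).latent_dist.sample(generator=torch.Generator().manual_seed(0))
    latents = latents * vae.config.scaling_factor
```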
```diff
@@ -715,37 +713,37 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
             width = components.default_width
 
         if block_state.padding_mask_crop is not None:
-            crops_coords = components.mask_processor.get_crop_region(
+            block_state.crops_coords = components.mask_processor.get_crop_region(
                 mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop
             )
             resize_mode = "fill"
         else:
-            crops_coords = None
+            block_state.crops_coords = None
             resize_mode = "default"
 
-        processed_image = components.image_processor.preprocess(
+        image = components.image_processor.preprocess(
             block_state.image,
             height=height,
             width=width,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
             resize_mode=resize_mode,
         )
 
-        processed_image = processed_image.to(dtype=torch.float32)
+        image = image.to(dtype=torch.float32)
 
-        processed_mask_image = components.mask_processor.preprocess(
+        mask = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
             resize_mode=resize_mode,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
         )
 
-        masked_image = processed_image * (block_state.mask_latents < 0.5)
+        masked_image = image * (mask < 0.5)
 
         # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
-            image=processed_image,
+            image=image,
             vae=components.vae,
             generator=block_state.generator,
             dtype=dtype,
@@ -763,11 +761,11 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
 
         # resize mask to match the image latents
         _, _, height_latents, width_latents = block_state.image_latents.shape
-        block_state.mask_latents = torch.nn.functional.interpolate(
-            processed_mask_image,
+        block_state.mask = torch.nn.functional.interpolate(
+            mask,
             size=(height_latents, width_latents),
         )
-        block_state.mask_latents = block_state.mask_latents.to(dtype=dtype, device=device)
+        block_state.mask = block_state.mask.to(dtype=dtype, device=device)
 
         self.set_block_state(state, block_state)
```
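With the rename, the block builds the masked image from the full-resolution mask and stores the latent-resolution mask as `block_state.mask`. A sketch of that mask path with illustrative shapes (1024px image, VAE scale factor 8):

```python
import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 1024, 1024) * 2 - 1          # preprocessed image in [-1, 1]
mask = (torch.rand(1, 1, 1024, 1024) > 0.5).float()   # preprocessed binary mask

# Keep only pixels outside the inpaint region (mask >= 0.5 marks the hole).
masked_image = image * (mask < 0.5)

# Downsample the mask to the latent grid so it can be concatenated with the
# 4-channel latents during denoising (1024 / 8 = 128).
height_latents, width_latents = 128, 128
mask_for_latents = F.interpolate(mask, size=(height_latents, width_latents))
print(masked_image.shape, mask_for_latents.shape)  # (1, 3, 1024, 1024) (1, 1, 128, 128)
```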

src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@
     StableDiffusionXLPrepareAdditionalConditioningStep,
     StableDiffusionXLPrepareLatentsStep,
     StableDiffusionXLSetTimestepsStep,
+    StableDiffusionXLLCMStep,
 )
 from .decoders import (
     StableDiffusionXLDecodeStep,
```
```diff
@@ -79,6 +80,16 @@ def description(self):
         return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
 
 
+class StableDiffusionXLAutoLCMStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLLCMStep]
+    block_names = ["lcm"]
+    block_trigger_inputs = ["embedded_guidance_scale"]
+
+    @property
+    def description(self):
+        return "Run LCM step if `embedded_guidance_scale` is provided. This step should be placed before the 'input' step.\n"
+
+
 # before_denoise: text2img
 class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
```
```diff
@@ -262,6 +273,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
         StableDiffusionXLAutoIPAdapterStep,
         StableDiffusionXLAutoVaeEncoderStep,
         StableDiffusionXLAutoBeforeDenoiseStep,
+        StableDiffusionXLAutoLCMStep,
         StableDiffusionXLAutoControlNetInputStep,
         StableDiffusionXLAutoDenoiseStep,
         StableDiffusionXLAutoDecodeStep,
@@ -271,6 +283,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
         "ip_adapter",
         "image_encoder",
         "before_denoise",
+        "lcm",
         "controlnet_input",
         "denoise",
         "decoder",
```
```diff
@@ -286,6 +299,7 @@ def description(self):
             + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
             + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
             + "- for text-to-image generation, all you need to provide is `prompt`"
+            + "\n- to run the latent consistency models workflow, you need to provide `embedded_guidance_scale`"
         )
 
 
```
```diff
@@ -357,6 +371,13 @@ def description(self):
     ]
 )
 
+LCM_BLOCKS = InsertableDict(
+    [
+        ("lcm", StableDiffusionXLAutoLCMStep),
+    ]
+)
+
 AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", StableDiffusionXLTextEncoderStep),
@@ -376,5 +397,6 @@ def description(self):
     "inpaint": INPAINT_BLOCKS,
     "controlnet": CONTROLNET_BLOCKS,
     "ip_adapter": IP_ADAPTER_BLOCKS,
+    "lcm": LCM_BLOCKS,
     "auto": AUTO_BLOCKS,
 }
```
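`AutoPipelineBlocks` gates each sub-block on a trigger input, so the new `lcm` step only runs when `embedded_guidance_scale` is passed. A self-contained toy of that dispatch rule (illustrative only, not the diffusers API):

```python
class ToyAutoBlocks:
    """Run a sub-block only when its trigger input is present."""

    block_names = ["lcm"]
    block_trigger_inputs = ["embedded_guidance_scale"]

    def __call__(self, **inputs):
        for name, trigger in zip(self.block_names, self.block_trigger_inputs):
            if inputs.get(trigger) is not None:
                return f"running {name!r} (triggered by {trigger!r})"
        return "no trigger input; step skipped"

print(ToyAutoBlocks()(embedded_guidance_scale=7.5))  # running 'lcm' ...
print(ToyAutoBlocks()(prompt="a cat"))               # step skipped
```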

src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -95,7 +95,10 @@ def requires_unconditional_embeds(self):
         # by default, always prepare unconditional embeddings
         requires_unconditional_embeds = True
 
-        if hasattr(self, "guider") and self.guider is not None:
+        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
+            requires_unconditional_embeds = False
+
+        elif hasattr(self, "guider") and self.guider is not None:
             requires_unconditional_embeds = self.guider.num_conditions > 1
 
         return requires_unconditional_embeds
```
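`time_cond_proj_dim` is only set on guidance-distilled (LCM-style) UNets, which consume the guidance scale as a `timestep_cond` embedding instead of running classifier-free guidance, so no unconditional embeddings are needed for them. A sketch of that embedding, modeled on the `get_guidance_scale_embedding` helper in the standard pipelines:

```python
import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    # Sinusoidal embedding of the guidance scale, following the LCM recipe:
    # callers conventionally pass w = guidance_scale - 1.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim) * -emb)
    emb = w[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# timestep_cond for a batch of one, matching unet.config.time_cond_proj_dim == 256.
timestep_cond = guidance_scale_embedding(torch.tensor([7.5 - 1.0]), embedding_dim=256)
print(timestep_cond.shape)  # torch.Size([1, 256])
```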
