Skip to content

Commit f869c66

Browse files
authored
Merge branch 'main' into export-tests
2 parents 4a58c14 + 901da9d commit f869c66

10 files changed

Lines changed: 1187 additions & 25 deletions

File tree

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
120120
- all
121121
- __call__
122122

123+
## QwenImageEditInpaintPipeline
124+
125+
[[autodoc]] QwenImageEditInpaintPipeline
126+
- all
127+
- __call__
128+
123129
## QwenImaggeControlNetPipeline
124130
- all
125131
- __call__

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@
494494
"PixArtSigmaPAGPipeline",
495495
"PixArtSigmaPipeline",
496496
"QwenImageControlNetPipeline",
497+
"QwenImageEditInpaintPipeline",
497498
"QwenImageEditPipeline",
498499
"QwenImageImg2ImgPipeline",
499500
"QwenImageInpaintPipeline",
@@ -1134,6 +1135,7 @@
11341135
PixArtSigmaPAGPipeline,
11351136
PixArtSigmaPipeline,
11361137
QwenImageControlNetPipeline,
1138+
QwenImageEditInpaintPipeline,
11371139
QwenImageEditPipeline,
11381140
QwenImageImg2ImgPipeline,
11391141
QwenImageInpaintPipeline,

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,6 +2129,10 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
21292129

21302130

21312131
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
2132+
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
2133+
if has_diffusion_model:
2134+
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
2135+
21322136
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
21332137
if has_lora_unet:
21342138
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
@@ -2201,29 +2205,44 @@ def convert_key(key: str) -> str:
22012205
all_keys = list(state_dict.keys())
22022206
down_key = ".lora_down.weight"
22032207
up_key = ".lora_up.weight"
2208+
a_key = ".lora_A.weight"
2209+
b_key = ".lora_B.weight"
22042210

2205-
def get_alpha_scales(down_weight, alpha_key):
2206-
rank = down_weight.shape[0]
2207-
alpha = state_dict.pop(alpha_key).item()
2208-
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2209-
scale_down = scale
2210-
scale_up = 1.0
2211-
while scale_down * 2 < scale_up:
2212-
scale_down *= 2
2213-
scale_up /= 2
2214-
return scale_down, scale_up
2211+
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
2212+
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
22152213

2216-
for k in all_keys:
2217-
if k.endswith(down_key):
2218-
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
2219-
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
2220-
alpha_key = k.replace(down_key, ".alpha")
2221-
2222-
down_weight = state_dict.pop(k)
2223-
up_weight = state_dict.pop(k.replace(down_key, up_key))
2224-
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2225-
converted_state_dict[diffusers_down_key] = down_weight * scale_down
2226-
converted_state_dict[diffusers_up_key] = up_weight * scale_up
2214+
if has_non_diffusers_lora_id:
2215+
2216+
def get_alpha_scales(down_weight, alpha_key):
2217+
rank = down_weight.shape[0]
2218+
alpha = state_dict.pop(alpha_key).item()
2219+
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2220+
scale_down = scale
2221+
scale_up = 1.0
2222+
while scale_down * 2 < scale_up:
2223+
scale_down *= 2
2224+
scale_up /= 2
2225+
return scale_down, scale_up
2226+
2227+
for k in all_keys:
2228+
if k.endswith(down_key):
2229+
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
2230+
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
2231+
alpha_key = k.replace(down_key, ".alpha")
2232+
2233+
down_weight = state_dict.pop(k)
2234+
up_weight = state_dict.pop(k.replace(down_key, up_key))
2235+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2236+
converted_state_dict[diffusers_down_key] = down_weight * scale_down
2237+
converted_state_dict[diffusers_up_key] = up_weight * scale_up
2238+
2239+
# Already in diffusers format (lora_A/lora_B), just pop
2240+
elif has_diffusers_lora_id:
2241+
for k in all_keys:
2242+
if a_key in k or b_key in k:
2243+
converted_state_dict[k] = state_dict.pop(k)
2244+
elif ".alpha" in k:
2245+
state_dict.pop(k)
22272246

22282247
if len(state_dict) > 0:
22292248
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6684,7 +6684,8 @@ def lora_state_dict(
66846684

66856685
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
66866686
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
6687-
if has_alphas_in_sd or has_lora_unet:
6687+
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
6688+
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
66886689
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
66896690

66906691
out = (state_dict, metadata) if return_lora_metadata else state_dict

src/diffusers/models/attention_dispatch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,12 +955,13 @@ def _native_npu_attention(
955955
dropout_p: float = 0.0,
956956
scale: Optional[float] = None,
957957
) -> torch.Tensor:
958-
return npu_fusion_attention(
958+
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
959+
out = npu_fusion_attention(
959960
query,
960961
key,
961962
value,
962-
query.size(2), # num_heads
963-
input_layout="BSND",
963+
query.size(1), # num_heads
964+
input_layout="BNSD",
964965
pse=None,
965966
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
966967
pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
969970
sync=False,
970971
inner_precise=0,
971972
)[0]
973+
out = out.transpose(1, 2).contiguous()
974+
return out
972975

973976

974977
# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@
393393
"QwenImageImg2ImgPipeline",
394394
"QwenImageInpaintPipeline",
395395
"QwenImageEditPipeline",
396+
"QwenImageEditInpaintPipeline",
396397
"QwenImageControlNetPipeline",
397398
]
398399
try:
@@ -714,6 +715,7 @@
714715
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
715716
from .qwenimage import (
716717
QwenImageControlNetPipeline,
718+
QwenImageEditInpaintPipeline,
717719
QwenImageEditPipeline,
718720
QwenImageImg2ImgPipeline,
719721
QwenImageInpaintPipeline,

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
2727
_import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
2828
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
29+
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
2930
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
3031
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
3132

@@ -39,6 +40,7 @@
3940
from .pipeline_qwenimage import QwenImagePipeline
4041
from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
4142
from .pipeline_qwenimage_edit import QwenImageEditPipeline
43+
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
4244
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
4345
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
4446
else:

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,12 @@ def __call__(
551551
Function invoked when calling the pipeline for generation.
552552
553553
Args:
554+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
555+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
556+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
557+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
558+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
559+
latents as `image`, but if passing latents directly it is not encoded again.
554560
prompt (`str` or `List[str]`, *optional*):
555561
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
556562
instead.

0 commit comments

Comments
 (0)