
Fix Wan batch conditioning and config handling #13693

Open

hlky wants to merge 1 commit into huggingface:main from hlky:codex/wan-review-fixes

Conversation

@hlky (Contributor) commented May 7, 2026

Fixes #13578.

What does this PR do?

This PR addresses the Wan systemic review findings:

  • Expands Wan i2v and Animate image embeddings, masks, latents, and conditioning tensors to batch_size * num_videos_per_prompt in per-prompt order.
  • Keeps image required where VAE conditioning still needs it, while allowing precomputed CLIP image_embeds.
  • Accepts documented Wan i2v image list/tuple inputs and validates precomputed image-embed batch sizes.
  • Uses Wan VAE config scale factors in VACE, video-to-video, and modular Wan paths.
  • Trims VACE reference latents for output_type="latent".
  • Preserves Wan Animate image processor config values.
  • Stops silently ignoring unsupported Wan video-to-video num_videos_per_prompt != 1.
  • Keeps modular Wan timesteps in scheduler precision while casting model inputs to the selected transformer dtype.
  • Adds focused fast regression tests and a .ai/pipelines.md checklist note for num_*_per_prompt expansion drift.
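The per-prompt expansion order described in the first bullet can be sketched as follows. This is a minimal illustration with hypothetical tensor shapes, not the pipeline's actual code; the key point is that `repeat_interleave` keeps copies of the same prompt adjacent, matching how `prompt_embeds` are expanded:

```python
import torch

# Two prompts, three videos per prompt: conditioning tensors must be expanded
# so that copies for the same prompt stay adjacent (per-prompt order).
batch_size, num_videos_per_prompt = 2, 3
image_embeds = torch.randn(batch_size, 257, 1280)  # hypothetical CLIP shape

# repeat_interleave yields row order [p0, p0, p0, p1, p1, p1]
expanded = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0)

assert expanded.shape[0] == batch_size * num_videos_per_prompt
assert torch.equal(expanded[2], image_embeds[0])  # last copy of prompt 0
assert torch.equal(expanded[3], image_embeds[1])  # first copy of prompt 1
```

By contrast, `tensor.repeat(num_videos_per_prompt, 1, 1)` would interleave prompts as [p0, p1, p0, p1, ...], which is the kind of expansion drift the checklist note guards against.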

Tests

  • ruff check ... on touched Python files
  • ruff format --check ... on touched Python files
  • git diff --check
  • python -m compileall -q src/diffusers/pipelines/wan src/diffusers/modular_pipelines/wan tests/pipelines/wan tests/modular_pipelines/wan
  • python -m pytest tests/pipelines/wan/test_wan_image_to_video.py tests/pipelines/wan/test_wan_vace.py tests/pipelines/wan/test_wan_video_to_video.py tests/pipelines/wan/test_wan_animate.py tests/modular_pipelines/wan/test_modular_pipeline_wan.py -q -k "not test_save_load_float16"

Result: 180 passed, 21 skipped, 5 deselected, 61 warnings, 6 subtests passed.

Note: WanFLFToVideoPipelineFastTests::test_save_load_float16 fails the same way on clean upstream/main (max diff 0.04175 > 0.01), so it was excluded from the clean touched-file verification.

@yiyixuxu (Collaborator) left a comment

I left some comments

```python
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat_interleave(
    block_state.num_videos_per_prompt, dim=0
).to(block_state.dtype)
```
I think we already cast dtype later in the denoise step, no?

```python
]

@property
def intermediate_outputs(self) -> list[OutputParam]:
```

I think we don't need to declare it here; this does not go beyond the loop step:
https://huggingface.co/docs/diffusers/modular_diffusers/loop_sequential_pipeline_blocks#loop-blocks

```diff
-hidden_states=block_state.latent_model_input.to(block_state.dtype),
-timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
+hidden_states=block_state.latent_model_input.to(dtype),
+timestep=t.expand(block_state.latent_model_input.shape[0]),
```

the dtype update here makes sense,

> Keeps modular Wan timesteps in scheduler precision while casting model inputs to the selected transformer dtype.

I also like that we use transformer.dtype instead of prompt_embeds.dtype -> this should be the standard preferred way of inferring dtype moving forward.

But most of the rest of the changes on inputs/outputs do not seem valid to me; the loop step works a bit differently: https://huggingface.co/docs/diffusers/modular_diffusers/loop_sequential_pipeline_blocks#loop-blocks. Let me know if I missed anything.
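The dtype convention discussed here can be sketched in isolation. A minimal illustration, not the pipeline's code: `nn.Linear` stands in for the Wan transformer (diffusers models expose a `.dtype` property directly; a plain `nn.Module` does not, so the sketch reads it from the parameters):

```python
import torch
from torch import nn

# Stand-in for the transformer; diffusers ModelMixin exposes `.dtype`,
# here we derive it from the parameters instead.
transformer = nn.Linear(4, 4).to(torch.float16)
dtype = next(transformer.parameters()).dtype

# Cast model inputs to the transformer dtype, but leave the timestep in the
# scheduler's own precision (here float32).
latent_model_input = torch.randn(1, 4, dtype=torch.float32)
t = torch.tensor(999.0)  # scheduler timestep, intentionally not cast

hidden_states = latent_model_input.to(dtype)
timestep = t.expand(latent_model_input.shape[0])

assert hidden_states.dtype == torch.float16
assert timestep.dtype == torch.float32
```

Inferring dtype from `prompt_embeds` instead would silently follow whatever precision the text encoder happened to run in, which is why reading it from the transformer is preferred.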


```python
    return mask_lat_size

def _expand_tensor_to_effective_batch(
```

can we make it a regular function?
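Since the helper presumably never touches `self`, it can live as a module-level function. A hypothetical sketch, with the signature assumed from the name shown above rather than taken from the PR:

```python
import torch

def _expand_tensor_to_effective_batch(tensor: torch.Tensor, num_videos_per_prompt: int) -> torch.Tensor:
    """Repeat each row so copies of the same prompt stay adjacent (per-prompt
    order), mirroring how prompt_embeds are expanded."""
    return tensor.repeat_interleave(num_videos_per_prompt, dim=0)

# Usage: two prompts expanded to an effective batch of six.
latents = torch.randn(2, 4)
expanded = _expand_tensor_to_effective_batch(latents, 3)
assert expanded.shape == (6, 4)
```

As a free function it needs no pipeline instance, which also makes it trivial to unit-test on its own.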

```diff
     " only forward one of the two."
 )
-if image is None and image_embeds is None:
+if image is None:
```

I don't think a list of images makes sense here for the I2V pipeline, no?
Maybe we should just fix the docstring instead?
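The PR description also mentions validating precomputed image-embed batch sizes. A hedged sketch of that kind of check; the function name, signature, and messages are assumptions for illustration, not the actual pipeline code:

```python
import torch

def check_image_inputs(image, image_embeds, batch_size):
    # Hypothetical validation: `image` stays required because VAE conditioning
    # still needs it, while any precomputed CLIP `image_embeds` must match the
    # prompt batch size.
    if image is None:
        raise ValueError("`image` is required for VAE conditioning.")
    if image_embeds is not None and image_embeds.shape[0] != batch_size:
        raise ValueError(
            f"`image_embeds` batch size {image_embeds.shape[0]} does not match "
            f"prompt batch size {batch_size}."
        )

# A well-formed call passes silently.
check_image_inputs(torch.zeros(1, 3, 32, 32), torch.zeros(2, 257, 1280), batch_size=2)
```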



Development

Successfully merging this pull request may close these issues: wan model/pipeline review.

2 participants