Fix Flax LMSDiscreteScheduler sigma indexing and align LMS prediction_type with PyTorch#13556

Open
HaoyuLi-Nova wants to merge 1 commit into huggingface:main from HaoyuLi-Nova:main
Conversation

@HaoyuLi-Nova HaoyuLi-Nova commented Apr 24, 2026

What does this PR do?

Pipelines using FlaxLMSDiscreteScheduler pass timestep values drawn from state.timesteps (training-time indices), while state.sigmas after set_timesteps is indexed by inference step (and has length num_inference_steps + 1). The previous step implementation indexed sigmas directly with the timestep value, which is inconsistent with FlaxEulerDiscreteScheduler and can yield wrong sigma values or out-of-bounds reads. This PR resolves the inference-step index via jnp.where(state.timesteps == timestep, size=1) before reading sigma and before the LMS coefficient integration.
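The timestep-to-index resolution described above can be sketched as follows (a minimal illustration under the assumptions in this PR description, not the exact diffusers implementation; the helper name is hypothetical):

```python
import jax.numpy as jnp

def timestep_to_step_index(timesteps: jnp.ndarray, timestep: int) -> jnp.ndarray:
    # Pipelines pass training-time timestep values, but sigmas after
    # set_timesteps are indexed by inference step. Find the position of
    # `timestep` inside the inference-time `timesteps` array. The static
    # `size=1` keeps the output shape known, so this works under jit.
    (step_index,) = jnp.where(timesteps == timestep, size=1)
    return step_index[0]

# Example: 10 inference steps spread over 1000 training steps
timesteps = jnp.arange(999, -1, -100)   # [999, 899, 799, ..., 99]
sigmas = jnp.linspace(10.0, 0.0, 11)    # length num_inference_steps + 1
i = timestep_to_step_index(timesteps, 799)  # inference step 2, not index 799
sigma = sigmas[i]
```

Indexing sigmas with the raw timestep value (799) would read far past the 11-element array; resolving the inference-step index first keeps the lookup in bounds.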

Additional alignment with PyTorch LMSDiscreteScheduler:

  • Flax: support prediction_type="sample", validate prediction_type in __init__, clarify docs (beta_schedule, get_lms_coefficient, step timestep semantics), and fix a comment typo (“settable”).
  • PyTorch: fix the ValueError message in step so it lists epsilon, sample, and v_prediction (the implementation already supported sample; only the message was out of date).

Tests: extend test_prediction_type with "sample", add test_full_loop_with_sample_prediction, and add tests/schedulers/test_scheduler_lms_flax.py (@require_flax, lazy JAX import).
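The @require_flax / lazy-import pattern mentioned for the new test file can be sketched like this (a generic stand-in using stdlib unittest; the real decorator lives in diffusers' test utilities, and the test class name is hypothetical):

```python
import importlib.util
import unittest

def require_flax(test_case):
    # Skip the decorated test unless Flax is importable, mirroring the
    # guard the new test file relies on (sketch, not the diffusers code).
    has_flax = importlib.util.find_spec("flax") is not None
    return unittest.skipUnless(has_flax, "test requires Flax")(test_case)

@require_flax
class FlaxLMSDiscreteSchedulerSmokeTest(unittest.TestCase):
    def test_lazy_jax_import(self):
        # JAX is imported inside the test body, so merely collecting this
        # file does not require JAX/Flax to be installed.
        import jax.numpy as jnp
        self.assertTrue(bool(jnp.allclose(jnp.array([1.0]), jnp.array([1.0]))))
```

Keeping the jax import inside the test method is what makes the file safe to collect in environments without Flax.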

Fixes # (N/A — no linked issue)

Before submitting

Who can review?

cc @yiyixuxu (schedulers) @pcuenca (JAX/Flax)

@github-actions github-actions Bot added schedulers size/S PR with diff < 50 LOC labels Apr 24, 2026
@HaoyuLi-Nova HaoyuLi-Nova changed the title from “Replace the TODO placeholders in the docstring of” to “Document FlaxLMSDiscreteScheduler.get_lms_coefficient arguments” Apr 24, 2026
@HaoyuLi-Nova HaoyuLi-Nova reopened this Apr 24, 2026
@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 24, 2026
@HaoyuLi-Nova HaoyuLi-Nova changed the title from “Document FlaxLMSDiscreteScheduler.get_lms_coefficient arguments” to “Fix FlaxLMS step timestep→index; complete get_lms_coefficient docs” Apr 24, 2026
…improve documentation

- Updated the `FlaxLMSDiscreteScheduler` and `LMSDiscreteScheduler` classes to include 'sample' as a valid option for `prediction_type`.
- Improved docstrings for clarity, specifying the expected types and values for various parameters.
- Added a new test case to validate the full loop functionality with the 'sample' prediction type.

This change ensures better flexibility in prediction methods and enhances code documentation for future reference.
@github-actions github-actions Bot added tests size/M PR with diff < 200 LOC and removed size/S PR with diff < 50 LOC labels May 9, 2026
@HaoyuLi-Nova HaoyuLi-Nova changed the title from “Fix FlaxLMS step timestep→index; complete get_lms_coefficient docs” to “Fix Flax LMSDiscreteScheduler sigma indexing and align LMS prediction_type with PyTorch” May 9, 2026