Skip to content

Add MoE load balancing loss to distillation#3679

Open
JamesDeng42 wants to merge 1 commit intomainfrom
yujiedeng/load-balance-loss
Open

Add MoE load balancing loss to distillation#3679
JamesDeng42 wants to merge 1 commit intomainfrom
yujiedeng/load-balance-loss

Conversation

@JamesDeng42
Copy link
Copy Markdown
Collaborator

@JamesDeng42 JamesDeng42 commented Apr 16, 2026

Description

This PR introduces support for Mixture of Experts (MoE) load balancing loss during the distillation workflow.

Key Changes

  1. NNX Intermediate Extraction (maxtext_utils.py & qwen3.py):
    • Replaced legacy Linen self.sow(...) calls with native nnx.Intermediate(load_balance_loss) inside
  2. Distillation Strategy Updates (distillation_utils.py & train_distill.py):
    • Upgraded DistillationForwardOutput to carry the collected moe_lb_loss.
    • Updated CombinedDistillationStrategy to actively add the moe_lb_loss to the total_loss so the optimizer
      minimizes it.
    • Surfaced "distill/moe_lb_loss" to the metrics dictionary for TensorBoard logging and visibility.
  3. Model Mutability (models.py):
    • Automatically appended "intermediates" to the mutable_collections list during the Transformer's forward pass
      whenever load_balance_loss_weight > 0.0 to ensure NNX variables can successfully write to the state.

Tests

Added "distill/moe_lb_loss" to the expected metrics keys in the test suite to prevent regressions in train_distill_test.py.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov bot commented Apr 16, 2026

@JamesDeng42 JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch from 7eaad4e to 69341c6 Compare April 16, 2026 00:23
@JamesDeng42 JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch from 69341c6 to 414e2f0 Compare April 16, 2026 00:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant