Skip to content

Commit 414e2f0

Browse files
Commit 414e2f0 — "Add MoE load balancing loss to distillation"
Parent: 51c7f2b

6 files changed

Lines changed: 43 additions & 3 deletions

File tree

src/maxtext/models/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,8 @@ def __call__(
505505
mutable_collections.append("intermediates")
506506
if self.config.distill_beta > 0.0 and "intermediates" not in mutable_collections:
507507
mutable_collections.append("intermediates")
508+
if self.config.load_balance_loss_weight > 0.0 and "intermediates" not in mutable_collections:
509+
mutable_collections.append("intermediates")
508510

509511
if self.config.pure_nnx_decoder:
510512
logits, hidden_state, kv_caches = self.decoder(

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,7 @@ def __call__(
10551055
# We sow the load balancing loss so it can be collected and added to the total loss
10561056
# during training.
10571057
if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
1058-
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
1058+
self.moe_lb_loss = nnx.Intermediate(load_balance_loss)
10591059

10601060
# Final residual connection (after the MoE block)
10611061
layer_output = residual + mlp_output
@@ -1299,7 +1299,7 @@ def __call__(
12991299
mlp_lnx, load_balance_loss, _ = self.moe_block(hidden_states)
13001300
mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)
13011301
if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
1302-
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
1302+
self.moe_lb_loss = nnx.Intermediate(load_balance_loss)
13031303

13041304
layer_output = intermediate_inputs + mlp_lnx
13051305
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class DistillationForwardOutput:
5252
logits: jax.Array
5353
#: out_projection_activations
5454
out_projection_activations: jax.Array | None = None
55+
#: moe load balance loss
56+
moe_lb_loss: jax.Array | None = None
5557

5658

5759
@flax.struct.dataclass(frozen=True)
@@ -373,13 +375,19 @@ def compute_loss(
373375

374376
total_loss = base_logit_loss + feature_loss
375377

378+
moe_lb_loss = jnp.array(0.0)
379+
if student_output.moe_lb_loss is not None:
380+
moe_lb_loss = student_output.moe_lb_loss
381+
total_loss += moe_lb_loss
382+
376383
# 4. Return Loss AND Metrics
377384
metrics = {
378385
"distill/soft_loss": soft_loss,
379386
"distill/hard_loss": hard_loss,
380387
"distill/kl_div": jnp.mean(kl_div, where=mean_mask),
381388
"distill/teacher_loss": teacher_hard_loss,
382389
"distill/out_proj_feature_loss": feature_loss,
390+
"distill/moe_lb_loss": moe_lb_loss,
383391
"distill/total_loss": total_loss,
384392
"distill/temperature": self.temperature,
385393
"distill/alpha": self.alpha,

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,14 @@ def model_forward_fn(
150150
if config.distill_beta > 0.0:
151151
out_projection_activations = maxtext_utils.get_intermediate_value(model, "out_projection_activations", clear=True)
152152

153+
moe_lb_loss = None
154+
if config.num_experts > 1 and config.load_balance_loss_weight > 0.0:
155+
total_moe_lb_loss = maxtext_utils.get_intermediate_value(model, "moe_lb_loss", clear=False)
156+
if total_moe_lb_loss is not None:
157+
moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss))
158+
153159
retval = distillation_utils.DistillationForwardOutput(
154-
logits=logits, out_projection_activations=out_projection_activations
160+
logits=logits, out_projection_activations=out_projection_activations, moe_lb_loss=moe_lb_loss
155161
)
156162
return retval
157163

src/maxtext/utils/maxtext_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,29 @@ def get_intermediate_value(model, nested_key, default=None, clear=False):
10811081
intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1]
10821082
if clear:
10831083
del model.decoder.layers["self_attention"][nested_key]
1084+
case "moe_lb_loss":
1085+
# pylint: disable=import-outside-toplevel
1086+
from flax import nnx
1087+
1088+
losses = []
1089+
for path, val in nnx.state(model).flat_state():
1090+
v = val.value if hasattr(val, "value") else val
1091+
if isinstance(v, dict) and "moe_lb_loss" in v:
1092+
losses.append(v["moe_lb_loss"][-1])
1093+
if clear:
1094+
v.pop("moe_lb_loss", None)
1095+
elif "moe_lb_loss" in path:
1096+
losses.append(v[-1] if isinstance(v, (list, tuple)) else v)
1097+
if clear:
1098+
curr = model
1099+
try:
1100+
for p in path[:-1]:
1101+
curr = getattr(curr, p)
1102+
delattr(curr, path[-1])
1103+
except AttributeError:
1104+
pass
1105+
if losses:
1106+
intermediate_value = losses
10841107
case _:
10851108
# Default case to handle any unknown nested keys
10861109
raise ValueError(f"Incorrect nested_key: {nested_key}")

tests/post_training/unit/train_distill_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def _test_monitored_strategy(self, *, sft_mode: bool, feature_loss_type: Literal
437437
"distill/kl_div",
438438
"distill/teacher_loss",
439439
"distill/out_proj_feature_loss",
440+
"distill/moe_lb_loss",
440441
"distill/total_loss",
441442
"distill/temperature",
442443
"distill/alpha",

Commit comments: 0