feat: Add linear CE loss fusion for DPO #2139
pengdurice wants to merge 5 commits into NVIDIA-NeMo:main
Conversation
…the loss values being nearly identical between base and exp. Signed-off-by: pengdurice <pengduhit@gmail.com>
📝 Walkthrough

This PR extends DPO training support for chunked linear cross-entropy fusion loss optimization. It adds documentation describing the feature, introduces Megatron configuration flags to control activation, and threads the

Changes
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/unit/models/policy/test_megatron_worker.py (1)

1990-2056: Exercise the reference-policy logprob path in this agreement test.

This test injects `reference_policy_logprobs` and keeps `init_reference_model=False`, so it only compares the actor-side fused loss. A regression in the newly wired reference-model `get_logprobs()` path would still pass here. Please derive the reference logprobs from the policy, or add a small companion assertion that does.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/models/policy/test_megatron_worker.py` around lines 1990-2056: the test injects reference_policy_logprobs while creating Policy with init_reference_model=False, so the reference-model logprob path (Policy.get_logprobs / reference model wiring) is not exercised; fix by deriving reference_policy_logprobs from the policy under test (call Policy.get_logprobs or the actor/reference logprob method on policy_std/policy_fuse) instead of using torch.randn, or add a small assertion that compares the injected reference_policy_logprobs to policy.get_logprobs(...) output (using the same input_ids/attention_mask/token_mask) to ensure the reference-model logprob path is exercised; update the test around the reference_policy_logprobs creation and where Policy(policy_std)/Policy(policy_fuse) are used so the generated logprobs come from or are validated against the policy's get_logprobs method.

examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml (1)

22-24: MoE-specific settings may be unnecessary for this model.

`Qwen/Qwen2.5-Math-7B` is a dense (non-MoE) model. The MoE-related settings (`freeze_moe_router`, `moe_router_bias_update_rate`, `moe_permute_fusion`) will likely be ignored but add noise to the config. Consider removing them if they were copied from an MoE recipe.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml` around lines 22-24: remove the unnecessary MoE-specific config keys from this dense model recipe — delete freeze_moe_router, moe_router_bias_update_rate, and moe_permute_fusion, since Qwen2.5-Math-7B is not MoE; if any higher-level code relies on their presence, replace with explicit comments or defaults in the recipe loader rather than leaving these MoE flags in the dense model config.
📒 Files selected for processing (11)
- docs/guides/dpo.md
- docs/guides/sft.md
- examples/configs/dpo.yaml
- examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml
- nemo_rl/algorithms/dpo.py
- nemo_rl/algorithms/loss/loss_functions.py
- nemo_rl/models/megatron/train.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
- tests/test_suites/nightly.txt
- tests/unit/models/policy/test_megatron_worker.py
> Add the following to your Megatron config in your YAML file:
>
> ```yaml
> policy:
>   megatron_cfg:
>     enabled: true
>     use_linear_ce_fusion_loss: true
>     linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
> ```
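To get a feel for why `linear_ce_fusion_chunk_size` matters, here is a rough back-of-envelope memory estimate. The shapes and vocabulary size below are illustrative assumptions (not measured from this PR), and the chunk is taken over the flattened token dimension:

```python
# Rough memory estimate: full logit tensor vs. one chunk of logits.
# All numbers are illustrative assumptions, not measurements from the PR.
batch, seq, vocab = 8, 4096, 152064   # vocab ~ a Qwen-scale vocabulary
bytes_bf16 = 2

# Standard path: materialize [batch, seq, vocab] logits at once.
full_logits_gb = batch * seq * vocab * bytes_bf16 / 1e9

# Chunked path: only chunk_size tokens of the flattened [batch*seq]
# dimension have live logits at any moment.
chunk_size = 256
chunk_logits_gb = chunk_size * vocab * bytes_bf16 / 1e9

print(f"full logits: {full_logits_gb:.1f} GB")
print(f"one chunk:   {chunk_logits_gb:.3f} GB")
```

A single copy of the full logit tensor is on the order of 10 GB here, while one 256-token chunk stays well under 0.1 GB, which is the trade-off the chunk-size comment in the snippet refers to.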
Show the backend switch in the YAML snippet.

`examples/configs/dpo.yaml` keeps `policy.dtensor_cfg.enabled: true` by default, so copying only this block can leave both backends enabled. Please include `policy.dtensor_cfg.enabled: false` here, or call it out explicitly, so the enablement instructions switch to Megatron cleanly.
✏️ Suggested doc fix

```diff
 policy:
+  dtensor_cfg:
+    enabled: false
   megatron_cfg:
     enabled: true
     use_linear_ce_fusion_loss: true
     linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/guides/dpo.md` around lines 212 - 220, Update the YAML snippet so
enabling Megatron is explicit and disables the other backend: add
policy.dtensor_cfg.enabled: false alongside the policy.megatron_cfg block (or
clearly call out to set policy.dtensor_cfg.enabled to false) so the final
snippet contains policy.megatron_cfg.enabled: true and
policy.dtensor_cfg.enabled: false, ensuring the switch activates Megatron
cleanly without leaving dtensor enabled.
What does this PR do?

Support Linear CE Loss Fusion for DPO

On top of #2036, where Linear CE loss fusion support was added for SFT, this PR adds the same support for the DPO loss.
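For context, the DPO objective that any fused logprob path must reproduce can be sketched from summed per-sequence log-probabilities. This is a minimal illustration of the standard DPO loss, not this PR's API; the function name and `beta` default are assumptions:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_lp, policy_rejected_lp,
             ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Standard DPO loss from summed per-sequence log-probabilities.

    Each argument is a [batch] tensor of log p(response | prompt) sums.
    """
    # Log-ratio of policy vs. reference for chosen and rejected responses.
    chosen_logratio = policy_chosen_lp - ref_chosen_lp
    rejected_logratio = policy_rejected_lp - ref_rejected_lp
    # -log(sigmoid(beta * margin)), averaged over the batch.
    logits = beta * (chosen_logratio - rejected_logratio)
    return -F.logsigmoid(logits).mean()
```

Because the loss only consumes per-sequence logprob sums, the policy and reference logprobs can each be produced by a chunked computation without ever holding the full logit tensor.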
Optimizations
Chunked Linear Cross-Entropy Fusion Loss
During standard DPO training the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]` (up to parallelism) for both the policy forward-backward pass and the reference model logprob computation. This can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The chunked linear cross-entropy fusion loss avoids this by computing log probabilities directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, gathers per-token log probabilities, and discards the logits before moving to the next chunk.

Benefits:
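The per-chunk scheme described above can be sketched as follows. This is a minimal single-device illustration, not the PR's Megatron implementation; the function name and tensor layout are assumptions:

```python
import torch

def chunked_token_logprobs(hidden, lm_head_weight, target_ids, chunk_size=256):
    """Per-token log-probs without materializing the full [tokens, vocab] logits.

    hidden:         [num_tokens, hidden_dim] final hidden states
    lm_head_weight: [vocab_size, hidden_dim] output-projection weight
    target_ids:     [num_tokens] label token ids
    """
    out = []
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size]            # [chunk, hidden_dim]
        logits = h @ lm_head_weight.T                   # [chunk, vocab], freed per chunk
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        tgt = target_ids[start:start + chunk_size]
        out.append(logprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out)                               # [num_tokens]
```

Only one `[chunk_size, vocab_size]` logit slice is live at a time, and the result matches the unchunked computation exactly, which is why the PR can validate the fused path by comparing loss values against the baseline.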
Issues
NA
Tests
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
Documentation
New Features
Tests