
feat: Add linear CE loss fusion for DPO#2139

Open
pengdurice wants to merge 5 commits into NVIDIA-NeMo:main from pengdurice:peng-add-linear-ce-fusion-v2

Conversation

@pengdurice
Contributor

@pengdurice pengdurice commented Mar 22, 2026

What does this PR do ?

Support Linear CE Loss Fusion for DPO
Builds on #2036, which added Linear CE loss fusion support for SFT. This PR extends that support to the DPO loss.

Optimizations

Chunked Linear Cross-Entropy Fusion Loss

During standard DPO training, the model materializes a full logit tensor of shape [batch_size, seq_length, vocab_size] (before parallelism sharding) for both the policy forward-backward pass and the reference-model logprob computation. This can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The chunked linear cross-entropy fusion loss avoids this by computing log probabilities directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, gathers per-token log probabilities, and discards the logits before moving to the next chunk.
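The chunked path can be sketched as follows. This is a minimal single-device illustration; the names (`chunked_logprobs`, `lm_head_weight`) are hypothetical and not this PR's actual API, which also handles tensor parallelism and packed sequences:

```python
import torch
import torch.nn.functional as F

def chunked_logprobs(hidden, lm_head_weight, target_ids, chunk_size=256):
    """Gather per-token logprobs without materializing the full
    [batch, seq, vocab] logit tensor.

    hidden:         [batch, seq, hidden_dim] final hidden states
    lm_head_weight: [vocab, hidden_dim] output projection weight
    target_ids:     [batch, seq] token ids to gather logprobs for
    """
    out = []
    for start in range(0, hidden.size(1), chunk_size):
        h = hidden[:, start : start + chunk_size]   # [b, c, d]
        logits = h @ lm_head_weight.t()             # [b, c, V]; only one chunk alive at a time
        logp = F.log_softmax(logits.float(), dim=-1)
        tgt = target_ids[:, start : start + chunk_size]
        out.append(logp.gather(-1, tgt.unsqueeze(-1)).squeeze(-1))
        # logits/logp go out of scope here, freeing the chunk before the next projection
    return torch.cat(out, dim=1)                    # [b, seq]
```

Peak memory for the logits drops from O(seq_length × vocab_size) to O(chunk_size × vocab_size) per sample, while the gathered logprobs match the unchunked computation.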

Benefits:

  • Extends the maximum trainable sequence length significantly by eliminating the large logit tensor from GPU memory.
  • Applies to both the training forward-backward pass and the reference model logprob computation.
  • Produces numerically equivalent loss values to the standard path.
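For context, the per-sequence log probabilities produced by either path feed the standard DPO objective. A minimal sketch of that objective (the textbook formulation, not this PR's `DPOLossFn` code):

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_lp, policy_rejected_lp,
             ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """Standard DPO loss from per-sequence summed logprobs.

    Each argument is a [batch] tensor of summed token logprobs; the fused
    and unfused paths only differ in how these logprobs are computed.
    """
    chosen_rewards = beta * (policy_chosen_lp - ref_chosen_lp)
    rejected_rewards = beta * (policy_rejected_lp - ref_rejected_lp)
    # Maximize the margin between chosen and rejected implicit rewards.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```

At initialization, when the policy equals the reference model, every margin is zero and the loss is log(2).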

Issues

NA

Tests

  1. Unit tests passed.
  2. Local tests (running the sh file) passed.
  3. Extended the context window from <30K to 40K tokens. Even larger context windows trigger an OOM elsewhere, which will be fixed in a follow-up PR.
  4. Side-by-side loss curve comparison between the baseline and the experiment (linear CE loss fusion enabled):

[loss curve comparison image]

Usage

```yaml
# Add the following to your dpo.yaml file.
megatron_cfg:
  enabled: true
  use_linear_ce_fusion_loss: true
  linear_ce_fusion_chunk_size: 256  # or another value
```

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.


Summary by CodeRabbit

Release Notes

  • Documentation

    • Added DPO training optimization guide section
    • Updated SFT training documentation to reference optimization
  • New Features

    • Added example DPO configuration for Qwen2.5-Math-7B model
    • Enabled chunked linear cross-entropy loss support for DPO training
  • Tests

    • Added nightly test suite for DPO with chunked linear cross-entropy loss
    • Added unit test validating chunked linear cross-entropy loss behavior

…the loss values being nearly identical between base and exp.

Signed-off-by: pengdurice <pengduhit@gmail.com>
@copy-pr-bot

copy-pr-bot bot commented Mar 22, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: pengdurice <pengduhit@gmail.com>
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Mar 23, 2026
Signed-off-by: pengdurice <pengduhit@gmail.com>
@pengdurice pengdurice changed the title Add linear CE loss fusion for DPO feat: Add linear CE loss fusion for DPO Mar 23, 2026
@pengdurice pengdurice marked this pull request as ready for review March 23, 2026 18:19
@pengdurice pengdurice requested review from a team as code owners March 23, 2026 18:19
@coderabbitai
Contributor

coderabbitai bot commented Mar 23, 2026

📝 Walkthrough

Walkthrough

This PR extends DPO training support for chunked linear cross-entropy fusion loss optimization. It adds documentation describing the feature, introduces Megatron configuration flags to control activation, threads the use_linear_ce_fusion parameter through loss functions and post-processors, adds an example recipe configuration for Qwen2.5-Math-7B, and includes integration tests.

Changes

  • Documentation (docs/guides/dpo.md, docs/guides/sft.md): Added an "Optimizations" section to the DPO guide detailing chunked linear CE fusion loss, including configuration flags (enabled, use_linear_ce_fusion_loss, linear_ce_fusion_chunk_size) and constraints. Updated the SFT guide note to clarify that the optimization applies to both SFT and DPO training.
  • Configuration Files (examples/configs/dpo.yaml, examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml): Added Megatron loss flags (use_linear_ce_fusion_loss: false, linear_ce_fusion_chunk_size: 256) to the base DPO config. Created a new recipe file with tensor/pipeline parallelism settings (tensor_model_parallel_size: 4, pipeline_model_parallel_size: 2), fusion loss enabled with chunk size 128, MoE router settings, and W&B/TensorBoard/MLflow logging configuration.
  • Loss Function & Post-Processor (nemo_rl/algorithms/loss/loss_functions.py, nemo_rl/models/megatron/train.py): Updated DPOLossFn.__init__ to accept an optional use_linear_ce_fusion parameter and pass it to the underlying NLLLossFn. Modified LogprobsPostProcessor.__init__ to accept a use_linear_ce_fusion flag and branched the post-processing logic: when enabled, directly computes token logprobs via torch.float32 cast and slicing; when disabled, preserves the original tensor-parallel and packed/unpacked sequence paths.
  • Feature Wiring (nemo_rl/algorithms/dpo.py, nemo_rl/models/policy/workers/megatron_policy_worker.py): Extended DPOAlgorithm.setup() to extract the use_linear_ce_fusion flag from the policy config and pass it to DPOLossFn. Updated MegatronPolicyWorkerImpl.get_logprobs() to read use_linear_ce_fusion_loss from megatron_cfg and thread it into both LogprobsPostProcessor and megatron_forward_backward().
  • Tests (tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh, tests/test_suites/nightly.txt, tests/unit/models/policy/test_megatron_worker.py): Added a new nightly integration test script for DPO training with linear CE fusion enabled on Qwen2.5-Math-7B. Created the unit test test_megatron_dpo_linear_ce_fusion_agreement comparing standard DPO loss vs. fusion-enabled loss with numerical tolerance (rtol=1e-2, atol=1e-2) to validate correctness of the optimized path.
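The agreement check in the new unit test amounts to a tolerance comparison of this shape (schematic only; the real test exercises full Megatron workers, and the loose tolerances allow for chunked-accumulation round-off):

```python
import torch

def losses_agree(loss_standard: torch.Tensor,
                 loss_fused: torch.Tensor,
                 rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    # True when |fused - standard| <= atol + rtol * |standard|,
    # i.e. the fused path is numerically equivalent up to round-off.
    return torch.allclose(loss_fused, loss_standard, rtol=rtol, atol=atol)
```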

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • PR #2036: Provides foundational implementations of chunked hidden-states→logprobs autograd path, GPT forward patch, and NLLLinearCEFusionLoss that this PR depends on to enable the full fusion optimization flow.

Suggested reviewers

  • ananthsub
  • yaoyu-33
  • yuki-97
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 60.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)

  • Title check ✅: The title 'feat: Add linear CE loss fusion for DPO' directly and clearly summarizes the main change, adding Linear Cross-Entropy loss fusion support to DPO training, matching the PR's primary objective.
  • Test Results For Major Changes ✅: PR includes documented test results: unit tests passed, functional tests added to the nightly suite, and the loss curve comparison shows no regression in convergence.
  • Description Check ✅: Check skipped; CodeRabbit's high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/unit/models/policy/test_megatron_worker.py (1)

1990-2056: Exercise the reference-policy logprob path in this agreement test.

This test injects reference_policy_logprobs and keeps init_reference_model=False, so it only compares the actor-side fused loss. A regression in the newly wired reference-model get_logprobs() path would still pass here. Please derive the reference logprobs from the policy, or add a small companion assertion that does.

examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml (1)

22-24: MoE-specific settings may be unnecessary for this model.

Qwen/Qwen2.5-Math-7B is a dense (non-MoE) model. The MoE-related settings (freeze_moe_router, moe_router_bias_update_rate, moe_permute_fusion) will likely be ignored but add noise to the config. Consider removing them if they were copied from an MoE recipe.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bcabae4a-06fe-4bd1-b2d7-e0e85ed7b178

📥 Commits

Reviewing files that changed from the base of the PR and between 9feb4b0 and eaeb51b.

📒 Files selected for processing (11)
  • docs/guides/dpo.md
  • docs/guides/sft.md
  • examples/configs/dpo.yaml
  • examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/algorithms/loss/loss_functions.py
  • nemo_rl/models/megatron/train.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py
  • tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
  • tests/test_suites/nightly.txt
  • tests/unit/models/policy/test_megatron_worker.py

Comment on lines +212 to +220
Add the following to your Megatron config in your YAML file:

```yaml
policy:
  megatron_cfg:
    enabled: true
    use_linear_ce_fusion_loss: true
    linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
```
Contributor


⚠️ Potential issue | 🟡 Minor

Show the backend switch in the YAML snippet.

examples/configs/dpo.yaml keeps policy.dtensor_cfg.enabled: true by default, so copying only this block can leave both backends enabled. Please include policy.dtensor_cfg.enabled: false here, or call it out explicitly, so the enablement instructions switch to Megatron cleanly.

✏️ Suggested doc fix

```diff
 policy:
+  dtensor_cfg:
+    enabled: false
   megatron_cfg:
     enabled: true
     use_linear_ce_fusion_loss: true
     linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
```


Labels

community-request documentation Improvements or additions to documentation
