Skip to content

Add LoRA co-training support for HF EAGLE speculative decoding#1060

Open
yeyu-nvidia wants to merge 47 commits intomainfrom
yeyu/speculative-lora-cotrain
Open

Add LoRA co-training support for HF EAGLE speculative decoding#1060
yeyu-nvidia wants to merge 47 commits intomainfrom
yeyu/speculative-lora-cotrain

Conversation

@yeyu-nvidia
Copy link
Contributor

@yeyu-nvidia yeyu-nvidia commented Mar 17, 2026

● ### What does this PR do?

Type of change: New feature

Adds LoRA co-training support for HF EAGLE speculative decoding. When
eagle_base_lora=True, HF PEFT LoRA adapters are injected into the base
model and co-trained alongside the EAGLE draft module in a single online
training pass. A preservation loss (KL divergence between the original
frozen base model output and the LoRA-adapted output) is added to prevent
the base model from drifting during training. LoRA adapter weights are
exported separately alongside the EAGLE draft model artifacts.

Usage

import modelopt.torch.speculative as mtsp

# Convert model to EAGLE with LoRA co-training enabled
mtsp.convert(model, mode=[("eagle", {
    "eagle_architecture_config": eagle_arch_cfg,
    "eagle_base_lora": True,
    "eagle_base_lora_rank": 64,
    "eagle_base_lora_alpha": 16.0,
    "eagle_base_lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "eagle_base_lora_preservation_loss_weight": 0.1,
})])

# Train as usual — LoRA params and eagle_module params are trainable,
# base model weights are frozen. Total loss = eagle_loss + preservation_loss.
output = model(input_ids=input_ids, labels=labels)
output.loss.backward()

# Export: eagle draft weights + LoRA adapter weights saved separately
model.get_exporter().export("./export_dir")
# export_dir/
#   model.safetensors            <- EAGLE draft module
#   config.json                  <- EAGLE config
#   lora_adapter_model.safetensors  <- LoRA adapter weights
#   lora_adapter_config.json        <- LoRA config (rank, alpha, target_modules)

Testing

Added tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py with 5 unit tests:
- test_lora_layers_injectedverifies LoRA layers are present in the base model after conversion
- test_trainable_paramsverifies only lora_* and eagle_module params have requires_grad=True
- test_forward_returns_lossverifies the forward pass returns a non-zero scalar loss
- test_eagle_offline_incompatibleverifies eagle_base_lora=True + eagle_offline=True raises an error
- test_export_lora_artifactsverifies export() produces the expected LoRA files

Before your PR is "Ready for review"

- Is this change backward compatible?: ✅ — new config fields all have defaults; existing EAGLE workflows are unaffected.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ — uses peft>=0.17.0 which is already listed in the [hf] optional extra.
- Did you write any new necessary tests?: ✅
- Did you update Changelog?: ✅

### Bug fix (latest commit)

Fixed a `case` pattern ordering bug in `launch_train.sh`: the glob
`--eagle_base_lora*` was listed before the specific patterns
(`--eagle_base_lora_rank*`, `--eagle_base_lora_alpha*`, etc.). Since bash
`case` uses first-match-wins, passing any specific LoRA arg (e.g.
`--eagle_base_lora_rank 64`) would silently overwrite `EAGLE_BASE_LORA`
instead of the intended variable, causing the `"True"` check to fail and
disabling LoRA co-training entirely. Fixed by placing specific patterns
before the general one.

Additional Information

This feature is intended for online HF training only (eagle_offline=True is explicitly blocked). The LoRA adapters are applied to the base model via peft.inject_adapter_in_model (in-place, no wrapper), keeping the existing HFEagleModel structure intact.

@yeyu-nvidia yeyu-nvidia requested review from a team as code owners March 17, 2026 18:56
@yeyu-nvidia yeyu-nvidia requested a review from ChenhanYu March 17, 2026 18:56
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds configurable PEFT LoRA support for the EAGLE base model: new config fields, runtime LoRA adapter injection/toggling, a KL-based preservation loss, export of LoRA artifacts, dependency addition, and unit tests covering injection, training constraints, loss, incompatibility, and export.

Changes

Cohort / File(s) Summary
Dependency
examples/speculative_decoding/requirements.txt
Adds peft>=0.17.0.
Config
modelopt/torch/speculative/config.py
Adds five EagleConfig fields for base-model LoRA: eagle_base_lora, eagle_base_lora_rank, eagle_base_lora_alpha, eagle_base_lora_target_modules, eagle_base_lora_preservation_loss_weight.
Runtime LoRA integration
modelopt/torch/speculative/plugins/transformers.py
Injects/manages PEFT LoRA adapters: adds _inject_base_lora, _set_base_lora_enabled, _preservation_loss; updates modify() and _base_model_forward() to run LoRA-disabled reference forward and LoRA-enabled forward and compute preservation loss.
Eagle model wiring
modelopt/torch/speculative/eagle/eagle_model.py
Assigns new LoRA-related config fields to EagleModel attributes during modify().
Export
modelopt/torch/export/plugins/hf_spec_export.py
Adds _export_lora() to export LoRA adapter weights (lora_adapter_model.safetensors) and lora_adapter_config.json; adjusts state-key validation and integrates LoRA export into the export flow when eagle_base_lora is present.
Tests
tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py
New test suite and fixture lora_eagle_model: verifies LoRA injection, trainable-parameter freezing, forward loss, eagle_offline incompatibility, and export artifacts (model.safetensors, lora_adapter_model.safetensors, lora_adapter_config.json).
Other runtime change
modelopt/torch/speculative/utils.py
enable_cp_ttt_patch now calls sdpa_kernel with backends list [CUDNN_ATTENTION, MATH] instead of a single backend.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant HFEagleModel
    participant BaseModel
    participant LoRAAdapter
    participant Exporter

    Trainer->>HFEagleModel: init(config with eagle_base_lora)
    HFEagleModel->>LoRAAdapter: _inject_base_lora() (create LoraConfig, inject adapters)
    LoRAAdapter-->>HFEagleModel: adapters installed, LoRA params unfrozen

    Trainer->>HFEagleModel: training step
    HFEagleModel->>BaseModel: _set_base_lora_enabled(False)
    HFEagleModel->>BaseModel: forward -> ref_logits
    BaseModel-->>HFEagleModel: ref_logits

    HFEagleModel->>BaseModel: _set_base_lora_enabled(True)
    HFEagleModel->>BaseModel: forward -> lora_logits, hidden_states
    BaseModel-->>HFEagleModel: lora_logits, hidden_states

    HFEagleModel->>HFEagleModel: _preservation_loss(ref_logits, lora_logits)
    HFEagleModel-->>Trainer: combined loss

    Trainer->>Exporter: export request
    Exporter->>LoRAAdapter: export weights -> `lora_adapter_model.safetensors`
    Exporter->>Exporter: write `lora_adapter_config.json`
    Exporter-->>Trainer: exported artifacts
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Add LoRA co-training support for HF EAGLE speculative decoding' directly and accurately summarizes the main change—introducing LoRA co-training capability for Hugging Face EAGLE speculative decoding models.
Docstring Coverage ✅ Passed Docstring coverage is 84.21% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed PR introduces no unsafe torch.load, numpy.load, or hardcoded trust_remote_code settings. No eval/exec calls or nosec comments added. PEFT dependency is Apache-2.0 licensed.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch yeyu/speculative-lora-cotrain

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (3)
modelopt/torch/speculative/config.py (1)

122-145: Strengthen LoRA config typing and value bounds.

Consider constraining invalid user input at config parse time (e.g., rank <= 0, negative preservation weight) and avoid mutable list defaults.

Proposed patch
-    eagle_base_lora_rank: int = ModeloptField(
+    eagle_base_lora_rank: int = ModeloptField(
         default=64,
+        ge=1,
         description="LoRA rank for the base model adapters.",
     )

     eagle_base_lora_alpha: float = ModeloptField(
         default=16.0,
+        gt=0.0,
         description="LoRA alpha (scaling) for the base model adapters.",
     )

-    eagle_base_lora_target_modules: list = ModeloptField(
-        default=[],
+    eagle_base_lora_target_modules: tuple[str, ...] = ModeloptField(
+        default=(),
         description=(
             "List of module name patterns to apply LoRA to in the base model "
             "(e.g. ['q_proj', 'v_proj']). Empty list uses peft defaults."
         ),
     )

     eagle_base_lora_preservation_loss_weight: float = ModeloptField(
         default=0.1,
+        ge=0.0,
         description=(
             "Weight for the preservation loss that minimizes the KL divergence between "
             "the LoRA-adapted base model output and the original base model output."
         ),
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/config.py` around lines 122 - 145, The config
fields eagle_base_lora_rank, eagle_base_lora_alpha,
eagle_base_lora_target_modules, and eagle_base_lora_preservation_loss_weight use
permissive types and a mutable list default; update their ModeloptField
definitions to enforce proper typing and validate bounds at parse/validation
time: require eagle_base_lora_rank to be an int > 0, eagle_base_lora_alpha to be
a float >= 0, eagle_base_lora_preservation_loss_weight to be a float >= 0
(reject negatives), and replace the mutable default for
eagle_base_lora_target_modules with an immutable default (e.g., None or tuple)
and coerce/validate it into a list of strings; implement these checks using the
config validation hook or the ModeloptField's validator callbacks so invalid
inputs raise a clear parsing error.
tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py (1)

93-100: Strengthen export test by validating LoRA config contents, not just file existence.

Existence checks can pass with malformed config. Assert expected r, lora_alpha, and target_modules in lora_adapter_config.json.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py` around
lines 93 - 100, In test_export_lora_artifacts, after exporting via
lora_eagle_model.get_exporter().export(export_dir), open and parse export_dir /
"lora_adapter_config.json" as JSON and assert the config contains the keys "r",
"lora_alpha", and "target_modules"; further validate that "r" and "lora_alpha"
are positive integers and that "target_modules" is a non-empty list of strings
(or matches the expected module names for this model), so the test checks
semantic correctness not just file existence.
modelopt/torch/speculative/plugins/transformers.py (1)

581-582: Prefer F.kl_div for preservation loss clarity/stability.

Current expression is a manual cross-entropy form; F.kl_div makes intent explicit and is less error-prone to maintain.

Proposed patch
+import torch.nn.functional as F
...
-        loss = nn.Softmax(dim=-1)(ref_logits.detach()) * nn.LogSoftmax(dim=-1)(lora_logits)
-        return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight
+        ref_prob = F.softmax(ref_logits.detach(), dim=-1)
+        lora_log_prob = F.log_softmax(lora_logits, dim=-1)
+        kl = F.kl_div(lora_log_prob, ref_prob, reduction="batchmean")
+        return kl * self.eagle_base_lora_preservation_loss_weight
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 581 - 582,
Replace the manual cross-entropy-like expression with torch.nn.functional.kl_div
to make intent explicit and numerically stable: compute log-probs from
lora_logits with F.log_softmax, compute target probs from ref_logits.detach()
with F.softmax, call F.kl_div(log_probs, target_probs, reduction='none'), sum
over the last dim, take the mean, and multiply by
self.eagle_base_lora_preservation_loss_weight (no leading negative). Update the
expression that currently uses nn.Softmax/nn.LogSoftmax and returns
-loss.sum(...).mean()*self.eagle_base_lora_preservation_loss_weight to the
F.kl_div-based sequence using the same dimensions and detachment of ref_logits.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 191-193: The export currently builds lora_sd = {k: v for k, v in
full_sd.items() if "lora_A" in k or "lora_B" in k} and calls save_file(...) even
if lora_sd is empty; add a guard after constructing lora_sd in the
hf_spec_export export routine to fail fast: check if lora_sd is empty and if so,
raise a clear exception (or call processLogger.error and raise RuntimeError)
indicating no LoRA tensors found instead of writing an empty file; reference the
lora_sd variable, full_sd source, save_file call and
export_dir/"lora_adapter_model.safetensors" target so the change is applied in
the right spot.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 812-818: The block that disables LoRA adapters using
self._set_base_lora_enabled(False) before calling _run_forward can leave
adapters disabled if _run_forward raises; wrap the reference forward in a
try/finally so that self._set_base_lora_enabled(True) always executes, still
clearing self._aux_hidden_states when present and returning/using ref_logits
from _run_forward; specifically, call _set_base_lora_enabled(False), run
ref_logits = _run_forward(no_grad=True).logits inside try, then in finally
re-enable via _set_base_lora_enabled(True) and clear self._aux_hidden_states if
present.
- Around line 648-650: The code currently uses an assert to enforce that
eagle_base_lora and eagle_offline are not both set (in the block that calls
self._inject_base_lora()); replace the assert with an explicit runtime exception
(e.g., raise ValueError or RuntimeError) so the check always runs in production.
Locate the conditional that checks self.eagle_base_lora and the incompatible
flag self.eagle_offline, and throw a clear exception with a descriptive message
instead of using assert before calling self._inject_base_lora().

---

Nitpick comments:
In `@modelopt/torch/speculative/config.py`:
- Around line 122-145: The config fields eagle_base_lora_rank,
eagle_base_lora_alpha, eagle_base_lora_target_modules, and
eagle_base_lora_preservation_loss_weight use permissive types and a mutable list
default; update their ModeloptField definitions to enforce proper typing and
validate bounds at parse/validation time: require eagle_base_lora_rank to be an
int > 0, eagle_base_lora_alpha to be a float >= 0,
eagle_base_lora_preservation_loss_weight to be a float >= 0 (reject negatives),
and replace the mutable default for eagle_base_lora_target_modules with an
immutable default (e.g., None or tuple) and coerce/validate it into a list of
strings; implement these checks using the config validation hook or the
ModeloptField's validator callbacks so invalid inputs raise a clear parsing
error.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 581-582: Replace the manual cross-entropy-like expression with
torch.nn.functional.kl_div to make intent explicit and numerically stable:
compute log-probs from lora_logits with F.log_softmax, compute target probs from
ref_logits.detach() with F.softmax, call F.kl_div(log_probs, target_probs,
reduction='none'), sum over the last dim, take the mean, and multiply by
self.eagle_base_lora_preservation_loss_weight (no leading negative). Update the
expression that currently uses nn.Softmax/nn.LogSoftmax and returns
-loss.sum(...).mean()*self.eagle_base_lora_preservation_loss_weight to the
F.kl_div-based sequence using the same dimensions and detachment of ref_logits.

In `@tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py`:
- Around line 93-100: In test_export_lora_artifacts, after exporting via
lora_eagle_model.get_exporter().export(export_dir), open and parse export_dir /
"lora_adapter_config.json" as JSON and assert the config contains the keys "r",
"lora_alpha", and "target_modules"; further validate that "r" and "lora_alpha"
are positive integers and that "target_modules" is a non-empty list of strings
(or matches the expected module names for this model), so the test checks
semantic correctness not just file existence.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 66fcf495-7fbb-405d-9f5d-1206155ab766

📥 Commits

Reviewing files that changed from the base of the PR and between 00fa5bd and ebdbf651b7b5a00b143e9628b22c4be81338e5ea.

📒 Files selected for processing (5)
  • examples/speculative_decoding/requirements.txt
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py

@h-guo18
Copy link
Contributor

h-guo18 commented Mar 17, 2026

How would the base model quality and AL looks like with this lora cotraining?

@yeyu-nvidia
Copy link
Contributor Author

How would the base model quality and AL looks like with this lora cotraining?

Haven't tested. will report later

@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/speculative-lora-cotrain branch from bcee9bd to d933f42 Compare March 18, 2026 16:21
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (3)
modelopt/torch/speculative/plugins/transformers.py (2)

647-650: ⚠️ Potential issue | 🟠 Major

Replace assert with explicit exception for runtime config validation.

assert statements can be optimized out with -O flag. Use an explicit exception to ensure the incompatibility check always executes.

Proposed patch
         # Inject HF PEFT LoRA adapters into the base model for co-training
         if self.eagle_base_lora:
-            assert not self.eagle_offline, "eagle_base_lora is incompatible with eagle_offline=True"
+            if self.eagle_offline:
+                raise ValueError("eagle_base_lora is incompatible with eagle_offline=True")
             self._inject_base_lora()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 647 - 650,
Replace the runtime config assertion in the block that injects HF PEFT LoRA
adapters: instead of using assert not self.eagle_offline, raise an explicit
exception (e.g., raise RuntimeError or ValueError) when self.eagle_base_lora is
true and self.eagle_offline is true so the check always runs; keep the existing
call to self._inject_base_lora() when the check passes and reference the same
symbols (self.eagle_base_lora, self.eagle_offline, self._inject_base_lora) when
making the change.

811-817: ⚠️ Potential issue | 🟠 Major

Wrap reference forward in try/finally to ensure LoRA re-enablement.

If _run_forward raises an exception, LoRA adapters remain disabled, causing incorrect behavior in subsequent training steps.

Proposed patch
         ref_logits = None
         if self.eagle_base_lora:
             self._set_base_lora_enabled(False)
-            ref_logits = _run_forward(no_grad=True).logits
-            if hasattr(self, "_aux_hidden_states"):
-                self._aux_hidden_states.clear()
-            self._set_base_lora_enabled(True)
+            try:
+                ref_logits = _run_forward(no_grad=True).logits
+                if hasattr(self, "_aux_hidden_states"):
+                    self._aux_hidden_states.clear()
+            finally:
+                self._set_base_lora_enabled(True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 811 - 817,
The block that disables LoRA adapters (guarded by self.eagle_base_lora) should
wrap the call to _run_forward in a try/finally so _set_base_lora_enabled(True)
always runs even if _run_forward raises; keep the current behavior of clearing
self._aux_hidden_states and assigning ref_logits from _run_forward().logits, but
move those operations into the try (or after a successful call) and perform
re-enablement in the finally block to guarantee LoRA is restored.
modelopt/torch/export/plugins/hf_spec_export.py (1)

195-212: ⚠️ Potential issue | 🟠 Major

Missing guard for empty LoRA state dict.

If LoRA injection regresses or filtering fails, lora_sd will be empty, resulting in an empty adapter file being written without error. Add a guard to fail fast.

Proposed patch
     def _export_lora(self, export_dir: Path, full_sd: dict):
         """Export base model LoRA adapter weights alongside the eagle module artifacts."""
         lora_sd = {k: v for k, v in full_sd.items() if "lora_A" in k or "lora_B" in k}
+        if not lora_sd:
+            raise RuntimeError(
+                "No LoRA adapter tensors found in state_dict; refusing to export empty adapter."
+            )
         save_file(lora_sd, export_dir / "lora_adapter_model.safetensors")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 195 - 212, The
_export_lora method currently builds lora_sd and unconditionally writes it and a
config; add a fail-fast guard that checks if lora_sd is empty and raises a clear
exception (e.g., RuntimeError/ValueError) before calling save_file or
constructing/writing the LoraConfig, so you don't create an empty adapter file
if LoRA filtering/injection failed; reference lora_sd, _export_lora, save_file,
LoraConfig and export_dir when adding the check and error.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

574-582: Prefer F.softmax / F.log_softmax over module instantiation.

Creating nn.Softmax and nn.LogSoftmax module instances on each call adds unnecessary overhead. Use the functional API instead.

Proposed patch
+import torch.nn.functional as F
+
     def _preservation_loss(
         self, ref_logits: torch.Tensor, lora_logits: torch.Tensor
     ) -> torch.Tensor:
         """KL divergence encouraging LoRA output to stay close to the original base model.
 
         KL(softmax(ref) || log_softmax(lora)) weighted by eagle_base_lora_preservation_loss_weight.
         """
-        loss = nn.Softmax(dim=-1)(ref_logits.detach()) * nn.LogSoftmax(dim=-1)(lora_logits)
+        loss = F.softmax(ref_logits.detach(), dim=-1) * F.log_softmax(lora_logits, dim=-1)
         return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 574 - 582,
In _preservation_loss, avoid instantiating nn.Softmax and nn.LogSoftmax on each
call; use the functional API (torch.nn.functional.softmax and
torch.nn.functional.log_softmax) to compute softmax(ref_logits.detach(), dim=-1)
and log_softmax(lora_logits, dim=-1) respectively, then compute the elementwise
product, sum over vocab dim and return the weighted mean exactly as before using
self.eagle_base_lora_preservation_loss_weight; ensure you keep the
ref_logits.detach() semantics and the dim=-1 argument.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 195-212: The _export_lora method currently builds lora_sd and
unconditionally writes it and a config; add a fail-fast guard that checks if
lora_sd is empty and raises a clear exception (e.g., RuntimeError/ValueError)
before calling save_file or constructing/writing the LoraConfig, so you don't
create an empty adapter file if LoRA filtering/injection failed; reference
lora_sd, _export_lora, save_file, LoraConfig and export_dir when adding the
check and error.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 647-650: Replace the runtime config assertion in the block that
injects HF PEFT LoRA adapters: instead of using assert not self.eagle_offline,
raise an explicit exception (e.g., raise RuntimeError or ValueError) when
self.eagle_base_lora is true and self.eagle_offline is true so the check always
runs; keep the existing call to self._inject_base_lora() when the check passes
and reference the same symbols (self.eagle_base_lora, self.eagle_offline,
self._inject_base_lora) when making the change.
- Around line 811-817: The block that disables LoRA adapters (guarded by
self.eagle_base_lora) should wrap the call to _run_forward in a try/finally so
_set_base_lora_enabled(True) always runs even if _run_forward raises; keep the
current behavior of clearing self._aux_hidden_states and assigning ref_logits
from _run_forward().logits, but move those operations into the try (or after a
successful call) and perform re-enablement in the finally block to guarantee
LoRA is restored.

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 574-582: In _preservation_loss, avoid instantiating nn.Softmax and
nn.LogSoftmax on each call; use the functional API (torch.nn.functional.softmax
and torch.nn.functional.log_softmax) to compute softmax(ref_logits.detach(),
dim=-1) and log_softmax(lora_logits, dim=-1) respectively, then compute the
elementwise product, sum over vocab dim and return the weighted mean exactly as
before using self.eagle_base_lora_preservation_loss_weight; ensure you keep the
ref_logits.detach() semantics and the dim=-1 argument.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b4fbd165-c015-4d1e-a7a2-edfbff923aec

📥 Commits

Reviewing files that changed from the base of the PR and between 7fae8b1909338f1976f83480d78c8c9b1315f500 and bcee9bdf4b3d545c4205817abdddd95723591180.

📒 Files selected for processing (2)
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/plugins/transformers.py

@codecov
Copy link

codecov bot commented Mar 18, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.36%. Comparing base (839fa3d) to head (b5bf776).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1060      +/-   ##
==========================================
+ Coverage   70.30%   70.36%   +0.06%     
==========================================
  Files         227      227              
  Lines       25857    25878      +21     
==========================================
+ Hits        18179    18210      +31     
+ Misses       7678     7668      -10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@ChenhanYu ChenhanYu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds LoRA co-training support for HF EAGLE speculative decoding — reasonable feature design with good test coverage. However, peft is now a hard import-time dependency for the entire transformers.py plugin due to top-level imports. This will break anyone importing modelopt.torch.speculative.plugins.transformers without peft installed, even if they don't use LoRA. This needs to be fixed before merging. Several other issues below.

@@ -37,6 +37,9 @@
import torch
import transformers
from packaging.version import Version
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker: These top-level from peft import ... lines make peft a hard dependency for the entire transformers plugin. Anyone importing this module without peft installed will get an ImportError, even if they don't use LoRA.

These must be lazy imports inside _inject_base_lora, _set_base_lora_enabled, and the export method. For example:

def _inject_base_lora(self):
    from peft import LoraConfig
    from peft.mapping import inject_adapter_in_model
    ...

def _set_base_lora_enabled(self, enabled: bool):
    from peft.tuners.lora import LoraLayer
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5632abab — peft imports moved inside _inject_base_lora() (LoraConfig, inject_adapter_in_model) and _set_base_lora_enabled() (LoraLayer), so peft is no longer a hard top-level dependency.

eagle_base_lora_alpha: float = ModeloptField(
default=16.0,
description="LoRA alpha (scaling) for the base model adapters.",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutable default: default=[] means all config instances share the same list object. Use default_factory=list or default=None with a note that None uses peft defaults.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5632abab — changed to list | None = ModeloptField(default=None). All existing usages already used or None so nothing breaks.


def _set_base_lora_enabled(self, enabled: bool) -> None:
"""Enable or disable LoRA adapters in the base model."""
for module in self._base_model.modules():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says "KL divergence" but this computes cross-entropy: -softmax(ref) * log_softmax(lora). The missing entropy term is constant w.r.t. LoRA params so gradients are correct, but the naming is misleading. Either rename or add a comment clarifying this is KL up to a constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is logit KL divergence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To expand on this: the implementation computes cross-entropy H(ref, lora) = -softmax(ref) · log_softmax(lora), which equals KL(ref ∥ lora) + H(ref). Since H(ref) is constant w.r.t. LoRA parameters, the gradients are identical to true KL divergence — so the optimization objective is equivalent. The docstring has been updated to clarify this: it now reads "KL(softmax(ref) || log_softmax(lora))" and notes that the entropy term of the reference is constant and dropped.

@@ -723,7 +762,9 @@ def _compute_ttt_attention_mask(
) -> BlockMask | torch.Tensor:
"""Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl."""
msk_func = get_ttt_msk_func(seq_length, ttt_step)
dtypemin = torch.finfo(self._base_llm_config.dtype).min
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference forward with LoRA disabled has no try/finally. If the forward throws, LoRA stays disabled for all subsequent calls:

self._set_base_lora_enabled(False)
try:
    ref_logits = _run_forward(no_grad=True).logits
    if hasattr(self, "_aux_hidden_states"):
        self._aux_hidden_states.clear()
finally:
    self._set_base_lora_enabled(True)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 6e709b55 — wrapped the reference forward in try/finally so _set_base_lora_enabled(True) is guaranteed to run even if the forward throws.

@@ -610,6 +644,11 @@ def modify(
if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert not self.eagle_offline can be optimized out with python -O. Use if self.eagle_offline: raise ValueError(...) for runtime config validation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 6e709b55 — replaced with if self.eagle_offline: raise ValueError(...). Also updated the test to expect ValueError instead of AssertionError.

aux_only_keys = {"fc", "layers.0.hidden_norm"}
required_keys = set(expected_keys_single_layer["required"])
if not use_aux:
required_keys -= aux_only_keys
# Check that export sd has required keys
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LoRA key filtering ("lora_A" in k or "lora_B" in k) is fragile. PEFT may use other key patterns (lora_embedding_A, lora_magnitude_vector, etc.). Consider using peft's own utilities to identify adapter parameters, or add a warning if zero LoRA tensors are found.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5632abab — tightened the filter to .lora_A. / .lora_B. (dot-bounded) and added a RuntimeError if no LoRA tensors are found.

@@ -467,7 +467,7 @@ def enable_cp_ttt_patch():
import modelopt.torch.speculative.plugins.transformers

modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding SDPBackend.MATH fallback and the dtype getattr changes in transformers.py look unrelated to LoRA co-training. Consider splitting into a separate PR or calling them out in the description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed for tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To expand: the SDPBackend.MATH fallback in enable_cp_ttt_patch() is required for CPU unit tests. CUDNN_ATTENTION is only available on GPU, so without MATH as a fallback the test environment raises an error when no supported SDPA backend is found. This is directly exercised by the LoRA co-training forward pass test (test_forward_returns_loss), which runs the TTT attention path on CPU. The change is scoped to the enable_cp_ttt_patch() context manager and doesn't affect production GPU paths.

@yeyu-nvidia yeyu-nvidia requested a review from ChenhanYu March 19, 2026 16:30
yeyu-nvidia and others added 14 commits March 19, 2026 09:38
Introduces eagle_base_lora training mode where HF PEFT LoRA adapters are
injected into the base model and co-trained with the EAGLE draft module.
A preservation loss (KL divergence between original and LoRA-adapted base
model outputs) is added to prevent the base model from drifting during
training. LoRA adapter weights are exported separately alongside the EAGLE
draft model artifacts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The LoRA co-training config fields (eagle_base_lora, eagle_base_lora_rank,
eagle_base_lora_alpha, eagle_base_lora_target_modules,
eagle_base_lora_preservation_loss_weight) were defined in the config but
never assigned in EagleModel.modify(), causing DynamicModule.__getattr__
to raise AttributeError when HFEagleModel.modify() accessed them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Set num_key_value_heads=16 (matching num_attention_heads) to avoid GQA,
which triggers enable_gqa=True in SDPA — unsupported on CPU backends.
Set use_last_layernorm=True so the norm layer is created and norm.weight
is present in the export state dict as required by the export validator.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- enable_cp_ttt_patch: add SDPBackend.MATH alongside CUDNN_ATTENTION so
  the math kernel is available as fallback on CPU (fixes test_forward_returns_loss)
- _check_valid_sd: skip fc/hidden_norm from required keys when
  use_aux_hidden_state=False, as these layers only exist in EAGLE-3
  (fixes test_export_lora_artifacts)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Handle None dtype in _compute_ttt_attention_mask by falling back to
  torch.float32 when HF config.dtype is unset
- Fix JSON serialization of LoRA config by converting set target_modules
  to sorted list in _export_lora

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
peft is not available in all test environments. Move the LoraConfig
import inside _export_lora where it is actually used.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Use getattr with bfloat16 fallback instead of direct attribute access,
which raises AttributeError in transformers 4.53.3.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Use getattr with None default combined with or-fallback to bfloat16,
handling both: attribute missing (tf_min/4.53.3) and attribute present
but None (tf_latest).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Use getattr(..., None) or torch.bfloat16 to handle both absent attribute
(transformers tf_min) and attribute-exists-but-None (tf_latest) cases.

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Use torch.get_default_dtype() as fallback instead of torch.bfloat16 when
config.dtype is None (transformers >= 4.53 sets LlamaConfig.dtype=None).
This prevents RuntimeError from scaled_dot_product_attention when the model
computes in float32 but the attention mask was created as bfloat16.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Move peft imports (LoraConfig, inject_adapter_in_model, LoraLayer) inside
  the methods that use them (_inject_base_lora, _set_base_lora_enabled) so
  peft is not a hard top-level dependency for all speculative decoding users
- Change eagle_base_lora_target_modules default from [] to None to avoid
  mutable default shared across config instances
- Tighten LoRA key filtering from "lora_A" in k to ".lora_A." in k to avoid
  false positives, and add fail-fast RuntimeError when no LoRA tensors found

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Wrap reference forward pass in try/finally so LoRA adapters are always
  re-enabled even if the forward throws (prevents silent training with
  permanently disabled LoRA on subsequent calls)
- Replace assert with raise ValueError for eagle_offline compatibility check
  so it cannot be silently optimized away with python -O; update test to
  expect ValueError instead of AssertionError

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/speculative-lora-cotrain branch from 6e709b5 to 696d251 Compare March 19, 2026 16:38
yeyu-nvidia and others added 3 commits March 19, 2026 11:09
Add LoRA co-training flags to EagleArguments in main.py and wire them
into the eagle config passed to mtsp.convert(). In launch_train.sh,
group all LoRA flags into a single LORA_ARGS variable (built only when
--eagle_base_lora True) following the existing OFFLINE_TRAINING_ARGS /
VLM_ARGS pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…args

The glob pattern --eagle_base_lora* was listed before the more specific
--eagle_base_lora_rank*, --eagle_base_lora_alpha*, etc. patterns. Since
bash case uses first-match-wins, passing any of the specific LoRA args
(e.g. --eagle_base_lora_rank 64) would incorrectly overwrite EAGLE_BASE_LORA
instead of the intended variable, causing the "True" check at line 218 to
fail and silently disabling LoRA co-training.

Fix by reordering so specific patterns are matched before the general one.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
When eagle_base_lora=True, base_model_logits (LoRA output) was passed
un-detached as soft labels to the EAGLE loss. This created a circular
gradient: the EAGLE loss pushed LoRA to match eagle's predictions while
EAGLE was simultaneously trained to match LoRA's output, causing both
to collapse to a degenerate low-entropy solution (loss 19 -> 0.69 in
100 steps, fake AR 3.9).

Detach base_model_logits in the EAGLE soft-label path when LoRA is
active. LoRA still receives meaningful EAGLE gradients through the
hidden-state path (out_hiddens -> eagle_input_hiddens -> EAGLE module
-> EAGLE loss), which is the intended co-training signal for improving
base model representations. Preservation loss continues to prevent
base model drift.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
yeyu-nvidia and others added 30 commits March 19, 2026 13:24
Base model logits are the teacher signal for EAGLE distillation and
should always be treated as fixed targets. Remove the eagle_base_lora
conditional and unconditionally detach, keeping the change minimal.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…raining

The hidden-state-only gradient path (from the previous detach approach)
produced negligible LoRA updates: the gradient attenuates over many base
model layers and the EAGLE module adapts to compensate, leaving LoRA
effectively frozen.

Restore the direct logits gradient path (remove detach) so LoRA receives
a strong learning signal. Raise the default preservation loss weight from
0.1 to 1.0 to balance the two forces — EAGLE loss drives LoRA toward a
more predictable distribution while the preservation loss prevents
degenerate collapse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Detach base_outputs.logits when used as soft labels in the EAGLE loss so
gradients do not flow back to LoRA through the label path (which causes
circular collapse). LoRA still receives EAGLE gradients via the hidden-
state path (out_hiddens -> eagle_input_hiddens).

Add eagle_base_lora_lr_multiplier (default 10x) to compensate for the
weaker hidden-state gradient signal: LoRA parameters are split into a
separate optimizer param group with lr = base_lr * multiplier.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
With the detach fix the collapse risk is gone, so the weight reverts to
its original value. The 10x LoRA LR multiplier already amplifies all
LoRA gradients; keeping weight=1.0 would over-constrain LoRA training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The hidden-state gradient path alone (EAGLE loss -> out_hiddens -> LoRA)
is too weak to produce meaningful LoRA updates in practice. Add a language
modeling loss on the LoRA-adapted base model output as the primary LoRA
training signal (eagle_base_lora_lm_loss_weight, default 1.0).

This avoids the circular collapse seen with the logits-path gradient:
LoRA directly optimizes base model LM quality on the training data, while
EAGLE co-adapts to predict the improved base model's tokens (via detached
logits). The preservation loss remains as KL regularization.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
With LoRA co-training the model carries extra parameters and optimizer
states (LoRA A/B + Adam moments), reducing the headroom available for
the validation forward passes. Call torch.cuda.empty_cache() before
validate_ar() to release unused cached allocations without affecting
any live tensors (parameters, optimizer states, gradients).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
B=0 initialization creates a saddle point where the preservation gradient
is exactly zero at init, allowing the EAGLE logits gradient to dominate
unopposed before preservation can react. Initialize lora_B with N(0, 0.01)
so the preservation loss is active from step 0 and constrains LoRA from
the start.

With preservation active at init, restore the direct logits gradient path
(remove detach on base_outputs.logits in EAGLE loss) to give LoRA a strong
training signal while relying on preservation loss to prevent collapse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…ils)

Non-zero lora_B initialization did not prevent collapse: the EAGLE logits
gradient (~16 loss scale) overwhelms preservation loss (~0.1 * small KL)
before it can react, regardless of B init or preservation weight tried
(0.1, 1.0, 5.0 — all collapse; large weights freeze LoRA at zero instead).

The coupled LoRA-EAGLE logits gradient has no stable non-degenerate fixed
point. Restore the detach so EAGLE loss only trains the EAGLE module.
LoRA trains exclusively via LM loss + preservation loss with the non-zero
B init kept to break the B=0 saddle point for those loss terms.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Remove .clone().detach() from aux hidden states hook when LoRA
co-training is active, so EAGLE loss can backpropagate through
intermediate hidden states to LoRA parameters. This is the primary
gradient path: EAGLE loss → fc → aux_hiddens → base transformer
layers → LoRA weights. The logit path remains detached to prevent
mode collapse.

Also remove LM loss on LoRA (redundant with preservation loss on a
well-trained base model) and reset LoRA LR multiplier default to 1x.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
EAGLE gradient through aux_hiddens breaks the B=0 saddle point,
so non-zero B init is no longer needed. B=0 means LoRA starts as
identity — EAGLE's initial accuracy is unperturbed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Replace full gradient flow with a configurable scale factor
(default 0.01) to prevent the "moving target" problem where
LoRA changes hidden states faster than EAGLE can adapt.

Formula: detached + scale * (live - detached), so LoRA receives
scale * full_gradient — a weak directional signal without
destabilizing EAGLE training.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Replace gradient scaling with GAN-style alternating training:
- Phase A (default 100 steps): freeze LoRA, train EAGLE normally
  with detached aux_hiddens and logits.
- Phase B (default 10 steps): freeze EAGLE, train LoRA with full
  gradient from EAGLE loss (through both logits and hidden states)
  plus preservation loss. Safe because EAGLE is frozen.

This avoids both the "moving target" problem (full gradient +
simultaneous training) and the "invisible LoRA" problem (detached
gradient gives LoRA no useful signal).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Replace 2-phase alternating with 3-phase to avoid competing gradients:
- Phase A (eagle): train EAGLE, LoRA frozen — standard training
- Phase B (lora_eagle): train LoRA via EAGLE loss only, EAGLE frozen
  — LoRA moves hidden states + logits toward EAGLE-friendly
- Phase C (lora_preserve): train LoRA via preservation loss only,
  EAGLE frozen — logits pulled back toward reference

This separates the competing EAGLE and preservation gradients so each
phase has a single clear optimization objective.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
After exporting an EAGLE checkpoint with export_hf_checkpoint.py, this
script loads the exported LoRA adapter weights and merges them into the
original base model for TRT-LLM deployment.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
In Phase C, EAGLE params are frozen and logits are detached, so the
EAGLE loss produces no useful gradients. Early-return after the base
model forward with just the preservation loss.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Each of the 3 alternating training phases now reports its loss
separately, making it easier to diagnose training dynamics.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Phase C (preservation-only) returns empty train_acc lists since EAGLE
forward is skipped. Skip appending these to avoid numpy shape mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
_remove_adapter was removed from peft.utils. Use the public
delete_adapter method on each LoraLayer after merging instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The previous approach using inject_adapter_in_model + manual merge left
LoRA wrapper modules (base_layer) in the saved model, producing invalid
weight keys. Use PeftModel.from_pretrained + merge_and_unload() which
properly unwraps modules back to plain Linear.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The exported LoRA keys use "lora_A.weight" but PeftModel expects
"lora_A.default.weight". Also keep the "model." prefix which PeftModel
needs. The previous version stripped the prefix and didn't add .default,
causing all LoRA weights to be silently missing (merged as zeros).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Move the lora_A.weight -> lora_A.default.weight renaming from
merge_lora.py into the export step so exported files are directly
loadable by PeftModel without any key transformation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…ontinuing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
When target_modules is None (PEFT defaults), PeftModel.from_pretrained
cannot determine which modules have LoRA adapters, causing all keys to
be reported as missing. Infer the actual module names (e.g., q_proj,
v_proj) from the exported LoRA state dict keys.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Keys are model.layers.0.self_attn.q_proj.lora_A.default.weight,
so the module name is at index [-4], not [-3].

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…bility

PeftModel expects adapter safetensors keys to include the full
base_model.model. prefix. Our exported keys start with model. so
we prepend base_model.model. before loading. Also added diagnostic
prints to confirm keys loaded successfully.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Peft 0.18+ no longer expects base_model.model. prefix in adapter
safetensors. Also strip .default. segment from keys since older peft
uses lora_A.weight while newer uses lora_A.default.weight.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Peft 0.18 emits a false "missing adapter keys" warning even when weights
load correctly. Suppress it and instead verify that loaded LoRA weights
are non-zero.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Stripping .default from keys caused peft to not match the weights,
leaving lora_B at zero (init value). Keep exported keys as-is since
peft 0.18 uses .default internally. Verify lora_B norms instead of
lora_A since A gets kaiming-initialized regardless of loading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
PeftModel.from_pretrained silently fails to match exported LoRA keys
due to prefix mismatch. Switch to get_peft_model + explicit key mapping:
auto-detect the prefix between exported keys and PeftModel state dict,
load weights directly, and verify lora_B norms before merging.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants