Skip to content

Commit 696d251

Browse files
yeyu-nvidia and claude
authored and committed
Address remaining PR review feedback
- Wrap reference forward pass in try/finally so LoRA adapters are always re-enabled even if the forward throws (prevents silent training with permanently disabled LoRA on subsequent calls)
- Replace assert with raise ValueError for the eagle_offline compatibility check so it cannot be silently optimized away with python -O; update test to expect ValueError instead of AssertionError

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 56f459f commit 696d251

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,8 @@ def modify(
648648

649649
# Inject HF PEFT LoRA adapters into the base model for co-training
650650
if self.eagle_base_lora:
651-
assert not self.eagle_offline, "eagle_base_lora is incompatible with eagle_offline=True"
651+
if self.eagle_offline:
652+
raise ValueError("eagle_base_lora is incompatible with eagle_offline=True")
652653
self._inject_base_lora()
653654

654655
# delete base model layers for offline training
@@ -818,10 +819,12 @@ def _run_forward(no_grad):
818819
ref_logits = None
819820
if self.eagle_base_lora:
820821
self._set_base_lora_enabled(False)
821-
ref_logits = _run_forward(no_grad=True).logits
822-
if hasattr(self, "_aux_hidden_states"):
823-
self._aux_hidden_states.clear()
824-
self._set_base_lora_enabled(True)
822+
try:
823+
ref_logits = _run_forward(no_grad=True).logits
824+
finally:
825+
if hasattr(self, "_aux_hidden_states"):
826+
self._aux_hidden_states.clear()
827+
self._set_base_lora_enabled(True)
825828

826829
# Main forward — LoRA params receive gradients when eagle_base_lora is True.
827830
outputs = _run_forward(no_grad=freeze_base_model and not self.eagle_base_lora)

tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_eagle_offline_incompatible():
8686
model = get_tiny_llama(num_hidden_layers=4)
8787
config = deepcopy(EAGLE_LORA_CONFIG)
8888
config["eagle_offline"] = True
89-
with pytest.raises(AssertionError, match="eagle_base_lora is incompatible with eagle_offline"):
89+
with pytest.raises(ValueError, match="eagle_base_lora is incompatible with eagle_offline"):
9090
mtsp.convert(model, mode=[("eagle", config)])
9191

9292

0 commit comments

Comments (0)