Add bypass distillation (blockwise local KD) to puzzletron pipeline#1111
Separius wants to merge 1 commit into `feature/puzzletron` from
Conversation
Bypass distillation trains alternative transformer block configurations using per-block knowledge distillation from the teacher model, producing a library of better "puzzle pieces" for the MIP solver. It is most beneficial for aggressive FFN pruning (≤ 1/8 of teacher width) or significant KV head compression.

Changes:
- Add modelopt/torch/puzzletron/bypass_distillation/ module with full training loop, stitched model factory, checkpoint management, and data classes
- Integrate bypass as optional Step 3 in puzzletron.py and puzzletron_nas_plugin.py (pipeline progress counter updates to 9 steps when bypass is enabled)
- Add HuggingFace auto-download and skip-if-exists logic to puzzletron_nas_plugin.py for all pipeline steps
- Add normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss to sewing_kit/utils.py
- Fix child_init.py: support list of pruning mixins; fix None override treated as "keep original value" instead of raising TypeCheckError
- Fix dataset.py: graceful fallback when tokenizer has no chat_template (base models)
- Add _copy_auto_map_code_files to checkpoint_utils_hf.py so modeling Python files are copied alongside config.json (required for trust_remote_code checkpoints such as NemotronH)
- Add create_train_dataloader to dataloaders.py
- Add MoEChannelPruning to MlpInitMode enum
- Add default pruning_mixins() to ModelDescriptor base class
- Add NemotronHKVHeadsLayerDescriptor and kv_heads mixin to NemotronH descriptor; fix _set_keys_to_learn to skip Mamba blocks during subblock_attention bypass (based on block config)
- Enable bypass in llama-3_1-8B_pruneffn_memory config; add example bypass/defaults.yaml
- Update README with bypass documentation: when to use, time cost, sequential execution, W&B logging
- Add unit tests for loss functions and distribution utilities
- Add GPU integration tests for bypass (FFN pruning, KV compression, multi-config sweep, checkpoint validation)
- Fix test_puzzletron.py assertion to handle variable GPU counts
📝 Walkthrough

This pull request adds a complete bypass distillation (blockwise local distillation) training pipeline as an optional stage to the Puzzletron model optimization framework. It introduces new configuration schemas, orchestration logic, checkpoint utilities, loss computations, stitched model factories for teacher-student learning, distributed training loops with gradient scaling, data loaders, and comprehensive test coverage including GPU integration tests.
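The core idea — each student block regresses onto its teacher block's output rather than an end-to-end loss — can be shown with a toy scalar example. This is purely illustrative (single scalar "block", plain gradient descent), not the actual distributed training loop:

```python
# Toy blockwise local KD: each "block" is a scalar linear map, and the
# student block is fit to the teacher block's output on the teacher's input.
teacher_w = 2.0
student_w = 0.0
lr = 0.1

for _ in range(100):
    x = 1.0                                      # stand-in for the teacher's hidden state
    teacher_out = teacher_w * x                  # teacher block forward
    student_out = student_w * x                  # student block forward on the same input
    grad = 2 * (student_out - teacher_out) * x   # d/dw of (student_out - teacher_out)^2
    student_w -= lr * grad

assert abs(student_w - teacher_w) < 1e-6
```

Because every block trains against its local target, blocks can be trained independently and in parallel across ranks, which is what the per-rank block ownership in the diagram below enables.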
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant Launcher as launch_bypass<br/>_distillation()
    participant Orchestrator as run_bypassed<br/>_training()
    participant Loader as DataLoader &<br/>Teacher Model
    participant Factory as stitched_model<br/>_factory()
    participant TrainingLoop as train()
    participant Dist as Distributed<br/>Sync & Checkpoints
    User->>Launcher: Hydra config
    Launcher->>Launcher: For each bypass.config
    Launcher->>Orchestrator: cfg with overrides
    Orchestrator->>Loader: Load teacher, init student
    Orchestrator->>Factory: Create stitched modules & descriptors
    Factory->>Factory: Compute block ownership<br/>per rank
    Factory-->>Orchestrator: Student + stitched modules
    Orchestrator->>Loader: Build streaming dataloader
    Orchestrator->>Orchestrator: Resume from checkpoint<br/>(if exists)
    Orchestrator->>TrainingLoop: Call train()
    loop Per iteration
        TrainingLoop->>Loader: Fetch batch
        TrainingLoop->>TrainingLoop: Teacher forward pass
        loop Per stitched module
            TrainingLoop->>TrainingLoop: Student forward + loss
            TrainingLoop->>TrainingLoop: Backward + optimize<br/>(with grad scaling)
        end
        TrainingLoop->>Dist: Gather losses across ranks
        TrainingLoop->>Dist: Log & merge history (master)
        alt Time to save
            TrainingLoop->>Dist: save_bypass_checkpoint()
            Dist->>Dist: Sync barriers
        end
        alt Validate
            TrainingLoop->>TrainingLoop: Validation forward
            TrainingLoop->>Dist: Sync val_loss
        end
    end
    TrainingLoop-->>Orchestrator: Training complete
    Orchestrator->>Dist: realize_bypass_checkpoints()
    Dist->>Dist: Create puzzle_dir symlink
    Orchestrator-->>Launcher: Run finished
    Launcher-->>User: All configs trained
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/gpu/torch/puzzletron/test_puzzletron.py (1)
236-245: ⚠️ Potential issue | 🟡 Minor — The fallback printer still emits only rank-local values.

This branch now advertises `num_layers={total_layers}`, but it still prints only the contents of `rank_{rank}.pth` and is executed on rank 0 only. On multi-GPU runs the suggested `EXPECTED_PRUNING_VALUES` snippet will be incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 236 - 245, the printer currently outputs only rank-local pruning_scores, causing an incomplete EXPECTED_PRUNING_VALUES snippet for multi-GPU runs; modify the logic so rank 0 aggregates pruning data from all ranks before printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth files) into a global pruning_scores for each layer_name, compute the global score and channels (e.g., combine/average or gather channel indices across ranks) respecting total_layers, and then have rank 0 iterate over layer_names using the aggregated values when printing the block that uses total_layers and prints the EXPECTED_PRUNING_VALUES snippet.

modelopt/torch/puzzletron/pruning/pruning_utils.py (1)
40-47: ⚠️ Potential issue | 🟠 Major — `MoEChannelPruning` is exposed before the init path supports it.

`modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` now branches on this enum and forwards it into `_init_mlp_module()`, but `_init_mlp_module()` still falls through to `Unsupported mlp_init_mode` for this value when expert widths change. Any config that selects `MoEChannelPruning` will fail during child initialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/pruning/pruning_utils.py` around lines 40 - 47, The enum MlpInitMode now includes MoEChannelPruning but _init_mlp_module still treats that case as unsupported; update the _init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the same call-site that child_init.py forwards into) by adding a branch for MlpInitMode.MoEChannelPruning that performs the correct initialization when expert widths change (e.g., adapt the weight/activation shapes by slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where appropriate), so the child init no longer falls through to the "Unsupported mlp_init_mode" error for MoEChannelPruning.
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
804-806: This change makes explicit `null` resets impossible.

Treating `None` as "keep original" fixes the accidental overwrite, but it also removes the only way for JSON/YAML overrides to clear an optional field back to `None`. If callers need both behaviors, use a sentinel for "no override" and reserve `None` for explicit clearing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 804 - 806, the current override function (override) treats item_overrides == None as "keep original", which prevents callers from explicitly clearing a value to None via JSON/YAML; change the logic to use a distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no override" and reserve None in item_overrides to mean "set to None"/clear the field, updating the override function to check against the sentinel (NO_OVERRIDE) instead of None and adjust any callers that construct overrides to use the sentinel when they mean "leave original".

modelopt/torch/puzzletron/utils/data/dataset.py (1)
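A minimal sketch of the sentinel pattern this comment suggests. The names `NO_OVERRIDE` and `apply_override` are illustrative, not the actual `child_init.py` API:

```python
# Sentinel distinct from None, so that None can mean "explicitly clear the field".
NO_OVERRIDE = object()

def apply_override(original, override=NO_OVERRIDE):
    """Return `override` unless it is the NO_OVERRIDE sentinel."""
    return original if override is NO_OVERRIDE else override

# None now propagates as an explicit value instead of meaning "keep original":
assert apply_override(5) == 5
assert apply_override(5, None) is None
assert apply_override(5, 7) == 7
```

Callers that mean "leave the field alone" simply omit the argument (or pass `NO_OVERRIDE`), while JSON/YAML `null` maps to a genuine `None` override.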
123-130: Keep role markers in the no-template fallback.

Joining only `content` collapses `system`/`user`/`assistant` turns into plain text, which changes the supervision for chat datasets. A lightweight fallback like `"{role}: {content}"` preserves the conversation structure without relying on a tokenizer template.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/utils/data/dataset.py` around lines 123 - 130, the fallback that builds sample when getattr(self.tokenizer, "chat_template", None) is None should preserve role markers instead of joining only message["content"]; update the else branch in dataset.py (the block that currently sets sample = "\n".join(m["content"] for m in sample)) to join messages using a lightweight role-prefixed format like "{role}: {content}" so conversation turns (system/user/assistant) are retained; keep using the same sample variable and ensure this mirrors the structure expected by downstream code that consumes apply_chat_template outputs.

modelopt/torch/puzzletron/utils/parsing.py (1)
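The suggested fallback is a one-liner; a self-contained sketch (function name is illustrative, not the dataset.py symbol):

```python
def render_chat_without_template(messages):
    """Fallback when the tokenizer has no chat_template: keep role markers
    so conversation turns (system/user/assistant) survive in the text."""
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
]
assert render_chat_without_template(messages) == "system: You are helpful.\nuser: Hi"
```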
337-345: Don't silently treat every NaN as a no-op block.

This formatter now drops any NaN entry and can report `No trainable blocks found`. If a trainable block diverges, the failure disappears from the logs instead of surfacing. Filter only known skipped block types, or emit a separate warning for unexpected NaNs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/utils/parsing.py` around lines 337 - 345, the current filtering silently drops any NaN in losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides diverging trainable blocks; instead, update the logic around losses_dict, best_steps_dict and best_values_dict so you only drop entries whose keys match known skipped block types (e.g., the explicit list of no-op block names like "Mamba"), and for any other NaN values emit a warning/error (via the existing logger) that a trainable block produced NaN rather than removing it; ensure best_steps_dict and best_values_dict are only pruned to match the filtered losses_dict after this selective filtering and warning behavior.

examples/puzzletron/main.py (1)
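A sketch of the selective filtering being asked for. `SKIPPED_BLOCK_TYPES` and the name-substring match are illustrative assumptions, not the actual parsing.py logic:

```python
import math
import warnings

# Illustrative: block types expected to report NaN during this bypass run.
SKIPPED_BLOCK_TYPES = ("mamba",)

def filter_losses(losses_dict):
    """Drop NaNs only for known skipped block types; warn on any other NaN."""
    kept = {}
    for name, loss in losses_dict.items():
        if math.isnan(loss):
            if any(t in name.lower() for t in SKIPPED_BLOCK_TYPES):
                continue  # expected no-op block: drop silently
            # Unexpected NaN on a trainable block: surface it, keep the entry.
            warnings.warn(f"Trainable block {name!r} produced NaN loss")
        kept[name] = loss
    return kept
```

A divergent attention block now stays visible in the output (with a warning) instead of vanishing into "No trainable blocks found".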
154-167: Progress messages in `run_mip_only` are hardcoded and inconsistent with the dynamic approach.

The `run_full_puzzletron` function now uses dynamic step counting (`N = _total_steps(hydra_cfg)`), but `run_mip_only` still uses hardcoded "7/8" and "8/8" progress messages. If bypass is configured, the step numbers would be incorrect (they should be 8/9 and 9/9). Consider applying the same dynamic step count logic here for consistency.
♻️ Suggested fix
```diff
 def run_mip_only(hydra_config_path: str):
     ...
     # Load hydra config
     hydra_cfg = initialize_hydra_config_for_dir(
         config_dir=hydra_config_dir,
         config_name=hydra_config_name,
         overrides=[],
     )
+    N = _total_steps(hydra_cfg)

     # Check if sweep mode is enabled
     if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
         mprint(
-            "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+            f"Puzzletron Progress {N-1}/{N}: running MIP sweep for multiple compression rates (multi-gpu)"
         )
         sweep.run_mip_sweep(hydra_cfg)
     else:
         # mip_and_realize_models (distributed processing)
         # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        mprint(f"Puzzletron Progress {N-1}/{N}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
     dist.cleanup()
-    mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+    mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/puzzletron/main.py` around lines 154 - 167, update run_mip_only to compute the total steps like run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when formatting the progress messages instead of hardcoded "7/8" and "8/8"; specifically, replace the two mprint calls around the conditional that currently show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g., f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step increments are correct for both the sweep branch (sweep.run_mip_sweep) and the mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress displays consistently with _total_steps.

modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)
548-556: Unused variable `num_trainable_params`.

The variable `num_trainable_params` is computed but never used in this function or elsewhere. This appears to be residual code. Consider removing it to reduce unnecessary computation and improve code clarity.

♻️ Proposed removal
```diff
     assert "learning_rate" in cfg.training
-    num_trainable_params = sum(
-        p.requires_grad and submodule_name in p_name
-        for p_name, p in student_stitched_module.named_parameters()
-        if "dummy_param" not in p_name  # exclude placeholder params
-    )
-    # Do NOT enable dummy params: blocks with no real trainable parameters
-    # (e.g. Mamba blocks during an attention-only bypass run) should produce
-    # NaN loss so they are excluded from statistics — identical to the
-    # optimizer=None path in the training loop.
     student_module_parameters = {
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 548 - 556, Remove the unused computation of num_trainable_params: delete the sum(...) assignment that iterates over student_stitched_module.named_parameters() checking p.requires_grad and submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding explanatory comment about dummy params if still relevant, but eliminate the dead variable and its associated needless iteration to avoid wasted computation and clarify stitched_model_factory.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 45-58: The fallback currently only sorts checkpoint directories by
iteration (get_iter_num) so when multiple checkpoints exist for the same iter we
may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 673-677: Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.
In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py`:
- Around line 146-149: The pre-checks that treat presence of files like
(teacher_dir / "config.json"), any rank_*.pth, files under
pruned_ckpts_output_dir, or library outputs as sufficient to skip stages are
unsafe; change these guards to rely on durable completion markers (e.g., a .done
or .complete file) created at the successful end of
conversion/scoring/pruning/library build instead of existence-only checks, so
functions like the conversion branch around teacher_dir/config.json, the rank_*
checkpoint checks, and the pruned_ckpts_output_dir/library checks only skip when
their corresponding completion marker exists; ensure launch_score_activations()
remains the stricter gate for pruning-activation scoring but remove or weaken
the naive existence checks noted at the conversion lines (the block using
teacher_dir/config.json) and the other mentioned blocks (191-193, 286-289) to
check for the specific "<stage>.complete" marker before skipping.
In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 452-454: The normalization denominator is computed as
F.mse_loss(target, torch.zeros_like(target) + epsilon, ...) which shifts the
target by epsilon and biases the scale; instead compute the denominator as
F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon (or
clamp_min the denominator to epsilon) so you add epsilon to the final scalar
denominator instead of to the zero tensor; update the occurrences around the
loss assignment (loss, input, target, epsilon, F.mse_loss) and the similar block
at lines 479-482 accordingly.
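The bias the prompt describes is easy to see numerically. A pure-Python illustration (no torch) of why `E[(target - ε)²]` misbehaves for near-zero targets while `max(E[target²], ε)` stays stable — the values are made up for the demonstration:

```python
epsilon = 1e-6
target = [0.0, 0.0]        # near-zero target activations
input_ = [1e-3, -1e-3]

mse = sum((i - t) ** 2 for i, t in zip(input_, target)) / len(target)

# Denominator from the original code: E[(target - eps)^2] = eps^2 here.
biased_denom = sum((t - epsilon) ** 2 for t in target) / len(target)

# Suggested denominator: clamp E[target^2] from below by eps.
safe_denom = max(sum(t ** 2 for t in target) / len(target), epsilon)

assert biased_denom == epsilon ** 2          # tiny denominator -> loss blows up
assert mse / safe_denom < mse / biased_denom
```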
In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 380-396: The auto_map parsing in checkpoint_utils_hf.py
incorrectly assumes each model_config.auto_map value is a dotted string; update
the logic that builds module_files (and any usage of class_ref) to first
normalize each value by: if it's a list/tuple take the first element, if it
contains a repo qualifier split off the "repo_id--" prefix, then take the module
part before the first '.' and append ".py" (so "tokenization_my.py"); apply this
normalization where module_files is created and when iterating filenames so
lists/tuples and repo-qualified references are handled and the correct source
filenames are copied.
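A sketch of the normalization the prompt asks for. The helper name and the `repo_id--module.Class` format assumption come from the prompt, not from verified transformers internals:

```python
def auto_map_module_file(value):
    """Normalize a config auto_map entry to its source filename.
    Handles list/tuple values and 'repo_id--module.Class' references."""
    if isinstance(value, (list, tuple)):
        value = next(v for v in value if v is not None)
    if "--" in value:
        value = value.split("--", 1)[1]      # drop the repo qualifier
    return value.split(".", 1)[0] + ".py"    # module part before first '.'

assert auto_map_module_file("modeling_my.MyModel") == "modeling_my.py"
assert auto_map_module_file("org/repo--modeling_my.MyModel") == "modeling_my.py"
assert auto_map_module_file(["tokenization_my.MyTok", None]) == "tokenization_my.py"
```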
In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 89-90: The DataLoader factory allows num_workers>0 while
ConstantLengthDataset.__iter__ does not shard via get_worker_info(), causing
duplicate samples; update the dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
- Around line 98-99: The call to train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) fails for streaming (Iterable) datasets because
IterableDataset.shuffle() doesn't accept keep_in_memory; update the code that
checks shuffle_seed to detect streaming datasets (e.g., via whatever marker
load_streaming_fn sets or by checking hasattr(train_data, "__iter__") vs
__len__/isinstance of IterableDataset) and branch: for non-streaming datasets
call train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) as before, and
for streaming/iterable datasets call train_data.shuffle(seed=shuffle_seed)
without keep_in_memory; ensure you modify the block that references shuffle_seed
and train_data.shuffle so runtime errors are avoided when load_streaming_fn()
returns a streaming dataset.
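Both dataloader findings can be combined into one guard sketch. The function name and the `__len__`-based streaming detection are assumptions (HF streaming `IterableDataset` objects do not define `__len__`), not the actual `create_train_dataloader` code:

```python
def build_dataloader_inputs(train_data, shuffle_seed, num_workers):
    """Illustrative guards for both review points:
    - force num_workers=0 until ConstantLengthDataset.__iter__ shards
      via torch.utils.data.get_worker_info()
    - only pass keep_in_memory to map-style datasets; streaming
      IterableDataset.shuffle() does not accept that parameter
    """
    if num_workers > 0:
        num_workers = 0  # TODO: remove once __iter__ is worker-aware

    if shuffle_seed is not None:
        if hasattr(train_data, "__len__"):   # map-style dataset
            train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
        else:                                # streaming / iterable dataset
            train_data = train_data.shuffle(seed=shuffle_seed)
    return train_data, num_workers
```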
In `@tests/gpu/torch/puzzletron/test_bypass.py`:
- Line 213: The timeout passed to dist.setup uses timedelta(10) which means 10
days; change it to an explicit unit like timedelta(seconds=10) (or
timedelta(minutes=10) if intended) to avoid 10-day test hangs — locate the call
to dist.setup (symbol: dist.setup) in tests/gpu/torch/puzzletron/test_bypass.py
and the other listed files and replace timedelta(10) with timedelta(seconds=10)
(or the correct unit) in each occurrence.
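The footgun here is easy to verify: `timedelta`'s first positional argument is days, not seconds:

```python
from datetime import timedelta

# timedelta(10) silently means ten DAYS:
assert timedelta(10) == timedelta(days=10)
assert timedelta(10).total_seconds() == 864_000   # 10 * 86400
assert timedelta(seconds=10).total_seconds() == 10
```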
---
Outside diff comments:
In `@modelopt/torch/puzzletron/pruning/pruning_utils.py`:
- Around line 40-47: The enum MlpInitMode now includes MoEChannelPruning but
_init_mlp_module still treats that case as unsupported; update the
_init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the
same call-site that child_init.py forwards into) by adding a branch for
MlpInitMode.MoEChannelPruning that performs the correct initialization when
expert widths change (e.g., adapt the weight/activation shapes by
slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where
appropriate), so the child init no longer falls through to the "Unsupported
mlp_init_mode" error for MoEChannelPruning.
In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 236-245: The printer currently outputs only rank-local
pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs;
modify the logic so rank 0 aggregates pruning data from all ranks before
printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth
files) into a global pruning_scores for each layer_name, compute the global
score and channels (e.g., combine/average or gather channel indices across
ranks) respecting total_layers, and then have rank 0 iterate over layer_names
using the aggregated values when printing the block that uses total_layers and
prints the EXPECTED_PRUNING_VALUES snippet.
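The rank-0 aggregation asked for above can be sketched as a plain dictionary merge. The averaging/union policy is an assumption for illustration, not the actual test_puzzletron.py logic:

```python
def merge_rank_scores(per_rank_scores):
    """Rank-0 aggregation sketch: average each layer's score across ranks
    and union the selected channel indices."""
    acc = {}
    for rank_scores in per_rank_scores:
        for layer, (score, channels) in rank_scores.items():
            total, chans, count = acc.get(layer, (0.0, set(), 0))
            acc[layer] = (total + score, chans | set(channels), count + 1)
    return {
        layer: (total / count, sorted(chans))
        for layer, (total, chans, count) in acc.items()
    }

# E.g. two ranks contributing partial views of the same layer:
per_rank = [{"layer0": (0.25, [0, 1])}, {"layer0": (0.75, [1, 2])}]
assert merge_rank_scores(per_rank) == {"layer0": (0.5, [0, 1, 2])}
```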
---
Nitpick comments:
In `@examples/puzzletron/main.py`:
- Around line 154-167: Update run_mip_only to compute the total steps like
run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when
formatting the progress messages instead of hardcoded "7/8" and "8/8";
specifically, replace the two mprint calls around the conditional that currently
show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g.,
f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step
increments are correct for both the sweep branch (sweep.run_mip_sweep) and the
mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress
displays consistently with _total_steps.
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 548-556: Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 804-806: The current override function (override) treats
item_overrides == None as "keep original", which prevents callers from
explicitly clearing a value to None via JSON/YAML; change the logic to use a
distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no
override" and reserve None in item_overrides to mean "set to None"/clear the
field, updating the override function to check against the sentinel
(NO_OVERRIDE) instead of None and adjust any callers that construct overrides to
use the sentinel when they mean "leave original".
In `@modelopt/torch/puzzletron/utils/data/dataset.py`:
- Around line 123-130: The fallback that builds sample when
getattr(self.tokenizer, "chat_template", None) is None should preserve role
markers instead of joining only message["content"]; update the else branch in
dataset.py (the block that currently sets sample = "\n".join(m["content"] for m
in sample)) to join messages using a lightweight role-prefixed format like
"{role}: {content}" so conversation turns (system/user/assistant) are retained;
keep using the same sample variable and ensure this mirrors the structure
expected by downstream code that consumes apply_chat_template outputs.
In `@modelopt/torch/puzzletron/utils/parsing.py`:
- Around line 337-345: The current filtering silently drops any NaN in
losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides
diverging trainable blocks; instead, update the logic around losses_dict,
best_steps_dict and best_values_dict so you only drop entries whose keys match
known skipped block types (e.g., the explicit list of no-op block names like
"Mamba"), and for any other NaN values emit a warning/error (via the existing
logger) that a trainable block produced NaN rather than removing it; ensure
best_steps_dict and best_values_dict are only pruned to match the filtered
losses_dict after this selective filtering and warning behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 371acd83-77b9-4396-8a82-eddd5b11dd40
📒 Files selected for processing (27)
examples/puzzletron/README.mdexamples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yamlexamples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yamlexamples/puzzletron/main.pymodelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.pymodelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.pymodelopt/torch/puzzletron/bypass_distillation/__init__.pymodelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.pymodelopt/torch/puzzletron/bypass_distillation/bypass_utils.pymodelopt/torch/puzzletron/bypass_distillation/data_classes.pymodelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.pymodelopt/torch/puzzletron/bypass_distillation/training_loop.pymodelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.pymodelopt/torch/puzzletron/pruning/pruning_utils.pymodelopt/torch/puzzletron/puzzletron.pymodelopt/torch/puzzletron/sewing_kit/utils.pymodelopt/torch/puzzletron/tools/bypassed_training/child_init.pymodelopt/torch/puzzletron/tools/checkpoint_utils_hf.pymodelopt/torch/puzzletron/utils/data/dataloaders.pymodelopt/torch/puzzletron/utils/data/dataset.pymodelopt/torch/puzzletron/utils/parsing.pytests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yamltests/gpu/torch/puzzletron/test_bypass.pytests/gpu/torch/puzzletron/test_puzzletron.pytests/unit/torch/puzzletron/__init__.pytests/unit/torch/puzzletron/test_bypass_losses.pytests/unit/torch/puzzletron/test_bypass_utils.py
```python
# If "latest" doesn't exist, look explicitly into directories with `*iter-*`
candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()]

if not candidate_dirs:
    return None

def get_iter_num(dir_name):
    match = re.search(r"iter-(\d+)", dir_name.name)
    return int(match.group(1)) if match else 0

checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
for latest_dir in checkpoint_dirs:
    if (latest_dir / "saving_completed").exists():
        return str(latest_dir)
```
Include step_num when picking the latest checkpoint.
This fallback only sorts on `iter-(\d+)`. If a run writes multiple checkpoints inside the same iteration, resume can load an older step even though a newer checkpoint exists in the same `run_parent_dir`.
💡 Suggested fix
```diff
-    def get_iter_num(dir_name):
-        match = re.search(r"iter-(\d+)", dir_name.name)
-        return int(match.group(1)) if match else 0
-
-    checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
+    def checkpoint_order(path: Path) -> tuple[int, int, float]:
+        match = re.search(r"iter-(\d+)(?:.*step-(\d+))?", path.name)
+        if not match:
+            return (0, 0, path.stat().st_mtime)
+        return (int(match.group(1)), int(match.group(2) or 0), path.stat().st_mtime)
+
+    checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)
```
+ checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 45 - 58, The fallback currently only sorts checkpoint directories
by iteration (get_iter_num) so when multiple checkpoints exist for the same iter
we may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
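The composite sort key the prompt describes works because Python compares tuples lexicographically. A standalone sketch (the directory-name format is an assumption drawn from the regexes above):

```python
import re

def checkpoint_order(name: str) -> tuple[int, int]:
    """Sort key combining iteration and step, defaulting step to 0."""
    it = re.search(r"iter-(\d+)", name)
    st = re.search(r"step-(\d+)", name)
    return (int(it.group(1)) if it else 0, int(st.group(1)) if st else 0)

dirs = ["iter-2_step-1", "iter-2_step-3", "iter-1_step-9", "iter-2"]
# Iteration dominates; step breaks ties within the same iteration.
assert max(dirs, key=checkpoint_order) == "iter-2_step-3"
```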
```python
tokenizer = AutoTokenizer.from_pretrained(
    cfg.teacher_dir,
    trust_remote_code=True,
    token=True,
)
```
CRITICAL: Hardcoded trust_remote_code=True violates security coding guidelines.
The coding guidelines explicitly prohibit hardcoding trust_remote_code=True for transformers model or tokenizer loading. This flag enables execution of arbitrary Python shipped with a checkpoint, creating an RCE vector if the model source is untrusted.
The code already retrieves trust_remote_code from the descriptor at line 596 and uses it for model config loading at lines 597 and 631. Apply the same pattern here.
🔒 Proposed fix
```diff
 tokenizer = AutoTokenizer.from_pretrained(
     cfg.teacher_dir,
-    trust_remote_code=True,
+    trust_remote_code=trust_remote_code,
     token=True,
 )
```

As per coding guidelines: "Flag trust_remote_code=True hardcoded for transformers model or tokenizer loading as CRITICAL security issue. Code should expose it as a caller-configurable parameter defaulting to False, not hardcode True"
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
tokenizer = AutoTokenizer.from_pretrained(
    cfg.teacher_dir,
    trust_remote_code=trust_remote_code,
    token=True,
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
673 - 677, Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.
```python
if (teacher_dir / "config.json").exists():
    mprint(f"Puzzletron Progress 2/{N}: teacher checkpoint already exists, skipping conversion")
else:
    mprint(f"Puzzletron Progress 2/{N}: converting model to Puzzletron heterogeneous format (single-gpu)")
```
Use completion markers, not existence-only skips.
These guards treat config.json, any rank_*.pth, any file under pruned_ckpts_output_dir, or the two library outputs as proof that the whole stage is reusable. After an interrupted run—or after rerunning in the same puzzle_dir with different config—this can skip conversion/scoring/pruning/library build on partial or stale artifacts. For pruning-activation scoring, launch_score_activations() already has a stricter completeness check, so this outer pre-check is weaker than the existing stage logic.
Also applies to: 181-183, 191-193, 286-289
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py` around lines
146-149: the pre-checks that treat presence of files like (teacher_dir /
"config.json"), any rank_*.pth, files under pruned_ckpts_output_dir, or library
outputs as sufficient to skip stages are unsafe; change these guards to rely on
durable completion markers (e.g., a .done or .complete file) created at the
successful end of conversion/scoring/pruning/library build instead of
existence-only checks, so functions like the conversion branch around
teacher_dir/config.json, the rank_* checkpoint checks, and the
pruned_ckpts_output_dir/library checks only skip when their corresponding
completion marker exists; ensure launch_score_activations() remains the stricter
gate for pruning-activation scoring but remove or weaken the naive existence
checks noted at the conversion lines (the block using teacher_dir/config.json)
and the other mentioned blocks (191-193, 286-289) to check for the specific
"<stage>.complete" marker before skipping.
```python
loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
    target, torch.zeros_like(target) + epsilon, reduction=reduction
)
```
Normalize with + epsilon/clamp_min, not target - epsilon.
Adding epsilon to the zero tensor changes the denominator from E[target^2] + ε to E[(target - ε)^2]. That biases the scale and makes near-zero targets produce an ε^2 denominator, which can blow up the bypass loss.
💡 Suggested fix

```diff
 def normalized_mse_loss(
@@
-    loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
-        target, torch.zeros_like(target) + epsilon, reduction=reduction
-    )
+    loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
+        target, torch.zeros_like(target), reduction=reduction
+    ).clamp_min(epsilon)
@@
-    norm_of_target_vectors = F.mse_loss(
-        target, torch.zeros_like(target) + epsilon, reduction="none"
-    ).mean(norm_dims)
+    norm_of_target_vectors = F.mse_loss(
+        target, torch.zeros_like(target), reduction="none"
+    ).mean(norm_dims).clamp_min(epsilon)
```

Also applies to: 479-482
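The bias is easy to see numerically. This pure-Python check (no torch) compares the two denominators for a near-zero target, with an assumed ε of 1e-6:

```python
# Pure-Python illustration of the finding: dividing by mean((target - eps)^2)
# gives an eps**2 denominator when the target is near zero, while clamping
# mean(target**2) from below by eps keeps the normalized loss bounded.
eps = 1e-6


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


inp, target = [0.1, -0.1], [0.0, 0.0]  # near-zero teacher activations

biased = mse(inp, target) / mse(target, [eps] * len(target))             # denom = eps**2
clamped = mse(inp, target) / max(mse(target, [0.0] * len(target)), eps)  # denom = eps

assert biased > clamped * 1e5  # the eps**2 denominator blows the loss up ~1e6x
```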
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `modelopt/torch/puzzletron/sewing_kit/utils.py` around lines 452-454: the
normalization denominator is computed as F.mse_loss(target,
torch.zeros_like(target) + epsilon, ...) which shifts the target by epsilon and
biases the scale; instead compute the denominator as F.mse_loss(target,
torch.zeros_like(target), reduction=reduction) + epsilon (or clamp_min the
denominator to epsilon) so you add epsilon to the final scalar denominator
instead of to the zero tensor; update the occurrences around the loss assignment
(loss, input, target, epsilon, F.mse_loss) and the similar block at lines
479-482 accordingly.
```python
if not hasattr(model_config, "auto_map"):
    return

# The config class's source file lives in the HF cache together with all other
# custom code files for this model. Walk the auto_map values to find every
# module file that needs to be present alongside config.json.
source_dir = Path(inspect.getfile(type(model_config))).parent

module_files = {
    f"{class_ref.split('.')[0]}.py" for class_ref in model_config.auto_map.values()
}

for filename in module_files:
    src = source_dir / filename
    dst = Path(checkpoint_dir) / filename
    if src.exists() and not dst.exists():
        shutil.copy(src, dst)
```
🧩 Analysis chain
🌐 Web query:
In Hugging Face Transformers, what forms can `config.auto_map` values take for custom-code models, including `AutoTokenizer` entries and repo-qualified references like `repo_id--module.ClassName`?
💡 Result:
In custom-code (i.e., `trust_remote_code=True`) Transformers models, `config.auto_map` is a dict whose values can be:

- A single string pointing to a class in the repo (Python dotted path): `"module_file.ClassName"` (e.g., `"models.TransformerForABSA"`). [3]
- The same, but repo-qualified: `"repo_id--module_file.ClassName"` (the `repo_id--` prefix is added when missing). [2]
- A 2-item sequence (list/tuple) of strings (used especially for `AutoTokenizer`): `[slow_tokenizer_ref, fast_tokenizer_ref]`, where each element is either `"module_file.ClassName"` or `"repo_id--module_file.ClassName"`, and may be `None` (commonly for "no fast tokenizer"). [1][2]

For `AutoTokenizer` entries specifically, the mapping can appear as:

- `config.auto_map["AutoTokenizer"] = [slow_ref, fast_ref]` (loaded and then used to pick index `0` or `1` depending on `use_fast`). [1]
- In `tokenizer_config.json`, `auto_map` may be either the legacy top-level 2-item list/tuple, or a dict with an `"AutoTokenizer"` key whose value is that 2-item list/tuple. [1]

Example shapes:

```json
"auto_map": { "AutoModelForCausalLM": "modeling_my.MyLM" }
"auto_map": { "AutoModelForCausalLM": "my-org/my-repo--modeling_my.MyLM" }
"auto_map": { "AutoTokenizer": ["tokenization_my.MyTokenizer", "tokenization_my.MyTokenizerFast"] }
"auto_map": { "AutoTokenizer": ["my-org/my-repo--tokenization_my.MyTokenizer", null] }
```

Sources: [1] [2] [3]
Fix auto_map value parsing to handle lists and repo-qualified references.
The code assumes all `auto_map` values are simple dotted strings, but Transformers uses multiple forms:

- Lists/tuples (especially for `AutoTokenizer`): `["tokenization_my.MyTokenizer", "tokenization_my.MyTokenizerFast"]`; calling `.split('.')` on a list crashes with `AttributeError`.
- Repo-qualified strings: `"repo_id--tokenization_my.MyTokenizer"`; `class_ref.split('.')[0]` extracts `"repo_id--tokenization_my"` instead of the filename `"tokenization_my.py"`, producing missing files.

Normalize values by (1) extracting the first element if a list/tuple, (2) stripping the `repo_id--` prefix, (3) then extracting the module name. Without this, custom-code checkpoints remain incomplete for `trust_remote_code=True` loads.
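A sketch of that normalization; the helper name is hypothetical, and for 2-item tokenizer entries it takes the first non-None reference (a slight refinement of "first element" so `[None, fast_ref]` still resolves):

```python
# Hypothetical normalizer implementing the three steps the review lists; a
# sketch of the suggested fix, not the repository's actual helper.
def auto_map_module_file(value):
    """Map one auto_map value to the module filename it references, or None."""
    # (1) AutoTokenizer entries are 2-item sequences; take the first non-None ref.
    if isinstance(value, (list, tuple)):
        value = next((v for v in value if v is not None), None)
    if value is None:
        return None
    # (2) Strip a repo qualifier like "my-org/my-repo--".
    if "--" in value:
        value = value.split("--", 1)[1]
    # (3) The module name is the part before the first dot.
    return value.split(".", 1)[0] + ".py"


assert auto_map_module_file("modeling_my.MyLM") == "modeling_my.py"
assert auto_map_module_file("my-org/my-repo--modeling_my.MyLM") == "modeling_my.py"
assert auto_map_module_file(["tokenization_my.MyTok", None]) == "tokenization_my.py"
assert auto_map_module_file([None, None]) is None
```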
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` around lines 380-396:
the auto_map parsing in checkpoint_utils_hf.py incorrectly assumes each
model_config.auto_map value is a dotted string; update the logic that builds
module_files (and any usage of class_ref) to first normalize each value by: if
it's a list/tuple take the first element, if it contains a repo qualifier split
off the "repo_id--" prefix, then take the module part before the first '.' and
append ".py" (so "tokenization_my.py"); apply this normalization where
module_files is created and when iterating filenames so lists/tuples and
repo-qualified references are handled and the correct source filenames are
copied.
```python
    num_workers: int = 0,
) -> DataLoader:
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail
rg -n -C3 'class ConstantLengthDataset|def __iter__|get_worker_info|worker_info' \
  modelopt/torch/puzzletron/utils/data/dataset.py \
  modelopt/torch/puzzletron/utils/data/dataloaders.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1247

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail
# Get the full __iter__ method implementation
echo "=== ConstantLengthDataset.__iter__() implementation ==="
sed -n '100,150p' modelopt/torch/puzzletron/utils/data/dataset.py
echo ""
echo "=== create_train_dataloader function signature and first 30 lines ==="
sed -n '80,120p' modelopt/torch/puzzletron/utils/data/dataloaders.py
echo ""
echo "=== Check for get_worker_info usage in dataset.py ==="
grep -n "get_worker_info" modelopt/torch/puzzletron/utils/data/dataset.py || echo "No get_worker_info found"
echo ""
echo "=== Check for any guards on num_workers in dataloaders.py ==="
grep -n "num_workers" modelopt/torch/puzzletron/utils/data/dataloaders.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 4020
`num_workers` is unsafe until ConstantLengthDataset is worker-sharded.
`ConstantLengthDataset.__iter__()` in modelopt/torch/puzzletron/utils/data/dataset.py walks `self.dataset` directly and never uses `get_worker_info()`. Passing `num_workers > 0` to the DataLoader will cause each worker to read the same iterable stream, silently duplicating training samples.
Suggested guard

```diff
 def create_train_dataloader(
     bos_rate: float = 1.0,
     num_workers: int = 0,
 ) -> DataLoader:
     """Create an infinite training DataLoader over ConstantLengthDataset."""
+    if num_workers != 0:
+        raise ValueError(
+            "create_train_dataloader currently requires num_workers=0 because "
+            "ConstantLengthDataset is not sharded across DataLoader workers."
+        )
+
     if isinstance(dataset_path, str):
         dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory)
```

Also applies to: 114-118
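For reference, the eventual worker-aware fix can be sketched in pure Python: each worker consumes every `num_workers`-th sample offset by its id, the pattern `torch.utils.data.get_worker_info()` enables inside `__iter__`. `WorkerInfo` here is a stand-in for the object that call returns inside a worker process.

```python
# Pure-Python sketch of round-robin worker sharding; not the repository's code.
from dataclasses import dataclass
from itertools import islice
from typing import Optional


@dataclass
class WorkerInfo:
    id: int
    num_workers: int


def sharded_iter(samples, worker):
    # worker is None in the main process (num_workers=0): yield everything.
    if worker is None:
        yield from samples
    else:
        # Worker w of n takes samples w, w+n, w+2n, ...
        yield from islice(samples, worker.id, None, worker.num_workers)


stream = list(range(10))
w0 = list(sharded_iter(stream, WorkerInfo(0, 2)))
w1 = list(sharded_iter(stream, WorkerInfo(1, 2)))
assert sorted(w0 + w1) == stream  # full coverage, no duplicated samples
```

With this in place, the `num_workers=0` guard above could be removed.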
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 89-90:
The DataLoader factory allows num_workers>0 while ConstantLengthDataset.__iter__
does not shard via get_worker_info(), causing duplicate samples; update the
dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
```python
if shuffle_seed is not None:
    train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
```
🧩 Analysis chain
🌐 Web query:
Does Hugging Face `datasets.IterableDataset.shuffle()` accept a `keep_in_memory` keyword argument?
💡 Result:
No. `datasets.IterableDataset.shuffle()` only takes `seed`, `generator`, and `buffer_size` (no `keep_in_memory` kwarg), per both the API reference and the implementation. [1], [2]
Branch the shuffle call for streaming datasets.
`IterableDataset.shuffle()` does not accept `keep_in_memory=True`. When `load_streaming_fn()` is used (streaming datasets), line 99 will fail at runtime with an unexpected keyword argument error.
Suggested fix

```diff
 train_data = dataset[dataset_name]
 if shuffle_seed is not None:
-    train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
+    if isinstance(train_data, datasets.IterableDataset):
+        train_data = train_data.shuffle(seed=shuffle_seed)
+    else:
+        train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 98-99:
The call to train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) fails for
streaming (Iterable) datasets because IterableDataset.shuffle() doesn't accept
keep_in_memory; update the code that checks shuffle_seed to detect streaming
datasets (e.g., via whatever marker load_streaming_fn sets or by checking
hasattr(train_data, "__iter__") vs __len__/isinstance of IterableDataset) and
branch: for non-streaming datasets call train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) as before, and for streaming/iterable datasets call
train_data.shuffle(seed=shuffle_seed) without keep_in_memory; ensure you modify
the block that references shuffle_seed and train_data.shuffle so runtime errors
are avoided when load_streaming_fn() returns a streaming dataset.
```python
    5. Run ``pruning_ckpts`` (rank 0 only) then barrier.
    """
    set_seed(SEED)
    dist.setup(timeout=timedelta(10))
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
python - <<'PY'
from datetime import timedelta
print("timedelta(10) =", timedelta(10))
print("timedelta(seconds=10) =", timedelta(seconds=10))
print("timedelta(minutes=10) =", timedelta(minutes=10))
PY
```

Repository: NVIDIA/Model-Optimizer
Length of output: 162

🏁 Script executed:

```bash
# Check the file and context around line 213
cd tests/gpu/torch/puzzletron && wc -l test_bypass.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 84

🏁 Script executed:

```bash
# View the context around line 213
sed -n '205,220p' tests/gpu/torch/puzzletron/test_bypass.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 749

🏁 Script executed:

```bash
# Search for other dist.setup() calls to check for timeout patterns
rg "dist\.setup\(" tests/gpu/torch/puzzletron/test_bypass.py -B 2 -A 2
```

Repository: NVIDIA/Model-Optimizer
Length of output: 209

🏁 Script executed:

```bash
# Search across the codebase for dist.setup() timeout patterns
rg "dist\.setup\(" --type py -A 1 | head -30
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1633
Fix `timedelta(10)` to use explicit time units.
`timedelta(10)` specifies 10 days, not 10 seconds or 10 minutes. This pattern appears across multiple test and example files and would leave GPU tests hanging for up to 10 days on a deadlock. Use explicit units instead: `timedelta(seconds=10)` or `timedelta(minutes=10)`.
Affected locations include:
- tests/gpu/torch/puzzletron/test_bypass.py
- tests/gpu/torch/puzzletron/test_puzzletron.py
- tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
- tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py (2 occurrences)
- examples/puzzletron/main.py (2 occurrences)
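Since the verification script's output was truncated above, the positional-argument pitfall checked directly against the stdlib:

```python
# datetime.timedelta's first positional parameter is *days*.
from datetime import timedelta

assert timedelta(10) == timedelta(days=10)
assert timedelta(10).total_seconds() == 864_000.0  # 10 days, not 10 seconds
assert timedelta(seconds=10).total_seconds() == 10.0
assert timedelta(minutes=10).total_seconds() == 600.0
```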
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `tests/gpu/torch/puzzletron/test_bypass.py` at line 213: the timeout passed to
dist.setup uses timedelta(10) which means 10 days; change it to an explicit unit
like timedelta(seconds=10) (or timedelta(minutes=10) if intended) to avoid
10-day test hangs — locate the call to dist.setup (symbol: dist.setup) in
tests/gpu/torch/puzzletron/test_bypass.py and the other listed files and replace
timedelta(10) with timedelta(seconds=10) (or the correct unit) in each
occurrence.