
Add bypass distillation (blockwise local KD) to puzzletron pipeline#1111

Open
Separius wants to merge 1 commit into feature/puzzletron from ssameni/puzzletron-bypass

Conversation


@Separius Separius commented Mar 24, 2026

Bypass distillation trains alternative transformer block configurations using per-block knowledge distillation from the teacher model, producing a library of better "puzzle pieces" for the MIP solver. It is most beneficial for aggressive FFN pruning (≤ 1/8 of teacher width) or significant KV head compression.

Changes:

  • Add modelopt/torch/puzzletron/bypass_distillation/ module with full training loop, stitched model factory, checkpoint management, and data classes
  • Integrate bypass as optional Step 3 in puzzletron.py and puzzletron_nas_plugin.py (pipeline progress counter updates to 9 steps when bypass is enabled)
  • Add HuggingFace auto-download and skip-if-exists logic to puzzletron_nas_plugin.py for all pipeline steps
  • Add normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss to sewing_kit/utils.py
  • Fix child_init.py: support a list of pruning mixins; treat a None override as "keep original value" instead of raising TypeCheckError
  • Fix dataset.py: graceful fallback when tokenizer has no chat_template (base models)
  • Add _copy_auto_map_code_files to checkpoint_utils_hf.py so modeling Python files are copied alongside config.json (required for trust_remote_code checkpoints such as NemotronH)
  • Add create_train_dataloader to dataloaders.py
  • Add MoEChannelPruning to MlpInitMode enum
  • Add default pruning_mixins() to ModelDescriptor base class
  • Add NemotronHKVHeadsLayerDescriptor and kv_heads mixin to NemotronH descriptor; fix _set_keys_to_learn to skip Mamba blocks during subblock_attention bypass (based on block config)
  • Enable bypass in llama-3_1-8B_pruneffn_memory config; add example bypass/defaults.yaml
  • Update README with bypass documentation: when to use, time cost, sequential execution, W&B logging
  • Add unit tests for loss functions and distribution utilities
  • Add GPU integration tests for bypass (FFN pruning, KV compression, multi-config sweep, checkpoint validation)
  • Fix test_puzzletron.py assertion to handle variable GPU counts
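The normalized-MSE losses added to sewing_kit/utils.py share one idea: scale the MSE by the target's own energy so per-block losses are comparable across blocks. A minimal scalar sketch in plain Python (the real implementations operate on torch tensors; the name and epsilon handling here are illustrative, not the actual code):

```python
def normalized_mse(pred, target, eps=1e-8):
    """Illustrative scalar variant: MSE divided by the target's mean squared magnitude."""
    n = len(pred)
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / n
    # normalize by the target's energy; eps guards against division by zero
    scale = sum(t ** 2 for t in target) / n
    return mse / (scale + eps)
```

The vectorwise and batched variants presumably apply the same normalization per vector or per batch element instead of over the whole tensor.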

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Bypass Distillation (blockwise local distillation) as an optional compression stage for aggressive model compression
    • Support for sequential multi-configuration training with per-configuration overrides
    • Weights & Biases logging integration for per-block distillation metrics and losses
  • Documentation

    • Added comprehensive guide for Bypass Distillation setup, configuration, and expected time costs
  • Tests

    • Added integration and unit test coverage for bypass distillation and distributed training

@Separius Separius requested review from a team as code owners March 24, 2026 16:21
@Separius Separius requested a review from cjluo-nv March 24, 2026 16:21

copy-pr-bot bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 24, 2026

📝 Walkthrough

This pull request adds a complete bypass distillation (blockwise local distillation) training pipeline as an optional stage to the Puzzletron model optimization framework. It introduces new configuration schemas, orchestration logic, checkpoint utilities, loss computations, stitched model factories for teacher-student learning, distributed training loops with gradient scaling, data loaders, and comprehensive test coverage including GPU integration tests.

Changes

• Documentation & Configuration (examples/puzzletron/README.md, examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml, examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml): Added bypass distillation documentation explaining usage, time costs, and configuration parameters. Updated Hydra defaults to reference the new bypass: defaults preset. Created a comprehensive bypass config defining training hyperparameters, dataset/model setup, distillation factory settings, and multi-config sweep support.
• Main Entry Points (examples/puzzletron/main.py, modelopt/torch/puzzletron/puzzletron.py, modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py): Updated progress tracking to dynamically compute total steps, accounting for the bypass stage. Integrated bypass distillation as an optional pipeline stage after pruning. Enhanced the NAS plugin with bypass support, checkpoint resumption, model auto-download, and cache-skipping logic for idempotent reruns.
• Bypass Distillation Core Package (modelopt/torch/puzzletron/bypass_distillation/__init__.py, modelopt/torch/puzzletron/bypass_distillation/data_classes.py, modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py, modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py, modelopt/torch/puzzletron/bypass_distillation/training_loop.py, modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py): Implemented the complete blockwise local distillation system: data structures for training telemetry, distributed module ownership computation, a teacher-student stitched module factory with per-block descriptors and gradient scaling, orchestration/resumption logic for single- and multi-config runs, distributed checkpoint persistence with symlink management, and a cosine-warmup training loop with validation and W&B logging.
• Model Extensions & Descriptors (modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py, modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py, modelopt/torch/puzzletron/pruning/pruning_utils.py): Added a pruning_mixins() extension point to model descriptors for bypass-distillation mixin support. Implemented KV-head pruning for Nemotron-H with a dedicated layer descriptor. Added the MoEChannelPruning enum mode for future MoE channel pruning support.
• Loss Functions & Data Utilities (modelopt/torch/puzzletron/sewing_kit/utils.py, modelopt/torch/puzzletron/utils/data/dataloaders.py, modelopt/torch/puzzletron/utils/data/dataset.py, modelopt/torch/puzzletron/utils/parsing.py): Introduced normalized MSE loss variants (scalar, vectorwise, batched) for bypass distillation. Added create_train_dataloader for infinite streaming training. Enhanced chat template handling with a fallback for tokenizers lacking chat templates. Updated loss formatting to skip NaN entries and report missing trainable blocks.
• Checkpoint & Training Infrastructure (modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py, modelopt/torch/puzzletron/tools/bypassed_training/child_init.py): Extended checkpoint saving to preserve Hugging Face auto_map code files. Enhanced pruning mixin handling to support lists of mixins per layer. Fixed config override logic to preserve original values instead of nullifying them.
• GPU Integration Tests (tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml, tests/gpu/torch/puzzletron/test_bypass.py, tests/gpu/torch/puzzletron/test_puzzletron.py): Added a minimal bypass test configuration and four multi-GPU test cases covering FFN pruning, KV-head compression, multi-config sequential runs, and checkpoint content validation. Updated existing pruning activation tests to correctly compute per-rank layer ownership.
• Unit Tests (tests/unit/torch/puzzletron/test_bypass_losses.py, tests/unit/torch/puzzletron/test_bypass_utils.py): Added unit test coverage for normalized loss functions (shape/reduction behavior, numerical correctness) and distributed module ownership computation (even/uneven distribution, edge cases).
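The "distributed module ownership computation" these unit tests cover can be illustrated with a contiguous-chunking sketch (a hypothetical helper; the actual partitioning strategy in bypass_utils.py may differ):

```python
def compute_ownership(num_modules: int, world_size: int) -> list[list[int]]:
    """Assign each rank a contiguous slice of module indices.

    Earlier ranks absorb the remainder, so uneven counts (and even
    world_size > num_modules, where some ranks own nothing) are handled.
    """
    base, rem = divmod(num_modules, world_size)
    ownership, start = [], 0
    for rank in range(world_size):
        count = base + (1 if rank < rem else 0)
        ownership.append(list(range(start, start + count)))
        start += count
    return ownership
```

Each rank then trains only its owned stitched modules, and losses are gathered across ranks afterwards.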

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Launcher as launch_bypass<br/>_distillation()
    participant Orchestrator as run_bypassed<br/>_training()
    participant Loader as DataLoader &<br/>Teacher Model
    participant Factory as stitched_model<br/>_factory()
    participant TrainingLoop as train()
    participant Dist as Distributed<br/>Sync & Checkpoints
    
    User->>Launcher: Hydra config
    Launcher->>Launcher: For each bypass.config
    Launcher->>Orchestrator: cfg with overrides
    
    Orchestrator->>Loader: Load teacher, init student
    Orchestrator->>Factory: Create stitched modules & descriptors
    Factory->>Factory: Compute block ownership<br/>per rank
    Factory-->>Orchestrator: Student + stitched modules
    
    Orchestrator->>Loader: Build streaming dataloader
    Orchestrator->>Orchestrator: Resume from checkpoint<br/>(if exists)
    
    Orchestrator->>TrainingLoop: Call train()
    
    loop Per iteration
        TrainingLoop->>Loader: Fetch batch
        TrainingLoop->>TrainingLoop: Teacher forward pass
        loop Per stitched module
            TrainingLoop->>TrainingLoop: Student forward + loss
            TrainingLoop->>TrainingLoop: Backward + optimize<br/>(with grad scaling)
        end
        TrainingLoop->>Dist: Gather losses across ranks
        TrainingLoop->>Dist: Log & merge history (master)
        alt Time to save
            TrainingLoop->>Dist: save_bypass_checkpoint()
            Dist->>Dist: Sync barriers
        end
        alt Validate
            TrainingLoop->>TrainingLoop: Validation forward
            TrainingLoop->>Dist: Sync val_loss
        end
    end
    
    TrainingLoop-->>Orchestrator: Training complete
    Orchestrator->>Dist: realize_bypass_checkpoints()
    Dist->>Dist: Create puzzle_dir symlink
    Orchestrator-->>Launcher: Run finished
    
    Launcher-->>User: All configs trained

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

• Security Anti-Patterns — ❌ Error. The PR contains a critical security violation: hardcoded trust_remote_code=True at line 675 in training_loop.py contradicts security guidelines and the available variable. Resolution: replace trust_remote_code=True with trust_remote_code=trust_remote_code at line 675; add a weights_only parameter and justifying comments to the torch.load() calls in bypass_checkpoint_utils.py at lines 85 and 99.
• Docstring Coverage — ⚠️ Warning. Docstring coverage is 73.47%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
• Description Check — ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
• Title Check — ✅ Passed. The PR title accurately summarizes the main change: adding bypass distillation (blockwise local KD) to the puzzletron pipeline. It is concise, clear, and directly reflects the primary objective.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/gpu/torch/puzzletron/test_puzzletron.py (1)

236-245: ⚠️ Potential issue | 🟡 Minor

The fallback printer still emits only rank-local values.

This branch now advertises num_layers={total_layers}, but it still prints only the contents of rank_{rank}.pth and is executed on rank 0 only. On multi-GPU runs the suggested EXPECTED_PRUNING_VALUES snippet will be incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 236 - 245, The
printer currently outputs only rank-local pruning_scores causing incomplete
EXPECTED_PRUNING_VALUES for multi-GPU runs; modify the logic so rank 0
aggregates pruning data from all ranks before printing: collect and merge
per-rank pruning_scores (or load all rank_{rank}.pth files) into a global
pruning_scores for each layer_name, compute the global score and channels (e.g.,
combine/average or gather channel indices across ranks) respecting total_layers,
and then have rank 0 iterate over layer_names using the aggregated values when
printing the block that uses total_layers and prints the EXPECTED_PRUNING_VALUES
snippet.
modelopt/torch/puzzletron/pruning/pruning_utils.py (1)

40-47: ⚠️ Potential issue | 🟠 Major

MoEChannelPruning is exposed before the init path supports it.

modelopt/torch/puzzletron/tools/bypassed_training/child_init.py now branches on this enum and forwards it into _init_mlp_module(), but _init_mlp_module() still falls through to Unsupported mlp_init_mode for this value when expert widths change. Any config that selects MoEChannelPruning will fail during child initialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/pruning/pruning_utils.py` around lines 40 - 47, The
enum MlpInitMode now includes MoEChannelPruning but _init_mlp_module still
treats that case as unsupported; update the _init_mlp_module implementation to
handle MlpInitMode.MoEChannelPruning (the same call-site that child_init.py
forwards into) by adding a branch for MlpInitMode.MoEChannelPruning that
performs the correct initialization when expert widths change (e.g., adapt the
weight/activation shapes by slicing/reshaping or reuse the
ConcatExpertsIntoDenseFFN logic where appropriate), so the child init no longer
falls through to the "Unsupported mlp_init_mode" error for MoEChannelPruning.
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

804-806: This change makes explicit null resets impossible.

Treating None as “keep original” fixes the accidental overwrite, but it also removes the only way for JSON/YAML overrides to clear an optional field back to None. If callers need both behaviors, use a sentinel for “no override” and reserve None for explicit clearing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
804 - 806, The current override function (override) treats item_overrides ==
None as "keep original", which prevents callers from explicitly clearing a value
to None via JSON/YAML; change the logic to use a distinct sentinel (e.g., a new
unique object like NO_OVERRIDE) to represent "no override" and reserve None in
item_overrides to mean "set to None"/clear the field, updating the override
function to check against the sentinel (NO_OVERRIDE) instead of None and adjust
any callers that construct overrides to use the sentinel when they mean "leave
original".
modelopt/torch/puzzletron/utils/data/dataset.py (1)

123-130: Keep role markers in the no-template fallback.

Joining only content collapses system/user/assistant turns into plain text, which changes the supervision for chat datasets. A lightweight fallback like "{role}: {content}" preserves the conversation structure without relying on a tokenizer template.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataset.py` around lines 123 - 130, The
fallback that builds sample when getattr(self.tokenizer, "chat_template", None)
is None should preserve role markers instead of joining only message["content"];
update the else branch in dataset.py (the block that currently sets sample =
"\n".join(m["content"] for m in sample)) to join messages using a lightweight
role-prefixed format like "{role}: {content}" so conversation turns
(system/user/assistant) are retained; keep using the same sample variable and
ensure this mirrors the structure expected by downstream code that consumes
apply_chat_template outputs.
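The role-prefixed fallback described above can be sketched as follows (illustrative only; the real fallback lives in dataset.py and must match what downstream consumers of apply_chat_template expect):

```python
def render_without_template(messages):
    # fallback when tokenizer.chat_template is None:
    # keep role markers so the conversation structure survives
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

rendered = render_without_template(
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
)
```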
modelopt/torch/puzzletron/utils/parsing.py (1)

337-345: Don’t silently treat every NaN as a no-op block.

This formatter now drops any NaN entry and can report No trainable blocks found. If a trainable block diverges, the failure disappears from the logs instead of surfacing. Filter only known skipped block types, or emit a separate warning for unexpected NaNs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/parsing.py` around lines 337 - 345, The
current filtering silently drops any NaN in losses_dict (and prunes
best_steps_dict/best_values_dict to match), which hides diverging trainable
blocks; instead, update the logic around losses_dict, best_steps_dict and
best_values_dict so you only drop entries whose keys match known skipped block
types (e.g., the explicit list of no-op block names like "Mamba"), and for any
other NaN values emit a warning/error (via the existing logger) that a trainable
block produced NaN rather than removing it; ensure best_steps_dict and
best_values_dict are only pruned to match the filtered losses_dict after this
selective filtering and warning behavior.
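The selective filtering proposed above might look like this sketch (SKIPPED_BLOCK_TYPES is a hypothetical allowlist; the real formatter in parsing.py also maintains best_steps_dict/best_values_dict):

```python
import math
import warnings

# hypothetical: block types expected to report NaN (no trainable params)
SKIPPED_BLOCK_TYPES = ("mamba",)

def filter_losses(losses: dict[str, float]) -> dict[str, float]:
    kept = {}
    for name, loss in losses.items():
        if math.isnan(loss):
            if any(t in name.lower() for t in SKIPPED_BLOCK_TYPES):
                continue  # known no-op block: drop silently
            # a trainable block diverged: warn loudly and keep it visible
            warnings.warn(f"trainable block {name} produced NaN loss")
        kept[name] = loss
    return kept

result = filter_losses({"mamba_block_0": float("nan"), "attn_block_1": 0.5})
```

Only known no-op blocks are dropped; an unexpected NaN stays in the report and triggers a warning instead of vanishing.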
examples/puzzletron/main.py (1)

154-167: Progress messages in run_mip_only are hardcoded and inconsistent with the dynamic approach.

The run_full_puzzletron function now uses dynamic step counting (N = _total_steps(hydra_cfg)), but run_mip_only still uses hardcoded "7/8" and "8/8" progress messages. If bypass is configured, the step numbers would be incorrect (should be 8/9 and 9/9).

Consider applying the same dynamic step count logic here for consistency.

♻️ Suggested fix
 def run_mip_only(hydra_config_path: str):
     ...
     # Load hydra config
     hydra_cfg = initialize_hydra_config_for_dir(
         config_dir=hydra_config_dir,
         config_name=hydra_config_name,
         overrides=[],
     )
+    N = _total_steps(hydra_cfg)

     # Check if sweep mode is enabled
     if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
         mprint(
-            "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+            f"Puzzletron Progress {N-1}/{N}: running MIP sweep for multiple compression rates (multi-gpu)"
         )
         sweep.run_mip_sweep(hydra_cfg)
     else:
         # mip_and_realize_models (distributed processing)
         # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        mprint(f"Puzzletron Progress {N-1}/{N}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)

     dist.cleanup()
-    mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+    mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 154 - 167, Update run_mip_only to
compute the total steps like run_full_puzzletron by calling
_total_steps(hydra_cfg) and use that N when formatting the progress messages
instead of hardcoded "7/8" and "8/8"; specifically, replace the two mprint calls
around the conditional that currently show "Puzzletron Progress 7/8" and "8/8"
with dynamic messages using N (e.g., f"Puzzletron Progress {current_step}/{N}:
...") and ensure current_step increments are correct for both the sweep branch
(sweep.run_mip_sweep) and the mip branch
(mip_and_realize_models.launch_mip_and_realize_model) so progress displays
consistently with _total_steps.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

548-556: Unused variable num_trainable_params.

The variable num_trainable_params is computed but never used in this function or elsewhere. This appears to be residual code. Consider removing it to reduce unnecessary computation and improve code clarity.

♻️ Proposed removal
             assert "learning_rate" in cfg.training
-            num_trainable_params = sum(
-                p.requires_grad and submodule_name in p_name
-                for p_name, p in student_stitched_module.named_parameters()
-                if "dummy_param" not in p_name  # exclude placeholder params
-            )
-            # Do NOT enable dummy params: blocks with no real trainable parameters
-            # (e.g. Mamba blocks during an attention-only bypass run) should produce
-            # NaN loss so they are excluded from statistics — identical to the
-            # optimizer=None path in the training loop.

             student_module_parameters = {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 548 - 556, Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 45-58: The fallback currently only sorts checkpoint directories by
iteration (get_iter_num) so when multiple checkpoints exist for the same iter we
may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
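The (iter, step) sort key described above can be sketched as (the regexes follow the iter-N/step-N naming in the suggestion; directory names here are illustrative):

```python
import re

def get_iter_and_step(name: str) -> tuple[int, int]:
    # step defaults to 0 when the directory name has no step component
    it = re.search(r"iter-(\d+)", name)
    st = re.search(r"step-(\d+)", name)
    return (int(it.group(1)) if it else 0, int(st.group(1)) if st else 0)

dirs = ["iter-2-step-100", "iter-2-step-400", "iter-1-step-900"]
# tuples sort lexicographically, so iteration wins before step
latest_first = sorted(dirs, key=get_iter_and_step, reverse=True)
```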

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 673-677: Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.

In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py`:
- Around line 146-149: The pre-checks that treat presence of files like
(teacher_dir / "config.json"), any rank_*.pth, files under
pruned_ckpts_output_dir, or library outputs as sufficient to skip stages are
unsafe; change these guards to rely on durable completion markers (e.g., a .done
or .complete file) created at the successful end of
conversion/scoring/pruning/library build instead of existence-only checks, so
functions like the conversion branch around teacher_dir/config.json, the rank_*
checkpoint checks, and the pruned_ckpts_output_dir/library checks only skip when
their corresponding completion marker exists; ensure launch_score_activations()
remains the stricter gate for pruning-activation scoring but remove or weaken
the naive existence checks noted at the conversion lines (the block using
teacher_dir/config.json) and the other mentioned blocks (191-193, 286-289) to
check for the specific "<stage>.complete" marker before skipping.
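The durable completion-marker pattern the reviewer suggests can be sketched as (hypothetical helper names; the marker naming convention "<stage>.complete" is taken from the comment above):

```python
import tempfile
from pathlib import Path

def stage_complete(stage_dir: Path, stage: str) -> bool:
    # skip a stage only when its explicit completion marker exists,
    # never on the mere presence of partial output files
    return (stage_dir / f"{stage}.complete").exists()

def mark_stage_complete(stage_dir: Path, stage: str) -> None:
    # written only at the successful end of the stage
    (stage_dir / f"{stage}.complete").touch()

with tempfile.TemporaryDirectory() as d:
    workdir = Path(d)
    before = stage_complete(workdir, "pruning")  # nothing ran yet
    mark_stage_complete(workdir, "pruning")      # stage finished successfully
    after = stage_complete(workdir, "pruning")   # safe to skip on rerun
```

A crashed run that leaves partial outputs behind will rerun the stage, because the marker is only written after success.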

In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 452-454: The normalization denominator is computed as
F.mse_loss(target, torch.zeros_like(target) + epsilon, ...) which shifts the
target by epsilon and biases the scale; instead compute the denominator as
F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon (or
clamp_min the denominator to epsilon) so you add epsilon to the final scalar
denominator instead of to the zero tensor; update the occurrences around the
loss assignment (loss, input, target, epsilon, F.mse_loss) and the similar block
at lines 479-482 accordingly.
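The difference between the two epsilon placements is easy to see numerically (toy scalars in plain Python standing in for F.mse_loss, with an exaggerated eps):

```python
def denom_biased(target, eps):
    # current form: compares target against a tensor filled with eps,
    # which shifts every element and biases the normalization scale
    return sum((t - eps) ** 2 for t in target) / len(target)

def denom_correct(target, eps):
    # suggested fix: MSE against exact zero, eps added to the final scalar
    return sum(t ** 2 for t in target) / len(target) + eps
```

With a realistic eps (e.g. 1e-8) the bias is small, but it grows with eps and distorts targets near zero; adding eps to the scalar only guards against division by zero.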

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 380-396: The auto_map parsing in checkpoint_utils_hf.py
incorrectly assumes each model_config.auto_map value is a dotted string; update
the logic that builds module_files (and any usage of class_ref) to first
normalize each value by: if it's a list/tuple take the first element, if it
contains a repo qualifier split off the "repo_id--" prefix, then take the module
part before the first '.' and append ".py" (so "tokenization_my.py"); apply this
normalization where module_files is created and when iterating filenames so
lists/tuples and repo-qualified references are handled and the correct source
filenames are copied.
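The normalization the comment calls for can be sketched as (hypothetical helper; value shapes follow the Hugging Face auto_map conventions described above):

```python
def auto_map_module_file(value) -> str:
    # auto_map values may be "module.Class", a list/tuple of refs,
    # or a repo-qualified "repo_id--module.Class"
    if isinstance(value, (list, tuple)):
        value = value[0]
    if "--" in value:
        value = value.split("--", 1)[1]
    # the module part before the first '.' names the source file
    return value.split(".", 1)[0] + ".py"
```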

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 89-90: The DataLoader factory allows num_workers>0 while
ConstantLengthDataset.__iter__ does not shard via get_worker_info(), causing
duplicate samples; update the dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
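The worker-aware sharding that would make num_workers>0 safe boils down to a strided split; in the real __iter__ the worker_id/num_workers values would come from torch.utils.data.get_worker_info() (sketch only, not the project's code):

```python
def shard_indices(num_samples: int, worker_id: int, num_workers: int) -> list[int]:
    # each DataLoader worker takes a strided slice, so the shards are
    # disjoint and together cover every sample exactly once
    return list(range(worker_id, num_samples, num_workers))
```

Without this, every worker iterates the full dataset and each sample is emitted num_workers times.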
- Around line 98-99: The call to train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) fails for streaming (Iterable) datasets because
IterableDataset.shuffle() doesn't accept keep_in_memory; update the code that
checks shuffle_seed to detect streaming datasets (e.g., via whatever marker
load_streaming_fn sets or by checking hasattr(train_data, "__iter__") vs
__len__/isinstance of IterableDataset) and branch: for non-streaming datasets
call train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) as before, and
for streaming/iterable datasets call train_data.shuffle(seed=shuffle_seed)
without keep_in_memory; ensure you modify the block that references shuffle_seed
and train_data.shuffle so runtime errors are avoided when load_streaming_fn()
returns a streaming dataset.

In `@tests/gpu/torch/puzzletron/test_bypass.py`:
- Line 213: The timeout passed to dist.setup uses timedelta(10) which means 10
days; change it to an explicit unit like timedelta(seconds=10) (or
timedelta(minutes=10) if intended) to avoid 10-day test hangs — locate the call
to dist.setup (symbol: dist.setup) in tests/gpu/torch/puzzletron/test_bypass.py
and the other listed files and replace timedelta(10) with timedelta(seconds=10)
(or the correct unit) in each occurrence.
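The pitfall is that timedelta's first positional argument is days, not seconds:

```python
from datetime import timedelta

ten_days = timedelta(10)             # positional argument means *days*
ten_seconds = timedelta(seconds=10)  # what a test timeout usually intends
```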

---

Outside diff comments:
In `@modelopt/torch/puzzletron/pruning/pruning_utils.py`:
- Around line 40-47: The enum MlpInitMode now includes MoEChannelPruning but
_init_mlp_module still treats that case as unsupported; update the
_init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the
same call-site that child_init.py forwards into) by adding a branch for
MlpInitMode.MoEChannelPruning that performs the correct initialization when
expert widths change (e.g., adapt the weight/activation shapes by
slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where
appropriate), so the child init no longer falls through to the "Unsupported
mlp_init_mode" error for MoEChannelPruning.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 236-245: The printer currently outputs only rank-local
pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs;
modify the logic so rank 0 aggregates pruning data from all ranks before
printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth
files) into a global pruning_scores for each layer_name, compute the global
score and channels (e.g., combine/average or gather channel indices across
ranks) respecting total_layers, and then have rank 0 iterate over layer_names
using the aggregated values when printing the block that uses total_layers and
prints the EXPECTED_PRUNING_VALUES snippet.

---

Nitpick comments:
In `@examples/puzzletron/main.py`:
- Around line 154-167: Update run_mip_only to compute the total steps like
run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when
formatting the progress messages instead of hardcoded "7/8" and "8/8";
specifically, replace the two mprint calls around the conditional that currently
show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g.,
f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step
increments are correct for both the sweep branch (sweep.run_mip_sweep) and the
mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress
displays consistently with _total_steps.
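The shape of the fix can be sketched as below; `total_steps` stands in for the existing `_total_steps(hydra_cfg)` helper, and the 8-vs-9 count comes from the PR description (bypass adds an optional Step 3):

```python
def total_steps(bypass_enabled: bool) -> int:
    # 8 base pipeline steps, 9 when bypass distillation is enabled.
    return 9 if bypass_enabled else 8


def progress_msg(step: int, n: int, text: str) -> str:
    """Format progress consistently instead of hardcoding '7/8', '8/8'."""
    return f"Puzzletron Progress {step}/{n}: {text}"
```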

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 548-556: Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 804-806: The current override function (override) treats
item_overrides == None as "keep original", which prevents callers from
explicitly clearing a value to None via JSON/YAML; change the logic to use a
distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no
override" and reserve None in item_overrides to mean "set to None"/clear the
field, updating the override function to check against the sentinel
(NO_OVERRIDE) instead of None and adjust any callers that construct overrides to
use the sentinel when they mean "leave original".
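The sentinel idiom being requested is a standard Python pattern; a minimal standalone version (the real `override` in `child_init.py` has more context around it):

```python
NO_OVERRIDE = object()  # unique sentinel: "caller supplied no override"


def override(original, item_override=NO_OVERRIDE):
    """Return the override when one was given. Identity comparison against
    the sentinel lets None pass through as an explicit 'clear this field'."""
    return original if item_override is NO_OVERRIDE else item_override
```

Because the sentinel is compared by identity (`is`), no value that can arrive from JSON/YAML, including `None`, can collide with it.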

In `@modelopt/torch/puzzletron/utils/data/dataset.py`:
- Around line 123-130: The fallback that builds sample when
getattr(self.tokenizer, "chat_template", None) is None should preserve role
markers instead of joining only message["content"]; update the else branch in
dataset.py (the block that currently sets sample = "\n".join(m["content"] for m
in sample)) to join messages using a lightweight role-prefixed format like
"{role}: {content}" so conversation turns (system/user/assistant) are retained;
keep using the same sample variable and ensure this mirrors the structure
expected by downstream code that consumes apply_chat_template outputs.
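The suggested fallback rendering is simple enough to sketch directly; the function name is illustrative, but the `{role}: {content}` join is exactly what the comment asks for:

```python
def render_without_chat_template(messages):
    """Fallback for tokenizers lacking a chat_template: keep a lightweight
    role prefix so system/user/assistant turns stay distinguishable."""
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)
```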

In `@modelopt/torch/puzzletron/utils/parsing.py`:
- Around line 337-345: The current filtering silently drops any NaN in
losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides
diverging trainable blocks; instead, update the logic around losses_dict,
best_steps_dict and best_values_dict so you only drop entries whose keys match
known skipped block types (e.g., the explicit list of no-op block names like
"Mamba"), and for any other NaN values emit a warning/error (via the existing
logger) that a trainable block produced NaN rather than removing it; ensure
best_steps_dict and best_values_dict are only pruned to match the filtered
losses_dict after this selective filtering and warning behavior.
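A sketch of the selective filtering, assuming the skipped block types are matched by substring against the loss-dict keys (the real key format in `parsing.py` may differ):

```python
import math
import warnings

SKIPPED_BLOCK_TYPES = ("Mamba",)  # no-op blocks expected to report NaN


def filter_losses(losses_dict):
    """Drop NaNs only for known skipped block types; any other NaN marks a
    diverging trainable block, so warn and keep it visible."""
    kept = {}
    for name, loss in losses_dict.items():
        if math.isnan(loss):
            if any(block in name for block in SKIPPED_BLOCK_TYPES):
                continue  # expected: this block type is skipped during bypass
            warnings.warn(f"trainable block {name!r} produced a NaN loss")
        kept[name] = loss
    return kept
```

`best_steps_dict` and `best_values_dict` would then be pruned against the keys of the returned dict, after the warning has fired.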

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 371acd83-77b9-4396-8a82-eddd5b11dd40

📥 Commits

Reviewing files that changed from the base of the PR and between e508b76 and e018ca0.

📒 Files selected for processing (27)
  • examples/puzzletron/README.md
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
  • modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/puzzletron.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/utils/data/dataloaders.py
  • modelopt/torch/puzzletron/utils/data/dataset.py
  • modelopt/torch/puzzletron/utils/parsing.py
  • tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml
  • tests/gpu/torch/puzzletron/test_bypass.py
  • tests/gpu/torch/puzzletron/test_puzzletron.py
  • tests/unit/torch/puzzletron/__init__.py
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py

Comment on lines +45 to +58
# If "latest" doesn't exist, look explicitly into directories with `*iter-*`
candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()]

if not candidate_dirs:
    return None

def get_iter_num(dir_name):
    match = re.search(r"iter-(\d+)", dir_name.name)
    return int(match.group(1)) if match else 0

checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
for latest_dir in checkpoint_dirs:
    if (latest_dir / "saving_completed").exists():
        return str(latest_dir)

⚠️ Potential issue | 🟠 Major

Include step_num when picking the latest checkpoint.

This fallback only sorts on iter-(\d+). If a run writes multiple checkpoints inside the same iteration, resume can load an older step even though a newer checkpoint exists in the same run_parent_dir.

💡 Suggested fix
-    def get_iter_num(dir_name):
-        match = re.search(r"iter-(\d+)", dir_name.name)
-        return int(match.group(1)) if match else 0
-
-    checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
+    def checkpoint_order(path: Path) -> tuple[int, int, float]:
+        match = re.search(r"iter-(\d+)(?:.*step-(\d+))?", path.name)
+        if not match:
+            return (0, 0, path.stat().st_mtime)
+        return (int(match.group(1)), int(match.group(2) or 0), path.stat().st_mtime)
+
+    checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 45 - 58, The fallback currently only sorts checkpoint directories
by iteration (get_iter_num) so when multiple checkpoints exist for the same iter
we may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
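The proposed two-level ordering is easy to verify in isolation; this standalone version sorts on string names rather than `Path` objects but uses the same `(iter, step)` key the fix describes:

```python
import re


def checkpoint_order(name: str) -> tuple[int, int]:
    """Sort key combining iteration and step, each defaulting to 0
    when the corresponding pattern is absent from the directory name."""
    it = re.search(r"iter-(\d+)", name)
    st = re.search(r"step-(\d+)", name)
    return (int(it.group(1)) if it else 0, int(st.group(1)) if st else 0)


dirs = ["iter-2_step-1", "iter-1_step-9", "iter-2_step-3"]
latest_first = sorted(dirs, key=checkpoint_order, reverse=True)
```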

Comment on lines +673 to +677
tokenizer = AutoTokenizer.from_pretrained(
    cfg.teacher_dir,
    trust_remote_code=True,
    token=True,
)

⚠️ Potential issue | 🔴 Critical

CRITICAL: Hardcoded trust_remote_code=True violates security coding guidelines.

The coding guidelines explicitly prohibit hardcoding trust_remote_code=True for transformers model or tokenizer loading. This flag enables execution of arbitrary Python shipped with a checkpoint, creating an RCE vector if the model source is untrusted.

The code already retrieves trust_remote_code from the descriptor at line 596 and uses it for model config loading at lines 597 and 631. Apply the same pattern here.

🔒 Proposed fix
 tokenizer = AutoTokenizer.from_pretrained(
     cfg.teacher_dir,
-    trust_remote_code=True,
+    trust_remote_code=trust_remote_code,
     token=True,
 )

As per coding guidelines: "Flag trust_remote_code=True hardcoded for transformers model or tokenizer loading as CRITICAL security issue. Code should expose it as a caller-configurable parameter defaulting to False, not hardcode True"

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 tokenizer = AutoTokenizer.from_pretrained(
     cfg.teacher_dir,
-    trust_remote_code=True,
+    trust_remote_code=trust_remote_code,
     token=True,
 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
673 - 677, Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.

Comment on lines +146 to +149
if (teacher_dir / "config.json").exists():
mprint(f"Puzzletron Progress 2/{N}: teacher checkpoint already exists, skipping conversion")
else:
mprint(f"Puzzletron Progress 2/{N}: converting model to Puzzletron heterogeneous format (single-gpu)")

⚠️ Potential issue | 🟠 Major

Use completion markers, not existence-only skips.

These guards treat config.json, any rank_*.pth, any file under pruned_ckpts_output_dir, or the two library outputs as proof that the whole stage is reusable. After an interrupted run—or after rerunning in the same puzzle_dir with different config—this can skip conversion/scoring/pruning/library build on partial or stale artifacts. For pruning-activation scoring, launch_score_activations() already has a stricter completeness check, so this outer pre-check is weaker than the existing stage logic.

Also applies to: 181-183, 191-193, 286-289

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py` around lines
146 - 149, The pre-checks that treat presence of files like (teacher_dir /
"config.json"), any rank_*.pth, files under pruned_ckpts_output_dir, or library
outputs as sufficient to skip stages are unsafe; change these guards to rely on
durable completion markers (e.g., a .done or .complete file) created at the
successful end of conversion/scoring/pruning/library build instead of
existence-only checks, so functions like the conversion branch around
teacher_dir/config.json, the rank_* checkpoint checks, and the
pruned_ckpts_output_dir/library checks only skip when their corresponding
completion marker exists; ensure launch_score_activations() remains the stricter
gate for pruning-activation scoring but remove or weaken the naive existence
checks noted at the conversion lines (the block using teacher_dir/config.json)
and the other mentioned blocks (191-193, 286-289) to check for the specific
"<stage>.complete" marker before skipping.
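The marker mechanics being requested are straightforward to sketch; the marker naming (`<stage>.complete`) follows the prompt above, while the helper names are illustrative:

```python
import tempfile
from pathlib import Path


def mark_complete(stage_dir: Path, stage: str) -> None:
    """Write the durable marker only after the stage fully succeeds."""
    (stage_dir / f"{stage}.complete").touch()


def stage_is_complete(stage_dir: Path, stage: str) -> bool:
    """Skip a stage only when its completion marker exists, not merely
    because some of its output files happen to be present."""
    return (stage_dir / f"{stage}.complete").exists()


with tempfile.TemporaryDirectory() as d:
    work = Path(d)
    before = stage_is_complete(work, "conversion")  # partial outputs: not done
    mark_complete(work, "conversion")
    after = stage_is_complete(work, "conversion")   # marker written: reusable
```

An interrupted run leaves outputs but no marker, so the rerun correctly redoes the stage instead of trusting stale artifacts.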

Comment on lines +452 to +454
loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
    target, torch.zeros_like(target) + epsilon, reduction=reduction
)

⚠️ Potential issue | 🟠 Major

Normalize with + epsilon/clamp_min, not target - epsilon.

Adding epsilon to the zero tensor changes the denominator from E[target^2] + ε to E[(target - ε)^2]. That biases the scale and makes near-zero targets produce an ε^2 denominator, which can blow up the bypass loss.

💡 Suggested fix
 def normalized_mse_loss(
@@
-    loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
-        target, torch.zeros_like(target) + epsilon, reduction=reduction
-    )
+    loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
+        target, torch.zeros_like(target), reduction=reduction
+    ).clamp_min(epsilon)
@@
-    norm_of_target_vectors = F.mse_loss(
-        target, torch.zeros_like(target) + epsilon, reduction="none"
-    ).mean(norm_dims)
+    norm_of_target_vectors = F.mse_loss(
+        target, torch.zeros_like(target), reduction="none"
+    ).mean(norm_dims).clamp_min(epsilon)

Also applies to: 479-482

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/sewing_kit/utils.py` around lines 452 - 454, The
normalization denominator is computed as F.mse_loss(target,
torch.zeros_like(target) + epsilon, ...) which shifts the target by epsilon and
biases the scale; instead compute the denominator as F.mse_loss(target,
torch.zeros_like(target), reduction=reduction) + epsilon (or clamp_min the
denominator to epsilon) so you add epsilon to the final scalar denominator
instead of to the zero tensor; update the occurrences around the loss assignment
(loss, input, target, epsilon, F.mse_loss) and the similar block at lines
479-482 accordingly.

Comment on lines +380 to +396
if not hasattr(model_config, "auto_map"):
    return

# The config class's source file lives in the HF cache together with all other
# custom code files for this model. Walk the auto_map values to find every
# module file that needs to be present alongside config.json.
source_dir = Path(inspect.getfile(type(model_config))).parent

module_files = {
    f"{class_ref.split('.')[0]}.py" for class_ref in model_config.auto_map.values()
}

for filename in module_files:
    src = source_dir / filename
    dst = Path(checkpoint_dir) / filename
    if src.exists() and not dst.exists():
        shutil.copy(src, dst)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In Hugging Face Transformers, what forms can `config.auto_map` values take for custom-code models, including `AutoTokenizer` entries and repo-qualified references like `repo_id--module.ClassName`?

💡 Result:

In custom-code (i.e., trust_remote_code=True) Transformers models, config.auto_map is a dict whose values can be:

  • A single string pointing to a class in the repo (Python dotted path):

    • "module_file.ClassName" (e.g., "models.TransformerForABSA"). [3]
    • The same, but repo-qualified:
      • "repo_id--module_file.ClassName" (the repo_id-- prefix is added when missing). [2]
  • A 2-item sequence (list/tuple) of strings (used especially for AutoTokenizer):

    • [slow_tokenizer_ref, fast_tokenizer_ref], where each element is either:
      • "module_file.ClassName", or
      • "repo_id--module_file.ClassName", and may be None (commonly for “no fast tokenizer”). [1][2]

For AutoTokenizer entries specifically, the mapping can appear as:

  • config.auto_map["AutoTokenizer"] = [slow_ref, fast_ref] (loaded and then used to pick index 0 or 1 depending on use_fast). [1]
  • In tokenizer_config.json, auto_map may be either:
    • the legacy top-level 2-item list/tuple, or
    • a dict with an "AutoTokenizer" key whose value is that 2-item list/tuple. [1]

Example shapes

"auto_map": { "AutoModelForCausalLM": "modeling_my.MyLM" }
"auto_map": { "AutoModelForCausalLM": "my-org/my-repo--modeling_my.MyLM" }

"auto_map": { "AutoTokenizer": ["tokenization_my.MyTokenizer", "tokenization_my.MyTokenizerFast"] }
"auto_map": { "AutoTokenizer": ["my-org/my-repo--tokenization_my.MyTokenizer", null] }

Sources: [1] [2] [3]


Fix auto_map value parsing to handle lists and repo-qualified references.

The code assumes all auto_map values are simple dotted strings, but Transformers uses multiple forms:

  • Lists/tuples (especially for AutoTokenizer): ["tokenization_my.MyTokenizer", "tokenization_my.MyTokenizerFast"] — calling .split('.') on a list crashes with AttributeError.
  • Repo-qualified strings: given "repo_id--tokenization_my.MyTokenizer", class_ref.split('.')[0] extracts "repo_id--tokenization_my" instead of the filename "tokenization_my.py", producing missing files.

Normalize values by (1) extracting the first element if a list/tuple, (2) stripping repo_id-- prefix, (3) then extracting the module name. Without this, custom-code checkpoints remain incomplete for trust_remote_code=True loads.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` around lines 380 -
396, The auto_map parsing in checkpoint_utils_hf.py incorrectly assumes each
model_config.auto_map value is a dotted string; update the logic that builds
module_files (and any usage of class_ref) to first normalize each value by: if
it's a list/tuple take the first element, if it contains a repo qualifier split
off the "repo_id--" prefix, then take the module part before the first '.' and
append ".py" (so "tokenization_my.py"); apply this normalization where
module_files is created and when iterating filenames so lists/tuples and
repo-qualified references are handled and the correct source filenames are
copied.
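The normalization described above can be sketched as a standalone helper; the three cases come from the web-query result (plain dotted strings, `repo_id--`-qualified strings, and list/tuple tokenizer entries whose slots may be `None`):

```python
def auto_map_module_files(auto_map):
    """Normalize auto_map values down to the .py module filenames they
    reference, handling lists/tuples and repo-qualified references."""
    files = set()
    for value in auto_map.values():
        refs = value if isinstance(value, (list, tuple)) else [value]
        for ref in refs:
            if ref is None:  # e.g. the "no fast tokenizer" slot
                continue
            ref = ref.split("--")[-1]  # strip any "repo_id--" qualifier
            files.add(ref.split(".")[0] + ".py")
    return files
```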

Comment on lines +89 to +90
num_workers: int = 0,
) -> DataLoader:

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n -C3 'class ConstantLengthDataset|def __iter__|get_worker_info|worker_info' \
  modelopt/torch/puzzletron/utils/data/dataset.py \
  modelopt/torch/puzzletron/utils/data/dataloaders.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1247


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Get the full __iter__ method implementation
echo "=== ConstantLengthDataset.__iter__() implementation ==="
sed -n '100,150p' modelopt/torch/puzzletron/utils/data/dataset.py

echo ""
echo "=== create_train_dataloader function signature and first 30 lines ==="
sed -n '80,120p' modelopt/torch/puzzletron/utils/data/dataloaders.py

echo ""
echo "=== Check for get_worker_info usage in dataset.py ==="
grep -n "get_worker_info" modelopt/torch/puzzletron/utils/data/dataset.py || echo "No get_worker_info found"

echo ""
echo "=== Check for any guards on num_workers in dataloaders.py ==="
grep -n "num_workers" modelopt/torch/puzzletron/utils/data/dataloaders.py

Repository: NVIDIA/Model-Optimizer

Length of output: 4020


num_workers is unsafe until ConstantLengthDataset is worker-sharded.

ConstantLengthDataset.__iter__() in modelopt/torch/puzzletron/utils/data/dataset.py walks self.dataset directly and never uses get_worker_info(). Passing num_workers > 0 to the DataLoader will cause each worker to read the same iterable stream, duplicating training samples silently.

Suggested guard
 def create_train_dataloader(
     bos_rate: float = 1.0,
     num_workers: int = 0,
 ) -> DataLoader:
     """Create an infinite training DataLoader over ConstantLengthDataset."""
+    if num_workers != 0:
+        raise ValueError(
+            "create_train_dataloader currently requires num_workers=0 because "
+            "ConstantLengthDataset is not sharded across DataLoader workers."
+        )
+
     if isinstance(dataset_path, str):
         dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory)

Also applies to: 114-118

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 89 - 90,
The DataLoader factory allows num_workers>0 while ConstantLengthDataset.__iter__
does not shard via get_worker_info(), causing duplicate samples; update the
dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
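The eventual worker-aware iteration the guard points toward amounts to round-robin sharding; a torch-free sketch of the core idea (in the real `__iter__`, `worker_id` and `num_workers` would come from `torch.utils.data.get_worker_info()`):

```python
from itertools import islice


def shard_for_worker(stream, worker_id, num_workers):
    """Round-robin shard so each DataLoader worker consumes a disjoint
    slice of the iterable stream instead of duplicating it."""
    return islice(stream, worker_id, None, num_workers)
```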

Comment on lines +98 to +99
if shuffle_seed is not None:
    train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Does Hugging Face `datasets.IterableDataset.shuffle()` accept a `keep_in_memory` keyword argument?

💡 Result:

No. datasets.IterableDataset.shuffle() only takes seed, generator, and buffer_size (no keep_in_memory kwarg), per both the API reference and the implementation. [1], [2]


Branch the shuffle call for streaming datasets.

IterableDataset.shuffle() does not accept keep_in_memory=True. When load_streaming_fn() is used (streaming datasets), line 99 will fail at runtime with an unexpected keyword argument error.

Suggested fix
     train_data = dataset[dataset_name]
     if shuffle_seed is not None:
-        train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
+        if isinstance(train_data, datasets.IterableDataset):
+            train_data = train_data.shuffle(seed=shuffle_seed)
+        else:
+            train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py` around lines 98 - 99,
The call to train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) fails for
streaming (Iterable) datasets because IterableDataset.shuffle() doesn't accept
keep_in_memory; update the code that checks shuffle_seed to detect streaming
datasets (e.g., via whatever marker load_streaming_fn sets or by checking
hasattr(train_data, "__iter__") vs __len__/isinstance of IterableDataset) and
branch: for non-streaming datasets call train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) as before, and for streaming/iterable datasets call
train_data.shuffle(seed=shuffle_seed) without keep_in_memory; ensure you modify
the block that references shuffle_seed and train_data.shuffle so runtime errors
are avoided when load_streaming_fn() returns a streaming dataset.

5. Run ``pruning_ckpts`` (rank 0 only) then barrier.
"""
set_seed(SEED)
dist.setup(timeout=timedelta(10))

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
from datetime import timedelta

print("timedelta(10) =", timedelta(10))
print("timedelta(seconds=10) =", timedelta(seconds=10))
print("timedelta(minutes=10) =", timedelta(minutes=10))
PY

Repository: NVIDIA/Model-Optimizer

Length of output: 162


🏁 Script executed:

# Check the file and context around line 213
cd tests/gpu/torch/puzzletron && wc -l test_bypass.py

Repository: NVIDIA/Model-Optimizer

Length of output: 84


🏁 Script executed:

# View the context around line 213
sed -n '205,220p' tests/gpu/torch/puzzletron/test_bypass.py

Repository: NVIDIA/Model-Optimizer

Length of output: 749


🏁 Script executed:

# Search for other dist.setup() calls to check for timeout patterns
rg "dist\.setup\(" tests/gpu/torch/puzzletron/test_bypass.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 209


🏁 Script executed:

# Search across the codebase for dist.setup() timeout patterns
rg "dist\.setup\(" --type py -A 1 | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 1633


Fix timedelta(10) to use explicit time units.

timedelta(10) specifies 10 days, not 10 seconds or 10 minutes. This pattern appears across multiple test and example files and would cause GPU tests to hang excessively (10 days) on deadlock. Use explicit units instead: timedelta(seconds=10) or timedelta(minutes=10).

Affected locations include:

  • tests/gpu/torch/puzzletron/test_bypass.py
  • tests/gpu/torch/puzzletron/test_puzzletron.py
  • tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
  • tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py (2 occurrences)
  • examples/puzzletron/main.py (2 occurrences)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/puzzletron/test_bypass.py` at line 213, The timeout passed to
dist.setup uses timedelta(10) which means 10 days; change it to an explicit unit
like timedelta(seconds=10) (or timedelta(minutes=10) if intended) to avoid
10-day test hangs — locate the call to dist.setup (symbol: dist.setup) in
tests/gpu/torch/puzzletron/test_bypass.py and the other listed files and replace
timedelta(10) with timedelta(seconds=10) (or the correct unit) in each
occurrence.
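The pitfall is worth demonstrating directly, since it is easy to misread:

```python
from datetime import timedelta

# The first positional argument of timedelta is days, so timedelta(10)
# is a ten-day timeout; always spell out the unit.
ambiguous = timedelta(10)
explicit = timedelta(seconds=10)
```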
