Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/torch_onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [resnet50](https://huggingface.co/timm/resnet50.a1_in1k) | ✅ | ✅ | ✅ | ✅ | | ✅ |

## Resources

Expand Down
135 changes: 118 additions & 17 deletions examples/torch_onnx/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import json
import re
import subprocess
import sys
import warnings
from pathlib import Path
Expand All @@ -35,13 +36,17 @@
import modelopt.torch.quantization as mtq

"""
This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4,
or using auto quantization for optimal per-layer quantization.
Quantize a timm vision model and export to ONNX for TensorRT deployment.

Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.

The script will:
1. Given the model name, create a timm torch model.
2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode.
3. Export the quantized torch model to ONNX format.
1. Load a pretrained timm model (e.g., ViT, Swin, ResNet).
2. Quantize the model using the specified mode. For models with Conv2d layers,
Conv2d quantization is automatically overridden for TensorRT compatibility
(FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ).
3. Export the quantized model to ONNX with FP16 weights.
4. Optionally evaluate accuracy on ImageNet-1k before and after quantization.
"""


Expand Down Expand Up @@ -109,7 +114,8 @@ def filter_func(name):
"""Filter function to exclude certain layers from quantization."""
pattern = re.compile(
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
r"downsample|maxpool|global_pool).*"
)
return pattern.match(name) is not None

Expand Down Expand Up @@ -147,6 +153,36 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels
)


def _calibrate_uncalibrated_quantizers(model, data_loader):
    """Run calibration for enabled per-tensor quantizers that have no ``_amax`` yet.

    Every module's ``input_quantizer`` / ``weight_quantizer`` (when present) is
    inspected. A quantizer is selected when it is enabled, has no ``block_sizes``
    and carries no ``_amax`` attribute. Selected quantizers are switched into
    calibration mode, the model is run over ``data_loader`` without gradients,
    and the collected amax is then loaded into each of them.

    Args:
        model: Quantized model whose quantizers are inspected.
        data_loader: Iterable of batches; each batch is passed directly to
            ``model(batch)``.

    Note:
        If nothing needs calibration the function returns immediately and the
        model is left untouched. Otherwise the model is put in eval mode and
        stays in eval mode afterwards.
    """
    quantizer_attrs = ("input_quantizer", "weight_quantizer")

    def _lacks_calibration(q):
        # Enabled, per-tensor (no block sizes), and no amax recorded so far.
        return q.is_enabled and not q.block_sizes and not hasattr(q, "_amax")

    pending = []
    for layer in model.modules():
        for attr in quantizer_attrs:
            if hasattr(layer, attr):
                candidate = getattr(layer, attr)
                if _lacks_calibration(candidate):
                    candidate.enable_calib()
                    pending.append(candidate)

    if len(pending) == 0:
        return

    # Feed the calibration data through the model so the quantizers can
    # observe activation / weight statistics.
    model.eval()
    with torch.no_grad():
        for sample in data_loader:
            model(sample)

    # Leave calibration mode and commit the observed amax values.
    for calibrated in pending:
        calibrated.disable_calib()
        calibrated.load_calib_amax()


def quantize_model(model, config, data_loader=None):
"""Quantize the model using the given config and calibration data."""
if data_loader is not None:
Expand All @@ -159,6 +195,10 @@ def forward_loop(model):
else:
quantized_model = mtq.quantize(model, config)

# Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize()
if data_loader is not None:
_calibrate_uncalibrated_quantizers(quantized_model, data_loader)

mtq.disable_quantizer(quantized_model, filter_func)
return quantized_model

Expand All @@ -185,6 +225,38 @@ def _disable_inplace_relu(model):
module.inplace = False


def _override_conv2d_to_fp8(model, data_loader):
    """Switch block-quantized Conv2d quantizers to per-tensor FP8 and calibrate them.

    TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors,
    so any enabled ``input_quantizer`` / ``weight_quantizer`` on a
    ``torch.nn.Conv2d`` that has ``block_sizes`` set is reconfigured: block sizes
    are cleared, ``_num_bits`` becomes ``(4, 3)`` and ``_axis`` becomes ``None``.
    The reconfigured quantizers are then calibrated by running the model over
    ``data_loader``.

    Args:
        model: Quantized model to patch in place.
        data_loader: Iterable of batches; each batch is indexed with
            ``batch["image"]`` to obtain the model input.

    Note:
        When no quantizer is overridden, the data loader is not consumed and the
        model's train/eval mode is not changed.
    """
    converted = []
    conv_layers = (m for m in model.modules() if isinstance(m, torch.nn.Conv2d))
    for conv in conv_layers:
        for attr in ("input_quantizer", "weight_quantizer"):
            if not hasattr(conv, attr):
                continue
            q = getattr(conv, attr)
            if not (q.is_enabled and q.block_sizes):
                continue
            # Reconfigure as per-tensor FP8: no blocks, (4, 3) bits, no axis.
            q.block_sizes = None
            q._num_bits = (4, 3)
            q._axis = None
            q.enable_calib()
            converted.append(q)

    if not converted:
        return

    # Calibrate the freshly converted quantizers on the provided data.
    model.eval()
    with torch.no_grad():
        for sample in data_loader:
            model(sample["image"])

    for q in converted:
        q.disable_calib()
        q.load_calib_amax()


def auto_quantize_model(
model,
data_loader,
Expand Down Expand Up @@ -233,6 +305,10 @@ def auto_quantize_model(
verbose=True,
)

# Override Conv2d layers that got NVFP4/MXFP8 to FP8 for TRT compatibility.
# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
_override_conv2d_to_fp8(quantized_model, data_loader)

# Disable quantization for specified layers
mtq.disable_quantizer(quantized_model, filter_func)

Expand Down Expand Up @@ -320,6 +396,11 @@ def main():
default=128,
help="Number of scoring steps for auto quantization. Default is 128.",
)
parser.add_argument(
"--trt_build",
action="store_true",
help="Build a TensorRT engine from the exported ONNX model using trtexec.",
)
parser.add_argument(
"--no_pretrained",
action="store_true",
Expand Down Expand Up @@ -378,18 +459,18 @@ def main():
args.num_score_steps,
)
else:
# Standard quantization - only load calibration data if needed
# Standard quantization - load calibration data
# Note: MXFP8 is dynamic and does not need calibration itself, but when
# Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
# quantizers require calibration data.
config = get_quant_config(args.quantize_mode)
if args.quantize_mode == "mxfp8":
data_loader = None
else:
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)

quantized_model = quantize_model(model, config, data_loader)

Expand Down Expand Up @@ -421,6 +502,26 @@ def main():

print(f"Quantized ONNX model is saved to {args.onnx_save_path}")

if args.trt_build:
build_trt_engine(args.onnx_save_path)


def build_trt_engine(onnx_path, builder_optimization_level=4, timeout=600):
    """Build a TensorRT engine from the exported ONNX model using trtexec.

    Runs ``trtexec --onnx=<onnx_path> --stronglyTyped
    --builderOptimizationLevel=<level>`` as a subprocess and captures its output.

    Args:
        onnx_path: Path to the ONNX model passed to ``trtexec``.
        builder_optimization_level: Value forwarded to
            ``--builderOptimizationLevel``. Defaults to 4, the level previously
            hard-coded here.
        timeout: Maximum number of seconds to wait for ``trtexec``. Defaults to
            600, the value previously hard-coded here.

    Raises:
        RuntimeError: If ``trtexec`` exits with a non-zero return code; the
            message includes the captured stdout and stderr.
        subprocess.TimeoutExpired: Propagated from ``subprocess.run`` when the
            build exceeds ``timeout``.
    """
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        f"--builderOptimizationLevel={builder_optimization_level}",
    ]
    print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
    # The return code is checked by hand (check=False) so the failure message
    # can carry both output streams.
    result = subprocess.run(
        cmd, capture_output=True, text=True, timeout=timeout, check=False
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
        )
    print("TensorRT engine build succeeded.")


if __name__ == "__main__":
main()
46 changes: 43 additions & 3 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Utility functions related to Onnx."""

import base64
import contextlib
import inspect
import json
import logging
Expand Down Expand Up @@ -402,6 +403,29 @@ def is_fp8_quantized(model: nn.Module) -> bool:
return False


@contextlib.contextmanager
def _disable_fp8_conv_weight_quantizers(model: nn.Module):
    """Context manager that switches off Conv weight quantizers for ONNX export.

    The TorchScript ONNX exporter requires static kernel shapes for Conv
    operations, but FP8 weight quantization
    (TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear) produces dynamic-shape
    outputs that break this requirement. With the weight quantizers disabled the
    Conv exports with static-shape FP16/FP32 weights; input quantizers are left
    alone, so Conv activations still get FP8 QDQ nodes.

    On entry, every ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d`` in ``model``
    whose ``weight_quantizer`` exists and is enabled has it disabled. On exit
    (including when the body raises) exactly those quantizers are re-enabled.

    Note:
        The quantizer format is not inspected here; every enabled Conv weight
        quantizer is affected.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    paused = []
    for layer in model.modules():
        if not isinstance(layer, conv_types):
            continue
        if not hasattr(layer, "weight_quantizer"):
            continue
        if not layer.weight_quantizer.is_enabled:
            continue
        layer.weight_quantizer.disable()
        paused.append(layer)

    try:
        yield
    finally:
        # Restore only what was switched off above.
        for layer in paused:
            layer.weight_quantizer.enable()


def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Real quantizes the weights in the onnx model.

Expand Down Expand Up @@ -522,7 +546,11 @@ def get_onnx_bytes_and_metadata(
input_none_names = list(set(tree_spec_input.names) - set(input_names))

use_torch_autocast = not (
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
is_fp4_quantized(model)
or is_mxfp8_quantized(model)
or is_fp8_quantized(model)
or is_int8_quantized(model)
or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()

Expand Down Expand Up @@ -556,7 +584,14 @@ def get_onnx_bytes_and_metadata(
if is_fp4_quantized(model) or is_mxfp8_quantized(model)
else nullcontext()
)
with torch.inference_mode(), autocast, quantizer_context:
# Disable FP8 Conv weight quantizers: TorchScript ONNX exporter requires static
# kernel shapes, but FP8 DequantizeLinear produces dynamic shapes.
conv_wq_context = (
_disable_fp8_conv_weight_quantizers(model)
if is_fp8_quantized(model)
else nullcontext()
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
with torch.inference_mode(), autocast, quantizer_context, conv_wq_context:
additional_kwargs = {}
if not dynamo_export:
additional_kwargs["dynamic_axes"] = dynamic_axes
Expand Down Expand Up @@ -598,7 +633,12 @@ def get_onnx_bytes_and_metadata(
onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

if weights_dtype in ["fp16", "bf16"]:
if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model):
if (
is_int4_quantized(model)
or is_mxfp8_quantized(model)
or is_fp8_quantized(model)
or is_int8_quantized(model)
):
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,
Expand Down
15 changes: 13 additions & 2 deletions modelopt/torch/quantization/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,20 @@ def export_fp4(

@contextlib.contextmanager
def configure_linear_module_onnx_quantizers(model):
    """Sets the onnx export attributes for the given model.

    For modules with block quantization (NVFP4/MXFP8):
    - Weight quantizers use "static" export (TRT_FP4QDQ for NVFP4, DQ-only for MXFP8)
    - Input/activation quantizers use "dynamic" export (TRT_FP4DynamicQuantize, etc.)

    This must be set for ALL modules with block quantization, not just nn.Linear,
    because models like ResNet have non-Linear modules (e.g., MaxPool2d) with NVFP4/MXFP8
    input quantizers that would otherwise default to the static path and produce
    TRT_FP4QDQ nodes on activations (which the NVFP4 exporter cannot handle).

    Args:
        model: Module whose sub-modules are scanned for ``input_quantizer`` and
            ``weight_quantizer`` attributes.

    Yields:
        None. The ``_onnx_quantizer_type`` attributes are assigned before the
        body of the ``with`` block runs and are not reset afterwards.
    """
    for _, module in model.named_modules():
        # No module-type check: a quantizer is configured whenever it has
        # truthy block_sizes, whatever module owns it.
        if hasattr(module, "input_quantizer") and module.input_quantizer.block_sizes:
            module.input_quantizer._onnx_quantizer_type = "dynamic"
        if hasattr(module, "weight_quantizer") and module.weight_quantizer.block_sizes:
            module.weight_quantizer._onnx_quantizer_type = "static"
    yield
1 change: 1 addition & 0 deletions tests/_test_utils/torch/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def get_model_and_input(on_gpu: bool = False):
# "dm_nfnet_f0",
"efficientnet_b0",
"swin_tiny_patch4_window7_224",
"resnet50",
],
_create_timm_fn,
),
Expand Down
36 changes: 2 additions & 34 deletions tests/examples/torch_onnx/test_torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
# limitations under the License.


import os
import subprocess

import pytest
from _test_utils.examples.run_command import extend_cmd_parts, run_example_command

Expand All @@ -28,42 +25,16 @@
"vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'),
"swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'),
"swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'),
"resnet50": ("resnet50", None),
}

# Builder optimization level: 4 for low-bit modes, 3 otherwise
_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"}


def _verify_trt_engine_build(onnx_save_path, quantize_mode):
"""Verify the exported ONNX model can be compiled into a TensorRT engine."""
example_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
)
onnx_path = os.path.join(example_dir, onnx_save_path)
assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}"

opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3"
cmd = [
"trtexec",
f"--onnx={onnx_path}",
"--stronglyTyped",
f"--builderOptimizationLevel={opt_level}",
]

result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
assert result.returncode == 0, (
f"TensorRT engine build failed for {onnx_save_path} "
f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}"
)


@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
@pytest.mark.parametrize("model_key", list(_MODELS))
def test_torch_onnx(model_key, quantize_mode):
timm_model_name, model_kwargs = _MODELS[model_key]
onnx_save_path = f"{model_key}.{quantize_mode}.onnx"

# Step 1: Quantize and export to ONNX
cmd_parts = extend_cmd_parts(
["python", "torch_quant_to_onnx.py"],
timm_model_name=timm_model_name,
Expand All @@ -73,8 +44,5 @@ def test_torch_onnx(model_key, quantize_mode):
calibration_data_size="1",
num_score_steps="1",
)
cmd_parts.append("--no_pretrained")
cmd_parts.extend(["--no_pretrained", "--trt_build"])
run_example_command(cmd_parts, "torch_onnx")

# Step 2: Verify the exported ONNX model builds a TensorRT engine
_verify_trt_engine_build(onnx_save_path, quantize_mode)
Loading