Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions modelopt/onnx/graph_surgery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,74 @@
from .gqa_replacement import replace_attention_with_gqa
from .utils.dtype_conversion import convert_fp16_to_bf16

# Registry of available graph surgeries.
# Maps surgery name -> (function, input_param_name)
# input_param_name is the keyword argument name that function uses for the input model path
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see input_param_name on Line#131. Should be same name here?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, it's removed now?

# (different surgeries use different names: model_path, encoder_path, input_path)
_SURGERY_REGISTRY = {
    # Each value is (surgery function, keyword name that receives the input model path).
    # run_graph_surgery() uses the second element to route its `input_path` argument.
    "replace-gqa": (replace_attention_with_gqa, "model_path"),
    "add-cross-kv": (add_cross_kv_to_encoder, "encoder_path"),
    "convert-bf16": (convert_fp16_to_bf16, "input_path"),
    "transpose-dq": (transpose_dequantize_linear_weights, "model_path"),
}


def get_available_surgeries() -> list[str]:
    """List the names of every graph surgery present in the registry.

    Returns:
        A new list of surgery names, in registry (insertion) order.
    """
    # Unpacking a dict yields its keys, so this is a fresh list of names.
    return [*_SURGERY_REGISTRY]


def run_graph_surgery(
    surgery_name: str,
    input_path: str,
    output_path: str,
    **kwargs,
):
    """Look up a graph surgery by name and execute it.

    Single dispatch point for every surgery in the registry: the name selects
    the surgery function, and the input path is forwarded under whichever
    keyword that function expects. Surgeries added to the registry become
    callable through this function with no change to calling code.

    Args:
        surgery_name: Registered surgery name (e.g. 'replace-gqa', 'transpose-dq').
            Call get_available_surgeries() for the full list.
        input_path: Path to the input ONNX model.
        output_path: Path where the output ONNX model is saved; forwarded to the
            surgery function as ``output_path``.
        **kwargs: Surgery-specific parameters, forwarded unchanged to the
            surgery function.

    Returns:
        Whatever the selected surgery function returns (typically ModelProto or dict).

    Raises:
        ValueError: If surgery_name is not registered.

    Example:
        >>> from modelopt.onnx.graph_surgery import run_graph_surgery, get_available_surgeries
        >>> print(get_available_surgeries())
        ['replace-gqa', 'add-cross-kv', 'convert-bf16', 'transpose-dq']
        >>> run_graph_surgery(
        ...     "replace-gqa",
        ...     input_path="model.onnx",
        ...     output_path="model_gqa.onnx",
        ...     hf_model_id="meta-llama/Llama-2-7b-hf",
        ... )
    """
    entry = _SURGERY_REGISTRY.get(surgery_name)
    if entry is None:
        known = ", ".join(f"'{name}'" for name in _SURGERY_REGISTRY)
        raise ValueError(f"Unknown surgery: '{surgery_name}'. Available surgeries: {known}")

    surgery_fn, input_param_name = entry
    # Surgeries disagree on the name of their input-path parameter, so the
    # registry supplies the keyword under which input_path must be passed.
    path_kwargs = {input_param_name: input_path, "output_path": output_path}
    return surgery_fn(**path_kwargs, **kwargs)


# Public API of the graph_surgery package: the individual surgery functions
# plus the name-based dispatch helpers.
__all__ = [
    "add_cross_kv_to_encoder",
    "convert_fp16_to_bf16",
    "get_available_surgeries",
    "replace_attention_with_gqa",
    "run_graph_surgery",
    "transpose_dequantize_linear_weights",
]
Loading