
Commit 5b3eb3d

[TRITON] Add Attention support to the bench_models benchmarking script (#2274)
* Add Attention support to bench_models.py
* Add MHA layout CLI arg
* Add support for batched_gemm_a16wfp4
* Refactor TP logic and _get_handler
* Remove unified attention from this branch
1 parent 5f9fdc9 commit 5b3eb3d

5 files changed: 364 additions & 102 deletions


op_tests/op_benchmarks/triton/bench_batched_gemm_a16wfp4.py

Lines changed: 14 additions & 11 deletions
@@ -2,6 +2,13 @@
 import torch
 import triton
 import math
+import aiter.ops.triton.utils._triton.arch_info as arch_info
+from aiter.ops.triton.gemm.batched.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from op_tests.triton_tests.gemm.batched.test_batched_gemm_a16wfp4 import (
+    generate_batched_gemm_a16wfp4_inputs,
+)
 from op_tests.op_benchmarks.triton.utils.argparse import (
     get_parser,
     add_argparse_ff,
@@ -14,10 +21,6 @@
     print_vgpr,
     get_caller_name_no_ext,
 )
-from aiter.ops.triton.gemm.batched.batched_gemm_afp4wfp4_pre_quant import (
-    batched_gemm_afp4wfp4_pre_quant,
-)
-import aiter.ops.triton.utils._triton.arch_info as arch_info


 def bench_gemm_fn(
@@ -29,7 +32,7 @@ def bench_gemm_fn(
     layout: str,
 ):
     c_dtype = torch.bfloat16
-    x, w, x_scale, w_scale, y = generate_batched_gemm_afp4wfp4_pre_quant_inputs(
+    x, w, x_scale, w_scale, y = generate_batched_gemm_a16wfp4_inputs(
         batch, M, N, K, c_dtype, layout=layout, output=True
     )
     # flops
@@ -145,7 +148,7 @@ def run_benchmark(args, defaults):
     run_shape_benchmark(args)


-def parse_args():
+def parse_args(args: list[str] | None = None):
     parser = get_parser("Batched MXFP4 x MXFP4 GEMM, Pre Quant")
     parser = add_argparse_ff(parser)
     parser.add_argument(
@@ -154,22 +157,22 @@ def parse_args():
         required=False,
         help="Batch size to be used when using --model flag.",
     )
-    return get_ff_args(parser)
+    return get_ff_args(parser, args=args)


-def main():
+def main(args: list[str] | None = None):
     if not (arch_info.is_fp4_avail()):
         print("MXFP4 is not available on this architecture")
         sys.exit()

-    args, defaults = parse_args()
+    args, defaults = parse_args(args=args)
     if args.print_vgpr:
         print("Retrieving VGPR usage for Triton kernels...")
         fun = lambda: run_benchmark(args, defaults)  # noqa: E731
         print_vgpr(fun, get_caller_name_no_ext())
-        return 0
+        return
     run_benchmark(args, defaults)


 if __name__ == "__main__":
-    sys.exit(main())
+    main()
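
The new optional args parameter on parse_args and main lets a driver script invoke this benchmark in-process instead of shelling out, which is presumably how the refactored _get_handler in bench_models.py dispatches to it. A minimal sketch of that calling pattern, assuming the module is importable from this path; only main(args=...) and the --model flag are confirmed by this diff, and the model name is illustrative:

    # Hypothetical driver: forwards a synthetic argv to the benchmark's main().
    from op_tests.op_benchmarks.triton import bench_batched_gemm_a16wfp4

    def run_batched_gemm(model: str) -> None:
        # Passing an explicit argv list bypasses sys.argv, so several
        # benchmarks can run back-to-back in a single process.
        bench_batched_gemm_a16wfp4.main(args=["--model", model])

    run_batched_gemm("llama3-8B")  # illustrative model preset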

op_tests/op_benchmarks/triton/bench_mla_decode.py

Lines changed: 0 additions & 1 deletion
@@ -62,7 +62,6 @@ def model_benchmark_configs(args):

 def benchmark(args):
     dtype = str_to_torch_dtype[args.dtype]
-    torch.set_default_dtype(dtype)

     configs = []
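
Dropping torch.set_default_dtype(dtype) removes a process-global side effect: the default dtype would otherwise leak into every later tensor allocation in the same process, which matters once multiple benchmarks run back-to-back from one driver. A small sketch of the behavior being avoided (my reading of the change, not stated in the commit message):

    import torch

    torch.set_default_dtype(torch.float16)
    print(torch.ones(2).dtype)                        # torch.float16: global default applies
    print(torch.ones(2, dtype=torch.bfloat16).dtype)  # torch.bfloat16: explicit dtype wins
    torch.set_default_dtype(torch.float32)            # restore PyTorch's usual default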
