
Commit 1a1d048

[Feature] Support NVFP4 Flashinfer-cutedsl MoE on SM100 (#6963)
1 parent 61a9079 commit 1a1d048

9 files changed: 1247 additions & 73 deletions


custom_ops/gpu_ops/moe/depermute_prefill_combine.cu

Lines changed: 4 additions & 2 deletions
@@ -172,17 +172,19 @@ std::vector<paddle::Tensor> DepermutePrefillCombine(
     case paddle::DataType::FLOAT8_E4M3FN: {
       switch (topk) {
         DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
+        DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 6)
         DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
         default:
-          PD_THROW("Unsupported topk value, must be 4 or 8");
+          PD_THROW("Unsupported topk value, must be 4, 6 or 8");
       }
     }
     case paddle::DataType::BFLOAT16: {
       switch (topk) {
         DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
+        DISPATCH_TOPK(paddle::DataType::BFLOAT16, 6)
         DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
         default:
-          PD_THROW("Unsupported topk value, must be 4 or 8");
+          PD_THROW("Unsupported topk value, must be 4, 6 or 8");
       }
     }
     default:

custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu

Lines changed: 3 additions & 1 deletion
@@ -217,10 +217,12 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
       switch (topk) {
         DISPATCH_TOPK(
             paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
+        DISPATCH_TOPK(
+            paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 6)
         DISPATCH_TOPK(
             paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
         default:
-          PD_THROW("Unsupported topk value, must be 4 or 6 or 8");
+          PD_THROW("Unsupported topk value, must be 4 or 6 or 8");
       }
     }
     case paddle::DataType::INT32: {

fastdeploy/envs.py

Lines changed: 3 additions & 1 deletion
@@ -62,8 +62,10 @@ def _validate_split_kv_size(value: int) -> int:
     "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
     # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
     "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
-    # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass" and "flashinfer-trtllm" can be set currently.
+    # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently.
     "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
+    # Set nvfp4 load interleaved weight scale.
+    "FD_NVFP4_LOAD_BLOCKSCALE_LEAVE": lambda: os.getenv("FD_NVFP4_LOAD_BLOCKSCALE_LEAVE", "0"),
     # Set mxfp4 backend."flashinfer" can be set currently.
     "FD_MOE_MXFP4_BACKEND": lambda: os.getenv("FD_MOE_MXFP4_BACKEND", "flashinfer"),
     # Whether to use Machete for wint4 dense gemm.

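Both touched entries are plain environment switches read through os.getenv, so opting into the new backend is a matter of setting them before FastDeploy reads envs.py. A minimal sketch; the value "1" for FD_NVFP4_LOAD_BLOCKSCALE_LEAVE is an assumption based on the "0" default above, not something this diff documents:

import os

# Select the new CuteDSL MoE backend instead of the default "cutlass".
os.environ["FD_MOE_BACKEND"] = "flashinfer-cutedsl"
# Assumption: "1" (rather than the "0" default) turns on the interleaved
# NVFP4 blockscale loading path guarded by this flag.
os.environ["FD_NVFP4_LOAD_BLOCKSCALE_LEAVE"] = "1"
# The same can be done from the shell, e.g. export FD_MOE_BACKEND=flashinfer-cutedsl.
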
fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 4 additions & 3 deletions
@@ -163,14 +163,15 @@ def create_buffer(self):
         if self.deepep_buffer is not None:
             self.clear_buffer()
 
+        num_qps_per_rank = max(24, self.num_experts // self.ep_size)
         if self.splitwise_role == "mixed":
             logger.info("Initializing mixed mode buffer (low latency).")
             self.deepep_buffer = deep_ep.Buffer(
                 self.group,
                 self.num_nvl_bytes,
                 self.num_rdma_bytes,
                 low_latency_mode=True,
-                num_qps_per_rank=24,
+                num_qps_per_rank=num_qps_per_rank,
             )
             self.deepep_buffer.set_num_sms(14)  # TODO: tune in future
         else:
@@ -183,7 +184,7 @@ def create_buffer(self):
                     self.num_nvl_bytes,
                     self.num_rdma_bytes,
                     low_latency_mode=True,
-                    num_qps_per_rank=24,
+                    num_qps_per_rank=num_qps_per_rank,
                 )
             else:
                 raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
@@ -199,7 +200,7 @@ def _create_low_latency_buffer(self):
         if self.ep_size // 8 > 1:
             num_qps_per_rank_now = self.ep_size // 8
         else:
-            num_qps_per_rank_now = 1
+            num_qps_per_rank_now = self.num_experts // self.ep_size
         self.deepep_buffer = deep_ep.Buffer(
             self.group,
             self.num_nvl_bytes,
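
The buffer changes replace the hard-coded 24 QPs per rank with a value derived from the expert layout. A quick worked example; the expert and rank counts below are illustrative, not taken from this diff:

# Illustrative sizes only: 256 routed experts across an 8-rank EP group.
num_experts, ep_size = 256, 8
num_qps_per_rank = max(24, num_experts // ep_size)  # 256 // 8 = 32 -> 32 QPs per rank
# With fewer experts per rank, the previous default of 24 still acts as a floor:
print(max(24, 64 // 8))  # 64 // 8 = 8 -> stays at 24
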
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Any, Optional

import paddle

paddle.compat.enable_torch_proxy(scope={"flashinfer"})


def _dtype_str(dtype) -> str:
    """Normalize dtype to string, handling both paddle and torch proxy dtypes."""
    return str(dtype).split(".")[-1]


def _is_dtype(tensor, *dtype_names: str) -> bool:
    """Check tensor dtype by name, compatible with both paddle and torch proxy tensors."""
    return _dtype_str(tensor.dtype) in dtype_names


def _perm(tensor, *dims):
    """Permute tensor dims, compatible with both paddle (transpose) and torch proxy (permute)."""
    try:
        return tensor.transpose(list(dims))
    except TypeError:
        return tensor.permute(*dims)


def get_cute_dtype(input) -> str:
    s = _dtype_str(input.dtype)
    if s == "bfloat16":
        return "bfloat16"
    elif s == "float16":
        return "float16"
    elif s == "float32":
        return "float32"
    else:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")


def flashinfer_cutedsl_moe_masked(
    hidden_states: tuple,
    input_global_scale: paddle.Tensor,
    w1: paddle.Tensor,
    w1_blockscale: paddle.Tensor,
    w1_alpha: paddle.Tensor,
    w2: paddle.Tensor,
    a2_global_scale: paddle.Tensor,
    w2_blockscale: paddle.Tensor,
    w2_alpha: paddle.Tensor,
    masked_m: paddle.Tensor,
    down_sm_count: Optional[int] = None,
    down_signals: Optional[paddle.Tensor] = None,
    down_start_event: Optional[Any] = None,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL kernels.

    Args:
        hidden_states: Either of the following:
            * (paddle.Tensor, None): [num_experts, m, k] bf16 — not pre-quantized
            * (paddle.Tensor, paddle.Tensor): [m, k//2, num_experts] uint8,
              [m, k//16, num_experts] float8_e4m3fn — pre-quantized FP4 from dispatch
        input_global_scale: (l,) float32, value is 1/input_scale per expert
        w1: [l, 2*n, k//2] uint8, FP4-packed gate+up projection weights
        w1_blockscale: float8_e4m3fn blockscale for w1
        w1_alpha: (l,) float32, = input_scale * w1_weight_scale_2
        w2: [l, k, n//2] uint8, FP4-packed down projection weights
        a2_global_scale: (l,) float32, 1/input_scale for GEMM2
        w2_blockscale: float8_e4m3fn blockscale for w2
        w2_alpha: (l,) float32, = input_scale * w2_weight_scale_2
        masked_m: (l,) int32, valid token count per expert; max(masked_m) == m

    Returns:
        paddle.Tensor: [num_experts, m, k] bf16
    """
    from flashinfer import (
        scaled_fp4_grouped_quantize,
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )
    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked

    # === Dtype assertions ===
    # Use string-based dtype check to be compatible with both paddle and torch proxy tensors
    assert _is_dtype(w1, "uint8"), f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
    assert _is_dtype(w1_blockscale, "float8_e4m3fn"), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    assert _is_dtype(w1_alpha, "float32"), f"w1_alpha must be float32, got {w1_alpha.dtype}"
    assert _is_dtype(w2, "uint8"), f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
    assert _is_dtype(a2_global_scale, "float32"), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    assert _is_dtype(w2_blockscale, "float8_e4m3fn"), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    assert _is_dtype(w2_alpha, "float32"), f"w2_alpha must be float32, got {w2_alpha.dtype}"
    assert len(hidden_states) == 2, f"hidden_states must be a tuple of length 2, got {len(hidden_states)}"

    # intermediate_size derived from w2 last dimension
    n = w2.shape[-1] * 2

    if hidden_states[1] is not None:
        # Pre-quantized path: tokens already FP4-packed by dispatch
        # a_q:    [m, k//2, num_experts] uint8
        # a_q_sf: [m, k//16, num_experts] float8_e4m3fn
        a_q = hidden_states[0].view(paddle.uint8)
        a_q_sf = hidden_states[1].view(paddle.float8_e4m3fn)
        m, k_by_2, num_experts = a_q.shape
        k = k_by_2 * 2
    else:
        # Standard path: bf16 [num_experts, m, k], quantize to FP4 here
        num_experts, m, k = hidden_states[0].shape

        assert _is_dtype(
            input_global_scale, "float32"
        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
        assert list(input_global_scale.shape) == [
            num_experts
        ], f"input_global_scale must be (l,), got {input_global_scale.shape}"

        a_q, a_q_sf = scaled_fp4_grouped_quantize(
            hidden_states[0],
            masked_m,
            input_global_scale,
        )

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n={2*n}, got {w1.shape[-2]}"
    assert w1.shape[-1] * 2 == k, f"w1 last dim * 2 must equal k={k}, got {w1.shape[-1] * 2}"
    assert (
        w2.shape[-2] == k and w2.shape[-1] == n // 2
    ), f"w2 shape mismatch, got {list(w2.shape[-2:])}, expected [{k}, {n // 2}]"
    assert list(w1_alpha.shape) == [num_experts], f"w1_alpha must be (l,), got {w1_alpha.shape}"
    assert list(a2_global_scale.shape) == [num_experts], f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    assert list(w2_alpha.shape) == [num_experts], f"w2_alpha must be (l,), got {w2_alpha.shape}"

    assert _is_dtype(a_q, "uint8")
    assert _is_dtype(a_q_sf, "float8_e4m3fn")

    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    sf_vec_size = 16

    # === GEMM1: gate+up projection ===
    # grouped_gemm_nt_masked requires output in [m, 2*n, l] layout
    gateup_output = paddle.empty([num_experts, m, n * 2], dtype=paddle.bfloat16)
    gateup_output = gateup_output.transpose([1, 2, 0])  # [m, 2*n, num_experts]

    # w1:            [E, 2*n, k//2] → _perm(., 1, 2, 0) → [2*n, k//2, E]
    # w1_blockscale: [E, 2*n, k//G] → _perm(., 1, 2, 0) → [2*n, k//G, E]
    # Both must share the same expert-last layout for grouped_gemm_nt_masked.
    grouped_gemm_nt_masked(
        (a_q, a_q_sf),
        (_perm(w1, 1, 2, 0), _perm(w1_blockscale, 1, 2, 0)),
        gateup_output,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.reshape([1, 1, num_experts]),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # fills gateup_output in logical [m, 2*n, l]

    # === SiLU + mul + quantize intermediate activations to FP4 ===
    # Input expected as [num_experts, m, 2*n]
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        gateup_output.transpose([2, 0, 1]),  # [num_experts, m, 2*n]
        masked_m,
        a2_global_scale,
    )

    if down_start_event is not None:
        down_start_event.record()

    # === GEMM2: down projection ===
    # grouped_gemm_nt_masked requires output in [m, k, l] layout
    out = paddle.empty([num_experts, m, k], dtype=paddle.bfloat16)
    out = out.transpose([1, 2, 0])  # [m, k, num_experts]

    # w2:            [E, k, n//2] → _perm(., 1, 2, 0) → [k, n//2, E]
    # w2_blockscale: [E, k, n//G] → _perm(., 1, 2, 0) → [k, n//G, E]
    # Both must share the same expert-last layout for grouped_gemm_nt_masked.
    grouped_gemm_nt_masked(
        (diq, diq_sf),
        (_perm(w2, 1, 2, 0), _perm(w2_blockscale, 1, 2, 0)),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.reshape([1, 1, num_experts]),
        alpha_dtype=get_cute_dtype(w2_alpha),
        **(
            dict(
                sm_count=down_sm_count,
                dst_signals=down_signals,
            )
            if down_sm_count is not None or down_signals is not None
            else {}
        ),
    )  # fills out in logical [m, k, l]

    # Return [num_experts, m, k]
    return out.transpose([2, 0, 1])
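
For orientation, a call-site sketch of the new entry point on the standard (not pre-quantized) path. This is only a sketch: it assumes an SM100 GPU and a FlashInfer build that ships the CuteDSL grouped GEMM, the sizes are made up, and the zero-filled weight and blockscale tensors merely illustrate the expected shapes and dtypes (real values come from NVFP4 quantization during weight loading). The import is left as a comment because the new file's path is not shown in this view.

import paddle

# from <module added in this commit> import flashinfer_cutedsl_moe_masked  # path not shown here

E, m, k, n = 8, 128, 4096, 2048  # illustrative: experts, max tokens per expert, hidden, intermediate

x = paddle.randn([E, m, k]).astype("bfloat16")            # hidden_states[0] for the bf16 path
in_gs = paddle.ones([E], dtype="float32")                 # input_global_scale, shape (l,)
w1 = paddle.zeros([E, 2 * n, k // 2], dtype="uint8")      # FP4-packed gate+up weights
w1_bs = paddle.zeros([E, 2 * n, k // 16], dtype="uint8").view(paddle.float8_e4m3fn)
w1_alpha = paddle.ones([E], dtype="float32")
w2 = paddle.zeros([E, k, n // 2], dtype="uint8")          # FP4-packed down weights
w2_bs = paddle.zeros([E, k, n // 16], dtype="uint8").view(paddle.float8_e4m3fn)
w2_alpha = paddle.ones([E], dtype="float32")
a2_gs = paddle.ones([E], dtype="float32")
masked_m = paddle.full([E], m, dtype="int32")             # every expert fully occupied here

out = flashinfer_cutedsl_moe_masked(
    hidden_states=(x, None),  # second element None -> FP4 quantization happens inside
    input_global_scale=in_gs,
    w1=w1,
    w1_blockscale=w1_bs,
    w1_alpha=w1_alpha,
    w2=w2,
    a2_global_scale=a2_gs,
    w2_blockscale=w2_bs,
    w2_alpha=w2_alpha,
    masked_m=masked_m,
)
print(out.shape)  # [E, m, k], bfloat16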
