|
| 1 | +""" |
| 2 | +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +""" |

from typing import Any, Optional

import paddle

# Route torch.* API calls issued inside the flashinfer package to paddle equivalents.
paddle.compat.enable_torch_proxy(scope={"flashinfer"})


def _dtype_str(dtype) -> str:
    """Normalize dtype to string, handling both paddle and torch proxy dtypes."""
    return str(dtype).split(".")[-1]


def _is_dtype(tensor, *dtype_names: str) -> bool:
    """Check tensor dtype by name, compatible with both paddle and torch proxy tensors."""
    return _dtype_str(tensor.dtype) in dtype_names


def _perm(tensor, *dims):
    """Permute tensor dims, compatible with both paddle (transpose) and torch proxy (permute)."""
    try:
        return tensor.transpose(list(dims))
    except TypeError:
        return tensor.permute(*dims)


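# Added illustrative sketch (not called anywhere in this module): demonstrates
# the compat helpers above on a plain paddle tensor; shapes are arbitrary.
def _compat_helpers_example():
    x = paddle.zeros([4, 8, 16], dtype=paddle.float32)  # [E, M, N]
    assert _dtype_str(x.dtype) == "float32"
    assert _is_dtype(x, "float32", "float64")
    # _perm moves the leading expert axis last: [E, M, N] -> [M, N, E]
    assert list(_perm(x, 1, 2, 0).shape) == [8, 16, 4]

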
def get_cute_dtype(input) -> str:
    """Map a tensor's dtype to the dtype string expected by the CuteDSL GEMM."""
    s = _dtype_str(input.dtype)
    if s in ("bfloat16", "float16", "float32"):
        return s
    raise ValueError(f"Unsupported cute dtype {input.dtype}")


def flashinfer_cutedsl_moe_masked(
    hidden_states: tuple,
    input_global_scale: paddle.Tensor,
    w1: paddle.Tensor,
    w1_blockscale: paddle.Tensor,
    w1_alpha: paddle.Tensor,
    w2: paddle.Tensor,
    a2_global_scale: paddle.Tensor,
    w2_blockscale: paddle.Tensor,
    w2_alpha: paddle.Tensor,
    masked_m: paddle.Tensor,
    down_sm_count: Optional[int] = None,
    down_signals: Optional[paddle.Tensor] = None,
    down_start_event: Optional[Any] = None,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL kernels.

    Args:
        hidden_states: Either of the following:
            * (paddle.Tensor, None): [num_experts, m, k] bf16, not pre-quantized
            * (paddle.Tensor, paddle.Tensor): [m, k//2, num_experts] uint8 plus
              [m, k//16, num_experts] float8_e4m3fn, pre-quantized FP4 from dispatch
        input_global_scale: (l,) float32, value is 1/input_scale per expert
        w1: [l, 2*n, k//2] uint8, FP4-packed gate+up projection weights
        w1_blockscale: float8_e4m3fn blockscale for w1
        w1_alpha: (l,) float32, = input_scale * w1_weight_scale_2
        w2: [l, k, n//2] uint8, FP4-packed down projection weights
        a2_global_scale: (l,) float32, 1/input_scale for GEMM2
        w2_blockscale: float8_e4m3fn blockscale for w2
        w2_alpha: (l,) float32, = input_scale * w2_weight_scale_2
        masked_m: (l,) int32, valid token count per expert; max(masked_m) == m
        down_sm_count: optional int, forwarded to GEMM2 as sm_count to cap the
            number of SMs the down projection may occupy
        down_signals: optional tensor, forwarded to GEMM2 as dst_signals for
            cross-kernel synchronization
        down_start_event: optional event, recorded immediately before GEMM2 launches

    Returns:
        paddle.Tensor: [num_experts, m, k] bf16
    """
    from flashinfer import (
        scaled_fp4_grouped_quantize,
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )
    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked

    # === Dtype assertions ===
    # Use string-based dtype checks to stay compatible with both paddle and torch proxy tensors
    assert _is_dtype(w1, "uint8"), f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
    assert _is_dtype(w1_blockscale, "float8_e4m3fn"), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    assert _is_dtype(w1_alpha, "float32"), f"w1_alpha must be float32, got {w1_alpha.dtype}"
    assert _is_dtype(w2, "uint8"), f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
    assert _is_dtype(a2_global_scale, "float32"), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    assert _is_dtype(w2_blockscale, "float8_e4m3fn"), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    assert _is_dtype(w2_alpha, "float32"), f"w2_alpha must be float32, got {w2_alpha.dtype}"
    assert len(hidden_states) == 2, f"hidden_states must be a tuple of length 2, got {len(hidden_states)}"

    # intermediate_size derived from the last dimension of w2
    n = w2.shape[-1] * 2

    if hidden_states[1] is not None:
        # Pre-quantized path: tokens already FP4-packed by dispatch
        # a_q:    [m, k//2,  num_experts] uint8
        # a_q_sf: [m, k//16, num_experts] float8_e4m3fn
        a_q = hidden_states[0].view(paddle.uint8)
        a_q_sf = hidden_states[1].view(paddle.float8_e4m3fn)
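        # Background note (added): NVFP4 packs two float4_e2m1 values per
        # uint8 byte, hence the k//2 packed dim; one float8_e4m3fn scale
        # covers a block of sf_vec_size=16 elements, hence the k//16 dim.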
        m, k_by_2, num_experts = a_q.shape
        k = k_by_2 * 2
    else:
        # Standard path: bf16 [num_experts, m, k], quantized to FP4 here
        num_experts, m, k = hidden_states[0].shape

        assert _is_dtype(
            input_global_scale, "float32"
        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
        assert list(input_global_scale.shape) == [
            num_experts
        ], f"input_global_scale must be (l,), got {input_global_scale.shape}"

        a_q, a_q_sf = scaled_fp4_grouped_quantize(
            hidden_states[0],
            masked_m,
            input_global_scale,
        )

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n={2*n}, got {w1.shape[-2]}"
    assert w1.shape[-1] * 2 == k, f"w1 last dim * 2 must equal k={k}, got {w1.shape[-1] * 2}"
    assert (
        w2.shape[-2] == k and w2.shape[-1] == n // 2
    ), f"w2 shape mismatch, got {list(w2.shape[-2:])}, expected [{k}, {n // 2}]"
    assert list(w1_alpha.shape) == [num_experts], f"w1_alpha must be (l,), got {w1_alpha.shape}"
    assert list(a2_global_scale.shape) == [num_experts], f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    assert list(w2_alpha.shape) == [num_experts], f"w2_alpha must be (l,), got {w2_alpha.shape}"

    assert _is_dtype(a_q, "uint8")
    assert _is_dtype(a_q_sf, "float8_e4m3fn")

    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    sf_vec_size = 16

    # === GEMM1: gate+up projection ===
    # grouped_gemm_nt_masked requires the output in [m, 2*n, l] layout
    gateup_output = paddle.empty([num_experts, m, n * 2], dtype=paddle.bfloat16)
    gateup_output = gateup_output.transpose([1, 2, 0])  # [m, 2*n, num_experts]

    # w1:            [E, 2*n, k//2] → _perm(., 1, 2, 0) → [2*n, k//2, E]
    # w1_blockscale: [E, 2*n, k//G] → _perm(., 1, 2, 0) → [2*n, k//G, E]
    # Both must share the same expert-last layout for grouped_gemm_nt_masked.
    grouped_gemm_nt_masked(
        (a_q, a_q_sf),
        (_perm(w1, 1, 2, 0), _perm(w1_blockscale, 1, 2, 0)),
        gateup_output,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.reshape([1, 1, num_experts]),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # fills gateup_output in logical [m, 2*n, l]

    # === SiLU + mul, then quantize intermediate activations to FP4 ===
    # Input is expected as [num_experts, m, 2*n]
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        gateup_output.transpose([2, 0, 1]),  # [num_experts, m, 2*n]
        masked_m,
        a2_global_scale,
    )

    if down_start_event is not None:
        down_start_event.record()

    # === GEMM2: down projection ===
    # grouped_gemm_nt_masked requires the output in [m, k, l] layout
    out = paddle.empty([num_experts, m, k], dtype=paddle.bfloat16)
    out = out.transpose([1, 2, 0])  # [m, k, num_experts]

    # Optional scheduling knobs: only forward sm_count/dst_signals when set,
    # so the default kernel path is untouched otherwise.
    extra_kwargs = {}
    if down_sm_count is not None or down_signals is not None:
        extra_kwargs = dict(sm_count=down_sm_count, dst_signals=down_signals)

    # w2:            [E, k, n//2] → _perm(., 1, 2, 0) → [k, n//2, E]
    # w2_blockscale: [E, k, n//G] → _perm(., 1, 2, 0) → [k, n//G, E]
    # Both must share the same expert-last layout for grouped_gemm_nt_masked.
    grouped_gemm_nt_masked(
        (diq, diq_sf),
        (_perm(w2, 1, 2, 0), _perm(w2_blockscale, 1, 2, 0)),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.reshape([1, 1, num_experts]),
        alpha_dtype=get_cute_dtype(w2_alpha),
        **extra_kwargs,
    )  # fills out in logical [m, k, l]

    # Return [num_experts, m, k]
    return out.transpose([2, 0, 1])
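

# Hypothetical caller sketch (added for illustration; not part of the module's
# API). The sizes, random weight bytes, and unswizzled blockscale layout are
# all assumptions: real inputs come from an NVFP4 quantization pipeline, and
# running this requires flashinfer's CuteDSL backend on a supported GPU.
def _moe_masked_example():  # pragma: no cover
    E, M, K, N = 8, 128, 4096, 2048  # hypothetical sizes
    hs = paddle.randn([E, M, K]).astype(paddle.bfloat16)
    masked_m = paddle.full([E], M, dtype=paddle.int32)

    def _u8(shape):
        # Random packed-FP4 bytes as a stand-in for real quantized weights.
        return paddle.randint(0, 256, shape, dtype=paddle.int32).astype(paddle.uint8)

    w1, w2 = _u8([E, 2 * N, K // 2]), _u8([E, K, N // 2])
    # Per-16-element e4m3 blockscales, shown unswizzled; the kernel may expect
    # flashinfer's swizzled scale layout in practice.
    w1_bs = paddle.ones([E, 2 * N, K // 16]).astype(paddle.float8_e4m3fn)
    w2_bs = paddle.ones([E, K, N // 16]).astype(paddle.float8_e4m3fn)
    scale = paddle.ones([E], dtype=paddle.float32)
    out = flashinfer_cutedsl_moe_masked(
        (hs, None), scale, w1, w1_bs, scale, w2, scale, w2_bs, scale, masked_m
    )
    assert list(out.shape) == [E, M, K]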