Skip to content

Commit aca1b79

Browse files
authored
Merge branch 'develop' into abort_requests
2 parents 7e0108d + 5c60e2f commit aca1b79

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -245,16 +245,17 @@ def moe_topk_select(
245245
probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
246246
) # [seq_len, n_group]
247247
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group]
248-
group_mask = paddle.zeros_like(group_scores).put_along_axis(
249-
group_idx, paddle.to_tensor(1.0, dtype=group_scores.dtype), axis=-1
248+
group_mask = paddle.sum(
249+
paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
250+
axis=1, # Sum over topk_group dimension -> [seq_len, n_group]
250251
)
251252
score_mask = (
252253
group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1])
253254
) # [seq_len, n_experts]
254255
probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
255256

256257
_, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1)
257-
topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1)
258+
topk_weights = paddle.index_sample(gate_probs, topk_ids)
258259

259260
# normalize combine weights
260261
if renormalize:

0 commit comments

Comments
 (0)