# Re-export the splash attention kernel's block-size configuration type.
BlockSizes = splash_attention_kernel.BlockSizes

# A sharding/axis specification is a tuple of axis-name strings.
AxisNames = tuple[str, ...]

# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
TENSOR = "tensor"

# Logical axis names for model parameters and activations.
BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"

WAN2_2 = "wan2.2"

# Default model selection.
# NOTE(review): WAN2_1 is defined earlier in the file (outside this chunk);
# presumably "wan2.1", matching the lowercase WAN2_2 convention — confirm.
WAN_MODEL = WAN2_1
# Logical axis names for sharding self- and cross-attention independently
# in the splash attention kernel.
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"

# NOTE(review): removed a duplicate `WAN_MODEL = "Wan2.1"` assignment here.
# WAN_MODEL is already set to WAN2_1 earlier in the file, and the literal used
# a different capitalization ("Wan2.1") than the file's lowercase "wan2.2"
# convention, silently overriding the earlier value.
### Common axis rules for ring attention ###
# Each entry maps a logical attention axis to the mesh axis it is sharded
# over (None = replicated). Ring attention shards both the query and the
# KV sequence dimensions over the FSDP mesh axis; heads stay replicated.
_RING_SHARDING = {
    SELF_ATTN_HEAD: None,
    SELF_ATTN_Q_LENGTH: FSDP,
    SELF_ATTN_KV_LENGTH: FSDP,
    CROSS_ATTN_HEAD: None,
    CROSS_ATTN_Q_LENGTH: FSDP,
    CROSS_ATTN_KV_LENGTH: FSDP,
}
RING_ATTENTION_AXIS_RULES = [[axis, mesh] for axis, mesh in _RING_SHARDING.items()]

# Sequence parallelism shards only the query sequence dimension over FSDP;
# the KV sequence and the heads remain replicated.
_SEQUENCE_PARALLEL_SHARDING = {
    SELF_ATTN_HEAD: None,
    SELF_ATTN_Q_LENGTH: FSDP,
    SELF_ATTN_KV_LENGTH: None,
    CROSS_ATTN_HEAD: None,
    CROSS_ATTN_Q_LENGTH: FSDP,
    CROSS_ATTN_KV_LENGTH: None,
}
SEQUENCE_PARALLEL_AXIS_RULES = [[axis, mesh] for axis, mesh in _SEQUENCE_PARALLEL_SHARDING.items()]