
Commit 90619c5

Reapply "Attention bug fixes, tokamax splash defaulting logic (#282)" (#287)
This reverts commit 503e9d6.
1 parent 5cbf844 commit 90619c5

5 files changed

Lines changed: 152 additions & 97 deletions

File tree

- docs/attention_blocks_flowchart.md
- docs/attention_blocks_flowchart.png
- src/maxdiffusion/max_utils.py
- src/maxdiffusion/models/attention_flax.py
- src/maxdiffusion/tests/wan_transformer_test.py

docs/attention_blocks_flowchart.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# Attention block sizes

## Description

- "block_q": Block size (HBM to VMEM and VREG) used to tile along the Q sequence in the forward pass.
- "block_kv_compute": Sub-block size (VMEM to VREG) of "block_kv" on which compute is performed in the forward pass. It must be a factor of, or equal to, "block_kv".
- "block_kv": Block size (HBM to VMEM) used to tile along the KV sequence in the forward pass.
- "block_q_dkv": Block size along the Q sequence in the backward pass with the fused kernel that computes the gradients of q, k, and v. It must be a factor of, or equal to, "block_q".
- "block_kv_dkv": Block size along the KV sequence in the backward pass. It must be a factor of, or equal to, "block_kv".
- "block_kv_dkv_compute": Sub-block size of "block_kv_dkv". It must be a factor of, or equal to, "block_kv_dkv".
- "block_q_dq": Block size along the Q sequence in the backward pass with the unfused kernel that computes the gradient of just q. It must be a factor of, or equal to, "block_q".
- "block_kv_dq": Block size used to tile along the KV sequence in the backward pass with the unfused kernel that computes the gradient of just q. It must be a factor of, or equal to, "block_kv".
- "use_fused_bwd_kernel": When set, the fused backward kernel is used, where dQ, dK, and dV are computed in a single kernel. It is usually more performant but comes with a slight HBM memory overhead. An example configuration follows this list.
## Flowchart

Maxdiffusion automatically adheres to this flowchart to ensure working block sizes, and a log informs you of the modifications that maxdiffusion makes to the specified block sizes.

![alt text](attention_blocks_flowchart.png)
> "tokamax_flash" uses the splash attention implementation in [tokamax-repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py). This kernel only supports a fused backward pass, where the gradients of q, k, and v are computed in a single kernel, so "block_q_dq" and "block_kv_dq" are not used.
## How block sizes matter for performance and accuracy

Block sizes are key to saturating HBM bandwidth and to ensuring the maximum possible overlap of compute on the cores with HBM transfers and VMEM-to-VREG transfers. It is highly recommended to tune them.

Block sizes also interact with the sequence length. Sequence length is a product of the resolution and the number of frames (for video), scaled by the VAE downscaling factors and patchifying ratios. This sequence length, or each shard of it, needs to be a multiple of the specified block sizes, so maxdiffusion pads the sequence length to the nearest multiple of the block sizes, as sketched below. It is advisable to choose block sizes that are factors of the sequence length, at least for the Q block sizes.
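A minimal sketch of that padding rule, assuming simple round-up-to-a-multiple behavior (illustrative, not maxdiffusion's exact code):

```python
# Round the sequence length up to the nearest multiple of the block size.
def padded_seq_len(seq_len: int, block: int) -> int:
  return -(-seq_len // block) * block  # ceiling division, then rescale

# A 75,600-token sequence with block_q=512 pads to 75,776 tokens, while
# block_q=400 (a factor of 75,600) needs no padding at all.
assert padded_seq_len(75600, 512) == 75776
assert padded_seq_len(75600, 400) == 75600
```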
> In cross attention, image or video tokens attend to text tokens. The sequence length of the text tokens is very small, and potentially smaller than the specified block size, so the KV block sizes are overwritten to safe values.

> KV block sizes must be a multiple of 128, since the register size is 8x128 and, in attention, the KV sequence dimension lies on the 128 lane for the multiplications, as K is transposed.
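The kind of clamping these notes describe can be sketched as follows (a hypothetical helper; the actual safe values come from maxdiffusion's flowchart logic above):

```python
# Hypothetical sketch: clamp a KV block size to a short cross-attention
# text sequence while keeping it a multiple of the 128-wide lane.
def safe_kv_block(requested: int, kv_seq_len: int) -> int:
  clamped = min(requested, kv_seq_len)
  return max(128, (clamped // 128) * 128)  # round down to a 128 multiple

print(safe_kv_block(1024, 512))  # -> 512: text tokens shorter than block
print(safe_kv_block(1024, 300))  # -> 256
```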
docs/attention_blocks_flowchart.png

229 KB (binary image; preview not shown)

src/maxdiffusion/max_utils.py

Lines changed: 19 additions & 10 deletions
@@ -501,17 +501,26 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
-    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
+    attention_is_tokamax = "tokamax" in config.attention
+    user_block_sizes: Dict[str, int] = config.flash_block_sizes
+    if attention_is_tokamax:
+      max_logging.log("Tokamax kernel specified. Note: Tokamax only supports the fused backward kernel, "
+                      "hence the following flash block properties specified will be ignored: "
+                      f"block_q: {user_block_sizes['block_q']}, "
+                      f"block_q_dq: {user_block_sizes.get('block_q_dq')}, "
+                      f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}, "
+                      f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}"
+                      )
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=config.flash_block_sizes["block_q"],
-        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
-        block_kv=config.flash_block_sizes["block_kv"],
-        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
-        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
-        block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
+        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"],
+        block_kv_compute=user_block_sizes["block_kv_compute"],
+        block_kv=user_block_sizes["block_kv"],
+        block_q_dkv=user_block_sizes["block_q_dkv"],
+        block_kv_dkv=user_block_sizes["block_kv_dkv"],
+        block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
+        block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"),
+        block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"),
+        use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
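In effect, when config.attention names a tokamax kernel, the user's forward and unfused-backward block settings are overridden. A minimal sketch of the resulting values, using hypothetical inputs (this mirrors the expressions in the hunk above; it is not library code):

```python
# Hypothetical user settings alongside attention="tokamax_flash".
user_block_sizes = {"block_q": 512, "block_kv": 512, "block_q_dkv": 256,
                    "block_q_dq": 512, "block_kv_dq": 512,
                    "use_fused_bwd_kernel": False}
attention_is_tokamax = True  # i.e. "tokamax" in config.attention

# Mirrors the defaulting expressions above.
block_q = (user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"])
           if attention_is_tokamax else user_block_sizes["block_q"])              # -> 256
block_q_dq = None if attention_is_tokamax else user_block_sizes["block_q_dq"]    # -> None
block_kv_dq = None if attention_is_tokamax else user_block_sizes["block_kv_dq"]  # -> None
use_fused = True if attention_is_tokamax else user_block_sizes["use_fused_bwd_kernel"]  # -> True
```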

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 3 deletions
@@ -189,14 +189,15 @@ def _tpu_flash_attention(
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
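The deleted line in this hunk is one of the attention bug fixes named in the commit title: block_q_dq was read from block_sizes.block_q_dq inside the very statement that first binds block_sizes in this branch, which can raise UnboundLocalError at runtime. A minimal reproduction of that pattern (illustrative only):

```python
# Reading an item of a local name inside the expression that first binds
# it raises UnboundLocalError before the assignment completes.
def make_sizes():
  sizes = dict(block_q_dq=sizes["block_q_dq"])  # 'sizes' not yet bound
  return sizes

try:
  make_sizes()
except UnboundLocalError as e:
  print(e)  # exact message varies by Python version
```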

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 99 additions & 84 deletions
@@ -23,7 +23,7 @@
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -53,6 +53,18 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -70,28 +82,31 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      assert dummy_output.shape == (1, 512, 5120)
 
   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)
 
   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -102,20 +117,21 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
 
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -181,68 +197,66 @@ def test_wan_block(self):
     assert dummy_output.shape == dummy_hidden_states.shape
 
   def test_wan_attention(self):
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-
-    batch_size = 1
-    channels = 16
-    frames = 21
-    height = 90
-    width = 160
-    hidden_states_shape = (batch_size, frames, height, width, channels)
-    dummy_hidden_states = jnp.ones(hidden_states_shape)
-    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
-
-    key = jax.random.key(0)
-    rngs = nnx.Rngs(key)
-    devices_array = create_device_mesh(config)
-
-    flash_block_sizes = get_flash_block_sizes(config)
-
-    mesh = Mesh(devices_array, config.mesh_axes)
-    batch_size = 1
-    query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-    assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
-      attention = FlaxWanAttention(
-          rngs=rngs,
-          query_dim=query_dim,
-          heads=40,
-          dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
-          mesh=mesh,
-          flash_block_sizes=flash_block_sizes,
+    for attention_kernel in ["flash", "tokamax_flash"]:
+      pyconfig.initialize(
+          [
+              None,
+              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+              f"attention={attention_kernel}"
+          ],
+          unittest=True
       )
-    except NotImplementedError:
-      pass
+      config = pyconfig.config
+      batch_size = 1
+      channels = 16
+      frames = 21
+      height = 90
+      width = 160
+      hidden_states_shape = (batch_size, frames, height, width, channels)
+      dummy_hidden_states = jnp.ones(hidden_states_shape)
+      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+
+      key = jax.random.key(0)
+      rngs = nnx.Rngs(key)
+      devices_array = create_device_mesh(config)
+      mesh = Mesh(devices_array, config.mesh_axes)
+      batch_size = 1
+      query_dim = 5120
+      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+        flash_block_sizes = get_flash_block_sizes(config)
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel=attention_kernel,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+        dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+        dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+        dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+        dummy_output = attention(
+            hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+        )
+        assert dummy_output.shape == dummy_hidden_states_shape
+
+        # dot product
+        try:
+          attention = FlaxWanAttention(
+              rngs=rngs,
+              query_dim=query_dim,
+              heads=40,
+              dim_head=128,
+              attention_kernel="dot_product",
+              split_head_dim=True,
+              mesh=mesh,
+              flash_block_sizes=flash_block_sizes,
+          )
+        except NotImplementedError:
+          pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -272,7 +286,8 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
