
Commit 0cc19c7

Commit message: push all the changes
1 parent: f7b4145

30 files changed: 492 additions & 441 deletions

src/maxdiffusion/__init__.py

Lines changed: 196 additions & 182 deletions
Large diffs are not rendered by default.

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'tokamax_flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring
 flash_min_seq_length: 0

 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
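
The only change here swaps the default Wan 2.1 attention kernel from 'tokamax_flash' back to 'flash'. As a hedged illustration of how that setting is constrained, the sketch below validates a configured kernel name against the values listed in the comment; the helper is hypothetical and not part of maxdiffusion.

# Hypothetical helper, not part of maxdiffusion: checks an `attention` value
# against the kernel names listed in the config comment above.
SUPPORTED_ATTENTION = {"dot_product", "flash", "tokamax_flash", "cudnn_flash_te", "ring", "tokamax_ring"}

def check_attention_kernel(name: str) -> str:
  if name not in SUPPORTED_ATTENTION:
    raise ValueError(f"Unsupported attention kernel {name!r}; expected one of {sorted(SUPPORTED_ATTENTION)}")
  return name

check_attention_kernel("flash")  # the default restored by this commit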

src/maxdiffusion/configuration_utils.py

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ def load_config(
     proxies=proxies,
     resume_download=resume_download,
     local_files_only=local_files_only,
-    use_auth_token=use_auth_token,
+    token=use_auth_token,
     user_agent=user_agent,
     subfolder=subfolder,
     revision=revision,
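
This rename tracks huggingface_hub, which deprecated the `use_auth_token` keyword in favor of `token`. A minimal sketch of a new-style call, assuming the config file is fetched with `hf_hub_download` (the repo id below is a placeholder):

from huggingface_hub import hf_hub_download

# `token` replaces the deprecated `use_auth_token` keyword; None means
# anonymous access, a string passes a user access token.
config_file = hf_hub_download(
    repo_id="org/some-model",  # placeholder repo id
    filename="config.json",
    token=None,
)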

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py

Lines changed: 5 additions & 4 deletions
@@ -895,9 +895,7 @@ def _splash_attention_forward_ring_raw(
   num_kv_heads = k.shape[0]

   if len(k.shape) != expected_kv_rank:
-    raise ValueError(
-        f"Expected {expected_kv_rank}-dim 'key' tensor for MQA. Instead got a {len(k.shape)}-dim one."
-    )
+    raise ValueError(f"Expected {expected_kv_rank}-dim 'key' tensor for MQA. Instead got a {len(k.shape)}-dim one.")

   if k.shape[-1] != head_dim_qk:
     raise ValueError(f"Expected 'key' head dimension to be: {head_dim_qk}. Instead got: {k.shape[-1]}.")

@@ -1054,10 +1052,13 @@ def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
       pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map),
   ]

-  kernel_name = f"{get_kernel_name(is_mqa=is_mqa, save_residuals=True, is_segmented=segment_ids is not None, phase='fwd')}_ring_raw"
+  kernel_name = (
+      f"{get_kernel_name(is_mqa=is_mqa, save_residuals=True, is_segmented=segment_ids is not None, phase='fwd')}_ring_raw"
+  )
   metadata = {"xprof_metadata": json.dumps(dataclasses.asdict(config))}

   vmem_inputs = [q, k, v, q_segment_ids, kv_segment_ids, mask_info.partial_mask_blocks]
+
   def _fwd_cost_estimate(
       q: jax.Array,
       k: jax.Array,

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_test.py

Lines changed: 8 additions & 1 deletion
@@ -290,7 +290,14 @@ def _generate_inputs(
    is_mqa: bool,
    is_segmented: bool,
    use_sinks: bool = False,
-) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None, splash.SegmentIds | None, jax.Array,]:
+) -> tuple[
+    jax.Array,
+    jax.Array,
+    jax.Array,
+    jax.Array | None,
+    splash.SegmentIds | None,
+    jax.Array,
+]:
   seed = data.draw(seed_strategy())
   key = random.key(seed)
   k1, k2, k3, k_sinks, k_do = random.split(key, 5)

src/maxdiffusion/kernels/splash_attention/splash_attention_mask.py

Lines changed: 25 additions & 19 deletions
@@ -278,12 +278,14 @@ def __eq__(self, other: object):
     return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence)

   def __hash__(self):
-    return hash((
-        type(self),
-        self.shape,
-        self.offset,
-        self.q_sequence.tobytes() if self.q_sequence is not None else None,
-    ))
+    return hash(
+        (
+            type(self),
+            self.shape,
+            self.offset,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        )
+    )


 class ChunkedCausalMask(_ComputableMask):

@@ -338,12 +340,14 @@ def __eq__(self, other: object):
     )

   def __hash__(self):
-    return hash((
-        type(self),
-        self.shape,
-        self.chunk_size,
-        self.q_sequence.tobytes() if self.q_sequence is not None else None,
-    ))
+    return hash(
+        (
+            type(self),
+            self.shape,
+            self.chunk_size,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        )
+    )


 class LocalMask(_ComputableMask):

@@ -415,13 +419,15 @@ def __eq__(self, other: object):
     )

   def __hash__(self):
-    return hash((
-        type(self),
-        self.shape,
-        self.window_size,
-        self.offset,
-        self.q_sequence.tobytes() if self.q_sequence is not None else None,
-    ))
+    return hash(
+        (
+            type(self),
+            self.shape,
+            self.window_size,
+            self.offset,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        )
+    )


 @dataclasses.dataclass(slots=True)
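
These hunks only rewrap the existing `__hash__` bodies, but they make the contract being preserved easy to see: every field compared in `__eq__` also feeds the hash, so equal masks hash equally. A simplified stand-in class (not the real mask types) showing the same pattern:

import numpy as np

class ToyMask:
  """Illustrative stand-in for the mask classes above."""

  def __init__(self, shape, offset, q_sequence=None):
    self.shape = shape
    self.offset = offset
    self.q_sequence = q_sequence

  def __eq__(self, other):
    if not isinstance(other, ToyMask):
      return NotImplemented
    return (
        self.shape == other.shape
        and self.offset == other.offset
        and np.array_equal(self.q_sequence, other.q_sequence)
    )

  def __hash__(self):
    # Hash exactly the fields used in __eq__; the array is converted to bytes
    # because ndarrays are not hashable.
    return hash(
        (
            type(self),
            self.shape,
            self.offset,
            self.q_sequence.tobytes() if self.q_sequence is not None else None,
        )
    )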

src/maxdiffusion/kernels/splash_attention/splash_attention_mask_info.py

Lines changed: 6 additions & 4 deletions
@@ -446,10 +446,12 @@ def _process_mask(
   # Partial blocks are deduplicated and stored in unique_chunks to save memory.
   for coords in np.ndindex((q_blocks_count, kv_blocks_count)):
     (q_idx, kv_idx) = coords
-    chunk = mask[(
-        slice(q_idx * q_block_size, (q_idx + 1) * q_block_size),
-        slice(kv_idx * kv_block_size, (kv_idx + 1) * kv_block_size),
-    )]
+    chunk = mask[
+        (
+            slice(q_idx * q_block_size, (q_idx + 1) * q_block_size),
+            slice(kv_idx * kv_block_size, (kv_idx + 1) * kv_block_size),
+        )
+    ]
     if chunk.any():
       if chunk.all():
         state_grid[q_idx, kv_idx] = 2
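
The hunk above only rewraps the tile-slicing expression, but the surrounding loop is the interesting part: `_process_mask` walks the attention mask one (q_block, kv_block) tile at a time and records whether each tile is empty, partially set, or fully set. Below is a simplified sketch of that classification using the same 0/1/2 coding visible in the hunk; `classify_blocks` is a hypothetical name, and the real function also deduplicates partial blocks and builds further metadata.

import numpy as np

def classify_blocks(mask: np.ndarray, q_block_size: int, kv_block_size: int) -> np.ndarray:
  """Mark each (q, kv) tile of a boolean mask: 0 = empty, 1 = partial, 2 = full."""
  q_blocks = mask.shape[0] // q_block_size
  kv_blocks = mask.shape[1] // kv_block_size
  state_grid = np.zeros((q_blocks, kv_blocks), dtype=np.int8)
  for q_idx, kv_idx in np.ndindex(q_blocks, kv_blocks):
    chunk = mask[
        q_idx * q_block_size : (q_idx + 1) * q_block_size,
        kv_idx * kv_block_size : (kv_idx + 1) * kv_block_size,
    ]
    if chunk.any():
      state_grid[q_idx, kv_idx] = 2 if chunk.all() else 1
  return state_grid

# Example: a 4x4 causal mask split into 2x2 tiles yields [[1, 0], [2, 1]].
print(classify_blocks(np.tril(np.ones((4, 4), dtype=bool)), 2, 2))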

src/maxdiffusion/kernels/splash_attention/splash_attention_mask_test.py

Lines changed: 58 additions & 52 deletions
@@ -374,37 +374,39 @@ def test_lazy_causal_mask_chunking(self, block_size: tuple[int, int], shape: tup
         block_size,
     )

-  @parameterized.parameters([
-      ((256, 256), (1024, 1024), (128, None), 0),
-      ((256, 128), (1024, 1024), (128, None), 16),
-      ((128, 256), (1024, 1024), (128, None), 16),
-      ((256, 256), (1024, 1024), (128, 256), 0),
-      ((256, 128), (1024, 1024), (128, 256), 0),
-      ((128, 256), (1024, 1024), (128, 256), 16),
-      ((256, 256), (1024, 1024), (None, 256), 0),
-      ((256, 128), (1024, 1024), (None, 256), 32),
-      ((128, 256), (1024, 1024), (None, 256), 32),
-      #
-      ((256, 256), (1024, 2048), (128, None), 0),
-      ((256, 128), (1024, 2048), (128, None), 16),
-      ((128, 256), (1024, 2048), (128, None), 16),
-      ((256, 256), (1024, 2048), (128, 256), 0),
-      ((256, 128), (1024, 2048), (128, 256), 0),
-      ((128, 256), (1024, 2048), (128, 256), 16),
-      ((256, 256), (1024, 2048), (None, 256), 0),
-      ((256, 128), (1024, 2048), (None, 256), 32),
-      ((128, 256), (1024, 2048), (None, 256), 32),
-      #
-      ((256, 256), (2048, 1024), (128, None), 0),
-      ((256, 128), (2048, 1024), (128, None), 16),
-      ((128, 256), (2048, 1024), (128, None), 16),
-      ((256, 256), (2048, 1024), (128, 256), 0),
-      ((256, 128), (2048, 1024), (128, 256), 0),
-      ((128, 256), (2048, 1024), (128, 256), 16),
-      ((256, 256), (2048, 1024), (None, 256), 0),
-      ((256, 128), (2048, 1024), (None, 256), 32),
-      ((128, 256), (2048, 1024), (None, 256), 32),
-  ])
+  @parameterized.parameters(
+      [
+          ((256, 256), (1024, 1024), (128, None), 0),
+          ((256, 128), (1024, 1024), (128, None), 16),
+          ((128, 256), (1024, 1024), (128, None), 16),
+          ((256, 256), (1024, 1024), (128, 256), 0),
+          ((256, 128), (1024, 1024), (128, 256), 0),
+          ((128, 256), (1024, 1024), (128, 256), 16),
+          ((256, 256), (1024, 1024), (None, 256), 0),
+          ((256, 128), (1024, 1024), (None, 256), 32),
+          ((128, 256), (1024, 1024), (None, 256), 32),
+          #
+          ((256, 256), (1024, 2048), (128, None), 0),
+          ((256, 128), (1024, 2048), (128, None), 16),
+          ((128, 256), (1024, 2048), (128, None), 16),
+          ((256, 256), (1024, 2048), (128, 256), 0),
+          ((256, 128), (1024, 2048), (128, 256), 0),
+          ((128, 256), (1024, 2048), (128, 256), 16),
+          ((256, 256), (1024, 2048), (None, 256), 0),
+          ((256, 128), (1024, 2048), (None, 256), 32),
+          ((128, 256), (1024, 2048), (None, 256), 32),
+          #
+          ((256, 256), (2048, 1024), (128, None), 0),
+          ((256, 128), (2048, 1024), (128, None), 16),
+          ((128, 256), (2048, 1024), (128, None), 16),
+          ((256, 256), (2048, 1024), (128, 256), 0),
+          ((256, 128), (2048, 1024), (128, 256), 0),
+          ((128, 256), (2048, 1024), (128, 256), 16),
+          ((256, 256), (2048, 1024), (None, 256), 0),
+          ((256, 128), (2048, 1024), (None, 256), 32),
+          ((128, 256), (2048, 1024), (None, 256), 32),
+      ]
+  )
   def test_lazy_local_mask_chunking(
       self,
       block_size: tuple[int, int],

@@ -1162,15 +1164,17 @@ def test_two_qseq_shards_causal_local_stacked(self):

     expected_num_active_blocks = np.array([10, 10], dtype=np.int32)

-    expected_partial_mask_blocks = np.stack([
-        np.tri(*block_shape, dtype=np.int8),
-        np.triu(
-            np.tri(*block_shape, window_size, dtype=np.int8),
-            -window_size,
-        ),
-        np.tri(*block_shape, -window_size, dtype=np.int8),
-        np.triu(np.ones(block_shape, dtype=np.int8), window_size),
-    ])
+    expected_partial_mask_blocks = np.stack(
+        [
+            np.tri(*block_shape, dtype=np.int8),
+            np.triu(
+                np.tri(*block_shape, window_size, dtype=np.int8),
+                -window_size,
+            ),
+            np.tri(*block_shape, -window_size, dtype=np.int8),
+            np.triu(np.ones(block_shape, dtype=np.int8), window_size),
+        ]
+    )

     expected_mask_info = mask_info_lib.MaskInfo(
         expected_mask_next,

@@ -1341,18 +1345,20 @@ def test_two_shards_local_wide_local_narrow_stacked(self, q_seq_shards, kv_seq_s

     expected_active_rows_dkv = np.concatenate(
         [
-            np.array([
-                0,
-                0,
-                1,
-                1,
-                1,
-                2,
-                2,
-                2,
-                3,
-                3,
-            ]),
+            np.array(
+                [
+                    0,
+                    0,
+                    1,
+                    1,
+                    1,
+                    2,
+                    2,
+                    2,
+                    3,
+                    3,
+                ]
+            ),
             np.array([0, 0, 1, 1, 2, 2, 3, -1, -1, -1]),
         ],
         axis=0,

src/maxdiffusion/max_utils.py

Lines changed: 14 additions & 2 deletions
@@ -46,7 +46,16 @@
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 from jax.experimental import mesh_utils
-from transformers import (FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel)
+
+try:
+  from transformers import (FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel)
+except ImportError:
+  # For transformers>=5.0, these need different import paths
+  try:
+    from transformers.models.clip.modeling_flax_clip import FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel
+  except ImportError:
+    FlaxCLIPTextModel = None
+    FlaxCLIPTextPreTrainedModel = None
 from flax import struct
 from typing import (
     Callable,

@@ -336,7 +345,10 @@ def init_train_state(model, tx, weights_init_fn, params=None, training=True, eva
   Args: model_params, model, tx, training
   """
   if not params:
-    if isinstance(model, FlaxCLIPTextModel) or isinstance(model, FlaxCLIPTextPreTrainedModel):
+    is_clip_model = False
+    if FlaxCLIPTextModel is not None and FlaxCLIPTextPreTrainedModel is not None:
+      is_clip_model = isinstance(model, FlaxCLIPTextModel) or isinstance(model, FlaxCLIPTextPreTrainedModel)
+    if is_clip_model:
       params = weights_init_fn()
     else:
       params = weights_init_fn(eval_only=eval_only)
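
The import guard above lets max_utils load even when transformers no longer exposes the Flax CLIP classes, and the second hunk gates the isinstance check on that availability. A minimal, standalone sketch of the same pattern (the helper name is hypothetical):

try:
  from transformers import FlaxCLIPTextModel  # legacy top-level import
except ImportError:
  try:
    from transformers.models.clip.modeling_flax_clip import FlaxCLIPTextModel
  except ImportError:
    FlaxCLIPTextModel = None  # transformers build without Flax CLIP support

def is_clip_text_model(model) -> bool:
  # Only attempt the isinstance check when the class could actually be imported.
  return FlaxCLIPTextModel is not None and isinstance(model, FlaxCLIPTextModel)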

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 6 deletions
@@ -962,12 +962,6 @@ def __init__(
       mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
-      added_kv_proj_dim: Optional[int] = None,
-      image_seq_len: Optional[int] = None,
-  ):
-    if attention_kernel == "cudnn_flash_te":
-      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
-
       added_kv_proj_dim: Optional[int] = None,  # New for I2V
       image_seq_len: Optional[int] = None,  # New for I2V
   ):
