@@ -374,37 +374,39 @@ def test_lazy_causal_mask_chunking(self, block_size: tuple[int, int], shape: tup
374374 block_size ,
375375 )
376376
377- @parameterized .parameters ([
378- ((256 , 256 ), (1024 , 1024 ), (128 , None ), 0 ),
379- ((256 , 128 ), (1024 , 1024 ), (128 , None ), 16 ),
380- ((128 , 256 ), (1024 , 1024 ), (128 , None ), 16 ),
381- ((256 , 256 ), (1024 , 1024 ), (128 , 256 ), 0 ),
382- ((256 , 128 ), (1024 , 1024 ), (128 , 256 ), 0 ),
383- ((128 , 256 ), (1024 , 1024 ), (128 , 256 ), 16 ),
384- ((256 , 256 ), (1024 , 1024 ), (None , 256 ), 0 ),
385- ((256 , 128 ), (1024 , 1024 ), (None , 256 ), 32 ),
386- ((128 , 256 ), (1024 , 1024 ), (None , 256 ), 32 ),
387- #
388- ((256 , 256 ), (1024 , 2048 ), (128 , None ), 0 ),
389- ((256 , 128 ), (1024 , 2048 ), (128 , None ), 16 ),
390- ((128 , 256 ), (1024 , 2048 ), (128 , None ), 16 ),
391- ((256 , 256 ), (1024 , 2048 ), (128 , 256 ), 0 ),
392- ((256 , 128 ), (1024 , 2048 ), (128 , 256 ), 0 ),
393- ((128 , 256 ), (1024 , 2048 ), (128 , 256 ), 16 ),
394- ((256 , 256 ), (1024 , 2048 ), (None , 256 ), 0 ),
395- ((256 , 128 ), (1024 , 2048 ), (None , 256 ), 32 ),
396- ((128 , 256 ), (1024 , 2048 ), (None , 256 ), 32 ),
397- #
398- ((256 , 256 ), (2048 , 1024 ), (128 , None ), 0 ),
399- ((256 , 128 ), (2048 , 1024 ), (128 , None ), 16 ),
400- ((128 , 256 ), (2048 , 1024 ), (128 , None ), 16 ),
401- ((256 , 256 ), (2048 , 1024 ), (128 , 256 ), 0 ),
402- ((256 , 128 ), (2048 , 1024 ), (128 , 256 ), 0 ),
403- ((128 , 256 ), (2048 , 1024 ), (128 , 256 ), 16 ),
404- ((256 , 256 ), (2048 , 1024 ), (None , 256 ), 0 ),
405- ((256 , 128 ), (2048 , 1024 ), (None , 256 ), 32 ),
406- ((128 , 256 ), (2048 , 1024 ), (None , 256 ), 32 ),
407- ])
377+ @parameterized .parameters (
378+ [
379+ ((256 , 256 ), (1024 , 1024 ), (128 , None ), 0 ),
380+ ((256 , 128 ), (1024 , 1024 ), (128 , None ), 16 ),
381+ ((128 , 256 ), (1024 , 1024 ), (128 , None ), 16 ),
382+ ((256 , 256 ), (1024 , 1024 ), (128 , 256 ), 0 ),
383+ ((256 , 128 ), (1024 , 1024 ), (128 , 256 ), 0 ),
384+ ((128 , 256 ), (1024 , 1024 ), (128 , 256 ), 16 ),
385+ ((256 , 256 ), (1024 , 1024 ), (None , 256 ), 0 ),
386+ ((256 , 128 ), (1024 , 1024 ), (None , 256 ), 32 ),
387+ ((128 , 256 ), (1024 , 1024 ), (None , 256 ), 32 ),
388+ #
389+ ((256 , 256 ), (1024 , 2048 ), (128 , None ), 0 ),
390+ ((256 , 128 ), (1024 , 2048 ), (128 , None ), 16 ),
391+ ((128 , 256 ), (1024 , 2048 ), (128 , None ), 16 ),
392+ ((256 , 256 ), (1024 , 2048 ), (128 , 256 ), 0 ),
393+ ((256 , 128 ), (1024 , 2048 ), (128 , 256 ), 0 ),
394+ ((128 , 256 ), (1024 , 2048 ), (128 , 256 ), 16 ),
395+ ((256 , 256 ), (1024 , 2048 ), (None , 256 ), 0 ),
396+ ((256 , 128 ), (1024 , 2048 ), (None , 256 ), 32 ),
397+ ((128 , 256 ), (1024 , 2048 ), (None , 256 ), 32 ),
398+ #
399+ ((256 , 256 ), (2048 , 1024 ), (128 , None ), 0 ),
400+ ((256 , 128 ), (2048 , 1024 ), (128 , None ), 16 ),
401+ ((128 , 256 ), (2048 , 1024 ), (128 , None ), 16 ),
402+ ((256 , 256 ), (2048 , 1024 ), (128 , 256 ), 0 ),
403+ ((256 , 128 ), (2048 , 1024 ), (128 , 256 ), 0 ),
404+ ((128 , 256 ), (2048 , 1024 ), (128 , 256 ), 16 ),
405+ ((256 , 256 ), (2048 , 1024 ), (None , 256 ), 0 ),
406+ ((256 , 128 ), (2048 , 1024 ), (None , 256 ), 32 ),
407+ ((128 , 256 ), (2048 , 1024 ), (None , 256 ), 32 ),
408+ ]
409+ )
408410 def test_lazy_local_mask_chunking (
409411 self ,
410412 block_size : tuple [int , int ],
@@ -1162,15 +1164,17 @@ def test_two_qseq_shards_causal_local_stacked(self):
11621164
11631165 expected_num_active_blocks = np .array ([10 , 10 ], dtype = np .int32 )
11641166
1165- expected_partial_mask_blocks = np .stack ([
1166- np .tri (* block_shape , dtype = np .int8 ),
1167- np .triu (
1168- np .tri (* block_shape , window_size , dtype = np .int8 ),
1169- - window_size ,
1170- ),
1171- np .tri (* block_shape , - window_size , dtype = np .int8 ),
1172- np .triu (np .ones (block_shape , dtype = np .int8 ), window_size ),
1173- ])
1167+ expected_partial_mask_blocks = np .stack (
1168+ [
1169+ np .tri (* block_shape , dtype = np .int8 ),
1170+ np .triu (
1171+ np .tri (* block_shape , window_size , dtype = np .int8 ),
1172+ - window_size ,
1173+ ),
1174+ np .tri (* block_shape , - window_size , dtype = np .int8 ),
1175+ np .triu (np .ones (block_shape , dtype = np .int8 ), window_size ),
1176+ ]
1177+ )
11741178
11751179 expected_mask_info = mask_info_lib .MaskInfo (
11761180 expected_mask_next ,
@@ -1341,18 +1345,20 @@ def test_two_shards_local_wide_local_narrow_stacked(self, q_seq_shards, kv_seq_s
13411345
13421346 expected_active_rows_dkv = np .concatenate (
13431347 [
1344- np .array ([
1345- 0 ,
1346- 0 ,
1347- 1 ,
1348- 1 ,
1349- 1 ,
1350- 2 ,
1351- 2 ,
1352- 2 ,
1353- 3 ,
1354- 3 ,
1355- ]),
1348+ np .array (
1349+ [
1350+ 0 ,
1351+ 0 ,
1352+ 1 ,
1353+ 1 ,
1354+ 1 ,
1355+ 2 ,
1356+ 2 ,
1357+ 2 ,
1358+ 3 ,
1359+ 3 ,
1360+ ]
1361+ ),
13561362 np .array ([0 , 0 , 1 , 1 , 2 , 2 , 3 , - 1 , - 1 , - 1 ]),
13571363 ],
13581364 axis = 0 ,
0 commit comments