@@ -374,39 +374,37 @@ def test_lazy_causal_mask_chunking(self, block_size: tuple[int, int], shape: tup
374374 block_size ,
375375 )
376376
377- @parameterized .parameters (
378- [
379- ((256 , 256 ), (1024 , 1024 ), (128 , None ), 0 ),
380- ((256 , 128 ), (1024 , 1024 ), (128 , None ), 16 ),
381- ((128 , 256 ), (1024 , 1024 ), (128 , None ), 16 ),
382- ((256 , 256 ), (1024 , 1024 ), (128 , 256 ), 0 ),
383- ((256 , 128 ), (1024 , 1024 ), (128 , 256 ), 0 ),
384- ((128 , 256 ), (1024 , 1024 ), (128 , 256 ), 16 ),
385- ((256 , 256 ), (1024 , 1024 ), (None , 256 ), 0 ),
386- ((256 , 128 ), (1024 , 1024 ), (None , 256 ), 32 ),
387- ((128 , 256 ), (1024 , 1024 ), (None , 256 ), 32 ),
388- #
389- ((256 , 256 ), (1024 , 2048 ), (128 , None ), 0 ),
390- ((256 , 128 ), (1024 , 2048 ), (128 , None ), 16 ),
391- ((128 , 256 ), (1024 , 2048 ), (128 , None ), 16 ),
392- ((256 , 256 ), (1024 , 2048 ), (128 , 256 ), 0 ),
393- ((256 , 128 ), (1024 , 2048 ), (128 , 256 ), 0 ),
394- ((128 , 256 ), (1024 , 2048 ), (128 , 256 ), 16 ),
395- ((256 , 256 ), (1024 , 2048 ), (None , 256 ), 0 ),
396- ((256 , 128 ), (1024 , 2048 ), (None , 256 ), 32 ),
397- ((128 , 256 ), (1024 , 2048 ), (None , 256 ), 32 ),
398- #
399- ((256 , 256 ), (2048 , 1024 ), (128 , None ), 0 ),
400- ((256 , 128 ), (2048 , 1024 ), (128 , None ), 16 ),
401- ((128 , 256 ), (2048 , 1024 ), (128 , None ), 16 ),
402- ((256 , 256 ), (2048 , 1024 ), (128 , 256 ), 0 ),
403- ((256 , 128 ), (2048 , 1024 ), (128 , 256 ), 0 ),
404- ((128 , 256 ), (2048 , 1024 ), (128 , 256 ), 16 ),
405- ((256 , 256 ), (2048 , 1024 ), (None , 256 ), 0 ),
406- ((256 , 128 ), (2048 , 1024 ), (None , 256 ), 32 ),
407- ((128 , 256 ), (2048 , 1024 ), (None , 256 ), 32 ),
408- ]
409- )
377+ @parameterized .parameters ([
378+ ((256 , 256 ), (1024 , 1024 ), (128 , None ), 0 ),
379+ ((256 , 128 ), (1024 , 1024 ), (128 , None ), 16 ),
380+ ((128 , 256 ), (1024 , 1024 ), (128 , None ), 16 ),
381+ ((256 , 256 ), (1024 , 1024 ), (128 , 256 ), 0 ),
382+ ((256 , 128 ), (1024 , 1024 ), (128 , 256 ), 0 ),
383+ ((128 , 256 ), (1024 , 1024 ), (128 , 256 ), 16 ),
384+ ((256 , 256 ), (1024 , 1024 ), (None , 256 ), 0 ),
385+ ((256 , 128 ), (1024 , 1024 ), (None , 256 ), 32 ),
386+ ((128 , 256 ), (1024 , 1024 ), (None , 256 ), 32 ),
387+ #
388+ ((256 , 256 ), (1024 , 2048 ), (128 , None ), 0 ),
389+ ((256 , 128 ), (1024 , 2048 ), (128 , None ), 16 ),
390+ ((128 , 256 ), (1024 , 2048 ), (128 , None ), 16 ),
391+ ((256 , 256 ), (1024 , 2048 ), (128 , 256 ), 0 ),
392+ ((256 , 128 ), (1024 , 2048 ), (128 , 256 ), 0 ),
393+ ((128 , 256 ), (1024 , 2048 ), (128 , 256 ), 16 ),
394+ ((256 , 256 ), (1024 , 2048 ), (None , 256 ), 0 ),
395+ ((256 , 128 ), (1024 , 2048 ), (None , 256 ), 32 ),
396+ ((128 , 256 ), (1024 , 2048 ), (None , 256 ), 32 ),
397+ #
398+ ((256 , 256 ), (2048 , 1024 ), (128 , None ), 0 ),
399+ ((256 , 128 ), (2048 , 1024 ), (128 , None ), 16 ),
400+ ((128 , 256 ), (2048 , 1024 ), (128 , None ), 16 ),
401+ ((256 , 256 ), (2048 , 1024 ), (128 , 256 ), 0 ),
402+ ((256 , 128 ), (2048 , 1024 ), (128 , 256 ), 0 ),
403+ ((128 , 256 ), (2048 , 1024 ), (128 , 256 ), 16 ),
404+ ((256 , 256 ), (2048 , 1024 ), (None , 256 ), 0 ),
405+ ((256 , 128 ), (2048 , 1024 ), (None , 256 ), 32 ),
406+ ((128 , 256 ), (2048 , 1024 ), (None , 256 ), 32 ),
407+ ])
410408 def test_lazy_local_mask_chunking (
411409 self ,
412410 block_size : tuple [int , int ],
@@ -1164,17 +1162,15 @@ def test_two_qseq_shards_causal_local_stacked(self):
11641162
11651163 expected_num_active_blocks = np .array ([10 , 10 ], dtype = np .int32 )
11661164
1167- expected_partial_mask_blocks = np .stack (
1168- [
1169- np .tri (* block_shape , dtype = np .int8 ),
1170- np .triu (
1171- np .tri (* block_shape , window_size , dtype = np .int8 ),
1172- - window_size ,
1173- ),
1174- np .tri (* block_shape , - window_size , dtype = np .int8 ),
1175- np .triu (np .ones (block_shape , dtype = np .int8 ), window_size ),
1176- ]
1177- )
1165+ expected_partial_mask_blocks = np .stack ([
1166+ np .tri (* block_shape , dtype = np .int8 ),
1167+ np .triu (
1168+ np .tri (* block_shape , window_size , dtype = np .int8 ),
1169+ - window_size ,
1170+ ),
1171+ np .tri (* block_shape , - window_size , dtype = np .int8 ),
1172+ np .triu (np .ones (block_shape , dtype = np .int8 ), window_size ),
1173+ ])
11781174
11791175 expected_mask_info = mask_info_lib .MaskInfo (
11801176 expected_mask_next ,
@@ -1345,20 +1341,18 @@ def test_two_shards_local_wide_local_narrow_stacked(self, q_seq_shards, kv_seq_s
13451341
13461342 expected_active_rows_dkv = np .concatenate (
13471343 [
1348- np .array (
1349- [
1350- 0 ,
1351- 0 ,
1352- 1 ,
1353- 1 ,
1354- 1 ,
1355- 2 ,
1356- 2 ,
1357- 2 ,
1358- 3 ,
1359- 3 ,
1360- ]
1361- ),
1344+ np .array ([
1345+ 0 ,
1346+ 0 ,
1347+ 1 ,
1348+ 1 ,
1349+ 1 ,
1350+ 2 ,
1351+ 2 ,
1352+ 2 ,
1353+ 3 ,
1354+ 3 ,
1355+ ]),
13621356 np .array ([0 , 0 , 1 , 1 , 2 , 2 , 3 , - 1 , - 1 , - 1 ]),
13631357 ],
13641358 axis = 0 ,
0 commit comments