Skip to content

Commit 52558b4

Browse files
DN6sayakpaul
andauthored
[CI] Flux2 Model Test Refactor (#13071)
* update * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent c02c17c commit 52558b4

2 files changed

Lines changed: 542 additions & 80 deletions

File tree

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
ModelOptCompileTesterMixin,
4242
ModelOptTesterMixin,
4343
ModelTesterMixin,
44-
PyramidAttentionBroadcastTesterMixin,
4544
QuantoCompileTesterMixin,
4645
QuantoTesterMixin,
4746
SingleFileTesterMixin,
@@ -219,6 +218,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
219218
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
220219
"""Training tests for Flux Transformer."""
221220

221+
def test_gradient_checkpointing_is_applied(self):
222+
expected_set = {"FluxTransformer2DModel"}
223+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
224+
222225

223226
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
224227
"""Attention processor tests for Flux Transformer."""
@@ -412,10 +415,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
412415
"""BitsAndBytes + compile tests for Flux Transformer."""
413416

414417

415-
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
416-
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
417-
418-
419418
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
420419
"""FirstBlockCache tests for Flux Transformer."""
421420

0 commit comments

Comments
 (0)