Commit f1fd515

[tests] fix lora logging tests for models. (#13318)

* fix lora logging tests for models.
* make style

1 parent afdda57 commit f1fd515
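
Context for the change: these tests assert on records captured by pytest's `caplog` fixture, but diffusers configures its library root logger not to propagate records to the root logger, so `caplog` never saw them. The fix temporarily enables propagation around each assertion, and separately marks two QwenImage hot-swap compilation tests as expected failures due to recompilation issues.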

2 files changed, 43 additions & 17 deletions

tests/models/testing_utils/lora.py

Lines changed: 35 additions & 17 deletions

```diff
@@ -481,28 +481,40 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         import logging

+        from diffusers.utils import logging as diffusers_logging
+
         lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
         init_dict = self.get_init_dict()
         model = self.model_class(**init_dict).to(torch_device)
         model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
-        with caplog.at_level(logging.WARNING):
-            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
-        assert any(msg in record.message for record in caplog.records)
+        diffusers_logging.enable_propagation()
+        try:
+            with caplog.at_level(logging.WARNING):
+                model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            assert any(msg in record.message for record in caplog.records)
+        finally:
+            diffusers_logging.disable_propagation()

     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
         # check possibility to ignore the error/warning
         import logging

+        from diffusers.utils import logging as diffusers_logging
+
         lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
         init_dict = self.get_init_dict()
         model = self.model_class(**init_dict).to(torch_device)
         model.add_adapter(lora_config)
-        with caplog.at_level(logging.WARNING):
-            model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
-        assert len(caplog.records) == 0
+        diffusers_logging.enable_propagation()
+        try:
+            with caplog.at_level(logging.WARNING):
+                model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
+            assert len(caplog.records) == 0
+        finally:
+            diffusers_logging.disable_propagation()

     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
@@ -518,20 +530,26 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
         # check the error and log
         import logging

+        from diffusers.utils import logging as diffusers_logging
+
         # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
         target_modules0 = ["to_q"]
         target_modules1 = ["to_q", "to_k"]
-        with pytest.raises(RuntimeError):  # peft raises RuntimeError
-            with caplog.at_level(logging.ERROR):
-                self._check_model_hotswap(
-                    tmp_path,
-                    do_compile=True,
-                    rank0=8,
-                    rank1=8,
-                    target_modules0=target_modules0,
-                    target_modules1=target_modules1,
-                )
-        assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
+        diffusers_logging.enable_propagation()
+        try:
+            with pytest.raises(RuntimeError):  # peft raises RuntimeError
+                with caplog.at_level(logging.ERROR):
+                    self._check_model_hotswap(
+                        tmp_path,
+                        do_compile=True,
+                        rank0=8,
+                        rank1=8,
+                        target_modules0=target_modules0,
+                        target_modules1=target_modules1,
+                    )
+            assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
+        finally:
+            diffusers_logging.disable_propagation()

     @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
     @require_torch_version_greater("2.7.1")
```
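
Why the pattern works: pytest's `caplog` fixture captures records through a handler attached to the root logger, while diffusers configures its library root logger with `propagate = False`, so library records never reach `caplog` unless propagation is switched on for the duration of the assertion. Below is a minimal standalone sketch of the same pattern, assuming only the public `diffusers.utils.logging` helpers; the logger name is illustrative, not taken from the commit:

```python
import logging

from diffusers.utils import logging as diffusers_logging


def test_warning_reaches_caplog(caplog):
    # Illustrative name; any logger under the "diffusers" namespace behaves the same.
    logger = diffusers_logging.get_logger("diffusers.sketch")

    # Let records propagate up to the root logger, where caplog's handler lives.
    diffusers_logging.enable_propagation()
    try:
        with caplog.at_level(logging.WARNING):
            logger.warning("something worth asserting on")
        assert any("something worth asserting on" in record.message for record in caplog.records)
    finally:
        # Restore the library default so unrelated tests don't see duplicated output.
        diffusers_logging.disable_propagation()
```

The `try`/`finally` mirrors the tests above: propagation is process-global state, so it must be reset even when the assertion fails.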

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -286,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
 class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
     """LoRA hot-swapping tests for QwenImage Transformer."""

+    @pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
+    def test_hotswapping_compiled_model_linear(self):
+        super().test_hotswapping_compiled_model_linear()
+
+    @pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
+    def test_hotswapping_compiled_model_both_linear_and_other(self):
+        super().test_hotswapping_compiled_model_both_linear_and_other()
+
     @property
     def different_shapes_for_compilation(self):
         return [(4, 4), (4, 8), (8, 8)]
```
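
For reference, `pytest.mark.xfail` with `strict=True` turns an unexpected pass into a hard failure, so these markers will surface loudly once the recompilation issue is resolved. A minimal sketch of the semantics, independent of any diffusers code:

```python
import pytest


@pytest.mark.xfail(True, reason="Known issue; remove the marker once fixed.", strict=True)
def test_known_issue():
    # Fails today, so pytest reports XFAIL rather than a failure. If a fix
    # makes this pass, strict=True converts the unexpected pass (XPASS)
    # into a test failure, prompting removal of the stale marker.
    assert 1 + 1 == 3
```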
