@@ -481,28 +481,40 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         import logging

+        from diffusers.utils import logging as diffusers_logging
+
         lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
         init_dict = self.get_init_dict()
         model = self.model_class(**init_dict).to(torch_device)
         model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
-        with caplog.at_level(logging.WARNING):
-            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
-        assert any(msg in record.message for record in caplog.records)
+        diffusers_logging.enable_propagation()
+        try:
+            with caplog.at_level(logging.WARNING):
+                model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            assert any(msg in record.message for record in caplog.records)
+        finally:
+            diffusers_logging.disable_propagation()

     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
         # check possibility to ignore the error/warning
         import logging

+        from diffusers.utils import logging as diffusers_logging
+
         lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
         init_dict = self.get_init_dict()
         model = self.model_class(**init_dict).to(torch_device)
         model.add_adapter(lora_config)
-        with caplog.at_level(logging.WARNING):
-            model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
-        assert len(caplog.records) == 0
+        diffusers_logging.enable_propagation()
+        try:
+            with caplog.at_level(logging.WARNING):
+                model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
+            assert len(caplog.records) == 0
+        finally:
+            diffusers_logging.disable_propagation()

     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
@@ -518,20 +530,26 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
         # check the error and log
         import logging

+        from diffusers.utils import logging as diffusers_logging
+
         # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
         target_modules0 = ["to_q"]
         target_modules1 = ["to_q", "to_k"]
-        with pytest.raises(RuntimeError):  # peft raises RuntimeError
-            with caplog.at_level(logging.ERROR):
-                self._check_model_hotswap(
-                    tmp_path,
-                    do_compile=True,
-                    rank0=8,
-                    rank1=8,
-                    target_modules0=target_modules0,
-                    target_modules1=target_modules1,
-                )
-        assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
+        diffusers_logging.enable_propagation()
+        try:
+            with pytest.raises(RuntimeError):  # peft raises RuntimeError
+                with caplog.at_level(logging.ERROR):
+                    self._check_model_hotswap(
+                        tmp_path,
+                        do_compile=True,
+                        rank0=8,
+                        rank1=8,
+                        target_modules0=target_modules0,
+                        target_modules1=target_modules1,
+                    )
+            assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
+        finally:
+            diffusers_logging.disable_propagation()

     @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
     @require_torch_version_greater("2.7.1")
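
Note on the repeated try/finally blocks: diffusers, like transformers, keeps its library root logger non-propagating by default, so its records never reach the root-logger handler that pytest's `caplog` fixture captures from. Without `enable_propagation()`, `caplog.records` would stay empty and the assertions above would be meaningless. A minimal sketch of how the pattern could be factored into a reusable pytest fixture — the fixture name `propagate_diffusers_logs` is hypothetical and not part of this PR:

import pytest

from diffusers.utils import logging as diffusers_logging


@pytest.fixture
def propagate_diffusers_logs():
    # Let diffusers log records propagate to the root logger, where
    # pytest's caplog handler can capture them.
    diffusers_logging.enable_propagation()
    try:
        yield
    finally:
        # Restore the library default so other tests are unaffected.
        diffusers_logging.disable_propagation()

A test would then request `propagate_diffusers_logs` alongside `caplog` and drop the inline enable/disable calls.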