@@ -764,7 +764,7 @@ def convert_vae_22():
764764 "conv2.weight" : "post_quant_conv.weight" ,
765765 "conv2.bias" : "post_quant_conv.bias" ,
766766 }
767-
767+
768768 # Process each key in the state dict
769769 for key , value in old_state_dict .items ():
770770 # Handle middle block keys using the mapping
@@ -797,12 +797,12 @@ def convert_vae_22():
797797 elif key .startswith ("encoder.downsamples." ):
798798 # Change encoder.downsamples to encoder.down_blocks
799799 new_key = key .replace ("encoder.downsamples." , "encoder.down_blocks." )
800-
800+
801801 # Handle residual blocks - change downsamples to resnets and rename components
802802 if "residual" in new_key or "shortcut" in new_key :
803803 # Change the second downsamples to resnets
804804 new_key = new_key .replace (".downsamples." , ".resnets." )
805-
805+
806806 # Rename residual components
807807 if ".residual.0.gamma" in new_key :
808808 new_key = new_key .replace (".residual.0.gamma" , ".norm1.gamma" )
@@ -820,7 +820,7 @@ def convert_vae_22():
820820 new_key = new_key .replace (".shortcut.weight" , ".conv_shortcut.weight" )
821821 elif ".shortcut.bias" in new_key :
822822 new_key = new_key .replace (".shortcut.bias" , ".conv_shortcut.bias" )
823-
823+
824824 # Handle resample blocks - change downsamples to downsampler and remove index
825825 elif "resample" in new_key or "time_conv" in new_key :
826826 # Change the second downsamples to downsampler and remove the index
@@ -831,19 +831,19 @@ def convert_vae_22():
831831 # Remove the index (parts[4]) and change downsamples to downsampler
832832 new_parts = parts [:3 ] + ["downsampler" ] + parts [5 :]
833833 new_key = "." .join (new_parts )
834-
834+
835835 new_state_dict [new_key ] = value
836836
837837 # Handle decoder upsamples
838838 elif key .startswith ("decoder.upsamples." ):
839839 # Change decoder.upsamples to decoder.up_blocks
840840 new_key = key .replace ("decoder.upsamples." , "decoder.up_blocks." )
841-
841+
842842 # Handle residual blocks - change upsamples to resnets and rename components
843843 if "residual" in new_key or "shortcut" in new_key :
844844 # Change the second upsamples to resnets
845845 new_key = new_key .replace (".upsamples." , ".resnets." )
846-
846+
847847 # Rename residual components
848848 if ".residual.0.gamma" in new_key :
849849 new_key = new_key .replace (".residual.0.gamma" , ".norm1.gamma" )
@@ -861,7 +861,7 @@ def convert_vae_22():
861861 new_key = new_key .replace (".shortcut.weight" , ".conv_shortcut.weight" )
862862 elif ".shortcut.bias" in new_key :
863863 new_key = new_key .replace (".shortcut.bias" , ".conv_shortcut.bias" )
864-
864+
865865 # Handle resample blocks - change upsamples to upsampler and remove index
866866 elif "resample" in new_key or "time_conv" in new_key :
867867 # Change the second upsamples to upsampler and remove the index
@@ -872,7 +872,7 @@ def convert_vae_22():
872872 # Remove the index (parts[4]) and change upsamples to upsampler
873873 new_parts = parts [:3 ] + ["upsampler" ] + parts [5 :]
874874 new_key = "." .join (new_parts )
875-
875+
876876 new_state_dict [new_key ] = value
877877 else :
878878 # Keep other keys unchanged
0 commit comments