@@ -236,7 +236,7 @@ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
236236 super ().__init__ ()
237237 self .dim = dim
238238 self .mode = mode
239-
239+
240240 # default to dim //2
241241 if upsample_out_dim is None :
242242 upsample_out_dim = dim // 2
@@ -524,7 +524,7 @@ class WanEncoder3d(nn.Module):
524524
525525 def __init__ (
526526 self ,
527- in_channels : int = 3 ,
527+ in_channels : int = 3 ,
528528 dim = 128 ,
529529 z_dim = 4 ,
530530 dim_mult = [1 , 2 , 4 , 4 ],
@@ -558,10 +558,10 @@ def __init__(
558558 if is_residual :
559559 self .down_blocks .append (
560560 WanResidualDownBlock (
561- in_dim ,
562- out_dim ,
563- dropout ,
564- num_res_blocks ,
561+ in_dim ,
562+ out_dim ,
563+ dropout ,
564+ num_res_blocks ,
565565 temperal_downsample = temperal_downsample [i ] if i != len (dim_mult ) - 1 else False ,
566566 down_flag = i != len (dim_mult ) - 1 ,
567567 )
@@ -708,10 +708,10 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
708708 x = self .upsampler (x , feat_cache , feat_idx )
709709 else :
710710 x = self .upsampler (x )
711-
711+
712712 if self .avg_shortcut is not None :
713713 x = x + self .avg_shortcut (x_copy , first_chunk = first_chunk )
714-
714+
715715 return x
716716
717717class WanUpBlock (nn .Module ):
@@ -912,10 +912,9 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
912912 return x
913913
914914
915- # YiYi TODO: refactor this
916- from einops import rearrange
917-
918915def patchify (x , patch_size ):
916+ # YiYi TODO: refactor this
917+ from einops import rearrange
919918 if patch_size == 1 :
920919 return x
921920 if x .dim () == 4 :
@@ -935,6 +934,8 @@ def patchify(x, patch_size):
935934
936935
937936def unpatchify (x , patch_size ):
937+ # YiYi TODO: refactor this
938+ from einops import rearrange
938939 if patch_size == 1 :
939940 return x
940941
0 commit comments