@@ -170,8 +170,11 @@ def forward(
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
-    ):
+        timestep_seq_len: Optional[int] = None,
+    ):
         timestep = self.timesteps_proj(timestep)
+        if timestep_seq_len is not None:
+            timestep = timestep.unflatten(0, (1, timestep_seq_len))

         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
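Note: the new `timestep_seq_len` path reshapes the flattened per-token timesteps back to a `(1, seq_len, ...)` layout right after projection, so the time embedder emits one embedding per token rather than per sample. A minimal shape sketch, with toy sizes and a stand-in for `timesteps_proj` (both assumptions, not model values):

    import torch

    # Toy sizes for illustration only; real dims come from the model config.
    seq_len, proj_dim = 4, 8

    # Wan 2.2 TI2V passes one timestep per latent token, flattened to
    # (batch_size * seq_len,) before the sinusoidal projection.
    timestep = torch.arange(seq_len).float()

    # Stand-in for self.timesteps_proj: one proj_dim-wide row per timestep.
    proj = timestep[:, None].expand(seq_len, proj_dim)   # (seq_len, proj_dim)

    # The new branch restores the sequence axis for per-token modulation.
    proj = proj.unflatten(0, (1, seq_len))               # (1, seq_len, proj_dim)
    print(proj.shape)  # torch.Size([1, 4, 8])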
@@ -309,9 +312,24 @@ def forward(
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
-        ).chunk(6, dim=1)
+
+        if temb.ndim == 4:
+            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table.unsqueeze(0) + temb.float()
+            ).chunk(6, dim=2)
+            # batch_size, seq_len, 1, inner_dim
+            shift_msa = shift_msa.squeeze(2)
+            scale_msa = scale_msa.squeeze(2)
+            gate_msa = gate_msa.squeeze(2)
+            c_shift_msa = c_shift_msa.squeeze(2)
+            c_scale_msa = c_scale_msa.squeeze(2)
+            c_gate_msa = c_gate_msa.squeeze(2)
+        else:
+            # temb: batch_size, 6, inner_dim
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)

         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
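Note: the 4-D branch broadcasts the block's learned `scale_shift_table` against a per-token `temb`. A standalone sketch of the shape arithmetic, assuming the table has shape `(1, 6, inner_dim)` (consistent with the `dim=1` chunk in the 3-D branch) and using toy sizes:

    import torch

    # Toy sizes; inner_dim stands in for the block's hidden size.
    batch_size, seq_len, inner_dim = 2, 4, 8
    scale_shift_table = torch.randn(1, 6, inner_dim)        # learned, per block
    temb = torch.randn(batch_size, seq_len, 6, inner_dim)   # per-token modulation

    # (1, 1, 6, d) + (b, s, 6, d) -> (b, s, 6, d); chunk along dim=2.
    chunks = (scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
    shift_msa = chunks[0].squeeze(2)                        # (b, s, d)
    assert shift_msa.shape == (batch_size, seq_len, inner_dim)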
@@ -469,10 +487,22 @@ def forward(
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)

+        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+        if timestep.ndim == 2:
+            ts_seq_len = timestep.shape[1]
+            timestep = timestep.flatten()  # batch_size * seq_len
+        else:
+            ts_seq_len = None
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
+            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        if ts_seq_len is not None:
+            # batch_size, seq_len, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # batch_size, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))

         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
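Note: the top-level forward flattens a 2-D `timestep` before calling the condition embedder, then re-splits the projection into `(6, inner_dim)` on whichever axis matches its rank. A round-trip shape sketch with assumed toy sizes:

    import torch

    batch_size, seq_len, inner_dim = 1, 4, 8   # toy sizes

    timestep = torch.zeros(batch_size, seq_len)   # one timestep per token (TI2V)
    ts_seq_len = timestep.shape[1] if timestep.ndim == 2 else None
    timestep = timestep.flatten()                 # (batch_size * seq_len,)

    # Pretend the condition embedder returned a per-token projection:
    timestep_proj = torch.zeros(batch_size, seq_len, 6 * inner_dim)
    if ts_seq_len is not None:
        timestep_proj = timestep_proj.unflatten(2, (6, -1))   # (b, s, 6, d)
    else:
        timestep_proj = timestep_proj.unflatten(1, (6, -1))   # (b, 6, d)
    assert timestep_proj.shape == (batch_size, seq_len, 6, inner_dim)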
@@ -488,7 +518,14 @@ def forward(
             hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

         # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        if temb.ndim == 3:
+            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # batch_size, inner_dim
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
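Note: the output head repeats the same broadcast pattern with a two-entry table. Assuming the model-level table has shape `(1, 2, inner_dim)` (consistent with the `dim=1` chunk in the else branch), a sketch with toy sizes:

    import torch

    batch_size, seq_len, inner_dim = 2, 4, 8   # toy sizes
    scale_shift_table = torch.randn(1, 2, inner_dim)
    temb = torch.randn(batch_size, seq_len, inner_dim)   # per-token temb (TI2V)

    # (1, 1, 2, d) + (b, s, 1, d) -> (b, s, 2, d); chunk on dim=2, squeeze it out.
    shift, scale = (scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
    shift, scale = shift.squeeze(2), scale.squeeze(2)    # each (b, s, d)
    assert shift.shape == (batch_size, seq_len, inner_dim)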