@@ -134,7 +134,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12909
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
@@ -147,7 +148,8 @@ def apply_rotary_emb(
             dropout_p=0.0,
             is_causal=False,
             backend=self._attention_backend,
-            parallel_config=self._parallel_config,
+            # Reference: https://github.com/huggingface/diffusers/pull/12909
+            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
552554 "blocks.0" : {
553555 "hidden_states" : ContextParallelInput (split_dim = 1 , expected_dims = 3 , split_output = False ),
554556 },
555- "blocks.*" : {
556- "encoder_hidden_states" : ContextParallelInput (split_dim = 1 , expected_dims = 3 , split_output = False ),
557- },
557+ # Reference: https://github.com/huggingface/diffusers/pull/12909
558+ # We need to disable the splitting of encoder_hidden_states because the image_encoder
559+ # (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape
560+ # of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation
561+ # —to be indivisible by the number of devices in the CP.
558562 "proj_out" : ContextParallelOutput (gather_dim = 1 , expected_dims = 3 ),
559563 "" : {
560564 "timestep" : ContextParallelInput (split_dim = 1 , expected_dims = 2 , split_output = False ),