Commit 3138e37

Fix wan 2.1 i2v context parallel (#12909)
* fix wan 2.1 i2v context parallel
* fix wan 2.1 i2v context parallel
* fix wan 2.1 i2v context parallel
* format
1 parent 0da1aa9 commit 3138e37

2 files changed: +13 −7 lines changed


src/diffusers/models/transformers/transformer_wan.py

Lines changed: 9 additions & 5 deletions
@@ -134,7 +134,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12909
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
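
Note: dropping the parallel config on the image branch is sound because cross-attention rows are independent across queries. Each rank can attend its local query shard over the full, replicated 257-token image context with plain attention, and gathering the per-rank outputs reproduces the unsharded result. A minimal sketch with illustrative shapes (not the processor's real code):

import torch
import torch.nn.functional as F

world_size, rank = 4, 0
query = torch.randn(1, 8, 1024, 64)   # (batch, heads, video tokens, head_dim)
kv_img = torch.randn(1, 8, 257, 64)   # 257 image tokens, replicated on every rank

q_local = query.chunk(world_size, dim=2)[rank]  # this rank's query shard
out_local = F.scaled_dot_product_attention(q_local, kv_img, kv_img)
out_full = F.scaled_dot_product_attention(query, kv_img, kv_img)
# the per-rank output equals the matching slice of the unsharded computation
assert torch.allclose(out_local, out_full.chunk(world_size, dim=2)[rank], atol=1e-5)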
@@ -147,7 +148,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12909
+                parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
             )
             hidden_states = hidden_states.flatten(2, 3)
             hidden_states = hidden_states.type_as(query)
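
Note: the conditional keeps context parallelism only for the self-attention path. When encoder_hidden_states is None, Q, K, and V all come from the sharded latent stream, so CP attention is valid; in cross-attention, K/V come from the fixed-length text/image context, which cannot be sharded evenly. A hypothetical helper restating the inline conditional:

def safe_parallel_config(parallel_config, encoder_hidden_states):
    # self-attention: K/V are the same shardable latent sequence as Q
    if encoder_hidden_states is None:
        return parallel_config
    # cross-attention: K/V are the replicated context, so run plain attention
    return None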
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
         "blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
-        "blocks.*": {
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-        },
+        # Reference: https://github.com/huggingface/diffusers/pull/12909
+        # We need to disable the splitting of encoder_hidden_states because the image_encoder
+        # (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape
+        # of encoder_hidden_states, whose token count is always 769 (512 + 257) after
+        # concatenation, to be indivisible by the number of devices in the CP.
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
         "": {
             "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 4 additions & 2 deletions
@@ -609,7 +609,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12909
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
@@ -622,7 +623,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12909
+                parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
             )
             hidden_states = hidden_states.flatten(2, 3)
             hidden_states = hidden_states.type_as(query)
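
Note: with these changes, Wan 2.1 I2V should run under context parallelism. A hedged usage sketch, assuming the ContextParallelConfig / enable_parallelism API of recent diffusers releases and a model id chosen for illustration (neither is confirmed by this commit); launch with torchrun --nproc-per-node=2:

import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, WanImageToVideoPipeline

dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank())
torch.cuda.set_device(device)

# model id is an assumption for illustration, not taken from this commit
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to(device)
# ring_degree=2 shards the video latent sequence across 2 ranks; after this fix
# the 769-token encoder_hidden_states stays replicated instead of being split
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))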
