Skip to content

Commit e96f638

Browse files
committed
make style, make quality, make fix-copies
1 parent f0a0d21 commit e96f638

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/diffusers/models/transformers/transformer_sana_video.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ def __init__(
189189

190190
h_dim = w_dim = 2 * (attention_head_dim // 6)
191191
t_dim = attention_head_dim - h_dim - w_dim
192+
193+
self.t_dim = t_dim
194+
self.h_dim = h_dim
195+
self.w_dim = w_dim
196+
192197
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
193198

194199
freqs_cos = []
@@ -214,11 +219,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
214219
p_t, p_h, p_w = self.patch_size
215220
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
216221

217-
split_sizes = [
218-
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
219-
self.attention_head_dim // 3,
220-
self.attention_head_dim // 3,
221-
]
222+
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
222223

223224
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
224225
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,11 @@ def __init__(
831831

832832
h_dim = w_dim = 2 * (attention_head_dim // 6)
833833
t_dim = attention_head_dim - h_dim - w_dim
834+
835+
self.t_dim = t_dim
836+
self.h_dim = h_dim
837+
self.w_dim = w_dim
838+
834839
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
835840

836841
freqs_cos = []
@@ -856,11 +861,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
856861
p_t, p_h, p_w = self.patch_size
857862
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
858863

859-
split_sizes = [
860-
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
861-
self.attention_head_dim // 3,
862-
self.attention_head_dim // 3,
863-
]
864+
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
864865

865866
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
866867
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

0 commit comments

Comments
 (0)