diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 3f79ec925419..59d11e6b5a1a 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -700,6 +700,7 @@ def forward( control_context = self.patchify(control_context, patch_size, f_patch_size) control_context = torch.cat(control_context, dim=0) + self.control_all_x_embedder = self.control_all_x_embedder.to(control_context.device) control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token @@ -739,6 +740,7 @@ def forward( layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input ) else: + layer = layer.to(control_context.device) control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input) hints = torch.unbind(control_context)[:-1] @@ -818,6 +820,7 @@ def forward( ) else: for layer in self.control_noise_refiner: + layer = layer.to(control_context.device) control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) # unified @@ -834,6 +837,7 @@ def forward( layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) else: + layer = layer.to(control_context_unified.device) control_context_unified = layer( control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input )