
@vladmandic
Contributor

@vladmandic vladmandic commented Dec 24, 2025

the z-image-turbo controlnet implementation in #12792 includes some manually created torch modules and dicts, which by default are created on cpu, while the context used in the controlnet loop is on gpu

depending on the offloading strategy and/or use of accelerate, this results in a typical error:

/app/venv/lib/python3.12/site-packages/diffusers/models/controlnets/controlnet_z_image.py:818 in forward                                                                                              │
│                                                                                                                                                                                                      │
│  817 │   │   │   else:                                                                                                                                                                               │
│❱ 818 │   │   │   │   control_context_unified = layer(                                                                                                                                                │
│  819 │   │   │   │   │   control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input                                                                                         │
│                                                                                                                                                                                                      │
│/app/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1775 in _wrapped_call_impl                                                                                                          │
│                                                                                                                                                                                                      │
│  1774 │   │   else:                                                                                                                                                                                  │
│❱ 1775 │   │   │   return self._call_impl(*args, **kwargs)                                                                                                                                            │
│  1776                                                                                                                                                                                                │
│                                                                                                                                                                                                      │
│/app/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1786 in _call_impl                                                                                                                  │
│                                                                                                                                                                                                      │
│  1785 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                                                                        │
│❱ 1786 │   │   │   return forward_call(*args, **kwargs)                                                                                                                                               │
│  1787                                                                                                                                                                                                │
│                                                                                                                                                                                                      │
│/app/venv/lib/python3.12/site-packages/diffusers/models/controlnets/controlnet_z_image.py:356 in forward                                                                                              │
│                                                                                                                                                                                                      │
│  355 │   │   if self.block_id == 0:                                                                                                                                                                  │
│❱ 356 │   │   │   c = self.before_proj(c) + x                                                                                                                                                         │
│  357 │   │   │   all_c = []                                                                                                                                                                          │
│                                                                                                                                                                                                      │
│/app/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1775 in _wrapped_call_impl                                                                                                          │
│                                                                                                                                                                                                      │
│  1774 │   │   else:                                                                                                                                                                                  │
│❱ 1775 │   │   │   return self._call_impl(*args, **kwargs)                                                                                                                                            │
│  1776                                                                                                                                                                                                │
│                                                                                                                                                                                                      │
│/app/venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1786 in _call_impl                                                                                                                  │
│                                                                                                                                                                                                      │
│  1785 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                                                                        │
│❱ 1786 │   │   │   return forward_call(*args, **kwargs)                                                                                                                                               │
│  1787                                                                                                                                                                                                │
│                                                                                                                                                                                                      │
│/app/venv/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward                                                                                                                      │
│                                                                                                                                                                                                      │
│  133 │   │   """                                                                                                                                                                                     │
│❱ 134 │   │   return F.linear(input, self.weight, self.bias)                                                                                                                                          │
│  135                                                                                                                                                                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected all tensors to be on the same device, but got mat1 is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA_addmm)

i've traced device placement to a couple of critical execution paths which, when corrected, allow z-image-turbo controlnet to run without issues.
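A mismatch like the one in the traceback can be narrowed down by listing which tensors of a module sit on which device. The helper below is a minimal sketch of that idea, not the change made in this PR: `group_tensors_by_device` and `has_mixed_devices` are hypothetical names, and the only assumption about the module is that it exposes torch-style `named_parameters()` (and optionally `named_buffers()`) whose tensors have a `.device` attribute.

```python
from collections import defaultdict


def group_tensors_by_device(module):
    """Return {device: [tensor names]} for a module's parameters and buffers.

    Works with any object exposing torch-style `named_parameters()` and,
    optionally, `named_buffers()`; each yielded tensor needs a `.device`.
    """
    groups = defaultdict(list)
    for name, tensor in module.named_parameters():
        groups[str(tensor.device)].append(name)
    # Buffers can be misplaced as well, so include them when available.
    named_buffers = getattr(module, "named_buffers", None)
    if named_buffers is not None:
        for name, tensor in named_buffers():
            groups[str(tensor.device)].append(name)
    return dict(groups)


def has_mixed_devices(module):
    """True when the module's tensors live on more than one device."""
    return len(group_tensors_by_device(module)) > 1
```

With a loaded pipeline, `group_tensors_by_device(pipe.controlnet)` would show whether a layer such as `before_proj` is still on `cpu`. Under `enable_model_cpu_offload()` the weights are moved to the GPU only when the module's forward is called, so a check made before the pipeline runs would report everything on `cpu`; calling it from a forward pre-hook is likely more informative.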

cc: @hlky @yiyixuxu @DN6 @sayakpaul

@vladmandic vladmandic changed the title from "z-image-controlnet" to "z-image-turbo controlnet correct device placement" Dec 24, 2025
@hlky
Contributor

hlky commented Dec 24, 2025

@vladmandic Thanks. I didn't test with offloading, but it seems unusual for this to happen: the ControlNet layers are created in the normal way, and I don't recall this being needed in other ControlNet models, which do work with offloading afaik. Do you have code for a reproduction?

@vladmandic
Contributor Author

this is not easy to reproduce, as behavior differs across environments, torch versions, etc.
but to me, this is logical behavior: some torch variants may handle it similarly to autocast, but if they don't, this is imo an expected failure.

@hlky
Contributor

hlky commented Dec 24, 2025

@vladmandic I've tracked this to an issue with accelerate, specifically accelerate >= 1.10.0. Offloading works with 1.9.0.

Code

import torch
from diffusers import ZImageControlNetModel, ZImageControlNetPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download

controlnet = ZImageControlNetModel.from_single_file(
    hf_hub_download(
        "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union",
        filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors",
    ),
    torch_dtype=torch.bfloat16,
)
pipe = ZImageControlNetPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16, variant="bf16",
    revision="refs/pr/102",
)
pipe.enable_model_cpu_offload()
control_image = load_image(
    "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true"
)
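# English translation of the prompt (the original Chinese text is kept for the reproduction):
# "A young woman stands on a sunlit coastline, her white dress fluttering gently in the sea breeze.
# She has long, vivid purple hair dancing lightly in the wind, tied with a delicate black bow that
# contrasts sharply with the soft azure sky behind her. Her face is delicate with refined features
# and a sweet, youthful air; her expression is gentle and slightly shy as she gazes quietly at the
# distant horizon, hands naturally folded in front of her, as if lost in thought. Behind her lies
# the vast, shimmering sea, the sunlight on the water casting a warm golden glow."
prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。"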
image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.75,
    height=1728,
    width=992,
    num_inference_steps=9,
    guidance_scale=0.0,
    generator=torch.Generator("cuda").manual_seed(43),
).images[0]

Interestingly, in my tests (on main) with accelerate >= 1.10.0, the error occurs after the ControlNet, in the transformer:

Traceback

  File "/workspace/diffusers/../test.py", line 23, in <module>
    image = pipe(
            ^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py", line 655, in __call__
    model_out_list = self.transformer(
                     ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/transformers/transformer_z_image.py", line 648, in forward
    unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/transformers/transformer_z_image.py", line 221, in forward
    scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/container.py", line 244, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)

So I'm not sure this PR would help in that case.

After reviewing the accelerate changelog for 1.10.0, I've tracked it further to this specific commit, huggingface/accelerate#3695, which was also discussed in diffusers #11872. Given that it provided a negligible benefit and has seemingly broken offloading, I think it should be raised with the accelerate team and reverted.

I'm surprised this issue has not been noticed before, and I wonder if other controlnet models are affected. Have your users reported any issues with other models?

In the meantime I would suggest your users install accelerate==1.9.0.
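For anyone applying that workaround, a small guard can warn when the installed accelerate falls in the range reported as broken above. This is only a sketch under the assumption that the regression covers every release from 1.10.0 onward, which the thread does not confirm for future versions; `parse_major_minor`, `is_affected`, and `installed_accelerate_version` are hypothetical helper names.

```python
import re
from importlib import metadata


def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '1.10.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_affected(version):
    """True for accelerate >= 1.10.0 (assumed affected, per this thread)."""
    return parse_major_minor(version) >= (1, 10)


def installed_accelerate_version():
    """Return the installed accelerate version, or None if it is not installed."""
    try:
        return metadata.version("accelerate")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    current = installed_accelerate_version()
    if current is not None and is_affected(current):
        # Suggested workaround from the thread: pip install "accelerate==1.9.0"
        print(f"accelerate {current} may break model CPU offloading; consider 1.9.0")
```

Comparing integer tuples rather than raw strings matters here, because a plain string comparison would sort "1.10.0" before "1.9.0".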

cc @SunMarc
