Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/cache_dit/_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from ..logger import init_logger
from ..platforms import current_platform
from ..compile.utils import set_compile_configs, _maybe_apply_mindiesd_compile
from ..compile.utils import (
set_compile_configs,
mindiesd_compile_available,
mindiesd_compile,
)
from ..distributed import ParallelismBackend, ParallelismConfig
from ..caching import enable_cache, steps_mask
from ..attention import set_attn_backend
Expand Down Expand Up @@ -1257,10 +1261,11 @@ def _compile_transformer_module(args, pipe, transformer, name):
return transformer

# Auto-enable MindieSDBackend on NPU
compiled = _maybe_apply_mindiesd_compile(transformer, name, transformer_cls_name)
if compiled is not None:
setattr(pipe, name, compiled)
return compiled
if mindiesd_compile_available():
compiled = mindiesd_compile(transformer, name, transformer_cls_name)
if compiled is not None:
setattr(pipe, name, compiled)
return compiled

use_regional_compile = not args.disable_compile_repeated_blocks and hasattr(
transformer, "compile_repeated_blocks")
Expand Down
2 changes: 2 additions & 0 deletions src/cache_dit/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_init_comm_metadata,
)
from ._diffusers_bridge import _register_cache_dit_attn_backends_to_diffusers
from ._backend_selector import AttnBackendSelector

logger = init_logger(__name__)

Expand All @@ -20,6 +21,7 @@
"_dispatch_attention_fn",
"set_attn_backend",
"_maybe_register_custom_attn_backends",
"AttnBackendSelector",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@
logger = init_logger(__name__)


class BackendSelector:
class AttnBackendSelector:
_attn_backend: str | None = None
_selected: bool = False

@classmethod
def auto_select(cls, pipe_or_adapter) -> str | None:
if cls._selected:
try:
if cls._selected:
return cls._attn_backend
device = cls._detect_device(pipe_or_adapter)
if device.type == "npu":
cls._attn_backend = "_native_npu"
cls._selected = True
return cls._attn_backend
device = cls._detect_device(pipe_or_adapter)
if device.type == "npu":
cls._attn_backend = "_native_npu"
cls._selected = True
return cls._attn_backend
except Exception as e:
logger.warning(f"Failed to auto-select attention backend: {e}")
return None

@classmethod
def auto_select_kernel_backend(cls) -> str | None:
Expand Down
13 changes: 2 additions & 11 deletions src/cache_dit/caching/cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,11 @@
from ..utils import parse_extra_modules
from ..logger import init_logger
from ..attention import set_attn_backend
from ..attention import AttnBackendSelector

logger = init_logger(__name__)


def _auto_select_attention_backend(pipe_or_adapter) -> Optional[str]:
"""Try to auto-select an optimal attention backend when none was specified."""
try:
from cache_dit.attention.backend_selector import BackendSelector

return BackendSelector.auto_select(pipe_or_adapter)
except Exception:
return None


def enable_cache(
pipe_or_adapter: Union[
DiffusionPipeline,
Expand Down Expand Up @@ -364,7 +355,7 @@ def _enable_cache_impl(

# Auto-select attention backend when none specified
if attention_backend is None and parallelism_config is None:
attention_backend = _auto_select_attention_backend(pipe_or_adapter)
attention_backend = AttnBackendSelector.auto_select(pipe_or_adapter)

# Set custom attention backend for non-parallelism case
if attention_backend is not None:
Expand Down
15 changes: 14 additions & 1 deletion src/cache_dit/compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,20 @@ def set_compile_configs(
pass


def _maybe_apply_mindiesd_compile(module, module_name, module_cls_name):
def mindiesd_compile_available():
try:
if not hasattr(torch, 'npu') or not torch.npu.is_available():
return False

import mindiesd # noqa F401
from mindiesd.compilation import MindieSDBackend # noqa F401

return True
except ImportError:
return False


def mindiesd_compile(module, module_name, module_cls_name):
# Auto-apply MindieSDBackend compile on NPU when mindiesd is available.
# Returns the compiled module if compiled, None if MindIE-SD not applicable.
try:
Expand Down
Loading