diff --git a/src/cache_dit/_utils/utils.py b/src/cache_dit/_utils/utils.py index 0377b6c5..0ebd5426 100644 --- a/src/cache_dit/_utils/utils.py +++ b/src/cache_dit/_utils/utils.py @@ -7,7 +7,11 @@ from ..logger import init_logger from ..platforms import current_platform -from ..compile.utils import set_compile_configs, _maybe_apply_mindiesd_compile +from ..compile.utils import ( + set_compile_configs, + mindiesd_compile_available, + mindiesd_compile, +) from ..distributed import ParallelismBackend, ParallelismConfig from ..caching import enable_cache, steps_mask from ..attention import set_attn_backend @@ -1257,10 +1261,11 @@ def _compile_transformer_module(args, pipe, transformer, name): return transformer # Auto-enable MindieSDBackend on NPU - compiled = _maybe_apply_mindiesd_compile(transformer, name, transformer_cls_name) - if compiled is not None: - setattr(pipe, name, compiled) - return compiled + if mindiesd_compile_available(): + compiled = mindiesd_compile(transformer, name, transformer_cls_name) + if compiled is not None: + setattr(pipe, name, compiled) + return compiled use_regional_compile = not args.disable_compile_repeated_blocks and hasattr( transformer, "compile_repeated_blocks") diff --git a/src/cache_dit/attention/__init__.py b/src/cache_dit/attention/__init__.py index 166fbc69..d2e2b9ad 100644 --- a/src/cache_dit/attention/__init__.py +++ b/src/cache_dit/attention/__init__.py @@ -11,6 +11,7 @@ _init_comm_metadata, ) from ._diffusers_bridge import _register_cache_dit_attn_backends_to_diffusers +from ._backend_selector import AttnBackendSelector logger = init_logger(__name__) @@ -20,6 +21,7 @@ "_dispatch_attention_fn", "set_attn_backend", "_maybe_register_custom_attn_backends", + "AttnBackendSelector", ] diff --git a/src/cache_dit/attention/backend_selector.py b/src/cache_dit/attention/_backend_selector.py similarity index 63% rename from src/cache_dit/attention/backend_selector.py rename to src/cache_dit/attention/_backend_selector.py index 7d6aefa6..d8751103 100644 --- a/src/cache_dit/attention/backend_selector.py +++ b/src/cache_dit/attention/_backend_selector.py @@ -4,19 +4,23 @@ logger = init_logger(__name__) -class BackendSelector: +class AttnBackendSelector: _attn_backend: str | None = None _selected: bool = False @classmethod def auto_select(cls, pipe_or_adapter) -> str | None: - if cls._selected: + try: + if cls._selected: + return cls._attn_backend + device = cls._detect_device(pipe_or_adapter) + if device.type == "npu": + cls._attn_backend = "_native_npu" + cls._selected = True return cls._attn_backend - device = cls._detect_device(pipe_or_adapter) - if device.type == "npu": - cls._attn_backend = "_native_npu" - cls._selected = True - return cls._attn_backend + except Exception as e: + logger.warning(f"Failed to auto-select attention backend: {e}") + return None @classmethod def auto_select_kernel_backend(cls) -> str | None: diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py index cea8f049..f44cca5a 100644 --- a/src/cache_dit/caching/cache_interface.py +++ b/src/cache_dit/caching/cache_interface.py @@ -18,20 +18,11 @@ from ..utils import parse_extra_modules from ..logger import init_logger from ..attention import set_attn_backend +from ..attention import AttnBackendSelector logger = init_logger(__name__) -def _auto_select_attention_backend(pipe_or_adapter) -> Optional[str]: - """Try to auto-select an optimal attention backend when none was specified.""" - try: - from cache_dit.attention.backend_selector import BackendSelector - - return BackendSelector.auto_select(pipe_or_adapter) - except Exception: - return None - - def enable_cache( pipe_or_adapter: Union[ DiffusionPipeline, @@ -364,7 +355,7 @@ def _enable_cache_impl( # Auto-select attention backend when none specified if attention_backend is None and parallelism_config is None: - attention_backend = _auto_select_attention_backend(pipe_or_adapter) + attention_backend = AttnBackendSelector.auto_select(pipe_or_adapter) # Set custom attention backend for non-parallelism case if attention_backend is not None: diff --git a/src/cache_dit/compile/utils.py b/src/cache_dit/compile/utils.py index a1838daa..128e36ba 100644 --- a/src/cache_dit/compile/utils.py +++ b/src/cache_dit/compile/utils.py @@ -161,7 +161,20 @@ def set_compile_configs( pass -def _maybe_apply_mindiesd_compile(module, module_name, module_cls_name): +def mindiesd_compile_available(): + try: + if not hasattr(torch, 'npu') or not torch.npu.is_available(): + return False + + import mindiesd # noqa F401 + from mindiesd.compilation import MindieSDBackend # noqa F401 + + return True + except ImportError: + return False + + +def mindiesd_compile(module, module_name, module_cls_name): # Auto-apply MindieSDBackend compile on NPU when mindiesd is available. # Returns the compiled module if compiled, None if MindIE-SD not applicable. try: