
Commit ebf891a

sayakpaul and DN6 authored
[core] gracefully error out when attn-backend x cp combo isn't supported. (#12832)
* gracefully error out when attn-backend x cp combo isn't supported.
* Revert "gracefully error out when attn-backend x cp combo isn't supported."
  This reverts commit c8abb5d.
* gracefully error out when attn-backend x cp combo isn't supported.
* up
* address PR feedback.
* up
* Update src/diffusers/models/modeling_utils.py
  Co-authored-by: Dhruv Nair <[email protected]>
* dot.

Co-authored-by: Dhruv Nair <[email protected]>
1 parent 8af8e86 commit ebf891a

File tree: 2 files changed (+30, −4 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 7 additions & 2 deletions
@@ -235,6 +235,10 @@ def decorator(func):
     def get_active_backend(cls):
         return cls._active_backend, cls._backends[cls._active_backend]
 
+    @classmethod
+    def set_active_backend(cls, backend: str):
+        cls._active_backend = backend
+
     @classmethod
     def list_backends(cls):
         return list(cls._backends.keys())

@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
     _maybe_download_kernel_for_backend(backend)
 
     old_backend = _AttentionBackendRegistry._active_backend
-    _AttentionBackendRegistry._active_backend = backend
+    _AttentionBackendRegistry.set_active_backend(backend)
 
     try:
         yield
     finally:
-        _AttentionBackendRegistry._active_backend = old_backend
+        _AttentionBackendRegistry.set_active_backend(old_backend)
 
 
 def dispatch_attention_fn(

@@ -348,6 +352,7 @@ def dispatch_attention_fn(
     check(**kwargs)
 
     kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+
     return backend_fn(**kwargs)
 
 
src/diffusers/models/modeling_utils.py

Lines changed: 23 additions & 2 deletions
@@ -599,6 +599,7 @@ def set_attention_backend(self, backend: str) -> None:
         from .attention import AttentionModuleMixin
         from .attention_dispatch import (
             AttentionBackendName,
+            _AttentionBackendRegistry,
             _check_attention_backend_requirements,
             _maybe_download_kernel_for_backend,
         )
@@ -607,17 +608,34 @@ def set_attention_backend(self, backend: str) -> None:
         from .attention_processor import Attention, MochiAttention
 
         logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        parallel_config_set = False
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if getattr(processor, "_parallel_config", None) is not None:
+                parallel_config_set = True
+                break
 
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
 
         backend = AttentionBackendName(backend)
+        if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
+            compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+            raise ValueError(
+                f"Context parallelism is enabled but current attention backend '{backend.value}' "
+                f"does not support context parallelism. "
+                f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
+            )
+
         _check_attention_backend_requirements(backend)
         _maybe_download_kernel_for_backend(backend)
 
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
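The new scan-then-validate flow fails before any processor is mutated: if any module's processor already carries a `_parallel_config`, the requested backend must support context parallelism. A standalone sketch of that check; the backend set below is illustrative, the real one lives in `_AttentionBackendRegistry._supports_context_parallel`:

```python
# Illustrative set; the real list lives in _AttentionBackendRegistry._supports_context_parallel.
_SUPPORTS_CONTEXT_PARALLEL = {"native", "flash", "_flash_3"}


def _validate_backend_for_cp(backend: str, parallel_config_set: bool) -> None:
    """Raise early when context parallelism is on but the backend can't serve it."""
    if parallel_config_set and backend not in _SUPPORTS_CONTEXT_PARALLEL:
        compatible_backends = sorted(_SUPPORTS_CONTEXT_PARALLEL)
        raise ValueError(
            f"Context parallelism is enabled but current attention backend '{backend}' "
            f"does not support context parallelism. "
            f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
        )


_validate_backend_for_cp("flash", parallel_config_set=True)   # passes silently
# _validate_backend_for_cp("sage", parallel_config_set=True)  # would raise ValueError
```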
@@ -626,6 +644,9 @@ def set_attention_backend(self, backend: str) -> None:
                 continue
             processor._attention_backend = backend
 
+        # Important to set the active backend so that it propagates gracefully throughout.
+        _AttentionBackendRegistry.set_active_backend(backend)
+
     def reset_attention_backend(self) -> None:
         """
         Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
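Because `set_attention_backend` now also calls `set_active_backend`, the per-processor tag and the registry's notion of the active backend stay in sync. A hedged usage sketch at the model level; the checkpoint ID is a placeholder and "flash" assumes flash-attn is installed:

```python
import torch
from diffusers import AutoModel

# Placeholder checkpoint; any diffusers model inheriting ModelMixin works here.
model = AutoModel.from_pretrained(
    "org/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

model.set_attention_backend("flash")  # tags every attention processor and updates the registry
model.reset_attention_backend()       # return to the environment default
```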
@@ -1538,7 +1559,7 @@ def enable_parallelism(
                     f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                     f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                     f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
-                    f"calling `enable_parallelism()`."
+                    f"calling `model.enable_parallelism()`."
                 )
 
                 # All modules use the same attention processor and backend. We don't need to
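Together with the updated error message, the intended call order is: pick a context-parallel-compatible backend first, then enable parallelism. A hedged sketch, assuming a multi-process `torchrun` launch; the checkpoint ID is a placeholder:

```python
import torch
from diffusers import AutoModel, ContextParallelConfig

# Placeholder checkpoint; run under torchrun with at least two processes.
transformer = AutoModel.from_pretrained(
    "org/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

transformer.set_attention_backend("flash")  # must support context parallelism, checked eagerly now
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
```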
