@@ -599,6 +599,7 @@ def set_attention_backend(self, backend: str) -> None:
599599 from .attention import AttentionModuleMixin
600600 from .attention_dispatch import (
601601 AttentionBackendName ,
602+ _AttentionBackendRegistry ,
602603 _check_attention_backend_requirements ,
603604 _maybe_download_kernel_for_backend ,
604605 )
@@ -607,17 +608,34 @@ def set_attention_backend(self, backend: str) -> None:
607608 from .attention_processor import Attention , MochiAttention
608609
609610 logger .warning ("Attention backends are an experimental feature and the API may be subject to change." )
611+ attention_classes = (Attention , MochiAttention , AttentionModuleMixin )
612+
613+ parallel_config_set = False
614+ for module in self .modules ():
615+ if not isinstance (module , attention_classes ):
616+ continue
617+ processor = module .processor
618+ if getattr (processor , "_parallel_config" , None ) is not None :
619+ parallel_config_set = True
620+ break
610621
611622 backend = backend .lower ()
612623 available_backends = {x .value for x in AttentionBackendName .__members__ .values ()}
613624 if backend not in available_backends :
614625 raise ValueError (f"`{ backend = } ` must be one of the following: " + ", " .join (available_backends ))
615626
616627 backend = AttentionBackendName (backend )
628+ if parallel_config_set and not _AttentionBackendRegistry ._is_context_parallel_available (backend ):
629+ compatible_backends = sorted (_AttentionBackendRegistry ._supports_context_parallel )
630+ raise ValueError (
631+ f"Context parallelism is enabled but current attention backend '{ backend .value } ' "
632+ f"does not support context parallelism. "
633+ f"Please set a compatible attention backend: { compatible_backends } using `model.set_attention_backend()`."
634+ )
635+
617636 _check_attention_backend_requirements (backend )
618637 _maybe_download_kernel_for_backend (backend )
619638
620- attention_classes = (Attention , MochiAttention , AttentionModuleMixin )
621639 for module in self .modules ():
622640 if not isinstance (module , attention_classes ):
623641 continue
@@ -626,6 +644,9 @@ def set_attention_backend(self, backend: str) -> None:
626644 continue
627645 processor ._attention_backend = backend
628646
647+ # In addition to the per-processor `_attention_backend` set above, mark this backend as the registry's active backend so that the choice propagates gracefully throughout.
648+ _AttentionBackendRegistry .set_active_backend (backend )
649+
629650 def reset_attention_backend (self ) -> None :
630651 """
631652 Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
@@ -1538,7 +1559,7 @@ def enable_parallelism(
15381559 f"Context parallelism is enabled but the attention processor '{ processor .__class__ .__name__ } ' "
15391560 f"is using backend '{ attention_backend .value } ' which does not support context parallelism. "
15401561 f"Please set a compatible attention backend: { compatible_backends } using `model.set_attention_backend()` before "
1541- f"calling `enable_parallelism()`."
1562+ f"calling `model. enable_parallelism()`."
15421563 )
15431564
15441565 # All modules use the same attention processor and backend. We don't need to
0 commit comments