@@ -445,7 +445,7 @@ class StableDiffusionGGML {
445445 }
446446 }
447447 if (is_chroma) {
448- if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask ) {
448+ if (( sd_ctx_params->flash_attn || sd_ctx_params-> diffusion_flash_attn ) && sd_ctx_params->chroma_use_dit_mask ) {
449449 LOG_WARN (
450450 " !!!It looks like you are using Chroma with flash attention. "
451451 " This is currently unsupported. "
@@ -571,14 +571,6 @@ class StableDiffusionGGML {
571571 }
572572 }
573573
574- if (sd_ctx_params->diffusion_flash_attn ) {
575- LOG_INFO (" Using flash attention in the diffusion model" );
576- diffusion_model->set_flash_attn_enabled (true );
577- if (high_noise_diffusion_model) {
578- high_noise_diffusion_model->set_flash_attn_enabled (true );
579- }
580- }
581-
582574 cond_stage_model->alloc_params_buffer ();
583575 cond_stage_model->get_param_tensors (tensors);
584576
@@ -725,6 +717,28 @@ class StableDiffusionGGML {
725717 pmid_model->get_param_tensors (tensors, " pmid" );
726718 }
727719
720+ if (sd_ctx_params->flash_attn ) {
721+ LOG_INFO (" Using flash attention" );
722+ cond_stage_model->set_flash_attention_enabled (true );
723+ if (clip_vision) {
724+ clip_vision->set_flash_attention_enabled (true );
725+ }
726+ if (first_stage_model) {
727+ first_stage_model->set_flash_attention_enabled (true );
728+ }
729+ if (tae_first_stage) {
730+ tae_first_stage->set_flash_attention_enabled (true );
731+ }
732+ }
733+
734+ if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn ) {
735+ LOG_INFO (" Using flash attention in the diffusion model" );
736+ diffusion_model->set_flash_attention_enabled (true );
737+ if (high_noise_diffusion_model) {
738+ high_noise_diffusion_model->set_flash_attention_enabled (true );
739+ }
740+ }
741+
728742 diffusion_model->set_circular_axes (sd_ctx_params->circular_x , sd_ctx_params->circular_y );
729743 if (high_noise_diffusion_model) {
730744 high_noise_diffusion_model->set_circular_axes (sd_ctx_params->circular_x , sd_ctx_params->circular_y );
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
29422956 " keep_clip_on_cpu: %s\n "
29432957 " keep_control_net_on_cpu: %s\n "
29442958 " keep_vae_on_cpu: %s\n "
2959+ " flash_attn: %s\n "
29452960 " diffusion_flash_attn: %s\n "
29462961 " circular_x: %s\n "
29472962 " circular_y: %s\n "
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
29732988 BOOL_STR (sd_ctx_params->keep_clip_on_cpu ),
29742989 BOOL_STR (sd_ctx_params->keep_control_net_on_cpu ),
29752990 BOOL_STR (sd_ctx_params->keep_vae_on_cpu ),
2991+ BOOL_STR (sd_ctx_params->flash_attn ),
29762992 BOOL_STR (sd_ctx_params->diffusion_flash_attn ),
29772993 BOOL_STR (sd_ctx_params->circular_x ),
29782994 BOOL_STR (sd_ctx_params->circular_y ),
0 commit comments