Skip to content

Commit f957fa3

Browse files
authored
feat: add --fa option (#1242)
1 parent c252e03 commit f957fa3

File tree

10 files changed

+88
-23
lines changed

10 files changed

+88
-23
lines changed

conditioner.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ struct Conditioner {
3434
virtual void free_params_buffer() = 0;
3535
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
3636
virtual size_t get_params_buffer_size() = 0;
37+
virtual void set_flash_attention_enabled(bool enabled) = 0;
3738
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
3839
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
3940
int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
115116
return buffer_size;
116117
}
117118

119+
void set_flash_attention_enabled(bool enabled) override {
120+
text_model->set_flash_attention_enabled(enabled);
121+
if (sd_version_is_sdxl(version)) {
122+
text_model2->set_flash_attention_enabled(enabled);
123+
}
124+
}
125+
118126
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
119127
text_model->set_weight_adapter(adapter);
120128
if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
783791
return buffer_size;
784792
}
785793

794+
void set_flash_attention_enabled(bool enabled) override {
795+
if (clip_l) {
796+
clip_l->set_flash_attention_enabled(enabled);
797+
}
798+
if (clip_g) {
799+
clip_g->set_flash_attention_enabled(enabled);
800+
}
801+
if (t5) {
802+
t5->set_flash_attention_enabled(enabled);
803+
}
804+
}
805+
786806
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
787807
if (clip_l) {
788808
clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
11911211
return buffer_size;
11921212
}
11931213

1214+
void set_flash_attention_enabled(bool enabled) override {
1215+
if (clip_l) {
1216+
clip_l->set_flash_attention_enabled(enabled);
1217+
}
1218+
if (t5) {
1219+
t5->set_flash_attention_enabled(enabled);
1220+
}
1221+
}
1222+
11941223
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
11951224
if (clip_l) {
11961225
clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
14401469
return buffer_size;
14411470
}
14421471

1472+
void set_flash_attention_enabled(bool enabled) override {
1473+
if (t5) {
1474+
t5->set_flash_attention_enabled(enabled);
1475+
}
1476+
}
1477+
14431478
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
14441479
if (t5) {
14451480
t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
16501685
return buffer_size;
16511686
}
16521687

1688+
void set_flash_attention_enabled(bool enabled) override {
1689+
llm->set_flash_attention_enabled(enabled);
1690+
}
1691+
16531692
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
16541693
if (llm) {
16551694
llm->set_weight_adapter(adapter);

diffusion_model.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct DiffusionModel {
3838
virtual size_t get_params_buffer_size() = 0;
3939
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
4040
virtual int64_t get_adm_in_channels() = 0;
41-
virtual void set_flash_attn_enabled(bool enabled) = 0;
41+
virtual void set_flash_attention_enabled(bool enabled) = 0;
4242
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
4343
};
4444

@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
8484
return unet.unet.adm_in_channels;
8585
}
8686

87-
void set_flash_attn_enabled(bool enabled) {
87+
void set_flash_attention_enabled(bool enabled) {
8888
unet.set_flash_attention_enabled(enabled);
8989
}
9090

@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
149149
return 768 + 1280;
150150
}
151151

152-
void set_flash_attn_enabled(bool enabled) {
152+
void set_flash_attention_enabled(bool enabled) {
153153
mmdit.set_flash_attention_enabled(enabled);
154154
}
155155

@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
215215
return 768;
216216
}
217217

218-
void set_flash_attn_enabled(bool enabled) {
218+
void set_flash_attention_enabled(bool enabled) {
219219
flux.set_flash_attention_enabled(enabled);
220220
}
221221

@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
286286
return 768;
287287
}
288288

289-
void set_flash_attn_enabled(bool enabled) {
289+
void set_flash_attention_enabled(bool enabled) {
290290
wan.set_flash_attention_enabled(enabled);
291291
}
292292

@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
357357
return 768;
358358
}
359359

360-
void set_flash_attn_enabled(bool enabled) {
360+
void set_flash_attention_enabled(bool enabled) {
361361
qwen_image.set_flash_attention_enabled(enabled);
362362
}
363363

@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
424424
return 768;
425425
}
426426

427-
void set_flash_attn_enabled(bool enabled) {
427+
void set_flash_attention_enabled(bool enabled) {
428428
z_image.set_flash_attention_enabled(enabled);
429429
}
430430

examples/cli/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ Context Options:
5252
--control-net-cpu keep controlnet in cpu (for low vram)
5353
--clip-on-cpu keep clip in cpu (for low vram)
5454
--vae-on-cpu keep vae in cpu (for low vram)
55-
--diffusion-fa use flash attention in the diffusion model
55+
--fa use flash attention
56+
--diffusion-fa use flash attention in the diffusion model only
5657
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
5758
--vae-conv-direct use ggml_conv2d_direct in the vae model
5859
--circular enable circular padding for convolutions

examples/common/common.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ struct SDContextParams {
457457
bool control_net_cpu = false;
458458
bool clip_on_cpu = false;
459459
bool vae_on_cpu = false;
460+
bool flash_attn = false;
460461
bool diffusion_flash_attn = false;
461462
bool diffusion_conv_direct = false;
462463
bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
615616
"--vae-on-cpu",
616617
"keep vae in cpu (for low vram)",
617618
true, &vae_on_cpu},
619+
{"",
620+
"--fa",
621+
"use flash attention",
622+
true, &flash_attn},
618623
{"",
619624
"--diffusion-fa",
620-
"use flash attention in the diffusion model",
625+
"use flash attention in the diffusion model only",
621626
true, &diffusion_flash_attn},
622627
{"",
623628
"--diffusion-conv-direct",
@@ -904,6 +909,7 @@ struct SDContextParams {
904909
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
905910
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
906911
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
912+
<< " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
907913
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
908914
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
909915
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
968974
clip_on_cpu,
969975
control_net_cpu,
970976
vae_on_cpu,
977+
flash_attn,
971978
diffusion_flash_attn,
972979
taesd_preview,
973980
diffusion_conv_direct,

examples/server/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ Context Options:
4444
--clip-on-cpu keep clip in cpu (for low vram)
4545
--vae-on-cpu keep vae in cpu (for low vram)
4646
--mmap whether to memory-map model
47-
--diffusion-fa use flash attention in the diffusion model
47+
--fa use flash attention
48+
--diffusion-fa use flash attention in the diffusion model only
4849
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
4950
--vae-conv-direct use ggml_conv2d_direct in the vae model
5051
--circular enable circular padding for convolutions

ggml_extend.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2623,7 +2623,7 @@ class MultiheadAttention : public GGMLBlock {
26232623
v = v_proj->forward(ctx, x);
26242624
}
26252625

2626-
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
2626+
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
26272627

26282628
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
26292629
return x;

stable-diffusion.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ class StableDiffusionGGML {
445445
}
446446
}
447447
if (is_chroma) {
448-
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
448+
if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
449449
LOG_WARN(
450450
"!!!It looks like you are using Chroma with flash attention. "
451451
"This is currently unsupported. "
@@ -571,14 +571,6 @@ class StableDiffusionGGML {
571571
}
572572
}
573573

574-
if (sd_ctx_params->diffusion_flash_attn) {
575-
LOG_INFO("Using flash attention in the diffusion model");
576-
diffusion_model->set_flash_attn_enabled(true);
577-
if (high_noise_diffusion_model) {
578-
high_noise_diffusion_model->set_flash_attn_enabled(true);
579-
}
580-
}
581-
582574
cond_stage_model->alloc_params_buffer();
583575
cond_stage_model->get_param_tensors(tensors);
584576

@@ -725,6 +717,28 @@ class StableDiffusionGGML {
725717
pmid_model->get_param_tensors(tensors, "pmid");
726718
}
727719

720+
if (sd_ctx_params->flash_attn) {
721+
LOG_INFO("Using flash attention");
722+
cond_stage_model->set_flash_attention_enabled(true);
723+
if (clip_vision) {
724+
clip_vision->set_flash_attention_enabled(true);
725+
}
726+
if (first_stage_model) {
727+
first_stage_model->set_flash_attention_enabled(true);
728+
}
729+
if (tae_first_stage) {
730+
tae_first_stage->set_flash_attention_enabled(true);
731+
}
732+
}
733+
734+
if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
735+
LOG_INFO("Using flash attention in the diffusion model");
736+
diffusion_model->set_flash_attention_enabled(true);
737+
if (high_noise_diffusion_model) {
738+
high_noise_diffusion_model->set_flash_attention_enabled(true);
739+
}
740+
}
741+
728742
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
729743
if (high_noise_diffusion_model) {
730744
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
29422956
"keep_clip_on_cpu: %s\n"
29432957
"keep_control_net_on_cpu: %s\n"
29442958
"keep_vae_on_cpu: %s\n"
2959+
"flash_attn: %s\n"
29452960
"diffusion_flash_attn: %s\n"
29462961
"circular_x: %s\n"
29472962
"circular_y: %s\n"
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
29732988
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
29742989
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
29752990
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
2991+
BOOL_STR(sd_ctx_params->flash_attn),
29762992
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
29772993
BOOL_STR(sd_ctx_params->circular_x),
29782994
BOOL_STR(sd_ctx_params->circular_y),

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ typedef struct {
189189
bool keep_clip_on_cpu;
190190
bool keep_control_net_on_cpu;
191191
bool keep_vae_on_cpu;
192+
bool flash_attn;
192193
bool diffusion_flash_attn;
193194
bool tae_preview_only;
194195
bool diffusion_conv_direct;

vae.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class AttnBlock : public UnaryBlock {
141141
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
142142
}
143143

144-
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
144+
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
145145

146146
if (use_linear) {
147147
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]

wan.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,8 @@ namespace WAN {
572572
auto v = qkv_vec[2];
573573
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
574574

575-
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
576-
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
575+
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
576+
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); // [t, h * w, c]
577577

578578
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
579579
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]

0 commit comments

Comments (0)