Skip to content

Commit a2d83dd

Browse files
authored
refactor: move pmid condition logic into get_pmid_condition (#1148)
1 parent cc10771 commit a2d83dd

File tree

1 file changed

+92
-109
lines changed

1 file changed

+92
-109
lines changed

stable-diffusion.cpp

Lines changed: 92 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class StableDiffusionGGML {
129129
bool use_tiny_autoencoder = false;
130130
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
131131
bool offload_params_to_cpu = false;
132-
bool stacked_id = false;
132+
bool use_pmid = false;
133133

134134
bool is_using_v_parameterization = false;
135135
bool is_using_edm_v_parameterization = false;
@@ -701,10 +701,10 @@ class StableDiffusionGGML {
701701
if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) {
702702
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
703703
} else {
704-
stacked_id = true;
704+
use_pmid = true;
705705
}
706706
}
707-
if (stacked_id) {
707+
if (use_pmid) {
708708
if (!pmid_model->alloc_params_buffer()) {
709709
LOG_ERROR(" pmid model params buffer allocation failed");
710710
return false;
@@ -745,7 +745,7 @@ class StableDiffusionGGML {
745745
if (use_tiny_autoencoder) {
746746
ignore_tensors.insert("first_stage_model.");
747747
}
748-
if (stacked_id) {
748+
if (use_pmid) {
749749
ignore_tensors.insert("pmid.unet.");
750750
}
751751
ignore_tensors.insert("model.diffusion_model.__x0__");
@@ -799,7 +799,7 @@ class StableDiffusionGGML {
799799
control_net_params_mem_size = control_net->get_params_buffer_size();
800800
}
801801
size_t pmid_params_mem_size = 0;
802-
if (stacked_id) {
802+
if (use_pmid) {
803803
pmid_params_mem_size = pmid_model->get_params_buffer_size();
804804
}
805805

@@ -1211,14 +1211,89 @@ class StableDiffusionGGML {
12111211
}
12121212
}
12131213

1214-
ggml_tensor* id_encoder(ggml_context* work_ctx,
1215-
ggml_tensor* init_img,
1216-
ggml_tensor* prompts_embeds,
1217-
ggml_tensor* id_embeds,
1218-
std::vector<bool>& class_tokens_mask) {
1219-
ggml_tensor* res = nullptr;
1220-
pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, work_ctx);
1221-
return res;
1214+
// Builds the PhotoMaker (pmid) ID condition for one generation request.
//
// Parameters:
//   work_ctx         - ggml context used to allocate the ID image tensor and
//                      passed on to the conditioner / pmid model calls.
//   pm_params        - PhotoMaker inputs: ID images and their count, style
//                      strength, and an optional ID-embeds file path (only
//                      consulted for PM_VERSION_2).
//   condition_params - in/out. `num_input_imgs` is set to the ID image count
//                      before the trigger-aware conditioning call; on success
//                      `text` is replaced by the prompt with the trigger word
//                      removed (used for delayed conditioning).
//
// Returns:
//   - A default-constructed SDCondition when `use_pmid` is false on entry or
//     no ID images were supplied.
//   - Otherwise the condition produced by get_learned_condition_with_trigger,
//     with `c_crossattn` replaced by the pmid model output on success.
//
// Side effects:
//   - Applies `pmid_lora` once (guarded by `pmid_lora->applied`).
//   - Sets `use_pmid = false` when PhotoMaker cannot be used for this request
//     (no ID images, or missing/mismatching ID embeds for PM v2). In the PM v2
//     failure cases the returned condition has NOT been passed through the
//     pmid model, so callers should re-check `use_pmid` before relying on it.
//   - Frees the pmid lora/model parameter buffers when
//     `free_params_immediately` is set.
SDCondition get_pmid_conditon(ggml_context* work_ctx,
                              sd_pm_params_t pm_params,
                              ConditionerParams& condition_params) {
    SDCondition id_cond;
    if (!use_pmid) {
        return id_cond;
    }

    // One-time application of the PhotoMaker LoRA weights.
    if (!pmid_lora->applied) {
        int64_t lora_t0 = ggml_time_ms();
        pmid_lora->apply(tensors, version, n_threads);
        int64_t lora_t1 = ggml_time_ms();
        pmid_lora->applied = true;
        LOG_INFO("pmid_lora apply completed, taking %.2fs", (lora_t1 - lora_t0) * 1.0f / 1000);
        if (free_params_immediately) {
            pmid_lora->free_params_buffer();
        }
    }

    const bool pmv2 = pmid_model->get_version() == PM_VERSION_2;

    if (pm_params.id_images_count <= 0) {
        LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
        LOG_WARN("Turn off PhotoMaker");
        use_pmid = false;
        return id_cond;
    }

    // Preprocess input ID images into one [224, 224, 3, count] F32 tensor.
    const int clip_image_size  = 224;
    pmid_model->style_strength = pm_params.style_strength;

    ggml_tensor* id_image_tensor = ggml_new_tensor_4d(work_ctx,
                                                      GGML_TYPE_F32,
                                                      clip_image_size,
                                                      clip_image_size,
                                                      3,
                                                      pm_params.id_images_count);

    std::vector<sd_image_f32_t> processed_id_images;
    processed_id_images.reserve(static_cast<size_t>(pm_params.id_images_count));
    for (int i = 0; i < pm_params.id_images_count; i++) {
        sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
        processed_id_images.push_back(clip_preprocess(id_image, clip_image_size, clip_image_size));
        // The intermediate float image is only needed as input to clip_preprocess.
        free(id_image.data);
        id_image.data = nullptr;
    }

    // i3 indexes the image within the batch; i0..i2 are forwarded to sd_image_get_f32.
    ggml_ext_tensor_iter(id_image_tensor, [&](ggml_tensor* dst, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false);
        ggml_ext_tensor_set_f32(dst, value, i0, i1, i2, i3);
    });

    // Pixel data now lives in the tensor; release the staging buffers.
    for (auto& image : processed_id_images) {
        free(image.data);
        image.data = nullptr;
    }
    processed_id_images.clear();

    int64_t t0                      = ggml_time_ms();
    condition_params.num_input_imgs = pm_params.id_images_count;
    auto cond_tup                   = cond_stage_model->get_learned_condition_with_trigger(work_ctx,
                                                                                           n_threads,
                                                                                           condition_params);
    id_cond                = std::get<0>(cond_tup);
    auto class_tokens_mask = std::get<1>(cond_tup);

    // ID embeds are only loaded (and required) for PhotoMaker v2.
    ggml_tensor* id_embeds = nullptr;
    if (pmv2 && pm_params.id_embed_path != nullptr) {
        id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
    }

    if (pmv2 && id_embeds == nullptr) {
        LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2");
        LOG_WARN("Turn off PhotoMaker");
        use_pmid = false;
        return id_cond;
    }

    if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) {
        // ggml tensor extents are int64_t: print with PRId64, not %d.
        LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%" PRId64 "). You should run face_detect.py again.",
                 pm_params.id_images_count,
                 static_cast<int64_t>(id_embeds->ne[1]));
        LOG_WARN("Turn off PhotoMaker");
        use_pmid = false;
        return id_cond;
    }

    // Run the pmid model with the ID images, the trigger-aware prompt
    // embeddings, the optional ID embeds and the class-token mask; its
    // output replaces the cross-attention condition.
    ggml_tensor* res = nullptr;
    pmid_model->compute(n_threads, id_image_tensor, id_cond.c_crossattn, id_embeds, class_tokens_mask, &res, work_ctx);
    id_cond.c_crossattn = res;
    int64_t t1          = ggml_time_ms();
    LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
    if (free_params_immediately) {
        pmid_model->free_params_buffer();
    }

    // Encode input prompt without the trigger word for delayed conditioning
    condition_params.text = cond_stage_model->remove_trigger_from_prompt(work_ctx, condition_params.text);

    return id_cond;
}
12231298

12241299
ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
@@ -3117,114 +3192,22 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
31173192
guidance.img_cfg = guidance.txt_cfg;
31183193
}
31193194

3120-
// for (auto v : sigmas) {
3121-
// std::cout << v << " ";
3122-
// }
3123-
// std::cout << std::endl;
3124-
31253195
int sample_steps = sigmas.size() - 1;
31263196

31273197
int64_t t0 = ggml_time_ms();
31283198

3129-
// Photo Maker
3130-
std::string prompt_text_only;
3131-
ggml_tensor* init_img = nullptr;
3132-
SDCondition id_cond;
3133-
std::vector<bool> class_tokens_mask;
3134-
31353199
ConditionerParams condition_params;
3200+
condition_params.text = prompt;
31363201
condition_params.clip_skip = clip_skip;
31373202
condition_params.width = width;
31383203
condition_params.height = height;
31393204
condition_params.ref_images = ref_images;
31403205
condition_params.adm_in_channels = sd_ctx->sd->diffusion_model->get_adm_in_channels();
31413206

3142-
if (sd_ctx->sd->stacked_id) {
3143-
if (!sd_ctx->sd->pmid_lora->applied) {
3144-
int64_t t0 = ggml_time_ms();
3145-
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
3146-
int64_t t1 = ggml_time_ms();
3147-
sd_ctx->sd->pmid_lora->applied = true;
3148-
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
3149-
if (sd_ctx->sd->free_params_immediately) {
3150-
sd_ctx->sd->pmid_lora->free_params_buffer();
3151-
}
3152-
}
3153-
// preprocess input id images
3154-
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
3155-
if (pm_params.id_images_count > 0) {
3156-
int clip_image_size = 224;
3157-
sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength;
3158-
3159-
init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count);
3160-
3161-
std::vector<sd_image_f32_t> processed_id_images;
3162-
for (int i = 0; i < pm_params.id_images_count; i++) {
3163-
sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
3164-
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size);
3165-
free(id_image.data);
3166-
id_image.data = nullptr;
3167-
processed_id_images.push_back(processed_id_image);
3168-
}
3169-
3170-
ggml_ext_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
3171-
float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false);
3172-
ggml_ext_tensor_set_f32(init_img, value, i0, i1, i2, i3);
3173-
});
3174-
3175-
for (auto& image : processed_id_images) {
3176-
free(image.data);
3177-
image.data = nullptr;
3178-
}
3179-
processed_id_images.clear();
3180-
3181-
int64_t t0 = ggml_time_ms();
3182-
condition_params.text = prompt;
3183-
condition_params.num_input_imgs = pm_params.id_images_count;
3184-
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
3185-
sd_ctx->sd->n_threads,
3186-
condition_params);
3187-
id_cond = std::get<0>(cond_tup);
3188-
class_tokens_mask = std::get<1>(cond_tup); //
3189-
struct ggml_tensor* id_embeds = nullptr;
3190-
if (pmv2 && pm_params.id_embed_path != nullptr) {
3191-
id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
3192-
// print_ggml_tensor(id_embeds, true, "id_embeds:");
3193-
}
3194-
if (pmv2 && id_embeds == nullptr) {
3195-
LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2");
3196-
LOG_WARN("Turn off PhotoMaker");
3197-
sd_ctx->sd->stacked_id = false;
3198-
} else {
3199-
if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) {
3200-
LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]);
3201-
LOG_WARN("Turn off PhotoMaker");
3202-
sd_ctx->sd->stacked_id = false;
3203-
} else {
3204-
id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
3205-
int64_t t1 = ggml_time_ms();
3206-
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
3207-
if (sd_ctx->sd->free_params_immediately) {
3208-
sd_ctx->sd->pmid_model->free_params_buffer();
3209-
}
3210-
// Encode input prompt without the trigger word for delayed conditioning
3211-
prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt);
3212-
// printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
3213-
prompt = prompt_text_only; //
3214-
if (sample_steps < 50) {
3215-
LOG_WARN("It's recommended to use >= 50 steps for photo maker!");
3216-
}
3217-
}
3218-
}
3219-
} else {
3220-
LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
3221-
LOG_WARN("Turn off PhotoMaker");
3222-
sd_ctx->sd->stacked_id = false;
3223-
}
3224-
}
3207+
// Photo Maker
3208+
SDCondition id_cond = sd_ctx->sd->get_pmid_conditon(work_ctx, pm_params, condition_params);
32253209

32263210
// Get learned condition
3227-
condition_params.text = prompt;
32283211
condition_params.zero_out_masked = false;
32293212
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
32303213
sd_ctx->sd->n_threads,
@@ -3364,7 +3347,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
33643347
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
33653348

33663349
int start_merge_step = -1;
3367-
if (sd_ctx->sd->stacked_id) {
3350+
if (sd_ctx->sd->use_pmid) {
33683351
start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps);
33693352
// if (start_merge_step > 30)
33703353
// start_merge_step = 30;

0 commit comments

Comments
 (0)