Skip to content

Commit 45e0b10

Browse files
committed
Add LoRA path sanitizer (leejet#1156)
1 parent 4f33e6f commit 45e0b10

File tree

3 files changed

+46
-14
lines changed

3 files changed

+46
-14
lines changed

build.bat

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pushd %~dp0
33
echo Compiling wait patiently
44
clang -Iggml -Ithirdparty entry.cpp -w --shared -DGGML_MAX_NAME=256 -O2 -flto=thin -ffunction-sections -fdata-sections -Lincludes -lstdc++ -lggml -lggml-base -lggml-cpu-x64 -lggml-cuda -DSD_USE_CUDA -DSD_BUILD_DLL -o stable-diffusion.dll -Wl,-s
55
clang -Iggml -Ithirdparty -I. -Iexamples -w examples\cli\main.cpp -lstdc++ -lstable-diffusion -Lincludes -o sd-cli.exe -Wl,-s
6+
clang -Iggml -Ithirdparty -I. -Iexamples -w examples\server\main.cpp -lstdc++ -lstable-diffusion -Lincludes -DSD_SAFE_LORA -o sd-server.exe -Wl,-s -D_WIN32_WINNT=0xA00 -lws2_32

examples/common/common.hpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ namespace fs = std::filesystem;
3737
#define SAFE_STR(s) ((s) ? (s) : "")
3838
#define BOOL_STR(b) ((b) ? "true" : "false")
3939

40+
static bool is_path_safe(const fs::path& base_dir, const fs::path& target_path) {
41+
try {
42+
// 获取目标路径的绝对路径(weakly_canonical 允许文件尚不存在,但能处理 .. 和 .)
43+
fs::path abs_target = fs::weakly_canonical(target_path);
44+
45+
auto base_str = base_dir.string();
46+
auto target_str = abs_target.string();
47+
return target_str.rfind(base_str, 0) == 0;
48+
} catch (...) {
49+
return false;
50+
}
51+
}
52+
4053
const char* modes_str[] = {
4154
"img_gen",
4255
"vid_gen",
@@ -1632,15 +1645,16 @@ struct SDGenerationParams {
16321645
}
16331646

16341647
void extract_and_remove_lora(const std::string& lora_model_dir) {
1635-
if (lora_model_dir.empty()) {
1636-
return;
1637-
}
1648+
if (lora_model_dir.empty()) return;
1649+
16381650
static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
16391651
static const std::vector<std::string> valid_ext = {".gguf", ".safetensors", ".pt"};
16401652
std::smatch m;
16411653

16421654
std::string tmp = prompt;
16431655

1656+
fs::path base_path = fs::absolute(lora_model_dir);
1657+
16441658
while (std::regex_search(tmp, m, re)) {
16451659
std::string raw_path = m[1].str();
16461660
const std::string raw_mul = m[2].str();
@@ -1663,10 +1677,22 @@ struct SDGenerationParams {
16631677

16641678
fs::path final_path;
16651679
if (is_absolute_path(raw_path)) {
1680+
#ifdef SD_SAFE_LORA
1681+
LOG_WARN("This LoRA path is forbidden [SD_SAFE_LORA]: %s", raw_path.c_str());
1682+
goto next_match;
1683+
#else
16661684
final_path = raw_path;
1685+
#endif
16671686
} else {
1668-
final_path = fs::path(lora_model_dir) / raw_path;
1687+
final_path = base_path / raw_path;
1688+
#ifdef SD_SAFE_LORA
1689+
if (!is_path_safe(base_path, final_path)) {
1690+
LOG_WARN("This LoRA path is forbidden [SD_SAFE_LORA]: %s", raw_path.c_str());
1691+
goto next_match;
1692+
}
1693+
#endif
16691694
}
1695+
16701696
if (!fs::exists(final_path)) {
16711697
bool found = false;
16721698
for (const auto& ext : valid_ext) {
@@ -1686,13 +1712,17 @@ struct SDGenerationParams {
16861712
}
16871713
}
16881714

1715+
{
16891716
const std::string key = final_path.lexically_normal().string();
16901717

16911718
if (is_high_noise)
16921719
high_noise_lora_map[key] += mul;
16931720
else
16941721
lora_map[key] += mul;
1722+
}
1723+
16951724

1725+
next_match:
16961726
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
16971727

16981728
tmp = m.suffix().str();
@@ -1999,7 +2029,7 @@ uint8_t* load_image_common(bool from_memory,
19992029
int c = 0;
20002030
const char* image_path;
20012031
uint8_t* image_buffer = nullptr;
2002-
2032+
20032033
bool is_qoi = false;
20042034
if (from_memory) {
20052035
// magic bytes
@@ -2039,7 +2069,7 @@ uint8_t* load_image_common(bool from_memory,
20392069
image_path_or_bytes, &width, &height, &c, expected_channel);
20402070
}
20412071
}
2042-
2072+
20432073
if (image_buffer == nullptr) {
20442074
LOG_ERROR("load image from '%s' failed", image_path);
20452075
return nullptr;
@@ -2180,6 +2210,6 @@ bool save_image_as_qoi(const char* filename, int width, int height, int channels
21802210
desc.height = height;
21812211
desc.channels = channels;
21822212
desc.colorspace = QOI_SRGB;
2183-
2213+
21842214
return qoi_write(filename, data, &desc) > 0;
21852215
}

examples/server/main.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ std::vector<uint8_t> write_image_to_vector(
254254
desc.height = height;
255255
desc.channels = channels;
256256
desc.colorspace = QOI_SRGB;
257-
257+
258258
int out_len = 0;
259259
void* qoi_data = qoi_encode(image, &desc, &out_len);
260260
if (qoi_data) {
@@ -289,6 +289,7 @@ int main(int argc, const char** argv) {
289289
SDContextParams ctx_params;
290290
SDGenerationParams default_gen_params;
291291
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
292+
ctx_params.lora_apply_mode = LORA_APPLY_AT_RUNTIME;
292293

293294
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
294295
log_verbose = svr_params.verbose;
@@ -425,7 +426,7 @@ int main(int argc, const char** argv) {
425426
if (gen_params.sample_params.sample_steps > 100)
426427
gen_params.sample_params.sample_steps = 100;
427428

428-
if (!gen_params.process_and_check(IMG_GEN, "")) {
429+
if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) {
429430
res.status = 400;
430431
res.set_content(R"({"error":"invalid params"})", "application/json");
431432
return;
@@ -482,8 +483,8 @@ int main(int argc, const char** argv) {
482483
continue;
483484
}
484485
auto image_bytes = write_image_to_vector(
485-
output_format == "jpeg" ? ImageFormat::JPEG :
486-
output_format == "qoi" ? ImageFormat::QOI :
486+
output_format == "jpeg" ? ImageFormat::JPEG :
487+
output_format == "qoi" ? ImageFormat::QOI :
487488
ImageFormat::PNG,
488489
results[i].data,
489490
results[i].width,
@@ -609,7 +610,7 @@ int main(int argc, const char** argv) {
609610
if (gen_params.sample_params.sample_steps > 100)
610611
gen_params.sample_params.sample_steps = 100;
611612

612-
if (!gen_params.process_and_check(IMG_GEN, "")) {
613+
if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) {
613614
res.status = 400;
614615
res.set_content(R"({"error":"invalid params"})", "application/json");
615616
return;
@@ -705,8 +706,8 @@ int main(int argc, const char** argv) {
705706
if (results[i].data == nullptr)
706707
continue;
707708
auto image_bytes = write_image_to_vector(
708-
output_format == "jpeg" ? ImageFormat::JPEG :
709-
output_format == "qoi" ? ImageFormat::QOI :
709+
output_format == "jpeg" ? ImageFormat::JPEG :
710+
output_format == "qoi" ? ImageFormat::QOI :
710711
ImageFormat::PNG,
711712
results[i].data,
712713
results[i].width,

0 commit comments

Comments
 (0)