perf: merge LoRA weights into model for inference #599
whiteswordLI wants to merge 2 commits into jingyaogong:master from
Conversation
Pull request overview
This PR introduces a merge_lora function to merge LoRA adapter weights directly into the base model's linear layers during inference initialization. This optimization eliminates the runtime overhead of separate LoRA forward passes by baking the low-rank adaptations into the model weights upfront.
Key changes:
- Added `merge_lora` function that merges LoRA weights (B @ A) directly into linear layer weights
- Updated `eval_llm.py` to use `merge_lora` instead of the `apply_lora` + `load_lora` pattern for inference
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| model/model_lora.py | Introduces merge_lora function to perform weight merging via W_new = W_old + B @ A |
| eval_llm.py | Replaces apply_lora + load_lora calls with single merge_lora call for inference optimization |
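For orientation, here is a runnable sketch of the technique under review, assuming the key layout the PR describes (`merge_lora_from_state_dict`, the toy model, and the shapes are illustrative, not the PR's exact code):

```python
# Sketch: merge LoRA factors from a state_dict into nn.Linear weights.
# Assumptions (not from the PR): a 'module.' prefix may be present, and
# adapter keys follow the '<layer>.lora.A.weight' / '<layer>.lora.B.weight' pattern.
import torch
import torch.nn as nn

def merge_lora_from_state_dict(model, state_dict):
    # Strip a possible 'module.' prefix (e.g. from DataParallel checkpoints).
    state_dict = {(k[7:] if k.startswith('module.') else k): v
                  for k, v in state_dict.items()}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                # Bake the low-rank update into the weight: W_new = W_old + B @ A
                module.weight.data += state_dict[key_B] @ state_dict[key_A]
    return model

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8, bias=False))
sd = {"module.0.lora.A.weight": torch.randn(2, 8) * 0.1,  # (rank, in_features)
      "module.0.lora.B.weight": torch.randn(8, 2) * 0.1}  # (out_features, rank)

x = torch.randn(4, 8)
y_before = model(x)
merge_lora_from_state_dict(model, sd)
y_after = model(x)
# After merging, the output gains exactly the adapter's contribution x @ A^T @ B^T:
delta = x @ sd["module.0.lora.A.weight"].t() @ sd["module.0.lora.B.weight"].t()
```

Because the update is baked into `module.weight`, subsequent forward passes need no adapter bookkeeping at all.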
```python
        module.lora.load_state_dict(lora_state)


def merge_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
```
Missing error handling for file loading. If the path doesn't exist or the file is corrupted, torch.load will raise an exception without a helpful error message. Consider adding a try-except block with a clear error message indicating the LoRA weight file could not be loaded.
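One way the suggested guard could look (a sketch only; `load_lora_state` and the exception set are illustrative, not part of the PR):

```python
import torch

def load_lora_state(path, device="cpu"):
    """Load a LoRA state_dict, failing with an actionable message (sketch)."""
    try:
        return torch.load(path, map_location=device)
    except (OSError, RuntimeError) as e:  # missing file, corrupt checkpoint, ...
        # Re-raise with context so the user knows which file failed and why.
        raise RuntimeError(f"Could not load LoRA weights from '{path}': {e}") from e
```

The exception tuple is not exhaustive; `torch.load` can also raise pickle- or zipfile-level errors depending on the checkpoint format.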
```python
def merge_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    # 移除可能的module前缀,确保key与模型层级名称一致
```
Comment is in Chinese. While the comment correctly describes removing the 'module.' prefix, consider translating inline code comments to English for broader accessibility, especially since the PR title and description are in English.
Suggested change:
```diff
- # 移除可能的module前缀,确保key与模型层级名称一致
+ # Remove potential 'module.' prefix to ensure keys align with model layer naming
```
```python
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                # 直接合并权重: W_new = W_old + B @ A
```
Comment is in Chinese. While the comment correctly describes the weight merging formula, consider translating inline code comments to English for broader accessibility, especially since the PR title and description are in English.
```python
            if key_A in state_dict and key_B in state_dict:
                # 直接合并权重: W_new = W_old + B @ A
                module.weight.data += state_dict[key_B] @ state_dict[key_A]
```
Trailing whitespace on a blank line inside merge_lora. This is a minor style issue that some linters may flag.
```python
def merge_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    # 移除可能的module前缀,确保key与模型层级名称一致
    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                # 直接合并权重: W_new = W_old + B @ A
                module.weight.data += state_dict[key_B] @ state_dict[key_A]

    return model
```
The function silently succeeds even if no LoRA weights are found in the state_dict. This could lead to confusing behavior where users think LoRA weights have been merged when they haven't. Consider adding a warning or raising an error if no LoRA weights are merged, or at least logging how many layers were successfully merged.
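A possible shape for that fix, sketched below (the function name, signature, and warning text are illustrative, not the PR's code):

```python
import warnings
import torch
import torch.nn as nn

def merge_lora_checked(model, state_dict):
    """Merge LoRA pairs into Linear layers, tracking how many were merged (sketch)."""
    merged = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                module.weight.data += state_dict[key_B] @ state_dict[key_A]
                merged += 1
    if merged == 0:
        # Surface the no-op case instead of silently returning the base model.
        warnings.warn("merge_lora: no LoRA weight pairs matched any nn.Linear "
                      "layer; the model was returned unmodified.")
    return model, merged

# A state_dict without any LoRA keys triggers the warning:
model = nn.Sequential(nn.Linear(4, 4, bias=False))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _, n_merged = merge_lora_checked(model, {})
print(n_merged, len(caught))  # 0 1
```

Returning (or logging) the merge count also gives callers a cheap sanity check after loading.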
fix some bugs in merge_lora

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This PR introduces the merge_lora function, which merges LoRA weights directly into the model's linear layers during the loading phase. This eliminates the need for additional LoRA calculations during each inference step, thereby reducing computational overhead and latency, and accelerating inference.
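The per-step saving can be illustrated with a toy example (the layer size, rank, and scaling are arbitrary choices for the demo, not the PR's configuration): the adapter path costs three matmuls per layer per inference step, while the merged path costs one.

```python
import torch

torch.manual_seed(0)
W = torch.randn(512, 512)         # base Linear weight
A = torch.randn(8, 512) * 0.01    # rank-8 LoRA factor (r, in_features)
B = torch.randn(512, 8) * 0.01    # rank-8 LoRA factor (out_features, r)
x = torch.randn(32, 512)

y_adapter = x @ W.t() + (x @ A.t()) @ B.t()  # three matmuls every step
W_merged = W + B @ A                          # one-off merge at load time
y_merged = x @ W_merged.t()                   # single matmul per step
print(torch.allclose(y_adapter, y_merged, atol=1e-3))  # True
```

The outputs agree up to float32 rounding, so the merge trades a one-time cost at load for a strictly cheaper forward pass.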