
perf: merge LoRA weights into model for inference#599

Open
whiteswordLI wants to merge 2 commits into jingyaogong:master from whiteswordLI:perf/merge-lora-weights

Conversation

@whiteswordLI
Contributor

This PR introduces the merge_lora function, which merges LoRA weights directly into the model's linear layers at load time. This removes the separate LoRA computation from every inference step, reducing computational overhead and latency.
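The core idea can be sketched as follows. This is a minimal standalone illustration, not the PR's exact code; the layer sizes, rank, and helper name are made up for the demo:

```python
import torch
import torch.nn as nn

def merge_lora_weight(linear: nn.Linear, A: torch.Tensor, B: torch.Tensor) -> None:
    """Fold a low-rank update into a linear layer in place: W <- W + B @ A.

    A has shape (rank, in_features); B has shape (out_features, rank),
    so B @ A matches the (out_features, in_features) weight matrix.
    """
    with torch.no_grad():
        linear.weight += B @ A

# Toy demonstration: after merging, a plain forward pass already includes
# the LoRA contribution, so no extra matmuls are needed at inference time.
layer = nn.Linear(4, 3, bias=False)
rank = 2
A = torch.randn(rank, 4) * 0.01
B = torch.randn(3, rank) * 0.01
x = torch.randn(1, 4)

base_out = layer(x)
lora_out = base_out + x @ (B @ A).T   # what a separate LoRA branch would add
merge_lora_weight(layer, A, B)
merged_out = layer(x)
assert torch.allclose(merged_out, lora_out, atol=1e-6)
```

Because the merged weight produces identical outputs, inference cost drops to that of the base model with no change in results.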

Copilot AI review requested due to automatic review settings December 23, 2025 13:02

Copilot AI left a comment


Pull request overview

This PR introduces a merge_lora function to merge LoRA adapter weights directly into the base model's linear layers during inference initialization. This optimization eliminates the runtime overhead of separate LoRA forward passes by baking the low-rank adaptations into the model weights upfront.

Key changes:

  • Added merge_lora function that merges LoRA weights (B @ A) directly into linear layer weights
  • Updated eval_llm.py to use merge_lora instead of the apply_lora + load_lora pattern for inference

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

File — Description
model/model_lora.py — Introduces the merge_lora function to perform weight merging via W_new = W_old + B @ A
eval_llm.py — Replaces the apply_lora + load_lora calls with a single merge_lora call for inference optimization
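The eval_llm.py change can be exercised end to end with a self-contained sketch. The merge_lora body below follows the reviewed diff; the toy model, the checkpoint written to a temp file, and the apply_lora/load_lora names in the comment (taken from the PR description) are illustrative:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Minimal stand-in for the PR's merge_lora, consistent with the reviewed diff:
# load a LoRA checkpoint and fold each B @ A product into the matching nn.Linear.
def merge_lora(model: nn.Module, path: str) -> nn.Module:
    state_dict = torch.load(path, map_location="cpu")
    # Strip a possible 'module.' prefix so keys match model layer names
    state_dict = {(k[7:] if k.startswith("module.") else k): v
                  for k, v in state_dict.items()}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                # Merge in place: W_new = W_old + B @ A
                module.weight.data += state_dict[key_B] @ state_dict[key_A]
    return model

# Before this PR (two steps): apply_lora(model); load_lora(model, lora_path)
# After this PR (one step):   model = merge_lora(model, lora_path)
model = nn.Sequential(nn.Linear(4, 4, bias=False))
w_before = model[0].weight.detach().clone()
lora = {"0.lora.A.weight": torch.ones(2, 4), "0.lora.B.weight": torch.ones(4, 2)}
with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f:
    lora_path = f.name
torch.save(lora, lora_path)
model = merge_lora(model, lora_path)   # weight grows by B @ A = 2 everywhere
os.unlink(lora_path)
```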


        module.lora.load_state_dict(lora_state)

def merge_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)

Copilot AI Dec 23, 2025


Missing error handling for file loading. If the path doesn't exist or the file is corrupted, torch.load will raise an exception without a helpful error message. Consider adding a try-except block with a clear error message indicating the LoRA weight file could not be loaded.
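A hedged sketch of the loading guard the reviewer suggests; the helper name, exception choices, and message wording are assumptions, not code from this PR:

```python
import torch

def load_lora_state(path: str, map_location="cpu") -> dict:
    """Load a LoRA checkpoint, failing with a clear message if it cannot be read."""
    try:
        return torch.load(path, map_location=map_location)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"LoRA weight file not found: {path}") from e
    except Exception as e:
        raise RuntimeError(f"Failed to load LoRA weights from {path}: {e}") from e
```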


def merge_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    # 移除可能的module前缀,确保key与模型层级名称一致

Copilot AI Dec 23, 2025


Comment is in Chinese. While the comment correctly describes removing the 'module.' prefix, consider translating inline code comments to English for broader accessibility, especially since the PR title and description are in English.

Suggested change
- # 移除可能的module前缀,确保key与模型层级名称一致
+ # Remove potential 'module.' prefix to ensure keys align with model layer naming

            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                # 直接合并权重: W_new = W_old + B @ A

Copilot AI Dec 23, 2025


Comment is in Chinese. While the comment correctly describes the weight merging formula, consider translating inline code comments to English for broader accessibility, especially since the PR title and description are in English.

            if key_A in state_dict and key_B in state_dict:
                # 直接合并权重: W_new = W_old + B @ A
                module.weight.data += state_dict[key_B] @ state_dict[key_A]


Copilot AI Dec 23, 2025


Trailing whitespace on this blank line. This is a minor style issue that some linters may flag.

Suggested change

Comment on lines +44 to +57
def merge_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    # 移除可能的module前缀,确保key与模型层级名称一致
    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                # 直接合并权重: W_new = W_old + B @ A
                module.weight.data += state_dict[key_B] @ state_dict[key_A]

    return model

Copilot AI Dec 23, 2025


The function silently succeeds even if no LoRA weights are found in the state_dict. This could lead to confusing behavior where users think LoRA weights have been merged when they haven't. Consider adding a warning or raising an error if no LoRA weights are merged, or at least logging how many layers were successfully merged.
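One way to surface the silent-success case the reviewer describes, sketched under the assumption that LoRA keys follow the `{name}.lora.A.weight` / `{name}.lora.B.weight` convention used in this PR; the function name and return value are illustrative:

```python
import warnings

import torch
import torch.nn as nn

def merge_lora_checked(model: nn.Module, state_dict: dict) -> int:
    """Merge LoRA A/B pairs into matching nn.Linear layers; return the merge count."""
    merged = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            key_A = f"{name}.lora.A.weight"
            key_B = f"{name}.lora.B.weight"
            if key_A in state_dict and key_B in state_dict:
                module.weight.data += state_dict[key_B] @ state_dict[key_A]
                merged += 1
    if merged == 0:
        # Surface the failure mode instead of silently returning an unmodified model
        warnings.warn("No LoRA weights were merged; check the checkpoint keys.")
    return merged

# An empty state_dict merges nothing, so the warning fires and 0 is returned.
model = nn.Sequential(nn.Linear(4, 4, bias=False))
count = merge_lora_checked(model, {})
```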

fix some bugs in merge_lora

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>