-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
19 lines (16 loc) · 743 Bytes
/
utils.py
File metadata and controls
19 lines (16 loc) · 743 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from loguru import logger
def adjust_prefix_and_load_state_dict(
    model,
    ckpt_path,
    ckpt_to_model_prefix=None,
    *,
    strict=True,
    map_location="cpu",
):
    """Load a checkpoint into ``model``, optionally renaming key prefixes first.

    The checkpoint is read with ``torch.load``. If the loaded object contains a
    ``"state_dict"`` key, that entry is used as the state dict; otherwise the
    loaded object itself is treated as the state dict.

    Args:
        model: Object whose ``load_state_dict`` receives the (renamed) weights,
            typically a ``torch.nn.Module``.
        ckpt_path: Anything ``torch.load`` accepts (path or file-like object).
        ckpt_to_model_prefix: Optional mapping ``{checkpoint_prefix: model_prefix}``.
            For every key, the rules are tried in the mapping's iteration order;
            each rule whose prefix matches the *current* (possibly already
            rewritten) key replaces that leading prefix, so rules can chain.
            ``None`` or an empty mapping leaves keys unchanged.
        strict: Forwarded to ``model.load_state_dict``. Defaults to ``True``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"``.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        ValueError: If two checkpoint keys are renamed to the same model key
            (one tensor would otherwise silently overwrite the other).
        RuntimeError: Propagated from ``load_state_dict`` when ``strict`` is
            true and the keys/shapes do not match the model.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    if ckpt_to_model_prefix:
        rules = list(ckpt_to_model_prefix.items())  # hoisted: invariant per key
        renamed = {}
        source_of = {}  # renamed key -> original checkpoint key (for error messages)
        for old_key, value in state_dict.items():
            new_key = old_key
            for prefix, replacement in rules:
                if new_key.startswith(prefix):
                    # Replace only the leading prefix (equivalent to
                    # str.replace(prefix, replacement, 1) once startswith holds).
                    new_key = replacement + new_key[len(prefix):]
            if new_key in renamed:
                raise ValueError(
                    f"Checkpoint keys {source_of[new_key]!r} and {old_key!r} both "
                    f"map to model key {new_key!r}; check ckpt_to_model_prefix"
                )
            renamed[new_key] = value
            source_of[new_key] = old_key
        # NOTE(review): rebuilding the dict drops any `_metadata` attribute the
        # original state dict carried (per-module version info); confirm this is
        # acceptable for the models loaded here.
        state_dict = renamed

    result = model.load_state_dict(state_dict, strict=strict)
    # With strict=True mismatches already raised; otherwise surface them.
    if not strict and (result.missing_keys or result.unexpected_keys):
        logger.warning(
            f"Non-strict load from {ckpt_path}: "
            f"missing_keys={list(result.missing_keys)}, "
            f"unexpected_keys={list(result.unexpected_keys)}"
        )
    logger.info(f"Loaded state dict from {ckpt_path}")
    return model