Skip to content

Commit b88e6e1

Browse files
author
lipengbo
committed
重构训练格式,yaml->py (refactor training configuration format: YAML -> Python module)
1 parent 32eab28 commit b88e6e1

5 files changed

Lines changed: 69 additions & 78 deletions

File tree

Config/config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
class Config(object):
    """Training configuration (replaces the deleted Config/train.yaml).

    All settings are plain class attributes; the module exposes a single
    shared instance ``cfg`` that the rest of the project imports.
    """

    # ====== Dataset ============
    size = [224, 224]              # input image size (presumably H x W — TODO confirm)
    sampler = "batch_balance"      # sampling strategy name
    txt = "./Config/dataset.txt"   # path to the dataset list file

    # ====== Model ============
    optimizer = "sgd"              # optimizer name

    # Plain classification task
    backbone = "mynet"             # backbone network name
    loss = "cross_entropy"         # loss function name

    # Metric-learning task (swap in these two lines instead)
    # backbone = "mynet_metric"    # backbone network name
    # loss = "arcface"             # loss function name

    # ====== Training ============
    lr = 0.01                      # base learning rate
    batch = 64                     # batch size
    epochs = 80                    # number of training epochs
    scheduler = "cosine"           # learning-rate scheduler name


# Shared configuration instance imported by the rest of the project.
cfg = Config()

Config/train.yaml

Lines changed: 0 additions & 23 deletions
This file was deleted.

Package/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ numpy
44
pyyaml
55
seaborn
66
matplotlib
7-
torch
8-
torchvision
7+
torch>=1.11.0
8+
torchvision>=0.12.0
99
tensorboard>1.15.0
1010
torchinfo #模型统计
1111
timm #模型库

Utils/tools.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
cur_path = os.path.abspath(os.path.dirname(__file__))
2828

29+
2930
def analysis_dataset(txt):
3031
"""
3132
解析dataset.txt
@@ -37,23 +38,34 @@ def analysis_dataset(txt):
3738
"val": {"imgs": [], "labels": []},
3839
"test": {"imgs": [], "labels": []},
3940
}
40-
labels=set()
41+
labels = set()
4142
for path, label, mode in imgs_list:
4243
assert mode in ["train", "val", "test"]
4344
labels.add(label)
4445
dataset[mode]["imgs"].append(path)
4546
dataset[mode]["labels"].append(label)
46-
labels=list(labels)
47+
labels = list(labels)
4748
labels.sort()
4849

4950
index = list(range(0, len(labels)))
50-
labels_dict= dict(zip(index,labels))
51-
52-
dataset["labels"]=labels
53-
dataset["labels_dict"]=labels_dict
51+
labels_dict = dict(zip(index, labels))
52+
53+
dataset["labels"] = labels
54+
dataset["labels_dict"] = labels_dict
5455

5556
return dataset
5657

58+
def object2dict(object):
    """Convert an object's attributes to a dict.

    Collects every attribute whose name does not start with ``'__'``
    (this includes any non-dunder methods and class attributes, since
    ``dir()`` is used for enumeration).

    NOTE: the parameter name shadows the builtin ``object``; it is kept
    unchanged for backward compatibility with existing callers.

    :param object: any object (e.g. the ``cfg`` config instance)
    :return: dict mapping attribute name -> attribute value
    """
    # Dict comprehension instead of the original append loop; also avoids
    # shadowing the builtin ``dict`` with a local variable.
    return {
        key: getattr(object, key)
        for key in dir(object)
        if not key.startswith("__")
    }
68+
5769
def init_env(cfg):
5870
"""
5971
初始化训练环境
@@ -83,10 +95,10 @@ def init_env(cfg):
8395

8496
# 初始化TensorBoard
8597
tb_writer = SummaryWriter(tb_path)
86-
tb_writer.add_text("Config", str(cfg))
98+
tb_writer.add_text("Config", str(object2dict(cfg)))
8799
print("*" * 28)
88100
print("TensorBoard | Checkpoint save to ", exp_path, "\n")
89-
return tb_writer, checkpoint_path + cfg["Models"]["backbone"]
101+
return tb_writer, checkpoint_path + cfg.backbone
90102

91103

92104
@torch.no_grad()

train.py

Lines changed: 23 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
import torch
44
import argparse
5-
import yaml
65
import copy
76
from DataSets.preprocess import PreProcess
87
from DataSets import create_datasets, create_dataloader
@@ -19,71 +18,54 @@
1918
cur_path = os.path.abspath(os.path.dirname(__file__))
2019

2120
if __name__ == "__main__":
22-
parser = argparse.ArgumentParser(description="Train")
23-
parser.add_argument("--yaml", help="训练配置", default=cur_path + "/Config/train.yaml")
24-
args = parser.parse_args()
21+
22+
# !导入训练配置!
23+
from Config.config import cfg
2524

2625
# 初始化环境
2726
device = "cuda" if torch.cuda.is_available() else "cpu"
28-
cfg = yaml.load(open(args.yaml, "r"), Loader=yaml.FullLoader)
29-
labels_list = analysis_dataset(cfg["DataSet"]["txt"])["labels"]
27+
labels_list = analysis_dataset(cfg.txt)["labels"]
3028

3129
tb_writer, checkpoint_path = init_env(cfg)
3230

3331
# 模型
34-
model = create_backbone(cfg["Models"]["backbone"], num_classes=len(labels_list))
32+
model = create_backbone(cfg.backbone, num_classes=len(labels_list))
3533
vis_model = copy.deepcopy(model)
3634
TASK = "metric" if hasattr(model, "embedding_size") else "class"
3735
# 区分任务
3836
if TASK == "metric":
3937
# 数据集
40-
train_set = create_datasets(
41-
txt=cfg["DataSet"]["txt"],
42-
mode="train",
43-
size=cfg["DataSet"]["size"],
44-
use_augment=True,
45-
)
46-
train_dataloader = create_dataloader(
47-
batch_size=cfg["DataSet"]["batch"],
48-
dataset=train_set,
49-
sampler_name=cfg["DataSet"]["sampler"],
50-
)
51-
val_set = create_datasets(
52-
txt=cfg["DataSet"]["txt"], mode="val", size=cfg["DataSet"]["size"]
53-
)
38+
train_set = create_datasets(txt=cfg.txt, mode="train", size=cfg.size, use_augment=True)
39+
train_dataloader = create_dataloader(batch_size=cfg.batch, dataset=train_set, sampler_name=cfg.sampler)
40+
val_set = create_datasets(txt=cfg.txt, mode="val", size=cfg.size)
5441

5542
# 难样例挖掘
5643
mining_func = miners.MultiSimilarityMiner()
5744

5845
# 损失函数(分类器)
5946
loss_func = create_metric_loss(
60-
name=cfg["Models"]["loss"],
47+
name=cfg.loss,
6148
num_classes=len(labels_list),
6249
embedding_size=model.embedding_size,
6350
).to(device)
64-
params = [{"params": loss_func.parameters(), "lr": cfg["Train"]["lr"]}]
51+
params = [{"params": loss_func.parameters(), "lr": cfg.lr}]
6552

6653
else:
6754
# 数据集
68-
train_set = create_datasets(
69-
txt=cfg["DataSet"]["txt"],
70-
mode="train",
71-
size=cfg["DataSet"]["size"],
72-
use_augment=True,
73-
)
74-
val_set = create_datasets(txt=cfg["DataSet"]["txt"], mode="val", size=cfg["DataSet"]["size"])
55+
train_set = create_datasets(txt=cfg.txt, mode="train", size=cfg.size, use_augment=True,)
56+
val_set = create_datasets(txt=cfg.txt, mode="val", size=cfg.size)
7557

7658
# 数据集加载器
7759
train_dataloader = create_dataloader(
78-
batch_size=cfg["DataSet"]["batch"],
60+
batch_size=cfg.batch,
7961
dataset=train_set,
80-
sampler_name=cfg["DataSet"]["sampler"],
62+
sampler_name=cfg.sampler,
8163
)
8264

83-
val_dataloader = create_dataloader(batch_size=cfg["DataSet"]["batch"], dataset=val_set)
65+
val_dataloader = create_dataloader(batch_size=cfg.batch, dataset=val_set)
8466

8567
# 损失函数
86-
loss_func = create_class_loss(cfg["Models"]["loss"]).to(device)
68+
loss_func = create_class_loss(cfg.loss).to(device)
8769
params = []
8870

8971
# 模型转为GPU
@@ -95,18 +77,18 @@
9577
# 优化器
9678
params.append({"params": model.parameters()})
9779
optimizer = create_optimizer(
98-
params, cfg["Models"]["optimizer"], lr=cfg["Train"]["lr"]
80+
params, cfg.optimizer, lr=cfg.lr
9981
)
10082

10183
# 学习率调度器
10284
lr_scheduler = create_scheduler(
103-
sched_name=cfg["Train"]["scheduler"],
104-
epochs=cfg["Train"]["epochs"],
85+
sched_name=cfg.scheduler,
86+
epochs=cfg.epochs,
10587
optimizer=optimizer,
10688
)
10789
best_score = 0.0
108-
for epoch in range(cfg["Train"]["epochs"]):
109-
print("start epoch {}/{}...".format(epoch, cfg["Train"]["epochs"]))
90+
for epoch in range(cfg.epochs):
91+
print("start epoch {}/{}...".format(epoch, cfg.epochs))
11092
tb_writer.add_scalar("Train/lr", optimizer.param_groups[-1]["lr"], epoch)
11193
optimizer.zero_grad()
11294

@@ -140,9 +122,7 @@
140122

141123
ema_model.update(model)
142124

143-
lr_scheduler.step_update(
144-
num_updates=epoch * len(train_dataloader) + batch_idx
145-
)
125+
lr_scheduler.step_update(num_updates=epoch * len(train_dataloader) + batch_idx)
146126

147127
if batch_idx % 100 == 0:
148128
iter_num = int(batch_idx + epoch * len(train_dataloader))
@@ -159,9 +139,7 @@
159139
elif TASK == "metric": # 度量学习
160140
score = eval_metric_model(model, train_set, val_set)
161141
ema_score = eval_metric_model(ema_model.module, train_set, val_set)
162-
tb_writer.add_scalars(
163-
"Eval", {"precision": score, "ema_precision": ema_score}, epoch
164-
)
142+
tb_writer.add_scalars("Eval", {"precision": score, "ema_precision": ema_score}, epoch)
165143
model.train()
166144

167145
# 保存最优模型

0 commit comments

Comments
 (0)