|
2 | 2 | import os |
3 | 3 | import torch |
4 | 4 | import argparse |
5 | | -import yaml |
6 | 5 | import copy |
7 | 6 | from DataSets.preprocess import PreProcess |
8 | 7 | from DataSets import create_datasets, create_dataloader |
|
19 | 18 | cur_path = os.path.abspath(os.path.dirname(__file__)) |
20 | 19 |
|
21 | 20 | if __name__ == "__main__": |
22 | | - parser = argparse.ArgumentParser(description="Train") |
23 | | - parser.add_argument("--yaml", help="训练配置", default=cur_path + "/Config/train.yaml") |
24 | | - args = parser.parse_args() |
| 21 | + |
| 22 | + # !导入训练配置! |
| 23 | + from Config.config import cfg |
25 | 24 |
|
26 | 25 | # 初始化环境 |
27 | 26 | device = "cuda" if torch.cuda.is_available() else "cpu" |
28 | | - cfg = yaml.load(open(args.yaml, "r"), Loader=yaml.FullLoader) |
29 | | - labels_list = analysis_dataset(cfg["DataSet"]["txt"])["labels"] |
| 27 | + labels_list = analysis_dataset(cfg.txt)["labels"] |
30 | 28 |
|
31 | 29 | tb_writer, checkpoint_path = init_env(cfg) |
32 | 30 |
|
33 | 31 | # 模型 |
34 | | - model = create_backbone(cfg["Models"]["backbone"], num_classes=len(labels_list)) |
| 32 | + model = create_backbone(cfg.backbone, num_classes=len(labels_list)) |
35 | 33 | vis_model = copy.deepcopy(model) |
36 | 34 | TASK = "metric" if hasattr(model, "embedding_size") else "class" |
37 | 35 | # 区分任务 |
38 | 36 | if TASK == "metric": |
39 | 37 | # 数据集 |
40 | | - train_set = create_datasets( |
41 | | - txt=cfg["DataSet"]["txt"], |
42 | | - mode="train", |
43 | | - size=cfg["DataSet"]["size"], |
44 | | - use_augment=True, |
45 | | - ) |
46 | | - train_dataloader = create_dataloader( |
47 | | - batch_size=cfg["DataSet"]["batch"], |
48 | | - dataset=train_set, |
49 | | - sampler_name=cfg["DataSet"]["sampler"], |
50 | | - ) |
51 | | - val_set = create_datasets( |
52 | | - txt=cfg["DataSet"]["txt"], mode="val", size=cfg["DataSet"]["size"] |
53 | | - ) |
| 38 | + train_set = create_datasets(txt=cfg.txt, mode="train", size=cfg.size, use_augment=True) |
| 39 | + train_dataloader = create_dataloader(batch_size=cfg.batch, dataset=train_set, sampler_name=cfg.sampler) |
| 40 | + val_set = create_datasets(txt=cfg.txt, mode="val", size=cfg.size) |
54 | 41 |
|
55 | 42 | # 难样例挖掘 |
56 | 43 | mining_func = miners.MultiSimilarityMiner() |
57 | 44 |
|
58 | 45 | # 损失函数(分类器) |
59 | 46 | loss_func = create_metric_loss( |
60 | | - name=cfg["Models"]["loss"], |
| 47 | + name=cfg.loss, |
61 | 48 | num_classes=len(labels_list), |
62 | 49 | embedding_size=model.embedding_size, |
63 | 50 | ).to(device) |
64 | | - params = [{"params": loss_func.parameters(), "lr": cfg["Train"]["lr"]}] |
| 51 | + params = [{"params": loss_func.parameters(), "lr": cfg.lr}] |
65 | 52 |
|
66 | 53 | else: |
67 | 54 | # 数据集 |
68 | | - train_set = create_datasets( |
69 | | - txt=cfg["DataSet"]["txt"], |
70 | | - mode="train", |
71 | | - size=cfg["DataSet"]["size"], |
72 | | - use_augment=True, |
73 | | - ) |
74 | | - val_set = create_datasets(txt=cfg["DataSet"]["txt"], mode="val", size=cfg["DataSet"]["size"]) |
| 55 | + train_set = create_datasets(txt=cfg.txt, mode="train", size=cfg.size, use_augment=True,) |
| 56 | + val_set = create_datasets(txt=cfg.txt, mode="val", size=cfg.size) |
75 | 57 |
|
76 | 58 | # 数据集加载器 |
77 | 59 | train_dataloader = create_dataloader( |
78 | | - batch_size=cfg["DataSet"]["batch"], |
| 60 | + batch_size=cfg.batch, |
79 | 61 | dataset=train_set, |
80 | | - sampler_name=cfg["DataSet"]["sampler"], |
| 62 | + sampler_name=cfg.sampler, |
81 | 63 | ) |
82 | 64 |
|
83 | | - val_dataloader = create_dataloader(batch_size=cfg["DataSet"]["batch"], dataset=val_set) |
| 65 | + val_dataloader = create_dataloader(batch_size=cfg.batch, dataset=val_set) |
84 | 66 |
|
85 | 67 | # 损失函数 |
86 | | - loss_func = create_class_loss(cfg["Models"]["loss"]).to(device) |
| 68 | + loss_func = create_class_loss(cfg.loss).to(device) |
87 | 69 | params = [] |
88 | 70 |
|
89 | 71 | # 模型转为GPU |
|
95 | 77 | # 优化器 |
96 | 78 | params.append({"params": model.parameters()}) |
97 | 79 | optimizer = create_optimizer( |
98 | | - params, cfg["Models"]["optimizer"], lr=cfg["Train"]["lr"] |
| 80 | + params, cfg.optimizer, lr=cfg.lr |
99 | 81 | ) |
100 | 82 |
|
101 | 83 | # 学习率调度器 |
102 | 84 | lr_scheduler = create_scheduler( |
103 | | - sched_name=cfg["Train"]["scheduler"], |
104 | | - epochs=cfg["Train"]["epochs"], |
| 85 | + sched_name=cfg.scheduler, |
| 86 | + epochs=cfg.epochs, |
105 | 87 | optimizer=optimizer, |
106 | 88 | ) |
107 | 89 | best_score = 0.0 |
108 | | - for epoch in range(cfg["Train"]["epochs"]): |
109 | | - print("start epoch {}/{}...".format(epoch, cfg["Train"]["epochs"])) |
| 90 | + for epoch in range(cfg.epochs): |
| 91 | + print("start epoch {}/{}...".format(epoch, cfg.epochs)) |
110 | 92 | tb_writer.add_scalar("Train/lr", optimizer.param_groups[-1]["lr"], epoch) |
111 | 93 | optimizer.zero_grad() |
112 | 94 |
|
|
140 | 122 |
|
141 | 123 | ema_model.update(model) |
142 | 124 |
|
143 | | - lr_scheduler.step_update( |
144 | | - num_updates=epoch * len(train_dataloader) + batch_idx |
145 | | - ) |
| 125 | + lr_scheduler.step_update(num_updates=epoch * len(train_dataloader) + batch_idx) |
146 | 126 |
|
147 | 127 | if batch_idx % 100 == 0: |
148 | 128 | iter_num = int(batch_idx + epoch * len(train_dataloader)) |
|
159 | 139 | elif TASK == "metric": # 度量学习 |
160 | 140 | score = eval_metric_model(model, train_set, val_set) |
161 | 141 | ema_score = eval_metric_model(ema_model.module, train_set, val_set) |
162 | | - tb_writer.add_scalars( |
163 | | - "Eval", {"precision": score, "ema_precision": ema_score}, epoch |
164 | | - ) |
| 142 | + tb_writer.add_scalars("Eval", {"precision": score, "ema_precision": ema_score}, epoch) |
165 | 143 | model.train() |
166 | 144 |
|
167 | 145 | # 保存最优模型 |
|
0 commit comments