-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathimle_sdd.py
More file actions
328 lines (240 loc) · 11.5 KB
/
imle_sdd.py
File metadata and controls
328 lines (240 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import os
import torch
import argparse
import copy
from torch.utils.data import DataLoader, random_split
from tensorboardX import SummaryWriter
from data.dataloader_sdd import SDDDataset, seq_collate_sdd, seq_collate_imle_train
from utils.config import Config
from utils.utils import back_up_code_git, set_random_seed, log_config_to_file
from models.flow_matching import FlowMatcher
from models.imle import IMLE
from models.backbone_eth_ucy import ETHIMLETransformer
from trainer.imle_trainers import IMLETrainer
def parse_config():
"""
Parse the command line arguments and return the configuration options.
"""
parser = argparse.ArgumentParser()
# Basic configuration
parser.add_argument('--cfg', default='cfg/sdd/imle.yml', type=str, help="Config file path")
parser.add_argument('--exp', default='', type=str, help='Experiment description for each run, name of the saving folder.')
parser.add_argument('--eval', default=False, action='store_true', help='Evaluate the model using the ckpt, default is to train the IMLE model.')
parser.add_argument('--eval_on_train', default=False, action='store_true', help='Evaluate the model on the training set.')
parser.add_argument('--save_samples', default=False, action='store_true', help='Save the samples during evaluation.')
# Data configuration
parser.add_argument('--epochs', default=None, type=int, help='Override the number of epochs in the config file.')
parser.add_argument('--batch_size', default=None, type=int, help='Override the batch size in the config file.')
parser.add_argument('--data_dir', type=str, default='./data/sdd', help='Directory where the data is stored.')
parser.add_argument('--n_train', type=int, default=32500, help='Number training scenes used.')
parser.add_argument('--n_test', type=int, default=12500, help='Number testing scenes used.')
parser.add_argument('--checkpt_freq', default=5, type=int, help='Override the checkpt_freq in the config file.')
parser.add_argument('--max_num_ckpts', default=5, type=int, help='Override the max_num_ckpts in the config file.')
parser.add_argument('--data_norm', default='min_max', choices=['min_max', 'sqrt'], help='Normalization method for the data.')
parser.add_argument('--rotate', default=False, action='store_true', help="Whether to rotate the trajectories in the dataset")
parser.add_argument('--rotate_time_frame', type=int, default=0, help='Index of time frames to rotate the trajectories.')
parser.add_argument('--rotate_aug', default=False, action='store_true', help='Whether to use rotation as data augmentation.')
# Reproducibility configuration
parser.add_argument('--fix_random_seed', action='store_true', default=False, help='fix random seed for reproducibility')
parser.add_argument('--seed', type=int, default=42, help='Set the random seed.')
### IMLE parameters ###
parser.add_argument('--load_pretrained', default=False, action='store_true', help='Whether to load a pretrained encoder.')
parser.add_argument('--ckpt_path', type=str, default=None, help='Path to the pretrained encoder checkpoint.')
parser.add_argument('--latent_tau', type=int, default=0, help='The remaining denoising steps needed for the distillation latents')
parser.add_argument('--objective', type=str, default='set', choices=['set', 'single'], help='Whether to generate a set of latents or a single latent for one noise vector.')
parser.add_argument('--num_to_gen', type=int, default=30, help='Number of samples to generate to compute IMLE loss.')
### IMLE parameters ###
### Architecture configuration ###
parser.add_argument("--use_pre_norm", default=False, action='store_true', help='Where to normalize the input trajectories in the Transformer Encoders.')
### Architecture configuration ###
### Regression loss configuration ###
parser.add_argument('--loss_reg_gt_weight', default=0.0, type=float, help='Weight for the ground truth supervision loss.')
parser.add_argument('--loss_reg_chamfer_weight', default=1.0, type=float, help='Weight for the chamfer distance loss.')
parser.add_argument('--perturb_ctx', default=0.0, type=float, help='The scale of perturbation applied to the contextual input to the network.')
parser.add_argument('--loss_reg_reduction', type=str, default='sum', choices=['mean', 'sum'], help='Reduction method for the regression loss.')
parser.add_argument('--loss_reg_squared', default=False, action='store_true', help='Whether to use the squared regression loss.')
### Regression loss configuration ###
### Optimization configuration ###
parser.add_argument('--init_lr', type=float, default=None, help='Override the peak learning rate in the config file.')
parser.add_argument('--weight_decay', type=float, default=None, help='Override the weight decay in the config file.')
### Optimization configuration ###
return parser.parse_args()
def init_basics(args):
"""
Init the basic configurations for the experiment.
"""
"""Load the config file"""
cfg = Config(args.cfg, f'{args.exp}')
tag = '_'
### Architecture configuration ###
def _update_architecture_params(args, cfg, tag):
cfg.MODEL.USE_PRE_NORM = args.use_pre_norm
return cfg, tag
cfg, tag = _update_architecture_params(args, cfg, tag)
### General denoising objective configuration ###
def _update_general_imle_params(args, cfg, tag):
cfg.latent_tau = args.latent_tau
cfg.objective = args.objective
cfg.num_to_gen = args.num_to_gen
cfg.load_pretrained = args.load_pretrained
tag += f'_IMLE_gen_{cfg.objective}_M_{cfg.num_to_gen}'
if cfg.load_pretrained:
cfg.ckpt_path = args.ckpt_path
tag += '_load_enc'
return cfg, tag
cfg, tag = _update_general_imle_params(args, cfg, tag)
### Regression loss configuration ###
def _update_regression_loss_params(args, cfg, tag):
cfg.loss_reg_gt_weight = args.loss_reg_gt_weight
cfg.loss_reg_chamfer_weight = args.loss_reg_chamfer_weight
tag += '_GT_{:.2f}_Chamfer_{:.2f}'.format(cfg.loss_reg_gt_weight, cfg.loss_reg_chamfer_weight)
tag += '_PC_{:.2f}'.format(args.perturb_ctx)
cfg.loss_reg_reduction = args.loss_reg_reduction
cfg.loss_reg_squared = args.loss_reg_squared
tag += f'_REG_{cfg.loss_reg_reduction[:1].upper()}'
if args.loss_reg_squared:
tag += '_SQ'
return cfg, tag
cfg, tag = _update_regression_loss_params(args, cfg, tag)
### Update data configuration ###
def _update_data_params(args, cfg, tag):
cfg.rotate = args.rotate
if args.rotate:
cfg.rotate_time_frame = args.rotate_time_frame
cfg.rotate_aug = args.rotate_aug
tag += f'_rot_{cfg.rotate_time_frame}'
if cfg.rotate_aug:
tag += '_aug'
if args.n_train != 32500:
tag += f'_subset{args.n_train}'
cfg.data_norm = args.data_norm
tag += f'_{args.data_norm}'
return cfg, tag
cfg, tag = _update_data_params(args, cfg, tag)
### Update optimization configs ###
def _update_optimization_params(args, cfg, tag):
if args.init_lr is not None:
cfg.OPTIMIZATION.LR = args.init_lr
if args.weight_decay is not None:
cfg.OPTIMIZATION.WEIGHT_DECAY = args.weight_decay
tag += f'_LR{cfg.OPTIMIZATION.LR}_WD{cfg.OPTIMIZATION.WEIGHT_DECAY}'
if args.epochs is not None:
# override the number of epochs
cfg.OPTIMIZATION.NUM_EPOCHS = args.epochs
if args.batch_size is not None:
# override the batch size
cfg.train_batch_size = args.batch_size
cfg.test_batch_size = args.batch_size
if args.checkpt_freq is not None:
# override the checkpt_freq
cfg.checkpt_freq = args.checkpt_freq
cfg.max_num_ckpts = args.max_num_ckpts
tag += f'_BS{cfg.train_batch_size}_EP{cfg.OPTIMIZATION.NUM_EPOCHS}'
return cfg, tag
cfg, tag = _update_optimization_params(args, cfg, tag)
### voila, create the saving directory ###
tag = tag.replace('__', '_')
cfg.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger = cfg.create_dirs(tag_suffix=tag)
"""fix random seed"""
if args.fix_random_seed:
set_random_seed(args.seed)
"""set up tensorboard and text log"""
tb_dir = os.path.abspath(os.path.join(cfg.log_dir, '../tb'))
os.makedirs(tb_dir, exist_ok=True)
tb_log = SummaryWriter(log_dir=tb_dir)
"""back up the code"""
back_up_code_git(cfg, logger=logger)
"""print the config file"""
log_config_to_file(cfg.yml_dict, logger=logger)
return cfg, logger, tb_log
def build_data_loader(cfg, args):
"""
Build the data loader for the NBA dataset.
"""
train_dset = SDDDataset(
cfg=cfg,
training=True,
data_dir=args.data_dir,
rotate_time_frame=args.rotate_time_frame,
imle=True)
train_loader = DataLoader(
train_dset,
batch_size=cfg.train_batch_size,
shuffle=True,
num_workers=4,
collate_fn=seq_collate_imle_train,
pin_memory=True)
test_dset = SDDDataset(
cfg=cfg,
training=False,
data_dir=args.data_dir,
rotate_time_frame=args.rotate_time_frame,
imle=False)
test_loader = DataLoader(
test_dset,
batch_size=cfg.test_batch_size, ### change it from 500
shuffle=False,
num_workers=4,
collate_fn=seq_collate_sdd,
pin_memory=True)
return train_loader, test_loader
def build_network(cfg, args, logger):
"""
Build the network for the denoising model.
"""
model = ETHIMLETransformer(
model_config=cfg.MODEL,
logger=logger,
config=cfg,
)
if args.load_pretrained:
state_dict = torch.load(args.ckpt_path, map_location='cpu')['model']
for key, val in state_dict.items():
if 'context_encoder' in key:
if key.startswith('model.'):
key = key.replace('model.', '')
assert key in model.state_dict(), f"Key {key} not in model state dict."
model.state_dict()[key].copy_(val)
if 'motion_decoder' in key:
if key.startswith('model.'):
key = key.replace('model.', '')
if 'adaLN' in key:
continue
assert key in model.state_dict(), f"Key {key} not in model state dict."
model.state_dict()[key].copy_(val)
logger.info("The pretrained encoder and decoder have been loaded.")
imle_model = IMLE(
cfg=cfg,
model=model,
logger=logger,
)
return imle_model
def main():
"""
Main function to train the model.
"""
"""Init everything"""
args = parse_config()
cfg, logger, tb_log = init_basics(args)
train_loader, test_loader = build_data_loader(cfg, args)
imle_model = build_network(cfg, args, logger)
"""Train the model"""
trainer = IMLETrainer(
cfg,
imle_model,
train_loader,
test_loader,
tb_log=tb_log,
logger=logger,
gradient_accumulate_every=1,
ema_decay = 0.995,
ema_update_every = 1,
save_samples=args.save_samples
) ### grid search
if args.eval:
trainer.test(mode='best', eval_on_train=args.eval_on_train)
else:
trainer.train()
if __name__ == "__main__":
main()