forked from TimSong412/OmniTrackFast
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup_trainer.py
More file actions
32 lines (29 loc) · 1.02 KB
/
setup_trainer.py
File metadata and controls
32 lines (29 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from trainer_triplanedep import TriplaneDepTrainer
from trainer_combo import ComboTrainer
import torch
import random
import numpy as np
from trainer import BaseTrainer
# Fixed seed applied to every RNG source at import time so runs are repeatable.
seed = 1234

# Seed the PyTorch CPU generator and (lazily, if CUDA is present) every CUDA
# device, then NumPy's legacy global RNG and Python's `random` module.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Ask cuDNN for deterministic kernels. PyTorch's reproducibility notes also
# recommend disabling benchmark mode: its autotuner may pick different
# algorithms from run to run. It is pinned explicitly here because the
# trainer modules imported above run first and could have switched it on.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setup_trainer(args, eval=False):
    """Construct the trainer named by ``args.trainer``, optionally in eval mode.

    Args:
        args: Configuration object. ``args.trainer`` chooses the class
            (``'triplanedep'`` -> TriplaneDepTrainer, ``"combo"`` ->
            ComboTrainer, anything else -> BaseTrainer), and the whole
            object is forwarded to that class's constructor.
        eval: If truthy, call ``.eval()`` on each of the trainer's
            ``color_mlp``, ``deform_mlp``, ``feature_mlp`` and ``depthmem``
            attributes that exists and is not ``None``. (Presumably these
            are torch modules — TODO confirm against the trainer classes.)

    Returns:
        The newly constructed trainer instance.
    """
    # Pick the class first, then build it once. Unrecognised names fall
    # through to BaseTrainer.
    trainer_cls = BaseTrainer
    if args.trainer == 'triplanedep':
        trainer_cls = TriplaneDepTrainer
    elif args.trainer == "combo":
        trainer_cls = ComboTrainer
    trainer = trainer_cls(args)

    if not eval:
        return trainer

    # Not every trainer defines every sub-network, and some may hold None,
    # so a missing or None attribute is silently skipped.
    for attr_name in ("color_mlp", "deform_mlp", "feature_mlp", "depthmem"):
        submodule = getattr(trainer, attr_name, None)
        if submodule is not None:
            submodule.eval()
    return trainer