Skip to content

Commit fb4fa40

Browse files
committed
nn wrapper; v4; test
1 parent 0665dad commit fb4fa40

13 files changed

+360
-112
lines changed

spf/dataset/spf_dataset.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def __setitem__(self, key, value):
421421
"ground_truth_theta",
422422
"ground_truth_phi",
423423
"craft_ground_truth_theta",
424+
"absolute_theta",
424425
"y_rad",
425426
"y_phi",
426427
"craft_y_rad",
@@ -429,7 +430,7 @@ def __setitem__(self, key, value):
429430

430431
segmentation_based_keys = [
431432
"weighted_beamformer",
432-
"all_windows_stats",
433+
# "all_windows_stats",
433434
"weighted_windows_stats",
434435
"downsampled_segmentation_mask",
435436
"simple_segmentations",
@@ -439,8 +440,8 @@ def __setitem__(self, key, value):
439440
v5_raw_keys = v5rx_f64_keys + v5rx_2xf64_keys + ["signal_matrix"]
440441

441442

442-
def data_single_radio_to_raw(d):
443-
return {k: d[k] for k in v5_raw_keys}
443+
def data_single_radio_to_raw(d, ds):
444+
return {k: d[k] for k in list(set(v5_raw_keys) - set(ds.skip_fields))}
444445

445446

446447
class v5inferencedataset(Dataset):
@@ -470,13 +471,16 @@ def __init__(
470471
skip_segmentation: bool = True,
471472
vehicle_type: str = "",
472473
max_in_memory: int = 10,
474+
realtime: bool = True,
473475
):
474476
# Store configuration parameters
475477
self.yaml_fn = yaml_fn
476478
self.n_parallel = n_parallel
477479
self.nthetas = nthetas # Number of angles to discretize space for beamforming
478480
self.target_ntheta = self.nthetas if target_ntheta is None else target_ntheta
479481

482+
self.realtime = realtime
483+
480484
self.max_in_memory = max_in_memory
481485
self.min_idx = 0
482486
self.condition = multiprocessing.Condition()
@@ -589,6 +593,9 @@ def __init__(
589593
self.empirical_data_fn = None
590594
self.empirical_data = None
591595

596+
def __len__(self):
597+
return self.serving_idx
598+
592599
def __iter__(self):
593600
self.serving_idx = 0
594601
return self
@@ -606,9 +613,11 @@ def __next__(self):
606613
def __getitem__(self, idx, timeout=10.0):
607614
start_time = time.time()
608615
with self.condition:
616+
print("waitinf to get get", idx, time.time() - start_time)
609617
while idx not in self.store or self.store[idx]["count"] != 2:
610618
self.condition.wait(0.01)
611619
if (time.time() - start_time) > timeout:
620+
print("ret waitinf to get get", idx, time.time() - start_time)
612621
return None
613622
return self.store[idx]["data"]
614623

@@ -637,6 +646,7 @@ def render_session(self, idx, ridx, data):
637646
data["receiver_idx"] = torch.tensor([[ridx]], dtype=torch.int)
638647

639648
data["ground_truth_theta"] = torch.tensor([torch.inf]) # unknown
649+
data["absolute_theta"] = torch.tensor([torch.inf])
640650
data["y_rad"] = data["ground_truth_theta"] # torch.inf
641651

642652
data["ground_truth_phi"] = torch.tensor([torch.inf]) # unkown
@@ -774,6 +784,7 @@ def __init__(
774784
skip_detrend: bool = False,
775785
vehicle_type: str = "",
776786
v4: bool = False,
787+
realtime: bool = False,
777788
):
778789
logging.debug(f"loading... {prefix}")
779790
# Store configuration parameters
@@ -786,6 +797,8 @@ def __init__(
786797
self.valid_entries = None
787798
self.temp_file = temp_file
788799

800+
self.realtime = realtime
801+
789802
# Segmentation parameters control how raw signal is processed into windows
790803
# and how phase difference is computed between antenna elements
791804
self.segmentation_version = (
@@ -1286,6 +1299,7 @@ def render_session(self, receiver_idx, session_idx, double_flip=False):
12861299
data["ground_truth_theta"] = self.ground_truth_thetas[receiver_idx][
12871300
snapshot_idxs
12881301
]
1302+
data["absolute_theta"] = self.absolute_thetas[receiver_idx][snapshot_idxs]
12891303
data["ground_truth_phi"] = self.ground_truth_phis[receiver_idx][snapshot_idxs]
12901304
data["craft_ground_truth_theta"] = self.craft_ground_truth_thetas[snapshot_idxs]
12911305
data["vehicle_type"] = torch.tensor(
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from functools import cache
2+
3+
import torch
4+
from torch.utils.data import Dataset
5+
6+
from spf.dataset.spf_dataset import v5_collate_keys_fast
7+
from spf.filters.particle_dual_radio_nn_filter import (
8+
cached_model_inference_to_absolute_north,
9+
)
10+
from spf.model_training_and_inference.models.single_point_networks_inference import (
11+
get_nn_inference_on_ds_and_cache,
12+
load_model_and_config_from_config_fn_and_checkpoint,
13+
)
14+
from spf.rf import rotate_dist, torch_pi_norm_pi
15+
from spf.scripts.train_utils import global_config_to_keys_used
16+
17+
18+
# attach nn inference attributes to dataset entries
19+
class v5spfdataset_nn_wrapper(Dataset):
20+
def __init__(
21+
self,
22+
ds,
23+
checkpoint_config_fn,
24+
checkpoint_fn,
25+
inference_cache,
26+
device="cpu",
27+
v4=None,
28+
absolute=False,
29+
):
30+
self.ds = ds
31+
assert self.ds.paired
32+
self.checkpoint_config_fn = checkpoint_config_fn
33+
self.checkpoint_fn = checkpoint_fn
34+
self.inference_cache = inference_cache
35+
self.absolute = absolute
36+
37+
if v4 is None:
38+
v4 = self.ds.v4
39+
40+
if not ds.realtime:
41+
self.cached_model_inference = {
42+
k: torch.as_tensor(v)
43+
for k, v in get_nn_inference_on_ds_and_cache(
44+
ds_fn=ds.zarr_fn,
45+
config_fn=self.checkpoint_config_fn,
46+
checkpoint_fn=self.checkpoint_fn,
47+
device=device,
48+
inference_cache=inference_cache,
49+
batch_size=64,
50+
workers=0,
51+
precompute_cache=ds.precompute_cache,
52+
crash_if_not_cached=False,
53+
segmentation_version=ds.segmentation_version,
54+
v4=v4,
55+
).items()
56+
}
57+
if self.absolute:
58+
self.cached_model_inference = {
59+
k: cached_model_inference_to_absolute_north(ds, v)
60+
for k, v in self.cached_model_inference.items()
61+
}
62+
else:
63+
self.model, self.model_config = (
64+
load_model_and_config_from_config_fn_and_checkpoint(
65+
self.checkpoint_config_fn, self.checkpoint_fn, device=device
66+
)
67+
)
68+
self.model.eval()
69+
self.keys_to_get = global_config_to_keys_used(
70+
global_config=self.model_config["global"]
71+
)
72+
73+
def to_absolute_north(self, sample):
74+
for ridx in range(2):
75+
ntheta = sample[ridx]["paired"].shape[-1]
76+
paired_nn_inference = sample[ridx]["paired"].reshape(-1, ntheta)
77+
paired_nn_inference_rotated = rotate_dist(
78+
paired_nn_inference,
79+
rotations=torch_pi_norm_pi(
80+
sample[ridx]["rx_heading_in_pis"][:, None] * torch.pi
81+
),
82+
).reshape(paired_nn_inference.shape)
83+
sample[ridx]["paired"] = paired_nn_inference_rotated
84+
return sample
85+
86+
@cache
87+
def get_inference_for_idx(self, idx):
88+
if not self.ds.realtime:
89+
return [
90+
{k: v[idx][ridx] for k, v in self.cached_model_inference.items()}
91+
for ridx in range(2)
92+
]
93+
return self.get_and_annotate_entry_at_idx(idx)
94+
95+
@cache
96+
def get_and_annotate_entry_at_idx(self, idx):
97+
sample = self.ds[idx]
98+
if not self.ds.realtime:
99+
for ridx in range(2):
100+
sample[ridx].update(
101+
{k: v[idx][ridx] for k, v in self.cached_model_inference.items()}
102+
)
103+
return sample
104+
else:
105+
single_example = v5_collate_keys_fast(self.keys_to_get, [sample]).to(
106+
self.model_config["optim"]["device"]
107+
)
108+
with torch.no_grad():
109+
nn_inference = self.model(single_example)
110+
for ridx in range(2):
111+
sample[ridx].update({k: v[ridx] for k, v in nn_inference.items()})
112+
if self.absolute:
113+
sample = self.to_absolute_north(sample)
114+
return sample
115+
116+
def __iter__(self):
117+
self.serving_idx = 0
118+
return self
119+
120+
def __next__(self):
121+
sample = self.get_and_annotate_entry_at_idx(self.serving_idx)
122+
self.serving_idx += 1
123+
return sample
124+
125+
@cache
126+
def __getitem__(self, idx):
127+
return self.get_and_annotate_entry_at_idx(idx)
128+
129+
def __len__(self):
130+
return len(self.ds)
131+
132+
@property
133+
def mean_phase(self):
134+
return self.ds.mean_phase
135+
136+
@property
137+
def ground_truth_phis(self):
138+
return self.ds.ground_truth_phis
139+
140+
@property
141+
def craft_ground_truth_thetas(self):
142+
return self.ds.craft_ground_truth_thetas
143+
144+
@property
145+
def absolute_thetas(self):
146+
return self.ds.absolute_thetas

spf/filters/particle_dual_radio_nn_filter.py

Lines changed: 4 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,12 @@
11
import torch
22

3-
from spf.dataset.spf_dataset import v5_collate_keys_fast
43
from spf.filters.filters import (
54
ParticleFilter,
65
add_noise,
76
dual_radio_mse_theta_metrics,
87
theta_phi_to_bins,
98
)
10-
from spf.model_training_and_inference.models.single_point_networks_inference import (
11-
convert_datasets_config_to_inference,
12-
get_inference_on_ds,
13-
load_model_and_config_from_config_fn_and_checkpoint,
14-
)
159
from spf.rf import rotate_dist, torch_pi_norm_pi
16-
from spf.scripts.train_single_point import (
17-
global_config_to_keys_used,
18-
load_config_from_fn,
19-
)
2010

2111

2212
def cached_model_inference_to_absolute_north(ds, cached_model_inference):
@@ -40,80 +30,15 @@ def cached_model_inference_to_absolute_north(ds, cached_model_inference):
4030
class PFSingleThetaDualRadioNN(ParticleFilter):
4131
def __init__(
4232
self,
43-
ds,
44-
checkpoint_fn,
45-
config_fn,
46-
inference_cache=None,
47-
device="cpu",
48-
absolute=False,
33+
nn_ds,
4934
):
50-
self.ds = ds
51-
self.absolute = absolute
35+
self.ds = nn_ds
36+
self.absolute = nn_ds.absolute
5237
self.generator = torch.Generator()
5338
self.generator.manual_seed(0)
5439

55-
# checkpoint_config = load_config_from_fn(config_fn)
56-
# assert (
57-
# self.ds.empirical_data_fn
58-
# == checkpoint_config["datasets"]["empirical_data_fn"]
59-
# )
60-
61-
if not self.ds.temp_file:
62-
# cache model results
63-
self.cached_model_inference = torch.as_tensor(
64-
get_inference_on_ds(
65-
ds_fn=ds.zarr_fn,
66-
config_fn=config_fn,
67-
checkpoint_fn=checkpoint_fn,
68-
device=device,
69-
inference_cache=inference_cache,
70-
batch_size=64,
71-
workers=0,
72-
precompute_cache=ds.precompute_cache,
73-
crash_if_not_cached=False,
74-
segmentation_version=ds.segmentation_version,
75-
)["paired"]
76-
)
77-
if self.absolute:
78-
self.cached_model_inference = cached_model_inference_to_absolute_north(
79-
ds, self.cached_model_inference
80-
)
81-
else:
82-
# load the model and such
83-
self.model, self.model_config = (
84-
load_model_and_config_from_config_fn_and_checkpoint(
85-
config_fn=config_fn, checkpoint_fn=checkpoint_fn, device=device
86-
)
87-
)
88-
self.model.eval()
89-
90-
self.model_datasets_config = convert_datasets_config_to_inference(
91-
self.model_config["datasets"],
92-
ds_fn=ds.zarr_fn,
93-
precompute_cache=self.ds.precompute_cache,
94-
)
95-
96-
self.model_optim_config = {"device": device, "dtype": torch.float32}
97-
98-
self.model_keys_to_get = global_config_to_keys_used(
99-
global_config=self.model_config["global"]
100-
)
101-
assert not self.absolute # this needs to be implemented
102-
103-
def model_inference_at_observation_idx(self, idx):
104-
if not self.ds.temp_file:
105-
return self.cached_model_inference[idx]
106-
107-
z = v5_collate_keys_fast(self.model_keys_to_get, [self.ds[idx]]).to(
108-
self.model_optim_config["device"]
109-
)
110-
with torch.no_grad():
111-
return self.model(z)["paired"].cpu()
112-
11340
def observation(self, idx):
114-
# even though the model outputs one paired dist for each reciever
115-
# they should be identical
116-
return self.model_inference_at_observation_idx(idx)[0, 0]
41+
return self.ds.get_inference_for_idx(idx)[0]["paired"][0]
11742

11843
def fix_particles(self):
11944
self.particles[:, 0] = torch_pi_norm_pi(self.particles[:, 0])
@@ -125,7 +50,6 @@ def predict(self, our_state, dt, noise_std):
12550
add_noise(self.particles, noise_std=noise_std, generator=self.generator)
12651

12752
def update(self, z):
128-
#
12953
# z is not the raw observation, but the processed model output
13054
theta_bin = theta_phi_to_bins(self.particles[:, 0], nbins=z.shape[0])
13155
prob_theta_given_observation = torch.take(z, theta_bin)

spf/filters/particle_single_radio_nn_filter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from spf.model_training_and_inference.models.single_point_networks_inference import (
1212
convert_datasets_config_to_inference,
13-
get_inference_on_ds,
13+
get_nn_inference_on_ds_and_cache,
1414
load_model_and_config_from_config_fn_and_checkpoint,
1515
)
1616
from spf.scripts.train_single_point import (
@@ -41,7 +41,7 @@ def __init__(
4141
if not self.ds.temp_file:
4242
# cache model results
4343
self.cached_model_inference = torch.as_tensor(
44-
get_inference_on_ds(
44+
get_nn_inference_on_ds_and_cache(
4545
ds_fn=ds.zarr_fn,
4646
config_fn=config_fn,
4747
checkpoint_fn=checkpoint_fn,

0 commit comments

Comments
 (0)