Skip to content

Commit 39bcb5b

Browse files
Add annotations retention tests
1 parent 50d90b3 commit 39bcb5b

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

src/pynwb/tests/test_probe.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import datetime
33
import numpy as np
4+
import json
45

56
import probeinterface as pi
67

@@ -25,14 +26,14 @@ def set_up_nwbfile():
2526

2627
def create_single_shank_probe():
2728
probe = pi.generate_linear_probe()
28-
probe.annotate(name="Single-shank")
29+
probe.annotate(name="Single-shank", custom_key="custom annotation")
2930
probe.set_contact_ids([f"c{i}" for i in range(probe.get_contact_count())])
3031
return probe
3132

3233

3334
def create_multi_shank_probe():
3435
probe = pi.generate_multi_shank()
35-
probe.annotate(name="Multi-shank")
36+
probe.annotate(name="Multi-shank", custom_key="custom annotation")
3637
probe.set_contact_ids([f"cm{i}" for i in range(probe.get_contact_count())])
3738
return probe
3839

@@ -69,6 +70,9 @@ def test_constructor_from_probe_single_shank(self):
6970
probe_array = probe.to_numpy()
7071
np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions)
7172
np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"])
73+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
74+
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
75+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
7276

7377
# set channel indices
7478
device_channel_indices = np.arange(probe.get_contact_count())
@@ -105,6 +109,9 @@ def test_constructor_from_probe_multi_shank(self):
105109
np.testing.assert_array_equal(
106110
contact_table["shank_id"][:], probe.shank_ids
107111
)
112+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
113+
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
114+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
108115

109116
def test_constructor_from_probegroup(self):
110117
"""Test that the constructor from probegroup sets values as expected."""
@@ -139,6 +146,10 @@ def test_constructor_from_probegroup(self):
139146
contact_table["device_channel_index_pi"][:], device_channel_indices
140147
)
141148

149+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
150+
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
151+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
152+
142153

143154
class TestProbeRoundtrip(TestCase):
144155
"""Simple roundtrip test for Probe device."""
@@ -171,7 +182,13 @@ def test_roundtrip_nwb_from_probe_single_shank(self):
171182
with NWBHDF5IO(self.path0, mode="r", load_namespaces=True) as io:
172183
read_nwbfile = io.read()
173184
devices = read_nwbfile.devices
174-
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
185+
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
186+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
187+
filtered_annotations = {
188+
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
189+
if key not in keys_to_filter
190+
}
191+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
175192

176193
def test_roundtrip_nwb_from_probe_multi_shank(self):
177194
devices = Probe.from_probeinterface(self.probe1)
@@ -185,6 +202,12 @@ def test_roundtrip_nwb_from_probe_multi_shank(self):
185202
read_nwbfile = io.read()
186203
devices = read_nwbfile.devices
187204
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
205+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
206+
filtered_annotations = {
207+
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
208+
if key not in keys_to_filter
209+
}
210+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
188211

189212
def test_roundtrip_nwb_from_probegroup(self):
190213
devices = Probe.from_probeinterface(self.probegroup)
@@ -198,18 +221,25 @@ def test_roundtrip_nwb_from_probegroup(self):
198221
read_nwbfile = io.read()
199222
for device in devices:
200223
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
201-
224+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
225+
filtered_annotations = {
226+
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
227+
if key not in keys_to_filter
228+
}
229+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
202230
def test_roundtrip_pi_from_probe_single_shank(self):
203231
probe_arr = self.probe0.to_numpy()
204232
devices = Probe.from_probeinterface(self.probe0)
205233
device = devices[0]
206234
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
235+
self.assertDictEqual(self.probe0.annotations, device.to_probeinterface().annotations)
207236

208237
def test_roundtrip_pi_from_probe_multi_shank(self):
209238
probe_arr = self.probe1.to_numpy()
210239
devices = Probe.from_probeinterface(self.probe1)
211240
device = devices[0]
212241
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
242+
self.assertDictEqual(self.probe1.annotations, device.to_probeinterface().annotations)
213243

214244

215245
if __name__ == "__main__":

0 commit comments

Comments
 (0)