Skip to content

Commit 0685e96

Browse files
authored
Merge pull request #28 from FrancescoNegri/main
Fix probe metadata and annotations retention
2 parents 1e24932 + 39bcb5b commit 0685e96

File tree

4 files changed

+78
-20
lines changed

4 files changed

+78
-20
lines changed

spec/ndx-probeinterface.extensions.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ groups:
2121
default_value: micrometer
2222
doc: SI unit used to define the probe; e.g. 'meter'.
2323
required: false
24+
- name: annotations
25+
dtype: text
26+
doc: annotations attached to the probe
27+
required: false
2428
datasets:
2529
- name: planar_contour
2630
dtype: float

src/pynwb/ndx_probeinterface/io.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Union, List, Optional
22
import numpy as np
3+
import json
34
from probeinterface import Probe, ProbeGroup
45
from pynwb.file import Device
56

@@ -11,8 +12,7 @@
1112
inverted_unit_map = {v: k for k, v in unit_map.items()}
1213

1314

14-
def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup],
15-
name: Optional[str] = None) -> List[Device]:
15+
def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[Device]:
1616
"""
1717
Construct ndx-probeinterface Probe devices from a probeinterface.Probe
1818
@@ -33,7 +33,7 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup],
3333
probes = probe_or_probegroup.probes
3434
devices = []
3535
for probe in probes:
36-
devices.append(_single_probe_to_nwb_device(probe, name=name))
36+
devices.append(_single_probe_to_nwb_device(probe))
3737
return devices
3838

3939

@@ -53,6 +53,11 @@ def to_probeinterface(ndx_probe) -> Probe:
5353
"""
5454
ndim = ndx_probe.ndim
5555
unit = inverted_unit_map[ndx_probe.unit]
56+
name = ndx_probe.name
57+
serial_number = ndx_probe.serial_number
58+
model_name = ndx_probe.model_name
59+
manufacturer = ndx_probe.manufacturer
60+
5661
polygon = ndx_probe.planar_contour
5762

5863
positions = []
@@ -105,19 +110,27 @@ def to_probeinterface(ndx_probe) -> Probe:
105110
if device_channel_indices is not None:
106111
device_channel_indices = [item for sublist in device_channel_indices for item in sublist]
107112

108-
probeinterface_probe = Probe(ndim=ndim, si_units=unit)
113+
probeinterface_probe = Probe(
114+
ndim=ndim,
115+
si_units=unit,
116+
name=name,
117+
serial_number=serial_number,
118+
manufacturer=manufacturer,
119+
model_name=model_name
120+
)
109121
probeinterface_probe.set_contacts(
110122
positions=positions, shapes=shapes, shape_params=shape_params, plane_axes=plane_axes, shank_ids=shank_ids
111123
)
112124
probeinterface_probe.set_contact_ids(contact_ids=contact_ids)
113125
if device_channel_indices is not None:
114126
probeinterface_probe.set_device_channel_indices(channel_indices=device_channel_indices)
115127
probeinterface_probe.set_planar_contour(polygon)
128+
probeinterface_probe.annotate(**json.loads(ndx_probe.annotations))
116129

117130
return probeinterface_probe
118131

119132

120-
def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None):
133+
def _single_probe_to_nwb_device(probe: Probe):
121134
from pynwb import get_class
122135

123136
Probe = get_class("Probe", "ndx-probeinterface")
@@ -156,10 +169,11 @@ def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None):
156169
kwargs["shank_id"] = probe.shank_ids[index]
157170
contact_table.add_row(kwargs)
158171

159-
serial_number = probe.serial_number
160-
model_name = probe.model_name
161-
manufacturer = probe.manufacturer
162-
name = name if name is not None else probe.name
172+
annotations = probe.annotations.copy()
173+
name = annotations.pop("name") if "name" in annotations else None
174+
serial_number = annotations.pop("serial_number") if "serial_number" in annotations else None
175+
model_name = annotations.pop("model_name") if "model_name" in annotations else None
176+
manufacturer = annotations.pop("manufacturer") if "manufacturer" in annotations else None
163177

164178
probe_device = Probe(
165179
name=name,
@@ -169,7 +183,8 @@ def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None):
169183
ndim=probe.ndim,
170184
unit=unit_map[probe.si_units],
171185
planar_contour=planar_contour,
172-
contact_table=contact_table
186+
contact_table=contact_table,
187+
annotations=json.dumps(annotations)
173188
)
174189

175-
return probe_device
190+
return probe_device

src/pynwb/tests/test_probe.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import datetime
33
import numpy as np
4+
import json
45

56
import probeinterface as pi
67

@@ -25,14 +26,14 @@ def set_up_nwbfile():
2526

2627
def create_single_shank_probe():
2728
probe = pi.generate_linear_probe()
28-
probe.annotate(name="Single-shank")
29+
probe.annotate(name="Single-shank", custom_key="custom annotation")
2930
probe.set_contact_ids([f"c{i}" for i in range(probe.get_contact_count())])
3031
return probe
3132

3233

3334
def create_multi_shank_probe():
3435
probe = pi.generate_multi_shank()
35-
probe.annotate(name="Multi-shank")
36+
probe.annotate(name="Multi-shank", custom_key="custom annotation")
3637
probe.set_contact_ids([f"cm{i}" for i in range(probe.get_contact_count())])
3738
return probe
3839

@@ -69,6 +70,9 @@ def test_constructor_from_probe_single_shank(self):
6970
probe_array = probe.to_numpy()
7071
np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions)
7172
np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"])
73+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
74+
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
75+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
7276

7377
# set channel indices
7478
device_channel_indices = np.arange(probe.get_contact_count())
@@ -78,9 +82,6 @@ def test_constructor_from_probe_single_shank(self):
7882
contact_table = device_w_indices.contact_table
7983
np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices)
8084

81-
devices_w_names = Probe.from_probeinterface(probe, name="Test Probe")
82-
assert devices_w_names[0].name == "Test Probe"
83-
8485
def test_constructor_from_probe_multi_shank(self):
8586
"""Test that the constructor from Probe sets values as expected for multi-shank."""
8687

@@ -108,6 +109,9 @@ def test_constructor_from_probe_multi_shank(self):
108109
np.testing.assert_array_equal(
109110
contact_table["shank_id"][:], probe.shank_ids
110111
)
112+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
113+
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
114+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
111115

112116
def test_constructor_from_probegroup(self):
113117
"""Test that the constructor from probegroup sets values as expected."""
@@ -142,6 +146,10 @@ def test_constructor_from_probegroup(self):
142146
contact_table["device_channel_index_pi"][:], device_channel_indices
143147
)
144148

149+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
150+
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
151+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
152+
145153

146154
class TestProbeRoundtrip(TestCase):
147155
"""Simple roundtrip test for Probe device."""
@@ -174,7 +182,13 @@ def test_roundtrip_nwb_from_probe_single_shank(self):
174182
with NWBHDF5IO(self.path0, mode="r", load_namespaces=True) as io:
175183
read_nwbfile = io.read()
176184
devices = read_nwbfile.devices
177-
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
185+
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
186+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
187+
filtered_annotations = {
188+
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
189+
if key not in keys_to_filter
190+
}
191+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
178192

179193
def test_roundtrip_nwb_from_probe_multi_shank(self):
180194
devices = Probe.from_probeinterface(self.probe1)
@@ -188,6 +202,12 @@ def test_roundtrip_nwb_from_probe_multi_shank(self):
188202
read_nwbfile = io.read()
189203
devices = read_nwbfile.devices
190204
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
205+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
206+
filtered_annotations = {
207+
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
208+
if key not in keys_to_filter
209+
}
210+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
191211

192212
def test_roundtrip_nwb_from_probegroup(self):
193213
devices = Probe.from_probeinterface(self.probegroup)
@@ -201,18 +221,25 @@ def test_roundtrip_nwb_from_probegroup(self):
201221
read_nwbfile = io.read()
202222
for device in devices:
203223
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
204-
224+
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
225+
filtered_annotations = {
226+
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
227+
if key not in keys_to_filter
228+
}
229+
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
205230
def test_roundtrip_pi_from_probe_single_shank(self):
206231
probe_arr = self.probe0.to_numpy()
207232
devices = Probe.from_probeinterface(self.probe0)
208233
device = devices[0]
209234
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
235+
self.assertDictEqual(self.probe0.annotations, device.to_probeinterface().annotations)
210236

211237
def test_roundtrip_pi_from_probe_multi_shank(self):
212238
probe_arr = self.probe1.to_numpy()
213239
devices = Probe.from_probeinterface(self.probe1)
214240
device = devices[0]
215241
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
242+
self.assertDictEqual(self.probe1.annotations, device.to_probeinterface().annotations)
216243

217244

218245
if __name__ == "__main__":

src/spec/create_extension_spec.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,13 @@ def main():
100100
probe = NWBGroupSpec(
101101
doc="Neural probe object according to probeinterface specification",
102102
attributes=[
103-
NWBAttributeSpec(name="ndim", doc="dimension of the probe", dtype="int", required=True, default_value=2),
103+
NWBAttributeSpec(
104+
name="ndim",
105+
doc="dimension of the probe",
106+
dtype="int",
107+
required=True,
108+
default_value=2
109+
),
104110
NWBAttributeSpec(
105111
name="model_name",
106112
doc="model of the probe; e.g. 'Neuropixels 1.0'",
@@ -120,6 +126,12 @@ def main():
120126
required=True,
121127
default_value="micrometer",
122128
),
129+
NWBAttributeSpec(
130+
name="annotations",
131+
doc="annotations attached to the probe",
132+
dtype="text",
133+
required=False
134+
),
123135
],
124136
neurodata_type_inc="Device",
125137
neurodata_type_def="Probe",
@@ -151,4 +163,4 @@ def main():
151163

152164
if __name__ == "__main__":
153165
# usage: python create_extension_spec.py
154-
main()
166+
main()

0 commit comments

Comments
 (0)