11import pytest
22import datetime
33import numpy as np
4+ import json
45
56import probeinterface as pi
67
@@ -25,14 +26,14 @@ def set_up_nwbfile():
2526
2627def create_single_shank_probe ():
2728 probe = pi .generate_linear_probe ()
28- probe .annotate (name = "Single-shank" )
29+ probe .annotate (name = "Single-shank" , custom_key = "custom annotation" )
2930 probe .set_contact_ids ([f"c{ i } " for i in range (probe .get_contact_count ())])
3031 return probe
3132
3233
3334def create_multi_shank_probe ():
3435 probe = pi .generate_multi_shank ()
35- probe .annotate (name = "Multi-shank" )
36+ probe .annotate (name = "Multi-shank" , custom_key = "custom annotation" )
3637 probe .set_contact_ids ([f"cm{ i } " for i in range (probe .get_contact_count ())])
3738 return probe
3839
@@ -69,6 +70,9 @@ def test_constructor_from_probe_single_shank(self):
6970 probe_array = probe .to_numpy ()
7071 np .testing .assert_array_equal (contact_table ["contact_position" ][:], probe .contact_positions )
7172 np .testing .assert_array_equal (contact_table ["contact_shape" ][:], probe_array ["contact_shapes" ])
73+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
74+ filtered_annotations = {key : value for key , value in probe .annotations .items () if key not in keys_to_filter }
75+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
7276
7377 # set channel indices
7478 device_channel_indices = np .arange (probe .get_contact_count ())
@@ -105,6 +109,9 @@ def test_constructor_from_probe_multi_shank(self):
105109 np .testing .assert_array_equal (
106110 contact_table ["shank_id" ][:], probe .shank_ids
107111 )
112+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
113+ filtered_annotations = {key : value for key , value in probe .annotations .items () if key not in keys_to_filter }
114+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
108115
109116 def test_constructor_from_probegroup (self ):
110117 """Test that the constructor from probegroup sets values as expected."""
@@ -139,6 +146,10 @@ def test_constructor_from_probegroup(self):
139146 contact_table ["device_channel_index_pi" ][:], device_channel_indices
140147 )
141148
149+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
150+ filtered_annotations = {key : value for key , value in probe .annotations .items () if key not in keys_to_filter }
151+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
152+
142153
143154class TestProbeRoundtrip (TestCase ):
144155 """Simple roundtrip test for Probe device."""
@@ -171,7 +182,13 @@ def test_roundtrip_nwb_from_probe_single_shank(self):
171182 with NWBHDF5IO (self .path0 , mode = "r" , load_namespaces = True ) as io :
172183 read_nwbfile = io .read ()
173184 devices = read_nwbfile .devices
174- self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
185+ self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
186+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
187+ filtered_annotations = {
188+ key : value for key , value in read_nwbfile .devices [device .name ].to_probeinterface ().annotations .items ()
189+ if key not in keys_to_filter
190+ }
191+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
175192
176193 def test_roundtrip_nwb_from_probe_multi_shank (self ):
177194 devices = Probe .from_probeinterface (self .probe1 )
@@ -185,6 +202,12 @@ def test_roundtrip_nwb_from_probe_multi_shank(self):
185202 read_nwbfile = io .read ()
186203 devices = read_nwbfile .devices
187204 self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
205+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
206+ filtered_annotations = {
207+ key : value for key , value in read_nwbfile .devices [device .name ].to_probeinterface ().annotations .items ()
208+ if key not in keys_to_filter
209+ }
210+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
188211
189212 def test_roundtrip_nwb_from_probegroup (self ):
190213 devices = Probe .from_probeinterface (self .probegroup )
@@ -198,18 +221,25 @@ def test_roundtrip_nwb_from_probegroup(self):
198221 read_nwbfile = io .read ()
199222 for device in devices :
200223 self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
201-
224+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
225+ filtered_annotations = {
226+ key : value for key , value in read_nwbfile .devices [device .name ].to_probeinterface ().annotations .items ()
227+ if key not in keys_to_filter
228+ }
229+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
202230 def test_roundtrip_pi_from_probe_single_shank (self ):
203231 probe_arr = self .probe0 .to_numpy ()
204232 devices = Probe .from_probeinterface (self .probe0 )
205233 device = devices [0 ]
206234 np .testing .assert_array_equal (probe_arr , device .to_probeinterface ().to_numpy ())
235+ self .assertDictEqual (self .probe0 .annotations , device .to_probeinterface ().annotations )
207236
208237 def test_roundtrip_pi_from_probe_multi_shank (self ):
209238 probe_arr = self .probe1 .to_numpy ()
210239 devices = Probe .from_probeinterface (self .probe1 )
211240 device = devices [0 ]
212241 np .testing .assert_array_equal (probe_arr , device .to_probeinterface ().to_numpy ())
242+ self .assertDictEqual (self .probe1 .annotations , device .to_probeinterface ().annotations )
213243
214244
215245if __name__ == "__main__" :
0 commit comments