11import pytest
22import datetime
33import numpy as np
4+ import json
45
56import probeinterface as pi
67
@@ -25,14 +26,14 @@ def set_up_nwbfile():
2526
2627def create_single_shank_probe ():
2728 probe = pi .generate_linear_probe ()
28- probe .annotate (name = "Single-shank" )
29+ probe .annotate (name = "Single-shank" , custom_key = "custom annotation" )
2930 probe .set_contact_ids ([f"c{ i } " for i in range (probe .get_contact_count ())])
3031 return probe
3132
3233
3334def create_multi_shank_probe ():
3435 probe = pi .generate_multi_shank ()
35- probe .annotate (name = "Multi-shank" )
36+ probe .annotate (name = "Multi-shank" , custom_key = "custom annotation" )
3637 probe .set_contact_ids ([f"cm{ i } " for i in range (probe .get_contact_count ())])
3738 return probe
3839
@@ -69,6 +70,9 @@ def test_constructor_from_probe_single_shank(self):
6970 probe_array = probe .to_numpy ()
7071 np .testing .assert_array_equal (contact_table ["contact_position" ][:], probe .contact_positions )
7172 np .testing .assert_array_equal (contact_table ["contact_shape" ][:], probe_array ["contact_shapes" ])
73+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
74+ filtered_annotations = {key : value for key , value in probe .annotations .items () if key not in keys_to_filter }
75+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
7276
7377 # set channel indices
7478 device_channel_indices = np .arange (probe .get_contact_count ())
@@ -78,9 +82,6 @@ def test_constructor_from_probe_single_shank(self):
7882 contact_table = device_w_indices .contact_table
7983 np .testing .assert_array_equal (contact_table ["device_channel_index_pi" ][:], device_channel_indices )
8084
81- devices_w_names = Probe .from_probeinterface (probe , name = "Test Probe" )
82- assert devices_w_names [0 ].name == "Test Probe"
83-
8485 def test_constructor_from_probe_multi_shank (self ):
8586 """Test that the constructor from Probe sets values as expected for multi-shank."""
8687
@@ -108,6 +109,9 @@ def test_constructor_from_probe_multi_shank(self):
108109 np .testing .assert_array_equal (
109110 contact_table ["shank_id" ][:], probe .shank_ids
110111 )
112+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
113+ filtered_annotations = {key : value for key , value in probe .annotations .items () if key not in keys_to_filter }
114+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
111115
112116 def test_constructor_from_probegroup (self ):
113117 """Test that the constructor from probegroup sets values as expected."""
@@ -142,6 +146,10 @@ def test_constructor_from_probegroup(self):
142146 contact_table ["device_channel_index_pi" ][:], device_channel_indices
143147 )
144148
149+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
150+ filtered_annotations = {key : value for key , value in probe .annotations .items () if key not in keys_to_filter }
151+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
152+
145153
146154class TestProbeRoundtrip (TestCase ):
147155 """Simple roundtrip test for Probe device."""
@@ -174,7 +182,13 @@ def test_roundtrip_nwb_from_probe_single_shank(self):
174182 with NWBHDF5IO (self .path0 , mode = "r" , load_namespaces = True ) as io :
175183 read_nwbfile = io .read ()
176184 devices = read_nwbfile .devices
177- self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
185+ self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
186+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
187+ filtered_annotations = {
188+ key : value for key , value in read_nwbfile .devices [device .name ].to_probeinterface ().annotations .items ()
189+ if key not in keys_to_filter
190+ }
191+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
178192
179193 def test_roundtrip_nwb_from_probe_multi_shank (self ):
180194 devices = Probe .from_probeinterface (self .probe1 )
@@ -188,6 +202,12 @@ def test_roundtrip_nwb_from_probe_multi_shank(self):
188202 read_nwbfile = io .read ()
189203 devices = read_nwbfile .devices
190204 self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
205+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
206+ filtered_annotations = {
207+ key : value for key , value in read_nwbfile .devices [device .name ].to_probeinterface ().annotations .items ()
208+ if key not in keys_to_filter
209+ }
210+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
191211
192212 def test_roundtrip_nwb_from_probegroup (self ):
193213 devices = Probe .from_probeinterface (self .probegroup )
@@ -201,18 +221,25 @@ def test_roundtrip_nwb_from_probegroup(self):
201221 read_nwbfile = io .read ()
202222 for device in devices :
203223 self .assertContainerEqual (device , read_nwbfile .devices [device .name ])
204-
224+ keys_to_filter = ["name" , "manufacturer" , "model_name" , "serial_number" ]
225+ filtered_annotations = {
226+ key : value for key , value in read_nwbfile .devices [device .name ].to_probeinterface ().annotations .items ()
227+ if key not in keys_to_filter
228+ }
229+ self .assertDictEqual (json .loads (device .annotations ), filtered_annotations )
205230 def test_roundtrip_pi_from_probe_single_shank (self ):
206231 probe_arr = self .probe0 .to_numpy ()
207232 devices = Probe .from_probeinterface (self .probe0 )
208233 device = devices [0 ]
209234 np .testing .assert_array_equal (probe_arr , device .to_probeinterface ().to_numpy ())
235+ self .assertDictEqual (self .probe0 .annotations , device .to_probeinterface ().annotations )
210236
211237 def test_roundtrip_pi_from_probe_multi_shank (self ):
212238 probe_arr = self .probe1 .to_numpy ()
213239 devices = Probe .from_probeinterface (self .probe1 )
214240 device = devices [0 ]
215241 np .testing .assert_array_equal (probe_arr , device .to_probeinterface ().to_numpy ())
242+ self .assertDictEqual (self .probe1 .annotations , device .to_probeinterface ().annotations )
216243
217244
218245if __name__ == "__main__" :
0 commit comments