1414
1515"""Load datasets."""
1616
17+ import copy
1718import json
1819import os
19- from typing import Literal , cast
2020import urllib .request
2121
2222from disentangled_rnns .library import pclicks
@@ -146,7 +146,7 @@ def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
146146 ys = np .concatenate ((free_choices , - 1 * np .ones ((1 , n_sess , 1 ))), axis = 0 )
147147
148148 # Pack into a DatasetRNN object
149- dataset_rat = rnn_utils .DatasetRNN (ys = ys , xs = xs , y_type = 'categorical' )
149+ dataset_rat = rnn_utils .DatasetRNNCategorical (ys = ys , xs = xs )
150150
151151 return dataset_rat
152152
@@ -243,7 +243,7 @@ def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
243243 ys = - 1 * np .ones ((101 , n_trials , 1 ))
244244 ys [- 1 ,:, 0 ] = choices
245245
246- dataset_rat = rnn_utils .DatasetRNN (xs , ys , y_type = 'categorical' )
246+ dataset_rat = rnn_utils .DatasetRNNCategorical (xs , ys )
247247
248248 return dataset_rat
249249
@@ -328,7 +328,7 @@ def get_bounded_accumulator_dataset(
328328 )
329329 ys = - 1 * np .ones ((stim_duration_max + 1 , n_trials , 1 ))
330330 ys [- 1 , :, 0 ] = decisions
331- dataset = rnn_utils .DatasetRNN (xs , ys , y_type = 'categorical' )
331+ dataset = rnn_utils .DatasetRNNCategorical (xs , ys )
332332 return dataset
333333
334334
@@ -340,7 +340,8 @@ def dataset_list_to_multisubject(
340340
341341 Multisubject dataset has a new first column containing an integer subject ID.
342342 DisRNN in multisubject mode will convert this first to a one-hot then to a
343- subject embedding.
343+ subject embedding. The returned DatasetRNN will inherit properties like
344+ `batch_mode` and `batch_size` from the first dataset in `dataset_list`.
344345
345346 Args:
346347 dataset_list: List of single-subject datasets
@@ -353,19 +354,6 @@ def dataset_list_to_multisubject(
353354 """
354355 data = dataset_list [0 ].get_all ()
355356 xs_dataset , ys_dataset = data ['xs' ], data ['ys' ]
356- x_names = dataset_list [0 ].x_names
357- y_names = dataset_list [0 ].y_names
358- y_type_str = dataset_list [0 ].y_type
359- n_classes = dataset_list [0 ].n_classes
360-
361- # Runtime check for y_type_str before casting
362- allowed_y_types = ('categorical' , 'scalar' , 'mixed' )
363- if y_type_str not in allowed_y_types :
364- raise ValueError (
365- f'Invalid y_type "{ y_type_str } " found in dataset_list. '
366- f'Expected one of { allowed_y_types } .' )
367- # Cast for pytype
368- y_type = cast (Literal ['categorical' , 'scalar' , 'mixed' ], y_type_str )
369357
370358 # If we're adding a subject ID, we'll add a feature to the xs
371359 if add_subj_id :
@@ -383,22 +371,12 @@ def dataset_list_to_multisubject(
383371 # multisubject dataset
384372 for dataset_i in range (len (dataset_list )):
385373 # Check datasets are compatible
386- assert x_names == dataset_list [dataset_i ].x_names , (
387- f'x_names do not match across datasets. Expected { x_names } , got'
388- f' { dataset_list [dataset_i ].x_names } '
389- )
390- assert y_names == dataset_list [dataset_i ].y_names , (
391- f'y_names do not match across datasets. Expected { y_names } , got'
392- f' { dataset_list [dataset_i ].y_names } '
393- )
394- assert y_type == dataset_list [dataset_i ].y_type , (
395- f'y_type does not match across datasets. Expected { y_type } , got'
396- f' { dataset_list [dataset_i ].y_type } '
397- )
398- assert n_classes == dataset_list [dataset_i ].n_classes , (
399- f'n_classes does not match across datasets. Expected { n_classes } , got'
400- f' { dataset_list [dataset_i ].n_classes } '
401- )
374+ if not rnn_utils .datasets_are_compatible (
375+ dataset_list [0 ], dataset_list [dataset_i ]
376+ ):
377+ raise ValueError (
378+ f'Dataset { dataset_i } is not compatible with dataset 0.'
379+ )
402380
403381 data = dataset_list [dataset_i ].get_all ()
404382 xs_dataset , ys_dataset = data ['xs' ], data ['ys' ]
@@ -458,16 +436,16 @@ def dataset_list_to_multisubject(
458436 ys = np .concatenate ((ys , ys_dataset ), axis = 1 )
459437
460438 if add_subj_id :
461- x_names = ['Subject ID' ] + x_names
462-
463- dataset = rnn_utils . DatasetRNN (
464- xs = xs ,
465- ys = ys ,
466- x_names = x_names ,
467- y_names = y_names ,
468- y_type = y_type ,
469- n_classes = n_classes ,
470- )
439+ x_names = ['Subject ID' ] + dataset_list [ 0 ]. x_names
440+ else :
441+ x_names = dataset_list [ 0 ]. x_names
442+
443+ dataset = copy . deepcopy ( dataset_list [ 0 ])
444+ dataset . _xs = xs # pylint: disable=protected-access
445+ dataset . _ys = ys # pylint: disable=protected-access
446+ dataset . _n_episodes = np . shape ( xs )[ 1 ] # pylint: disable=protected-access
447+ dataset . _n_timesteps = np . shape ( xs )[ 0 ] # pylint: disable=protected-access
448+ dataset . x_names = x_names
471449
472450 return dataset
473451
0 commit comments