1- # Copyright 2025 DeepMind Technologies Limited.
1+ # Copyright 2026 DeepMind Technologies Limited.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1414
1515"""Load datasets."""
1616
17+ import copy
1718import json
1819import os
19- from typing import Literal , cast
20+ from typing import TypeVar
2021import urllib .request
2122
2223from disentangled_rnns .library import pclicks
@@ -31,7 +32,7 @@ def find(s, ch):
3132 return [i for i , ltr in enumerate (s ) if ltr == ch ]
3233
3334
34- def get_rat_bandit_dataset (rat_i : int = 0 ) -> rnn_utils .DatasetRNN :
35+ def get_rat_bandit_dataset (rat_i : int = 0 ) -> rnn_utils .DatasetRNNCategorical :
3536 """Downloads and packages rat two-armed bandit datasets.
3637
3738 Dataset is from the following paper:
@@ -146,12 +147,12 @@ def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
146147 ys = np .concatenate ((free_choices , - 1 * np .ones ((1 , n_sess , 1 ))), axis = 0 )
147148
148149 # Pack into a DatasetRNN object
149- dataset_rat = rnn_utils .DatasetRNN (ys = ys , xs = xs , y_type = 'categorical' )
150+ dataset_rat = rnn_utils .DatasetRNNCategorical (ys = ys , xs = xs )
150151
151152 return dataset_rat
152153
153154
154- def get_pclicks_dataset (rat_i : int = 0 ) -> rnn_utils .DatasetRNN :
155+ def get_pclicks_dataset (rat_i : int = 0 ) -> rnn_utils .DatasetRNNCategorical :
155156 """Packages up rat poisson clicks datasets.
156157
157158 Dataset is from the following paper:
@@ -243,7 +244,7 @@ def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
243244 ys = - 1 * np .ones ((101 , n_trials , 1 ))
244245 ys [- 1 ,:, 0 ] = choices
245246
246- dataset_rat = rnn_utils .DatasetRNN (xs , ys , y_type = 'categorical' )
247+ dataset_rat = rnn_utils .DatasetRNNCategorical (xs , ys )
247248
248249 return dataset_rat
249250
@@ -255,16 +256,17 @@ def get_q_learning_dataset(
255256 n_trials : int = 500 ,
256257 n_sessions : int = 20000 ,
257258 np_rng_seed : float = 0
258- ) -> rnn_utils .DatasetRNN :
259+ ) -> rnn_utils .DatasetRNNCategorical :
259260 """Generates synthetic dataset from Q-Learning agent, using standard parameters."""
260- np .random .seed (np_rng_seed )
261+ rng = np .random .default_rng (np_rng_seed )
261262 agent = two_armed_bandits .AgentQ (alpha = alpha , beta = beta )
262263 environment = two_armed_bandits .EnvironmentBanditsDrift (sigma = sigma )
263264 dataset = two_armed_bandits .create_dataset (
264265 agent ,
265266 environment ,
266267 n_steps_per_session = n_trials ,
267268 n_sessions = n_sessions ,
269+ rng = rng ,
268270 )
269271 return dataset
270272
@@ -277,7 +279,7 @@ def get_actor_critic_dataset(
277279 n_trials : int = 500 ,
278280 n_sessions : int = 20000 ,
279281 np_rng_seed : float = 0 ,
280- ) -> rnn_utils .DatasetRNN :
282+ ) -> rnn_utils .DatasetRNNCategorical :
281283 """Generates synthetic dataset from Actor-Critic agent, using standard parameters."""
282284 np .random .seed (np_rng_seed )
283285 agent = two_armed_bandits .AgentLeakyActorCritic (
@@ -307,7 +309,7 @@ def get_bounded_accumulator_dataset(
307309 depression_tau : float = 8.0 ,
308310 bound : float = 2.9 ,
309311 lapse : float = 0.0 ,
310- ) -> rnn_utils .DatasetRNN :
312+ ) -> rnn_utils .DatasetRNNCategorical :
311313 """Generates synthetic dataset from Bounded Accumulator."""
312314 xs , _ = pclicks .generate_clicktrains (
313315 n_trials = n_trials ,
@@ -326,24 +328,30 @@ def get_bounded_accumulator_dataset(
326328 bound = bound ,
327329 lapse = lapse ,
328330 )
329- ys = - 1 * np .ones ((stim_duration_max + 1 , n_trials , 1 ))
331+ ys = - 1 * np .ones ((stim_duration_max + 1 , n_trials , 1 ), dtype = int )
330332 ys [- 1 , :, 0 ] = decisions
331- dataset = rnn_utils .DatasetRNN (xs , ys , y_type = 'categorical' )
333+ dataset = rnn_utils .DatasetRNNCategorical (xs , ys )
332334 return dataset
333335
334336
337+ T = TypeVar ('T' , bound = rnn_utils .DatasetRNN )
338+
339+
335340def dataset_list_to_multisubject (
336- dataset_list : list [rnn_utils . DatasetRNN ],
341+ dataset_list : list [T ],
337342 add_subj_id : bool = True ,
338- ) -> rnn_utils . DatasetRNN :
343+ ) -> T :
339344 """Turn a list of single-subject datasets into a multisubject dataset.
340345
341346 Multisubject dataset has a new first column containing an integer subject ID.
342347 DisRNN in multisubject mode will convert this first to a one-hot then to a
343348 subject embedding.
344349
345350 Args:
346- dataset_list: List of single-subject datasets
351+ dataset_list: List of single-subject datasets. Datasets must be compatible,
352+ i.e. have the same number of trials, timesteps, and features, and be
353+ instances of the same class -- this is necessary for them to be mergeable.
354+ They must also have the same batching and rng object.
347355 add_subj_id: Whether to add a subject ID column to the xs. If True, dataset
348356 is suitable for multisubject mode. If False, dataset is suitable for
349357 single-subject mode, treating all data as if from a single subject.
@@ -353,19 +361,6 @@ def dataset_list_to_multisubject(
353361 """
354362 data = dataset_list [0 ].get_all ()
355363 xs_dataset , ys_dataset = data ['xs' ], data ['ys' ]
356- x_names = dataset_list [0 ].x_names
357- y_names = dataset_list [0 ].y_names
358- y_type_str = dataset_list [0 ].y_type
359- n_classes = dataset_list [0 ].n_classes
360-
361- # Runtime check for y_type_str before casting
362- allowed_y_types = ('categorical' , 'scalar' , 'mixed' )
363- if y_type_str not in allowed_y_types :
364- raise ValueError (
365- f'Invalid y_type "{ y_type_str } " found in dataset_list. '
366- f'Expected one of { allowed_y_types } .' )
367- # Cast for pytype
368- y_type = cast (Literal ['categorical' , 'scalar' , 'mixed' ], y_type_str )
369364
370365 # If we're adding a subject ID, we'll add a feature to the xs
371366 if add_subj_id :
@@ -383,23 +378,22 @@ def dataset_list_to_multisubject(
383378 # multisubject dataset
384379 for dataset_i in range (len (dataset_list )):
385380 # Check datasets are compatible
386- assert x_names == dataset_list [dataset_i ].x_names , (
387- f'x_names do not match across datasets. Expected { x_names } , got'
388- f' { dataset_list [dataset_i ].x_names } '
389- )
390- assert y_names == dataset_list [dataset_i ].y_names , (
391- f'y_names do not match across datasets. Expected { y_names } , got'
392- f' { dataset_list [dataset_i ].y_names } '
393- )
394- assert y_type == dataset_list [dataset_i ].y_type , (
395- f'y_type does not match across datasets. Expected { y_type } , got'
396- f' { dataset_list [dataset_i ].y_type } '
397- )
398- assert n_classes == dataset_list [dataset_i ].n_classes , (
399- f'n_classes does not match across datasets. Expected { n_classes } , got'
400- f' { dataset_list [dataset_i ].n_classes } '
401- )
402-
381+ if not rnn_utils .datasets_are_compatible (
382+ dataset_list [0 ], dataset_list [dataset_i ]
383+ ):
384+ raise ValueError (
385+ f'Dataset { dataset_i } is not compatible with dataset 0.'
386+ )
387+ # Check datasets have the same batching. This is a bit paranoid, but
388+ # if they don't match the merged dataset could violate user expectations.
389+ if (
390+ dataset_list [0 ].batch_mode != dataset_list [dataset_i ].batch_mode
391+ or dataset_list [0 ].batch_size != dataset_list [dataset_i ].batch_size
392+ ):
393+ raise ValueError (
394+ f'Dataset { dataset_i } has a different batch mode or batch size than'
395+ f' dataset 0.'
396+ )
403397 data = dataset_list [dataset_i ].get_all ()
404398 xs_dataset , ys_dataset = data ['xs' ], data ['ys' ]
405399 n_sessions = np .shape (xs_dataset )[1 ]
@@ -458,16 +452,16 @@ def dataset_list_to_multisubject(
458452 ys = np .concatenate ((ys , ys_dataset ), axis = 1 )
459453
460454 if add_subj_id :
461- x_names = ['Subject ID' ] + x_names
462-
463- dataset = rnn_utils . DatasetRNN (
464- xs = xs ,
465- ys = ys ,
466- x_names = x_names ,
467- y_names = y_names ,
468- y_type = y_type ,
469- n_classes = n_classes ,
470- )
455+ x_names = ['Subject ID' ] + dataset_list [ 0 ]. x_names
456+ else :
457+ x_names = dataset_list [ 0 ]. x_names
458+
459+ dataset = copy . deepcopy ( dataset_list [ 0 ])
460+ dataset . _xs = xs # pylint: disable=protected-access
461+ dataset . _ys = ys # pylint: disable=protected-access
462+ dataset . _n_episodes = np . shape ( xs )[ 1 ] # pylint: disable=protected-access
463+ dataset . _n_timesteps = np . shape ( xs )[ 0 ] # pylint: disable=protected-access
464+ dataset . x_names = x_names
471465
472466 return dataset
473467
@@ -477,7 +471,7 @@ def get_q_learning_multisubject_dataset(
477471 n_sessions : int = 300 ,
478472 alphas : list [float ] | None = None ,
479473 np_rng_seed : float = 0 ,
480- ) -> rnn_utils .DatasetRNN :
474+ ) -> rnn_utils .DatasetRNNCategorical :
481475 """Returns a multisubject dataset for the Q-learning task."""
482476 if alphas is None :
483477 alphas = [0.1 , 0.2 , 0.3 , 0.5 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 ]
@@ -496,7 +490,7 @@ def get_q_learning_multisubject_dataset(
496490
497491def get_rat_bandit_multisubject_dataset (
498492 n_rats : int = 20 ,
499- ) -> rnn_utils .DatasetRNN :
493+ ) -> rnn_utils .DatasetRNNCategorical :
500494 """Returns a multisubject dataset for the rat bandit task."""
501495 dataset_list = []
502496 for rat_i in range (n_rats ):
@@ -507,7 +501,7 @@ def get_rat_bandit_multisubject_dataset(
507501
508502def get_pclick_multisubject_dataset (
509503 n_rats : int = 19 ,
510- ) -> rnn_utils .DatasetRNN :
504+ ) -> rnn_utils .DatasetRNNCategorical :
511505 """Returns a multisubject dataset for the pClick task."""
512506 dataset_list = []
513507 for rat_i in range (n_rats ):
0 commit comments