Skip to content

Commit 42bbd13

Browse files
kevin-j-miller authored and copybara-github committed
Refactor DatasetRNN to be subclassable for different target types.
Initial classes are categorical, continuous, and mixed. These reproduce the functionality of using the y_type argument to the previous version. DatasetRNNMixed keeps the logic of the current "hybrid" loss -- the intention is to refactor this in a future CL to have separate fields for the continuous and categorical portions of the loss. This also required a refactor of dataset_list_to_multisubject, to be more careful about merging only lists where the datasets are compatible with each other. PiperOrigin-RevId: 864300977
1 parent 93c497f commit 42bbd13

24 files changed

+376
-173
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2424

2525
## [Unreleased]
2626

27+
- Refactor DatasetRNN to be subclassable for different target types and
28+
introduce subclasses for the existing target types. DatasetRNN can no longer
29+
be used directly.
30+
2731
## [0.1.4] - 2026-01-22
2832

2933
- Modify default behavior of DatasetRNN to use random batching, and allow

disentangled_rnns/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/checkpoint_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/checkpoint_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/disrnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/disrnn_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/example_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

disentangled_rnns/library/get_datasets.py

Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 DeepMind Technologies Limited.
1+
# Copyright 2026 DeepMind Technologies Limited.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,9 +14,10 @@
1414

1515
"""Load datasets."""
1616

17+
import copy
1718
import json
1819
import os
19-
from typing import Literal, cast
20+
from typing import TypeVar
2021
import urllib.request
2122

2223
from disentangled_rnns.library import pclicks
@@ -31,7 +32,7 @@ def find(s, ch):
3132
return [i for i, ltr in enumerate(s) if ltr == ch]
3233

3334

34-
def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
35+
def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNNCategorical:
3536
"""Downloads and packages rat two-armed bandit datasets.
3637
3738
Dataset is from the following paper:
@@ -146,12 +147,12 @@ def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
146147
ys = np.concatenate((free_choices, -1*np.ones((1, n_sess, 1))), axis=0)
147148

148149
# Pack into a DatasetRNN object
149-
dataset_rat = rnn_utils.DatasetRNN(ys=ys, xs=xs, y_type='categorical')
150+
dataset_rat = rnn_utils.DatasetRNNCategorical(ys=ys, xs=xs)
150151

151152
return dataset_rat
152153

153154

154-
def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
155+
def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNNCategorical:
155156
"""Packages up rat poisson clicks datasets.
156157
157158
Dataset is from the following paper:
@@ -243,7 +244,7 @@ def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
243244
ys = -1*np.ones((101, n_trials, 1))
244245
ys[-1,:, 0] = choices
245246

246-
dataset_rat = rnn_utils.DatasetRNN(xs, ys, y_type='categorical')
247+
dataset_rat = rnn_utils.DatasetRNNCategorical(xs, ys)
247248

248249
return dataset_rat
249250

@@ -255,16 +256,17 @@ def get_q_learning_dataset(
255256
n_trials: int = 500,
256257
n_sessions: int = 20000,
257258
np_rng_seed: float = 0
258-
) -> rnn_utils.DatasetRNN:
259+
) -> rnn_utils.DatasetRNNCategorical:
259260
"""Generates synthetic dataset from Q-Learning agent, using standard parameters."""
260-
np.random.seed(np_rng_seed)
261+
rng = np.random.default_rng(np_rng_seed)
261262
agent = two_armed_bandits.AgentQ(alpha=alpha, beta=beta)
262263
environment = two_armed_bandits.EnvironmentBanditsDrift(sigma=sigma)
263264
dataset = two_armed_bandits.create_dataset(
264265
agent,
265266
environment,
266267
n_steps_per_session=n_trials,
267268
n_sessions=n_sessions,
269+
rng=rng,
268270
)
269271
return dataset
270272

@@ -277,7 +279,7 @@ def get_actor_critic_dataset(
277279
n_trials: int = 500,
278280
n_sessions: int = 20000,
279281
np_rng_seed: float = 0,
280-
) -> rnn_utils.DatasetRNN:
282+
) -> rnn_utils.DatasetRNNCategorical:
281283
"""Generates synthetic dataset from Actor-Critic agent, using standard parameters."""
282284
np.random.seed(np_rng_seed)
283285
agent = two_armed_bandits.AgentLeakyActorCritic(
@@ -307,7 +309,7 @@ def get_bounded_accumulator_dataset(
307309
depression_tau: float = 8.0,
308310
bound: float = 2.9,
309311
lapse: float = 0.0,
310-
) -> rnn_utils.DatasetRNN:
312+
) -> rnn_utils.DatasetRNNCategorical:
311313
"""Generates synthetic dataset from Bounded Accumulator."""
312314
xs, _ = pclicks.generate_clicktrains(
313315
n_trials=n_trials,
@@ -326,24 +328,30 @@ def get_bounded_accumulator_dataset(
326328
bound=bound,
327329
lapse=lapse,
328330
)
329-
ys = -1 * np.ones((stim_duration_max + 1, n_trials, 1))
331+
ys = -1 * np.ones((stim_duration_max + 1, n_trials, 1), dtype=int)
330332
ys[-1, :, 0] = decisions
331-
dataset = rnn_utils.DatasetRNN(xs, ys, y_type='categorical')
333+
dataset = rnn_utils.DatasetRNNCategorical(xs, ys)
332334
return dataset
333335

334336

337+
T = TypeVar('T', bound=rnn_utils.DatasetRNN)
338+
339+
335340
def dataset_list_to_multisubject(
336-
dataset_list: list[rnn_utils.DatasetRNN],
341+
dataset_list: list[T],
337342
add_subj_id: bool = True,
338-
) -> rnn_utils.DatasetRNN:
343+
) -> T:
339344
"""Turn a list of single-subject datasets into a multisubject dataset.
340345
341346
Multisubject dataset has a new first column containing an integer subject ID.
342347
DisRNN in multisubject mode will convert this first to a one-hot then to a
343348
subject embedding.
344349
345350
Args:
346-
dataset_list: List of single-subject datasets
351+
dataset_list: List of single-subject datasets. Datasets must be compatible,
352+
i.e. have the same number of trials, timesteps, and features, and be
353+
instances of the same class -- this is necessary for them to be mergable.
354+
They must also have the same batching and rng object.
347355
add_subj_id: Whether to add a subject ID column to the xs. If True, dataset
348356
is suitable for multisubject mode. If False, dataset is suitable for
349357
single-subject mode, treating all data as if from a single subject.
@@ -353,19 +361,6 @@ def dataset_list_to_multisubject(
353361
"""
354362
data = dataset_list[0].get_all()
355363
xs_dataset, ys_dataset = data['xs'], data['ys']
356-
x_names = dataset_list[0].x_names
357-
y_names = dataset_list[0].y_names
358-
y_type_str = dataset_list[0].y_type
359-
n_classes = dataset_list[0].n_classes
360-
361-
# Runtime check for y_type_str before casting
362-
allowed_y_types = ('categorical', 'scalar', 'mixed')
363-
if y_type_str not in allowed_y_types:
364-
raise ValueError(
365-
f'Invalid y_type "{y_type_str}" found in dataset_list. '
366-
f'Expected one of {allowed_y_types}.')
367-
# Cast for pytype
368-
y_type = cast(Literal['categorical', 'scalar', 'mixed'], y_type_str)
369364

370365
# If we're adding a subject ID, we'll add a feature to the xs
371366
if add_subj_id:
@@ -383,23 +378,22 @@ def dataset_list_to_multisubject(
383378
# multisubject dataset
384379
for dataset_i in range(len(dataset_list)):
385380
# Check datasets are compatible
386-
assert x_names == dataset_list[dataset_i].x_names, (
387-
f'x_names do not match across datasets. Expected {x_names}, got'
388-
f' {dataset_list[dataset_i].x_names}'
389-
)
390-
assert y_names == dataset_list[dataset_i].y_names, (
391-
f'y_names do not match across datasets. Expected {y_names}, got'
392-
f' {dataset_list[dataset_i].y_names}'
393-
)
394-
assert y_type == dataset_list[dataset_i].y_type, (
395-
f'y_type does not match across datasets. Expected {y_type}, got'
396-
f' {dataset_list[dataset_i].y_type}'
397-
)
398-
assert n_classes == dataset_list[dataset_i].n_classes, (
399-
f'n_classes does not match across datasets. Expected {n_classes}, got'
400-
f' {dataset_list[dataset_i].n_classes}'
401-
)
402-
381+
if not rnn_utils.datasets_are_compatible(
382+
dataset_list[0], dataset_list[dataset_i]
383+
):
384+
raise ValueError(
385+
f'Dataset {dataset_i} is not compatible with dataset 0.'
386+
)
387+
# Check datasets have the same batching. This is a bit paranoid, but
388+
# if they don't match the merged dataset could violate user expectations.
389+
if (
390+
dataset_list[0].batch_mode != dataset_list[dataset_i].batch_mode
391+
or dataset_list[0].batch_size != dataset_list[dataset_i].batch_size
392+
):
393+
raise ValueError(
394+
f'Dataset {dataset_i} has a different batch mode or batch size than'
395+
f' dataset 0.'
396+
)
403397
data = dataset_list[dataset_i].get_all()
404398
xs_dataset, ys_dataset = data['xs'], data['ys']
405399
n_sessions = np.shape(xs_dataset)[1]
@@ -458,16 +452,16 @@ def dataset_list_to_multisubject(
458452
ys = np.concatenate((ys, ys_dataset), axis=1)
459453

460454
if add_subj_id:
461-
x_names = ['Subject ID'] + x_names
462-
463-
dataset = rnn_utils.DatasetRNN(
464-
xs=xs,
465-
ys=ys,
466-
x_names=x_names,
467-
y_names=y_names,
468-
y_type=y_type,
469-
n_classes=n_classes,
470-
)
455+
x_names = ['Subject ID'] + dataset_list[0].x_names
456+
else:
457+
x_names = dataset_list[0].x_names
458+
459+
dataset = copy.deepcopy(dataset_list[0])
460+
dataset._xs = xs # pylint: disable=protected-access
461+
dataset._ys = ys # pylint: disable=protected-access
462+
dataset._n_episodes = np.shape(xs)[1] # pylint: disable=protected-access
463+
dataset._n_timesteps = np.shape(xs)[0] # pylint: disable=protected-access
464+
dataset.x_names = x_names
471465

472466
return dataset
473467

@@ -477,7 +471,7 @@ def get_q_learning_multisubject_dataset(
477471
n_sessions: int = 300,
478472
alphas: list[float] | None = None,
479473
np_rng_seed: float = 0,
480-
) -> rnn_utils.DatasetRNN:
474+
) -> rnn_utils.DatasetRNNCategorical:
481475
"""Returns a multisubject dataset for the Q-learning task."""
482476
if alphas is None:
483477
alphas = [0.1, 0.2, 0.3, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9]
@@ -496,7 +490,7 @@ def get_q_learning_multisubject_dataset(
496490

497491
def get_rat_bandit_multisubject_dataset(
498492
n_rats: int = 20,
499-
) -> rnn_utils.DatasetRNN:
493+
) -> rnn_utils.DatasetRNNCategorical:
500494
"""Returns a multisubject dataset for the rat bandit task."""
501495
dataset_list = []
502496
for rat_i in range(n_rats):
@@ -507,7 +501,7 @@ def get_rat_bandit_multisubject_dataset(
507501

508502
def get_pclick_multisubject_dataset(
509503
n_rats: int = 19,
510-
) -> rnn_utils.DatasetRNN:
504+
) -> rnn_utils.DatasetRNNCategorical:
511505
"""Returns a multisubject dataset for the pClick task."""
512506
dataset_list = []
513507
for rat_i in range(n_rats):

0 commit comments

Comments (0)