Skip to content

Commit 2b04a03

Browse files
kevin-j-millercopybara-github
authored andcommitted
Refactor DatasetRNN to be subclassable for different target types.
Initial classes are categorical, continuous, and hybrid. These reproduce the functionality of using the y_type argument to the previous version. Hybrid keeps the logic of the current "hybrid" loss -- the intention is to refactor it in a future CL to have separate fields for the continuous and categorical portions of the loss. This also required a refactor of dataset_list_to_multisubject, to be more careful about merging only lists where the datasets are compatible with each other. PiperOrigin-RevId: 864300977
1 parent 93c497f commit 2b04a03

File tree

8 files changed

+293
-119
lines changed

8 files changed

+293
-119
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2424

2525
## [Unreleased]
2626

27+
- Refactor DatasetRNN to be subclassable for different target types and
28+
introduce subclasses for the existing target types. DatasetRNN can no longer
29+
be used directly.
30+
2731
## [0.1.4] - 2026-01-22
2832

2933
- Modify default behavior of DatasetRNN to use random batching, and allow

disentangled_rnns/library/get_datasets.py

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
"""Load datasets."""
1616

17+
import copy
1718
import json
1819
import os
19-
from typing import Literal, cast
2020
import urllib.request
2121

2222
from disentangled_rnns.library import pclicks
@@ -146,7 +146,7 @@ def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
146146
ys = np.concatenate((free_choices, -1*np.ones((1, n_sess, 1))), axis=0)
147147

148148
# Pack into a DatasetRNN object
149-
dataset_rat = rnn_utils.DatasetRNN(ys=ys, xs=xs, y_type='categorical')
149+
dataset_rat = rnn_utils.DatasetRNNCategorical(ys=ys, xs=xs)
150150

151151
return dataset_rat
152152

@@ -243,7 +243,7 @@ def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
243243
ys = -1*np.ones((101, n_trials, 1))
244244
ys[-1,:, 0] = choices
245245

246-
dataset_rat = rnn_utils.DatasetRNN(xs, ys, y_type='categorical')
246+
dataset_rat = rnn_utils.DatasetRNNCategorical(xs, ys)
247247

248248
return dataset_rat
249249

@@ -328,7 +328,7 @@ def get_bounded_accumulator_dataset(
328328
)
329329
ys = -1 * np.ones((stim_duration_max + 1, n_trials, 1))
330330
ys[-1, :, 0] = decisions
331-
dataset = rnn_utils.DatasetRNN(xs, ys, y_type='categorical')
331+
dataset = rnn_utils.DatasetRNNCategorical(xs, ys)
332332
return dataset
333333

334334

@@ -340,7 +340,8 @@ def dataset_list_to_multisubject(
340340
341341
Multisubject dataset has a new first column containing an integer subject ID.
342342
DisRNN in multisubject mode will convert this first to a one-hot then to a
343-
subject embedding.
343+
subject embedding. The returned DatasetRNN will inherit properties like
344+
`batch_mode` and `batch_size` from the first dataset in `dataset_list`.
344345
345346
Args:
346347
dataset_list: List of single-subject datasets
@@ -353,19 +354,6 @@ def dataset_list_to_multisubject(
353354
"""
354355
data = dataset_list[0].get_all()
355356
xs_dataset, ys_dataset = data['xs'], data['ys']
356-
x_names = dataset_list[0].x_names
357-
y_names = dataset_list[0].y_names
358-
y_type_str = dataset_list[0].y_type
359-
n_classes = dataset_list[0].n_classes
360-
361-
# Runtime check for y_type_str before casting
362-
allowed_y_types = ('categorical', 'scalar', 'mixed')
363-
if y_type_str not in allowed_y_types:
364-
raise ValueError(
365-
f'Invalid y_type "{y_type_str}" found in dataset_list. '
366-
f'Expected one of {allowed_y_types}.')
367-
# Cast for pytype
368-
y_type = cast(Literal['categorical', 'scalar', 'mixed'], y_type_str)
369357

370358
# If we're adding a subject ID, we'll add a feature to the xs
371359
if add_subj_id:
@@ -383,22 +371,12 @@ def dataset_list_to_multisubject(
383371
# multisubject dataset
384372
for dataset_i in range(len(dataset_list)):
385373
# Check datasets are compatible
386-
assert x_names == dataset_list[dataset_i].x_names, (
387-
f'x_names do not match across datasets. Expected {x_names}, got'
388-
f' {dataset_list[dataset_i].x_names}'
389-
)
390-
assert y_names == dataset_list[dataset_i].y_names, (
391-
f'y_names do not match across datasets. Expected {y_names}, got'
392-
f' {dataset_list[dataset_i].y_names}'
393-
)
394-
assert y_type == dataset_list[dataset_i].y_type, (
395-
f'y_type does not match across datasets. Expected {y_type}, got'
396-
f' {dataset_list[dataset_i].y_type}'
397-
)
398-
assert n_classes == dataset_list[dataset_i].n_classes, (
399-
f'n_classes does not match across datasets. Expected {n_classes}, got'
400-
f' {dataset_list[dataset_i].n_classes}'
401-
)
374+
if not rnn_utils.datasets_are_compatible(
375+
dataset_list[0], dataset_list[dataset_i]
376+
):
377+
raise ValueError(
378+
f'Dataset {dataset_i} is not compatible with dataset 0.'
379+
)
402380

403381
data = dataset_list[dataset_i].get_all()
404382
xs_dataset, ys_dataset = data['xs'], data['ys']
@@ -458,16 +436,16 @@ def dataset_list_to_multisubject(
458436
ys = np.concatenate((ys, ys_dataset), axis=1)
459437

460438
if add_subj_id:
461-
x_names = ['Subject ID'] + x_names
462-
463-
dataset = rnn_utils.DatasetRNN(
464-
xs=xs,
465-
ys=ys,
466-
x_names=x_names,
467-
y_names=y_names,
468-
y_type=y_type,
469-
n_classes=n_classes,
470-
)
439+
x_names = ['Subject ID'] + dataset_list[0].x_names
440+
else:
441+
x_names = dataset_list[0].x_names
442+
443+
dataset = copy.deepcopy(dataset_list[0])
444+
dataset._xs = xs # pylint: disable=protected-access
445+
dataset._ys = ys # pylint: disable=protected-access
446+
dataset._n_episodes = np.shape(xs)[1] # pylint: disable=protected-access
447+
dataset._n_timesteps = np.shape(xs)[0] # pylint: disable=protected-access
448+
dataset.x_names = x_names
471449

472450
return dataset
473451

disentangled_rnns/library/get_datasets_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ def test_dataset_list_to_multisubject(self):
5151
[dataset1, dataset2]
5252
)
5353
self.assertIsInstance(multisubject_dataset, rnn_utils.DatasetRNN)
54+
data_dict = multisubject_dataset.get_all()
55+
xs = data_dict["xs"]
56+
ys = data_dict["ys"]
57+
self.assertEqual(xs.shape[0], 12)
58+
self.assertEqual(ys.shape[0], 12)
59+
self.assertEqual(xs.shape[1], 20)
60+
self.assertEqual(ys.shape[1], 20)
5461

5562

5663
if __name__ == "__main__":

0 commit comments

Comments
 (0)