Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

- Refactor DatasetRNN to be subclassable for different target types and
introduce subclasses for the existing target types. DatasetRNN can no longer
be instantiated directly; use a subclass (e.g. DatasetRNNCategorical) instead.

## [0.1.4] - 2026-01-22

- Modify default behavior of DatasetRNN to use random batching, and allow
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/checkpoint_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/disrnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/disrnn_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/example_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
110 changes: 52 additions & 58 deletions disentangled_rnns/library/get_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,9 +14,10 @@

"""Load datasets."""

import copy
import json
import os
from typing import Literal, cast
from typing import TypeVar
import urllib.request

from disentangled_rnns.library import pclicks
Expand All @@ -31,7 +32,7 @@ def find(s, ch):
return [i for i, ltr in enumerate(s) if ltr == ch]


def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNNCategorical:
"""Downloads and packages rat two-armed bandit datasets.

Dataset is from the following paper:
Expand Down Expand Up @@ -146,12 +147,12 @@ def get_rat_bandit_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
ys = np.concatenate((free_choices, -1*np.ones((1, n_sess, 1))), axis=0)

# Pack into a DatasetRNN object
dataset_rat = rnn_utils.DatasetRNN(ys=ys, xs=xs, y_type='categorical')
dataset_rat = rnn_utils.DatasetRNNCategorical(ys=ys, xs=xs)

return dataset_rat


def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNNCategorical:
"""Packages up rat poisson clicks datasets.

Dataset is from the following paper:
Expand Down Expand Up @@ -243,7 +244,7 @@ def get_pclicks_dataset(rat_i: int = 0) -> rnn_utils.DatasetRNN:
ys = -1*np.ones((101, n_trials, 1))
ys[-1,:, 0] = choices

dataset_rat = rnn_utils.DatasetRNN(xs, ys, y_type='categorical')
dataset_rat = rnn_utils.DatasetRNNCategorical(xs, ys)

return dataset_rat

Expand All @@ -255,16 +256,17 @@ def get_q_learning_dataset(
n_trials: int = 500,
n_sessions: int = 20000,
np_rng_seed: float = 0
) -> rnn_utils.DatasetRNN:
) -> rnn_utils.DatasetRNNCategorical:
"""Generates synthetic dataset from Q-Learning agent, using standard parameters."""
np.random.seed(np_rng_seed)
rng = np.random.default_rng(np_rng_seed)
agent = two_armed_bandits.AgentQ(alpha=alpha, beta=beta)
environment = two_armed_bandits.EnvironmentBanditsDrift(sigma=sigma)
dataset = two_armed_bandits.create_dataset(
agent,
environment,
n_steps_per_session=n_trials,
n_sessions=n_sessions,
rng=rng,
)
return dataset

Expand All @@ -277,7 +279,7 @@ def get_actor_critic_dataset(
n_trials: int = 500,
n_sessions: int = 20000,
np_rng_seed: float = 0,
) -> rnn_utils.DatasetRNN:
) -> rnn_utils.DatasetRNNCategorical:
"""Generates synthetic dataset from Actor-Critic agent, using standard parameters."""
np.random.seed(np_rng_seed)
agent = two_armed_bandits.AgentLeakyActorCritic(
Expand Down Expand Up @@ -307,7 +309,7 @@ def get_bounded_accumulator_dataset(
depression_tau: float = 8.0,
bound: float = 2.9,
lapse: float = 0.0,
) -> rnn_utils.DatasetRNN:
) -> rnn_utils.DatasetRNNCategorical:
"""Generates synthetic dataset from Bounded Accumulator."""
xs, _ = pclicks.generate_clicktrains(
n_trials=n_trials,
Expand All @@ -326,24 +328,30 @@ def get_bounded_accumulator_dataset(
bound=bound,
lapse=lapse,
)
ys = -1 * np.ones((stim_duration_max + 1, n_trials, 1))
ys = -1 * np.ones((stim_duration_max + 1, n_trials, 1), dtype=int)
ys[-1, :, 0] = decisions
dataset = rnn_utils.DatasetRNN(xs, ys, y_type='categorical')
dataset = rnn_utils.DatasetRNNCategorical(xs, ys)
return dataset


T = TypeVar('T', bound=rnn_utils.DatasetRNN)


def dataset_list_to_multisubject(
dataset_list: list[rnn_utils.DatasetRNN],
dataset_list: list[T],
add_subj_id: bool = True,
) -> rnn_utils.DatasetRNN:
) -> T:
"""Turn a list of single-subject datasets into a multisubject dataset.

Multisubject dataset has a new first column containing an integer subject ID.
DisRNN in multisubject mode will convert this first to a one-hot then to a
subject embedding.

Args:
dataset_list: List of single-subject datasets
dataset_list: List of single-subject datasets. Datasets must be compatible,
i.e. have the same number of trials, timesteps, and features, and be
instances of the same class -- this is necessary for them to be mergeable.
They must also have the same batching and rng object.
add_subj_id: Whether to add a subject ID column to the xs. If True, dataset
is suitable for multisubject mode. If False, dataset is suitable for
single-subject mode, treating all data as if from a single subject.
Expand All @@ -353,19 +361,6 @@ def dataset_list_to_multisubject(
"""
data = dataset_list[0].get_all()
xs_dataset, ys_dataset = data['xs'], data['ys']
x_names = dataset_list[0].x_names
y_names = dataset_list[0].y_names
y_type_str = dataset_list[0].y_type
n_classes = dataset_list[0].n_classes

# Runtime check for y_type_str before casting
allowed_y_types = ('categorical', 'scalar', 'mixed')
if y_type_str not in allowed_y_types:
raise ValueError(
f'Invalid y_type "{y_type_str}" found in dataset_list. '
f'Expected one of {allowed_y_types}.')
# Cast for pytype
y_type = cast(Literal['categorical', 'scalar', 'mixed'], y_type_str)

# If we're adding a subject ID, we'll add a feature to the xs
if add_subj_id:
Expand All @@ -383,23 +378,22 @@ def dataset_list_to_multisubject(
# multisubject dataset
for dataset_i in range(len(dataset_list)):
# Check datasets are compatible
assert x_names == dataset_list[dataset_i].x_names, (
f'x_names do not match across datasets. Expected {x_names}, got'
f' {dataset_list[dataset_i].x_names}'
)
assert y_names == dataset_list[dataset_i].y_names, (
f'y_names do not match across datasets. Expected {y_names}, got'
f' {dataset_list[dataset_i].y_names}'
)
assert y_type == dataset_list[dataset_i].y_type, (
f'y_type does not match across datasets. Expected {y_type}, got'
f' {dataset_list[dataset_i].y_type}'
)
assert n_classes == dataset_list[dataset_i].n_classes, (
f'n_classes does not match across datasets. Expected {n_classes}, got'
f' {dataset_list[dataset_i].n_classes}'
)

if not rnn_utils.datasets_are_compatible(
dataset_list[0], dataset_list[dataset_i]
):
raise ValueError(
f'Dataset {dataset_i} is not compatible with dataset 0.'
)
# Check datasets have the same batching. This is a bit paranoid, but
# if they don't match, the merged dataset could violate user expectations.
if (
dataset_list[0].batch_mode != dataset_list[dataset_i].batch_mode
or dataset_list[0].batch_size != dataset_list[dataset_i].batch_size
):
raise ValueError(
f'Dataset {dataset_i} has a different batch mode or batch size than'
f' dataset 0.'
)
data = dataset_list[dataset_i].get_all()
xs_dataset, ys_dataset = data['xs'], data['ys']
n_sessions = np.shape(xs_dataset)[1]
Expand Down Expand Up @@ -458,16 +452,16 @@ def dataset_list_to_multisubject(
ys = np.concatenate((ys, ys_dataset), axis=1)

if add_subj_id:
x_names = ['Subject ID'] + x_names

dataset = rnn_utils.DatasetRNN(
xs=xs,
ys=ys,
x_names=x_names,
y_names=y_names,
y_type=y_type,
n_classes=n_classes,
)
x_names = ['Subject ID'] + dataset_list[0].x_names
else:
x_names = dataset_list[0].x_names

dataset = copy.deepcopy(dataset_list[0])
dataset._xs = xs # pylint: disable=protected-access
dataset._ys = ys # pylint: disable=protected-access
dataset._n_episodes = np.shape(xs)[1] # pylint: disable=protected-access
dataset._n_timesteps = np.shape(xs)[0] # pylint: disable=protected-access
dataset.x_names = x_names

return dataset

Expand All @@ -477,7 +471,7 @@ def get_q_learning_multisubject_dataset(
n_sessions: int = 300,
alphas: list[float] | None = None,
np_rng_seed: float = 0,
) -> rnn_utils.DatasetRNN:
) -> rnn_utils.DatasetRNNCategorical:
"""Returns a multisubject dataset for the Q-learning task."""
if alphas is None:
alphas = [0.1, 0.2, 0.3, 0.5, 0.5, 0.6, 0.7, 0.8, 0.9]
Expand All @@ -496,7 +490,7 @@ def get_q_learning_multisubject_dataset(

def get_rat_bandit_multisubject_dataset(
n_rats: int = 20,
) -> rnn_utils.DatasetRNN:
) -> rnn_utils.DatasetRNNCategorical:
"""Returns a multisubject dataset for the rat bandit task."""
dataset_list = []
for rat_i in range(n_rats):
Expand All @@ -507,7 +501,7 @@ def get_rat_bandit_multisubject_dataset(

def get_pclick_multisubject_dataset(
n_rats: int = 19,
) -> rnn_utils.DatasetRNN:
) -> rnn_utils.DatasetRNNCategorical:
"""Returns a multisubject dataset for the pClick task."""
dataset_list = []
for rat_i in range(n_rats):
Expand Down
9 changes: 8 additions & 1 deletion disentangled_rnns/library/get_datasets_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,6 +51,13 @@ def test_dataset_list_to_multisubject(self):
[dataset1, dataset2]
)
self.assertIsInstance(multisubject_dataset, rnn_utils.DatasetRNN)
data_dict = multisubject_dataset.get_all()
xs = data_dict["xs"]
ys = data_dict["ys"]
self.assertEqual(xs.shape[0], 12)
self.assertEqual(ys.shape[0], 12)
self.assertEqual(xs.shape[1], 20)
self.assertEqual(ys.shape[1], 20)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/multisubject_disrnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/multisubject_disrnn_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/neuro_disrnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/neuro_disrnn_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
4 changes: 2 additions & 2 deletions disentangled_rnns/library/pclicks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -132,7 +132,7 @@ def drift_diffusion_model(
)
first_bound_crossing = np.argmax(crossed_bound, axis=0)

decisions = np.zeros(n_trials)
decisions = np.zeros(n_trials, dtype=int)
for trial_i in range(n_trials):
if first_bound_crossing[trial_i] > 0:
decision_variable[first_bound_crossing[trial_i] :, trial_i] = (
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/pclicks_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion disentangled_rnns/library/plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 DeepMind Technologies Limited.
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading