Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions disentangled_rnns/library/disrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Disentangled RNN and plotting functions."""
import dataclasses
from typing import Optional, Callable, Any, Sequence
from typing import Callable, Any, Sequence

from disentangled_rnns.library import rnn_utils
import haiku as hk
Expand All @@ -26,7 +26,7 @@
def information_bottleneck(
inputs: jnp.ndarray,
sigmas: jnp.ndarray,
multipliers: Optional[jnp.ndarray] = None,
multipliers: jnp.ndarray | None = None,
noiseless_mode: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray]:
r"""Output from an information bottleneck given a vector of means and std devs.
Expand Down Expand Up @@ -152,8 +152,8 @@ class DisRnnConfig:

max_latent_value: float = 2.

x_names: Optional[list[str]] = None
y_names: Optional[list[str]] = None
x_names: list[str] | None = None
y_names: list[str] | None = None

def __post_init__(self):
"""Checks that the configuration is valid."""
Expand Down Expand Up @@ -387,7 +387,7 @@ def _build_choice_bottlenecks(self):
)
)

def initial_state(self, batch_size: Optional[int]) -> Any:
def initial_state(self, batch_size: int | None) -> Any:
# (batch_size, latent_size)
latents = jnp.ones([batch_size, self._latent_size]) * self._latent_inits
return latents
Expand Down
4 changes: 2 additions & 2 deletions disentangled_rnns/library/get_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import json
import os
from typing import Optional, Literal, cast
from typing import Literal, cast
import urllib.request

from disentangled_rnns.library import pclicks
Expand Down Expand Up @@ -475,7 +475,7 @@ def dataset_list_to_multisubject(
def get_q_learning_multisubject_dataset(
n_trials: int = 200,
n_sessions: int = 300,
alphas: Optional[list[float]] = None,
alphas: list[float] | None = None,
np_rng_seed: float = 0,
) -> rnn_utils.DatasetRNN:
"""Returns a multisubject dataset for the Q-learning task."""
Expand Down
7 changes: 3 additions & 4 deletions disentangled_rnns/library/neuro_disrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from collections.abc import Callable
import copy
import dataclasses
from typing import Optional

from disentangled_rnns.library import disrnn
from disentangled_rnns.library import plotting
Expand Down Expand Up @@ -238,7 +237,7 @@ def plot_neural_activity_rules(
params: rnn_utils.RnnParams,
disrnn_config: DisRnnWNeuralActivityConfig,
axis_lim: float = 2.1,
) -> Optional[plt.Figure]:
) -> plt.Figure | None:
"""Plots the neural_activity rule of a DisRNN with neural_activity prediction.
This function visualizes how the predicted neural_activity level changes based
Expand Down Expand Up @@ -440,7 +439,7 @@ def plot_choice_rule(
params: rnn_utils.RnnParams,
disrnn_config: DisRnnWNeuralActivityConfig,
axis_lim: float = 2.1,
) -> Optional[plt.Figure]:
) -> plt.Figure | None:
"""Plots the choice rule of a DisRNN with neural_activity prediction."""

params = {
Expand All @@ -456,7 +455,7 @@ def plot_update_rules(
params: rnn_utils.RnnParams,
disrnn_config: DisRnnWNeuralActivityConfig,
axis_lim: float = 2.1,
) -> Optional[plt.Figure]:
) -> plt.Figure | None:
"""Plots the update rules of a DisRNN with neural_activity prediction."""
params = {
key.replace('hk_neuro_disentangled_rnn', 'hk_disentangled_rnn'): value
Expand Down
7 changes: 3 additions & 4 deletions disentangled_rnns/library/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Plotting functions for inspecting Disentangled RNNs."""

import copy
from typing import Optional

from disentangled_rnns.library import disrnn
from disentangled_rnns.library import multisubject_disrnn
Expand Down Expand Up @@ -204,7 +203,7 @@ def plot_bottlenecks(
def plot_update_rules(
params: rnn_utils.RnnParams,
disrnn_config: disrnn.DisRnnConfig,
subj_ind: Optional[int] = None,
subj_ind: int | None = None,
axis_lim: float = 2.1,
) -> list[plt.Figure]:
"""Generates visualizations of the update rules of a HkDisentangledRNN."""
Expand Down Expand Up @@ -439,9 +438,9 @@ def plot_update_2d(params, unit_i, unit_input, observations, titles):
def plot_choice_rule(
params: rnn_utils.RnnParams,
disrnn_config: disrnn.DisRnnConfig,
subj_embedding: Optional[np.ndarray] = None,
subj_embedding: np.ndarray | None = None,
axis_lim: float = 2.1,
) -> Optional[plt.Figure]:
) -> plt.Figure | None:
"""Plots the choice rule of a DisRNN.
Args:
Expand Down
16 changes: 8 additions & 8 deletions disentangled_rnns/library/rnn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Callable
import json
import sys
from typing import Any, Literal, Mapping, Optional
from typing import Any, Literal, Mapping
import warnings

from absl import logging
Expand Down Expand Up @@ -624,11 +624,11 @@ def compute_penalty(
def train_network(
make_network: Callable[[], hk.RNNCore],
training_dataset: DatasetRNN,
validation_dataset: Optional[DatasetRNN],
validation_dataset: DatasetRNN | None,
opt: optax.GradientTransformation = optax.adam(1e-3),
random_key: Optional[chex.PRNGKey] = None,
opt_state: Optional[optax.OptState] = None,
params: Optional[RnnParams] = None,
random_key: chex.PRNGKey | None = None,
opt_state: optax.OptState | None = None,
params: RnnParams | None = None,
n_steps: int = 1000,
max_grad_norm: float = 1,
loss_param: dict[str, float] | float = 1.0,
Expand All @@ -643,7 +643,7 @@ def train_network(
log_losses_every: int = 10,
do_plot: bool = False,
report_progress_by: Literal['print', 'log', 'wandb', 'none'] = 'print',
wandb_run: Optional[Any] = None,
wandb_run: Any | None = None,
wandb_step_offset: int = 0,
) -> tuple[RnnParams, optax.OptState, dict[str, np.ndarray]]:
"""Trains a Haiku recurrent neural network.
Expand Down Expand Up @@ -1069,7 +1069,7 @@ def step_network(

def get_initial_state(
make_network: Callable[[], hk.RNNCore],
params: Optional[RnnParams] = None,
params: RnnParams | None = None,
batch_size: int = 1,
seed: int = 0,
) -> Any:
Expand Down Expand Up @@ -1112,7 +1112,7 @@ def unroll_network():
def get_new_params(
make_network: Callable[..., hk.RNNCore],
input_size: int,
random_key: Optional[jax.Array] = None,
random_key: jax.Array | None = None,
) -> RnnParams:
"""Get a new set of random parameters for a network architecture.
Expand Down
8 changes: 4 additions & 4 deletions disentangled_rnns/library/two_armed_bandits.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import abc
from collections.abc import Callable
from typing import Literal, NamedTuple, Optional, Union
from typing import Literal, NamedTuple, Union
import warnings

from disentangled_rnns.library import rnn_utils
Expand Down Expand Up @@ -45,7 +45,7 @@ class BaseEnvironment(abc.ABC):
n_arms: The number of arms in the environment.
"""

def __init__(self, seed: Optional[int] = None, n_arms: int = 2):
def __init__(self, seed: int | None = None, n_arms: int = 2):
self._random_state = np.random.RandomState(seed)
self._n_arms = n_arms

Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
self,
sigma: float,
p_instructed: float = 0.0,
seed: Optional[int] = None,
seed: int | None = None,
n_arms: int = 2,
):
super().__init__(seed=seed, n_arms=n_arms)
Expand Down Expand Up @@ -183,7 +183,7 @@ class EnvironmentPayoutMatrix(BaseEnvironment):
def __init__(
self,
payout_matrix: np.ndarray,
instructed_matrix: Optional[np.ndarray] = None,
instructed_matrix: np.ndarray | None = None,
):
"""Initialize the environment.

Expand Down