84 changes: 49 additions & 35 deletions disentangled_rnns/library/disrnn.py
@@ -13,8 +13,9 @@
# limitations under the License.

"""Disentangled RNN and plotting functions."""

import dataclasses
from typing import Optional, Callable, Any, Sequence
from typing import Any, Callable, Optional, Sequence

import haiku as hk
import jax
@@ -42,8 +43,8 @@ def information_bottleneck(
inputs: The inputs to the bottleneck. Shape is (batch_size, bottleneck_dims)
sigmas: The standard deviations of the sampling distribution (diagonal of
the sqrt of the covariance matrix). Shape is (bottleneck_dims).
multipliers: The multipliers to apply to the inputs. Shape is
(batch_size, bottleneck_dims)
multipliers: The multipliers to apply to the inputs. Shape is (batch_size,
bottleneck_dims)
noiseless_mode: If True, no noise is added and no penalty is computed.

Returns:
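For context on the hunk above: what is being documented is a standard Gaussian information bottleneck. A minimal sketch of the idea, assuming the usual KL-to-unit-normal penalty (the function name and exact penalty form here are illustrative, not this library's API):

import jax
import jax.numpy as jnp

def gaussian_bottleneck(key, inputs, sigmas, multipliers, noiseless_mode=False):
  # Sample from N(inputs * multipliers, sigmas^2) and return a KL penalty.
  means = inputs * multipliers  # (batch_size, bottleneck_dims)
  if noiseless_mode:
    # Matches the docstring: no noise added, no penalty computed.
    return means, jnp.zeros(means.shape[0])
  samples = means + sigmas * jax.random.normal(key, means.shape)
  # KL( N(mean, sigma^2) || N(0, 1) ), summed over bottleneck dimensions.
  kl = 0.5 * jnp.sum(
      jnp.square(means) + jnp.square(sigmas) - 2.0 * jnp.log(sigmas) - 1.0,
      axis=-1,
  )
  return samples, kl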
@@ -92,6 +93,7 @@ def reparameterize_sigma(
hk_param: The haiku parameter corresponding to a bottleneck sigma. Range
from -inf to +inf
min_sigma: The minimum value of the standard deviation.

Returns:
sigma: The bottleneck standard deviation. Range from min_sigma to inf.
"""
@@ -104,8 +106,8 @@ class DisRnnConfig:

Attributes:
obs_size: Number of dimensions in the observation vector
output_size: Number of dimensions the disRNN will output
(logits or predicted targets)
output_size: Number of dimensions the disRNN will output (logits or
predicted targets)
latent_size: Number of recurrent variables
update_net_n_units_per_layer: Number of units in each layer of the update
networks
@@ -117,8 +119,8 @@ class DisRnnConfig:
latent_penalty: Multiplier for KL cost on the latent bottlenecks
choice_net_latent_penalty: Multiplier for bottleneck cost on latent inputs
to the choice network
update_net_obs_penalty: Multiplier for bottleneck cost on observation
inputs to the update network
update_net_obs_penalty: Multiplier for bottleneck cost on observation inputs
to the update network
update_net_latent_penalty: Multiplier for bottleneck cost on latent inputs
to the update networks
l2_scale: Multiplier for L2 penalty on hidden layer weights in both update
@@ -128,6 +130,9 @@ class DisRnnConfig:
prevent runaway latents resulting in NaNs
x_names: Names of the observation vector elements. Must have length obs_size
y_names: Names of the target vector elements. Must have length target_size
enable_aux_outputs: If True, supported classes will also return auxiliary
outputs. Can be used to retrieve internal model states such as subject
embeddings.
"""

obs_size: int = 2
@@ -149,11 +154,13 @@ class DisRnnConfig:

l2_scale: float = 0.01

max_latent_value: float = 2.
max_latent_value: float = 2.0

x_names: Optional[list[str]] = None
y_names: Optional[list[str]] = None

enable_aux_outputs: bool = False

def __post_init__(self):
"""Checks that the configuration is valid."""

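Putting the new field in context, a config using it might be built as follows (field names are taken from this diff; the specific values are illustrative):

config = DisRnnConfig(
    obs_size=2,
    output_size=2,
    latent_size=5,
    latent_penalty=1e-4,
    max_latent_value=2.0,
    x_names=['choice', 'reward'],  # length must equal obs_size
    enable_aux_outputs=True,  # new field introduced in this change
)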
@@ -190,13 +197,15 @@ class ResMLP(hk.Module):
name: Optional name, which affects the names of the haiku parameters
"""

def __init__(self,
input_size: int,
output_size: int,
n_layers: int = 5,
n_units_per_layer: int = 5,
activation_fn: Callable[[Any], Any] = jax.nn.relu,
name=None):
def __init__(
self,
input_size: int,
output_size: int,
n_layers: int = 5,
n_units_per_layer: int = 5,
activation_fn: Callable[[Any], Any] = jax.nn.relu,
name=None,
):
super().__init__(name=name)

self.n_layers = n_layers
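The signature change above is formatting-only, so construction is unchanged. A hypothetical instantiation inside hk.transform (the sizes and the surrounding usage are assumptions, not taken from this diff):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = ResMLP(
      input_size=3,
      output_size=2,
      n_layers=5,
      n_units_per_layer=5,
      activation_fn=jax.nn.relu,
      name='example_mlp',
  )
  return mlp(x)

model = hk.transform(forward)
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 3)))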
@@ -258,9 +267,7 @@ def __init__(self,

# Compute sum of squares of all hidden layer weights. This will be passed on
# and can be used to compute an L2 (ridge) penalty.
self.l2 = (
jnp.sum(jnp.square(jnp.array(self._hidden_layer_weights)))
)
self.l2 = jnp.sum(jnp.square(jnp.array(self._hidden_layer_weights)))

def __call__(self, inputs):

@@ -284,7 +291,8 @@ def __call__(self, inputs):


def get_initial_bottleneck_params(
shape: Sequence[int], name: str,
shape: Sequence[int],
name: str,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Defines a bottleneck with a sigma and a multiplier."""
# At init the bottlenecks should all be open: sigmas small and multipliers 1
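The "open at init" comment is the design point here: training should start with the bottlenecks passing information freely. A sketch of one parameterization satisfying it (the initializer constants are assumptions; the '*_sigma_params' naming pattern matches the keys read by get_total_sigma further down):

import haiku as hk

def get_initial_bottleneck_params_sketch(shape, name):
  # Must run inside an hk.transform-ed function.
  # A strongly negative raw value maps to a small sigma after
  # reparameterize_sigma, so the bottleneck starts open.
  sigma_params = hk.get_parameter(
      name + '_sigma_params', shape, init=hk.initializers.Constant(-3.0)
  )
  multipliers = hk.get_parameter(
      name + '_multipliers', shape, init=hk.initializers.Constant(1.0)
  )
  return sigma_params, multipliers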
@@ -328,6 +336,7 @@ def __init__(
self._choice_net_latent_penalty = config.choice_net_latent_penalty
self._activation = getattr(jax.nn, config.activation)
self._max_latent_value = config.max_latent_value
self._auxiliary_outputs = config.enable_aux_outputs

# Get Haiku parameters. IMPORTANT: if you are subclassing HkDisentangledRNN,
# you must override _get_haiku_parameters to add any new parameters that you
@@ -369,11 +378,9 @@ def _build_latent_bottlenecks(self):
"""Initializes parameters for the latent bottlenecks."""
# Latents will also go through a bottleneck after being updated. These
# bottlenecks do not need multipliers; the network output can rescale them.
self._latent_sigmas, _ = (
get_initial_bottleneck_params(
shape=(self._latent_size,),
name='latent',
)
self._latent_sigmas, _ = get_initial_bottleneck_params(
shape=(self._latent_size,),
name='latent',
)

def _build_choice_bottlenecks(self):
@@ -404,6 +411,7 @@ def update_latents(
Args:
update_net_inputs: Additional inputs for the update rules.
prev_latent_values: The latents from the previous time step.

Returns:
new_latent_values: The updated latents.
penalty_increment: A penalty associated with the update.
@@ -468,8 +476,8 @@ def predict_targets(
n_units_per_layer=self._choice_net_n_units_per_layer,
n_layers=self._choice_net_n_layers,
activation_fn=self._activation,
name='choice_net'
)(choice_net_inputs)
name='choice_net',
)(choice_net_inputs)
penalty_increment += self._l2_scale * choice_net_l2

return predicted_targets, penalty_increment
@@ -545,10 +553,12 @@ def __call__(self, observations: jnp.ndarray, prev_latents: jnp.ndarray):
return output, new_latents


def log_bottlenecks(params,
open_thresh: float = 0.1,
partially_open_thresh: float = 0.25,
closed_thresh: float = 0.9) -> dict[str, int]:
def log_bottlenecks(
params,
open_thresh: float = 0.1,
partially_open_thresh: float = 0.25,
closed_thresh: float = 0.9,
) -> dict[str, int]:
"""Computes info about bottlenecks for the base DisRNN."""

params_disrnn = params['hk_disentangled_rnn']
@@ -612,7 +622,7 @@ def log_bottlenecks(params,
'update_bottlenecks_open': int(update_bottlenecks_open),
'update_bottlenecks_partial': int(update_bottlenecks_partial),
'update_bottlenecks_closed': int(update_bottlenecks_closed),
}
}
return bottleneck_dict
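To make the three thresholds concrete, the counting amounts to the following (sigma values are illustrative and assumed to be already reparameterized; whether "partial" also counts fully open bottlenecks is an assumption here):

import jax.numpy as jnp

sigmas = jnp.array([0.05, 0.2, 0.5, 0.95])
n_open = int(jnp.sum(sigmas < 0.1))      # -> 1
n_partial = int(jnp.sum(sigmas < 0.25))  # -> 2
n_closed = int(jnp.sum(sigmas > 0.9))    # -> 1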


@@ -622,13 +632,17 @@ def get_total_sigma(params):
params_disrnn = params['hk_disentangled_rnn']

latent_bottlenecks = reparameterize_sigma(
params_disrnn['latent_sigma_params'])
params_disrnn['latent_sigma_params']
)
update_obs_bottlenecks = reparameterize_sigma(
params_disrnn['update_net_obs_sigma_params'])
params_disrnn['update_net_obs_sigma_params']
)
update_latent_bottlenecks = reparameterize_sigma(
params_disrnn['update_net_latent_sigma_params'])
params_disrnn['update_net_latent_sigma_params']
)
choice_bottlenecks = reparameterize_sigma(
params_disrnn['choice_net_sigma_params'])
params_disrnn['choice_net_sigma_params']
)

return float(
jnp.sum(latent_bottlenecks)
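A hypothetical monitoring snippet combining the two helpers in this file; params is assumed to be a trained parameter tree, and the dict key is one of those defined above:

total_sigma = get_total_sigma(params)
stats = log_bottlenecks(params)
print(f'total sigma: {total_sigma:.3f}, '
      f'update bottlenecks open: {stats["update_bottlenecks_open"]}')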
7 changes: 6 additions & 1 deletion disentangled_rnns/library/multisubject_disrnn.py
@@ -251,7 +251,12 @@ def __call__(self, inputs: jnp.ndarray, prev_latents: jnp.ndarray):
output = output.at[:, :-1].set(predicted_targets)
output = output.at[:, -1].set(penalty)

return output, new_latents
final_outputs = output, new_latents
if self._auxiliary_outputs:
aux_output = {'subject_embeddings': subject_embeddings}
final_outputs = (*final_outputs, aux_output)

return final_outputs


def get_auxiliary_metrics(
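With this change, the return contract of __call__ depends on the config. A sketch of consuming it (the model handle and its call signature are assumptions):

if config.enable_aux_outputs:
  output, new_latents, aux = model(inputs, prev_latents)
  subject_embeddings = aux['subject_embeddings']
else:
  output, new_latents = model(inputs, prev_latents)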
38 changes: 6 additions & 32 deletions disentangled_rnns/library/neuro_disrnn.py
@@ -198,39 +198,13 @@ def plot_bottlenecks(
sort_latents,
)

plt.close()

# Regenerate the figure with the neural_activity bottlenecks added to the end.
base_axes = fig.axes
fig, axes = plt.subplots(1, 4, figsize=(25, 5))

im = axes[-1].imshow(
np.swapaxes([1 - neural_activity_sigmas], 0, 1), cmap='Oranges'
fig = plotting.append_bottleneck(
fig=fig,
bottleneck_values=neural_activity_sigmas,
bottleneck_names=list(latent_names),
title='Neural Activity Bottlenecks',
sort_latents=sort_latents,
)
im.set_clim(vmin=0, vmax=1)
axes[-1].set_title('Neural Activity Bottlenecks')
axes[-1].set_ylabel('Latent # (Sorted)' if sort_latents else 'Latent #')
axes[-1].set_xticks([]) # Remove x-axis ticks as it's a 1D representation
axes[-1].set_yticks(ticks=range(len(latent_names)), labels=latent_names)

for i, ax in enumerate(base_axes):
if len(ax.images) < 1:
continue
image = ax.images[0].get_array()
im = axes[i].imshow(image, cmap='Oranges')
im.set_clim(vmin=0, vmax=1)
axes[i].set_title(ax.get_title())
axes[i].set_ylabel(ax.get_ylabel())
axes[i].set_xticks(
ticks=ax.get_xticks(),
labels=ax.get_xticklabels(),
rotation='vertical',
)
axes[i].set_yticks(ticks=ax.get_yticks(), labels=ax.get_yticklabels())
axes[i].set_ylim(ax.get_ylim())
axes[i].set_xlim(ax.get_xlim())

# fig.tight_layout()
return fig


77 changes: 77 additions & 0 deletions disentangled_rnns/library/plotting.py
@@ -201,6 +201,83 @@ def plot_bottlenecks(
return fig


def append_bottleneck(
fig: plt.Figure,
bottleneck_values: np.ndarray,
bottleneck_names: list[str],
title: str,
sort_latents: bool,
) -> plt.Figure:
"""Appends a bottleneck plot to an existing figure of bottleneck plots."""
old_axes_props = []
subplot_axes = [ax for ax in fig.axes if ax.get_subplotspec() is not None]

for ax in subplot_axes:
# Skip subplots that don't contain an image; there is nothing to copy.
if len(ax.images) < 1:
continue
xtickrotation = 0
if ax.get_xticklabels():
xtickrotation = ax.get_xticklabels()[0].get_rotation()

props = {
'title': ax.get_title(),
'ylabel': ax.get_ylabel(),
'xlabel': ax.get_xlabel(),
'xticks': ax.get_xticks(),
'xticklabels': [label.get_text() for label in ax.get_xticklabels()],
'xtickrotation': xtickrotation,
'yticks': ax.get_yticks(),
'yticklabels': [label.get_text() for label in ax.get_yticklabels()],
'ylim': ax.get_ylim(),
'xlim': ax.get_xlim(),
'images': [],
}
for im in ax.images:
props['images'].append({
'data': im.get_array(),
'cmap': im.get_cmap(),
'clim': im.get_clim(),
})
old_axes_props.append(props)

n_axes_old = len(old_axes_props)

fig, axes = plt.subplots(
1, n_axes_old + 1, figsize=(5 * (n_axes_old + 1), 5)
)

ax = axes[-1]
im = ax.imshow(np.swapaxes([1 - bottleneck_values], 0, 1), cmap='Oranges')
im.set_clim(vmin=0, vmax=1)
ax.set_title(title)
ylabel = 'Latent # (Sorted)' if sort_latents else 'Latent #'
ax.set_ylabel(ylabel)
ax.set_xticks([])
ax.set_yticks(ticks=range(len(bottleneck_names)), labels=bottleneck_names)

for i, props in enumerate(old_axes_props):
ax = axes[i]
if props['images']:
im_props = props['images'][0]
im = ax.imshow(im_props['data'], cmap=im_props['cmap'])
im.set_clim(im_props['clim'])

ax.set_title(props['title'])
ax.set_xlabel(props['xlabel'])
ax.set_ylabel(props['ylabel'])
ax.set_xticks(props['xticks'])
ax.set_xticklabels(props['xticklabels'], rotation=props['xtickrotation'])
ax.set_yticks(props['yticks'])
ax.set_yticklabels(props['yticklabels'])
ax.set_ylim(props['ylim'])
ax.set_xlim(props['xlim'])

# fig.tight_layout()
return fig

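A hypothetical standalone call, mirroring how neuro_disrnn.py now uses this helper (the input figure and values are illustrative):

import numpy as np

fig = append_bottleneck(
    fig=fig,  # an existing figure produced by plot_bottlenecks
    bottleneck_values=np.array([0.05, 0.4, 0.9]),  # one value per latent
    bottleneck_names=['Latent 1', 'Latent 2', 'Latent 3'],
    title='Extra Bottlenecks',
    sort_latents=False,
)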

def plot_update_rules(
params: hk.Params,
disrnn_config: disrnn.DisRnnConfig,
Expand Down