Equivariant Sparse Autoencoders: Mechanistic Interpretability of Neural Networks on Symmetric Data

📝 Abstract

Machine learning (ML) models achieve remarkable performance but remain hard to interpret due to their scale and complexity. This is a major obstacle to trusting such models and to deriving novel insights about the phenomena they operate on. Mechanistic interpretability methods such as sparse autoencoders (SAEs) can identify interpretable components inside ML models, but they suffer from unidentifiability: different explanations can fit equally well without necessarily being faithful to the underlying model.

We show that aligning explanations with data symmetries is a promising solution. We extend the Linear Representation Hypothesis, the theory behind SAEs, to account for symmetries. This motivates our design of Equivariant SAEs with symmetry-aware priors that can (1) avoid the pitfalls of existing SAEs on symmetric data and (2) discover features that are more useful for downstream tasks across different architectures.
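
For intuition, the sketch below shows a standard TopK SAE forward pass of the kind the training scripts build on. It is an illustrative minimal version, not the repository's actual module; the flag names in the comments (--sae-topk, --sae-unit-decoder, --pre-encoder-bias) refer to training options documented below.

# Minimal TopK SAE sketch (illustrative; not the repository's implementation).
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    def __init__(self, input_dim: int, latent_dim: int, k: int):
        super().__init__()
        self.k = k
        self.pre_bias = nn.Parameter(torch.zeros(input_dim))  # cf. --pre-encoder-bias
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, input_dim, bias=False)

    def forward(self, x):
        z = torch.relu(self.encoder(x - self.pre_bias))
        # Keep only the k largest latents per sample (cf. --sae-topk).
        topk = torch.topk(z, self.k, dim=-1)
        z_sparse = torch.zeros_like(z).scatter_(-1, topk.indices, topk.values)
        # Normalize decoder atoms to unit norm (cf. --sae-unit-decoder).
        w = self.decoder.weight / self.decoder.weight.norm(dim=0, keepdim=True)
        x_hat = z_sparse @ w.T + self.pre_bias
        return x_hat, z_sparse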

⚙️ Installation

1. Create and activate a virtual environment:

python -m venv equivariant-sae
source equivariant-sae/bin/activate  # On Windows: equivariant-sae\Scripts\activate

2. Install the required packages:

pip install torch torchvision torchaudio numpy pandas matplotlib opencv-python scikit-learn xgboost scipy tqdm overcomplete galaxy_datasets

3. Unzip the Shapes probing datasets:

unzip probe_datasets.zip
unzip probe_datasets_224.zip

🔬 Reproducing Results

Toy Model

See 01_toy_model.ipynb for the toy model experiments.


Shapes Dataset

Datasets

The image dataset used to train the SAEs is created when the training script is run. The probe datasets for the 180 tasks live in the probe_datasets folder and are loaded automatically by the probing script.

Training SAEs

Set --ae-architecture to one of cnn, mlp, or vit.

Set --sae-latent-dim to 32 for narrow baselines and 128 for wide baselines.

Vanilla SAE

python 02_shapes_train.py \
  --ae-architecture mlp \
  --orbit-size 4 \
  --train-setup normal \
  --lambda-sym 0.0 \
  --sae-loss-type NA \
  --sae-activation topk \
  --sae-topk 8 \
  --ae-epochs 100 \
  --ae-augment \
  --sae-hidden-dim 128 \
  --ae-latent-dim 256 \
  --sae-latent-dim 32 \
  --batch-size 128 \
  --learning-rate 0.0005 \
  --epochs 500 \
  --sae-enc-type nonlinear \
  --sae-unit-decoder \
  --pre-encoder-bias \
  --sae-enc-activation relu \
  --seed 42 \
  --n-samples 10000 \
  --save-path ./models/shapes

Equivariant SAE

Set --sae-loss-type to one of latent_inv, canonical, or output_inv.

python 02_shapes_train.py \
  --ae-architecture mlp \
  --orbit-size 4 \
  --train-setup equivariant \
  --lambda-mse 0.0 \
  --sae-loss-type latent_inv \
  --sae-activation topk \
  --sae-topk 4 \
  --ae-epochs 100 \
  --ae-augment \
  --sae-hidden-dim 128 \
  --ae-latent-dim 256 \
  --sae-latent-dim 32 \
  --batch-size 128 \
  --learning-rate 0.0005 \
  --epochs 500 \
  --sae-enc-type nonlinear \
  --sae-unit-decoder \
  --pre-encoder-bias \
  --sae-enc-activation relu \
  --seed 42 \
  --n-samples 10000 \
  --M-epochs 150 \
  --M-train-on-all-k \
  --save-path ./models/shapes
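
As rough intuition for the latent_inv option (a hedged sketch inferred from the flag names, not the repository's exact objective): encode every element of a sample's rotation orbit and penalize how much the sparse codes vary across the orbit. Using the TopKSAE sketch from the abstract section:

# Hedged sketch of a latent-invariance penalty; the repository's exact loss may differ.
import torch

def latent_invariance_loss(sae, orbit):
    # orbit: tensor of shape (orbit_size, batch, input_dim) holding the
    # activations of each rotated copy of the same batch (cf. --orbit-size).
    codes = torch.stack([sae(x)[1] for x in orbit])   # (orbit, batch, latent)
    mean_code = codes.mean(dim=0, keepdim=True)
    return ((codes - mean_code) ** 2).mean()          # variance across the orbit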

Archetypal SAE

python 02_shapes_train_archetypal.py \
  --ae-architecture mlp \
  --orbit-size 4 \
  --train-setup archetypal \
  --sae-loss-type NA \
  --sae-activation topk \
  --sae-topk 8 \
  --ae-epochs 100 \
  --ae-augment \
  --sae-hidden-dim 128 \
  --ae-latent-dim 256 \
  --sae-latent-dim 32 \
  --n-arch-clusters 2048 \
  --batch-size 128 \
  --learning-rate 0.0005 \
  --epochs 125 \
  --sae-enc-type nonlinear \
  --seed 42 \
  --n-samples 10000 \
  --save-path ./models/shapes
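
For context, an archetypal SAE constrains each decoder atom to a convex combination of a fixed set of candidate points, plausibly the --n-arch-clusters centroids here. The snippet below isolates that constraint; it is an assumption about the general method, not this script's exact code.

# Sketch of the archetypal decoder constraint (assumed from the flag names).
import torch
import torch.nn.functional as F

n_latents, n_candidates, d = 32, 2048, 256   # hypothetical sizes matching the flags above
C = torch.randn(n_candidates, d)             # stand-in candidate points (e.g. cluster centroids)
A_logits = torch.nn.Parameter(torch.zeros(n_latents, n_candidates))

A = F.softmax(A_logits, dim=-1)              # each row lies on the probability simplex
W_dec = A @ C                                # every decoder atom stays in the convex hull of C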

Probes

python 04_shapes_probes.py \
  --save-path models/shapes \
  --probe-trunc-lengths 8 \
  --ae-architecture cnn \
  --n-samples 10000 \
  --probe-sae-topks 4 8

Results are saved in probe_results/results_<n>.csv.
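
For custom analysis outside the notebook, the result files can be concatenated with pandas (a minimal sketch; the column names depend on the probing script):

# Load and concatenate all probe result files.
import glob
import pandas as pd

paths = sorted(glob.glob("probe_results/results_*.csv"))
results = pd.concat((pd.read_csv(p) for p in paths), ignore_index=True)
print(results.head())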

Analysis

See analyze_results.ipynb for probing performance visualizations.


GalaxyMNIST

Training SAEs

Base model activations are already saved in the data folder.

Vanilla SAE

python 03_galaxy_train.py \
  --dataset galaxy-mnist \
  --orbit-size 4 \
  --train-setup normal \
  --sae-expansion 2 \
  --sae-loss-type mse \
  --sae-topk 16 \
  --ae-architecture zoobot \
  --sae-hidden-dim 128 \
  --hook 1 \
  --input-dim 160 \
  --batch-size 64 \
  --learning-rate 0.0005 \
  --epochs 400 \
  --sae-unit-decoder \
  --pre-encoder-bias \
  --sae-enc-activation relu \
  --seed 42 \
  --save-path ./models/galaxy

Equivariant SAE

python 03_galaxy_train.py \
  --dataset galaxy-mnist \
  --orbit-size 4 \
  --train-setup equivariant \
  --sae-expansion 2 \
  --sae-loss-type latent_inv \
  --sae-topk 8 \
  --ae-architecture zoobot \
  --sae-hidden-dim 128 \
  --hook 1 \
  --input-dim 160 \
  --batch-size 64 \
  --learning-rate 0.0005 \
  --epochs 400 \
  --sae-unit-decoder \
  --pre-encoder-bias \
  --sae-enc-activation relu \
  --seed 42 \
  --M-epochs 150 \
  --M-learning-rate 0.001 \
  --save-path ./models/galaxy

Archetypal SAE

python 03_galaxy_train_archetypal.py \
  --dataset galaxy-mnist \
  --orbit-size 4 \
  --train-setup archetypal \
  --sae-topk 16 \
  --ae-architecture zoobot \
  --sae-hidden-dim 128 \
  --sae-loss-type mse \
  --hook 1 \
  --input-dim 160 \
  --sae-expansion 2 \
  --n-arch-clusters 1024 \
  --batch-size 64 \
  --learning-rate 0.0005 \
  --epochs 100 \
  --sae-enc-type nonlinear \
  --seed 42 \
  --save-path ./models/galaxy

Probes

Run the following command to train the XGBoost and logistic probes. The results can then be loaded and analyzed in analyze_results.ipynb.

python 05_galaxy_probes.py --probe-trunc-length 8 16
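
To illustrate what such a probe does (an illustrative sketch on synthetic stand-in data, not the repository's script): fit a simple classifier on truncated SAE codes and score it on a held-out split.

# Illustrative logistic probe; the actual script fits XGBoost and logistic
# probes on real SAE activations rather than random stand-in data.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
Z = rng.normal(size=(1000, 16))       # stand-in for truncated SAE codes
y = (Z[:, 0] > 0).astype(int)         # stand-in binary task label

Z_tr, Z_te, y_tr, y_te = train_test_split(Z, y, random_state=42)
probe = LogisticRegression(max_iter=1000).fit(Z_tr, y_tr)
print("held-out accuracy:", probe.score(Z_te, y_te))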
