```python
import random
import time
import numpy as np
import sae_lens
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_default_device("cuda")
# While we will use gradients later, we don't need them for most operations and
# will explicitly enable gradients when needed.
torch.set_grad_enabled(False)
```
Dreaming with sparse autoencoder features
I get a lot of questions about dreaming/feature visualization applied to sparse autoencoders (SAEs) 1. When we were writing Fluent dreaming for language models and the companion post for that paper, we thought that a natural application of feature visualization would be sparse autoencoder features because those features should be fairly monosemantic. But there weren’t any open source SAEs yet and we didn’t want to put in the effort to train our own.
Since then, several open source SAEs have been released [1], [2]. This post will demonstrate dreaming applied to SAE features. Technically, this is an extremely simple modification because an SAE encoder or decoder feature is just a direction in activation space.
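To make “just a direction” concrete, here is a minimal sketch. The shapes follow the `sae_lens` convention (`W_enc` has shape `(d_model, d_sae)` and `W_dec` has shape `(d_sae, d_model)`); the random tensors and the specific dimensions are illustrative stand-ins, not values from this post:

```python
import torch

# Illustrative stand-ins: random weights with roughly Gemma-2-2B-sized shapes.
d_model, d_sae, feature_idx = 2304, 16384, 0
W_enc = torch.randn(d_model, d_sae)  # corresponds to sae.W_enc in sae_lens
b_enc = torch.randn(d_sae)           # corresponds to sae.b_enc
W_dec = torch.randn(d_sae, d_model)  # corresponds to sae.W_dec
resid = torch.randn(1, 8, d_model)   # residual stream activations: (batch, seq, d_model)

# The encoder "feature" is a column of W_enc: its pre-activation is a dot product.
enc_pre_act = resid @ W_enc[:, feature_idx] + b_enc[feature_idx]
# The decoder "feature" is a row of W_dec: projecting onto it is also a dot product.
dec_act = resid @ W_dec[feature_idx]
```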
In addition, since writing the dreaming paper, we have worked on state-of-the-art token optimization methods [3]. I use those newer token optimization methods here.
The rest of this post is a self-contained notebook for feature visualization of SAE features 2.
Setup
In this first section, we set up our environment and define a few useful functions:

- `add_fwd_hooks` is a context manager for adding forward hooks to a model so that we can store and later access intermediate activations.
- `load_sae` loads a Gemma Scope SAE for the specified layer.
- `calc_xe` calculates a batched cross-entropy loss for the purpose of fluency scoring.
```python
import contextlib
from typing import Callable, List, Tuple


@contextlib.contextmanager
def add_fwd_hooks(module_hooks: List[Tuple[torch.nn.Module, Callable]]):
    """
    Context manager for temporarily adding forward hooks to a model.

    Parameters
    ----------
    module_hooks
        A list of pairs: (module, fnc). The function will be registered as a
        forward hook on the module.
    """
    try:
        handles = []
        for mod, hk in module_hooks:
            handles.append(mod.register_forward_hook(hk))
        yield
    finally:
        for h in handles:
            h.remove()


sae_cache = {}


def load_sae(layer):
    sae_id = f"layer_{layer}/width_16k/canonical"
    if sae_id not in sae_cache:
        sae, _, _ = sae_lens.SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=sae_id,
            device="cuda",
        )
        sae_cache[sae_id] = sae
    else:
        sae = sae_cache[sae_id]
    return sae


def calc_xe(logits, input_ids):
    return (
        torch.nn.functional.cross_entropy(
            logits[:, :-1].reshape(-1, logits.shape[-1]),
            input_ids[:, 1:].reshape(-1),
            reduction="none",
        )
        .reshape((logits.shape[0], -1))
        .mean(dim=-1)
    )
```
Dreaming
This section defines a `dream` function for optimizing token sequences to maximize a provided feature.

Broadly, the algorithm works as follows:

- Mutate the current token sequence into `explore` new sequences.
- Evaluate the feature activation and fluency of the new sequences.
- Retain the best sequences.
- Repeat.

The most important parameter is `get_feature_and_logits`. This function should accept a batch of sequences and return the target feature as well as the full logits tensor from calling the model. In the next section, we will provide a few examples of how to implement this function for SAE features.

Follow along with the code comments for more detail.
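Before the full implementation, here is a minimal toy sketch of the contract that `get_feature_and_logits` must satisfy. This toy version is not from the original notebook: it uses the norm of the last token's residual stream entering a given layer as a stand-in "feature" and reuses the `add_fwd_hooks` helper defined above. The real SAE-based implementations appear in the later sections.

```python
def make_toy_feature_fn(model, layer):
    # Toy get_feature_and_logits: returns (feature, logits), where `feature` is
    # one scalar per sequence (here, the norm of the last token's residual
    # stream entering `layer`) and `logits` is the model's full logits tensor.
    def f(input_ids=None, inputs_embeds=None, attention_mask=None):
        out = {}

        def get_res(module, input, output):
            out["res"] = input[0]

        # `dream` calls this either with inputs_embeds (for the GCG gradient
        # step) or with input_ids (for batched candidate evaluation).
        with add_fwd_hooks([(model.model.layers[layer], get_res)]):
            if inputs_embeds is not None:
                logits = model(
                    inputs_embeds=inputs_embeds, attention_mask=attention_mask
                ).logits
            else:
                logits = model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits
        feature = out["res"][:, -1].norm(dim=-1)
        return feature, logits

    return f
```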
```python
def dream(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    # This function accepts a batch of sequences and returns the target feature
    # as well as the full logits tensor from calling the model.
    get_feature_and_logits: Callable,
    init_prompt: str = "help! are you a purple bobcat?",
    verbose: bool = True,  # Print status updates.
    iters: int = 500,  # Number of iterations
    seed: int = 0,
    #
    # There are two approaches to fluency control:
    # 1. `xe_max` sets an absolute limit on the cross-entropy loss.
    # 2. `xe_regularization` adds a cross-entropy fluency penalty to the loss
    #    function.
    xe_max: float = None,
    xe_regularization: float = 0.0,
    buffer_size: int = 4,  # Number of prompts to keep in the buffer
    #
    # Parameters for choosing the type of mutation operation at each iteration.
    min_tokens: int = 8,  # Minimum allowed number of tokens
    max_tokens: int = 16,  # Maximum allowed number of tokens
    p_gcg_swap: float = 0.25,  # prob of an iteration being a gcg swap
    p_sample_insert: float = 0.25,  # prob of an iteration being a sampled insert
    p_sample_swap: float = 0.25,  # prob of an iteration being a sampled swap
    p_delete: float = 0.25,  # prob of an iteration being a delete
    #
    # Mutation parameters:
    # - `sample_k2`: For sampled mutations, we sample without replacement
    #   `sample_k2` candidate tokens per token position.
    # - `gcg_topk`: For GCG, we select swap tokens from the top `gcg_topk`
    #   tokens according to the loss gradient.
    # After sampling, we have a (n_tokens, k2 or gcg_topk) matrix of candidate
    # tokens. Then, for each child candidate, we sample a random entry from this
    # matrix and perform the corresponding mutation.
    sample_k2: int = 16,
    gcg_topk: int = 512,
    explore: int = 128,  # Number of child candidates per parent
):
    # Set random seeds.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # We track the best prompts in a buffer. At each step, we remove the best
    # prompt and mutate it to produce new candidates. Then, we merge those
    # candidates back into the buffer.
    buffer_prompts = [init_prompt] * buffer_size
    buffer_losses = torch.tensor([float("inf")] * buffer_size)

    # Track the history of the best prompts in the run.
    history = []

    for IT in range(iters):
        start = time.time()

        # At each step, we choose randomly between the four types of
        # operations. See below for precise descriptions of these operations.
        operation_idx = torch.multinomial(
            torch.tensor([p_gcg_swap, p_sample_insert, p_sample_swap, p_delete]),
            1,
        ).item()
        operation = ["gcg_swap", "sample_insert", "sample_swap", "delete"][
            operation_idx
        ]

        input_ids = tokenizer.encode(
            buffer_prompts[0], return_tensors="pt", add_special_tokens=False
        )

        # If the prompt is too short or too long, we force a different
        # operation.
        n_tokens = input_ids.shape[1]
        if n_tokens < min_tokens:
            operation = "sample_insert"
        elif n_tokens > max_tokens:
            operation = "delete"

        if operation == "gcg_swap":
            #########################
            # GCG SWAP
            #########################
            # A GCG swap proceeds as in Zou et al 2023:
            # - we backpropagate the loss to get the gradient of the loss with
            #   respect to each token in each token position
            # - we select the top K tokens in each position according to the
            #   loss gradient
            # - we sample uniformly at random between the token positions
            # - we sample uniformly among the top K tokens in each position
            with torch.enable_grad():
                embed = model.model.embed_tokens
                one_hot = torch.nn.functional.one_hot(
                    input_ids.clone(), num_classes=embed.num_embeddings
                ).to(embed.weight.dtype)
                one_hot.requires_grad = True
                embeds = torch.matmul(one_hot, embed.weight)
                feature, logits = get_feature_and_logits(inputs_embeds=embeds)
                (-feature).backward()

            topk_grad = (-one_hot.grad).topk(k=gcg_topk, dim=-1)
            token_pos = torch.randint(0, input_ids.shape[1], (explore,))
            topk_idx = torch.randint(0, gcg_topk, (explore,))
            candidate_ids = input_ids[0, None, :].repeat(explore, 1)
            candidate_ids[torch.arange(explore), token_pos] = topk_grad.indices[
                0, token_pos, topk_idx
            ]
        elif operation == "sample_insert":
            ##################
            # SAMPLE INSERT
            ##################
            # A sample insert proceeds similarly to the mutation operation from
            # the BEAST paper (Sadasivan et al 2024) but incorporates features
            # from GCG:
            # - we produce the next token probability distribution for each
            #   token position.
            # - we sample K tokens without replacement from the probability
            #   distribution for each token position.
            # - we select uniformly at random from the token positions.
            # - then we sample uniformly at random from those K tokens.
            logits = model(input_ids=input_ids).logits
            probs = torch.softmax(logits[0], dim=-1)
            candidate_ids = torch.empty((explore, n_tokens + 1), dtype=torch.long)
            insert_position = torch.randint(1, n_tokens + 1, (explore,))
            insert_probs = probs[insert_position - 1]
            sampled_ids = torch.multinomial(insert_probs, num_samples=sample_k2)
            sample_idx = torch.randint(0, sample_k2, (explore,))
            insert_ids = sampled_ids[torch.arange(explore), sample_idx]
            for j in range(explore):
                candidate_ids[j, : insert_position[j]] = input_ids[
                    0, : insert_position[j]
                ]
                candidate_ids[j, insert_position[j]] = insert_ids[j]
                candidate_ids[j, insert_position[j] + 1 :] = input_ids[
                    0, insert_position[j] :
                ]
        elif operation == "sample_swap":
            #################
            # SAMPLE SWAP
            #################
            # A sample swap proceeds similarly to the sample insert operation
            # except that we swap a token instead of inserting a new token.
            logits = model(input_ids).logits
            probs = torch.softmax(logits[0], dim=-1)
            candidate_ids = input_ids[0, None, :].repeat(explore, 1)
            swap_position = torch.randint(1, n_tokens, (explore,))
            swap_probs = probs[swap_position]
            sampled_ids = torch.multinomial(swap_probs, num_samples=sample_k2)
            sample_idx = torch.randint(0, sample_k2, (explore,))
            swap_ids = sampled_ids[torch.arange(explore), sample_idx]
            candidate_ids[torch.arange(explore), swap_position] = swap_ids
        elif operation == "delete":
            #################
            # DELETE
            #################
            # A delete operation removes a token from the prompt.
            # The set of candidates is the set of all possible deletions.
            if explore > n_tokens:
                n_candidates = n_tokens
                delete_indices = torch.arange(n_tokens)
            else:
                n_candidates = explore
                delete_indices = torch.randperm(n_tokens)[:n_candidates]

            candidate_ids = torch.empty(
                (n_candidates, n_tokens - 1), dtype=torch.long
            )
            for i in range(n_candidates):
                candidate_ids[i, : delete_indices[i]] = input_ids[
                    0, : delete_indices[i]
                ]
                candidate_ids[i, delete_indices[i] :] = input_ids[
                    0, delete_indices[i] + 1 :
                ]

        # To avoid issues with special tokens, we decode and re-encode.
        candidates = tokenizer.batch_decode(candidate_ids, skip_special_tokens=True)
        candidates_tokenized = tokenizer(candidates, padding=True, return_tensors="pt")

        ################
        # Calculate loss
        ################
        feature, logits = get_feature_and_logits(**candidates_tokenized)
        xe = calc_xe(logits, candidates_tokenized["input_ids"])
        # We maximize activation, so negate activation.
        candidate_losses = -feature + xe_regularization * xe
        if xe_max is not None:
            candidate_losses = torch.where(xe > xe_max, float("inf"), candidate_losses)
        best_idx = torch.argmin(candidate_losses)

        #########################
        # Update the top-N buffer
        #########################
        # Update the top-N buffer. We remove the first element because we used
        # it up this iteration. Then, we select the top-N best sequences between
        # the existing buffer entries and the new candidates.
        combined_prompts = buffer_prompts[1:] + candidates
        combined_losses = torch.cat([buffer_losses[1:], candidate_losses])
        keep_idxs = torch.argsort(combined_losses)[:buffer_size]
        buffer_prompts = [combined_prompts[i] for i in keep_idxs]
        buffer_losses = combined_losses[keep_idxs]

        # Report on the step and record the history.
        runtime = time.time() - start
        if verbose:
            print("\n\n")
            print(f"Iteration {IT} | Operation: {operation}")
            print(
                f"Loss={candidate_losses[best_idx].item():.2f}"
                f" | Activation={feature[best_idx].item():.2f}"
                f" | XE={xe[best_idx].item():.2f}"
            )
            print(f"Runtime={runtime:.2f}s")
            print(f"Best candidate: {buffer_prompts[0]}")
        history.append(
            dict(
                operation=operation,
                runtime=runtime,
                activation=feature[best_idx].item(),
                xe=xe[best_idx].item(),
                seq=buffer_prompts[0],
            )
        )
    return history
```
Dreaming encoder directions
In this section, we will run feature visualization on a particular SAE encoder direction.
First, we’ll load up Gemma 2 2B.
```python
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-2b", clean_up_tokenization_spaces=False
)
gemma2 = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)
gemma2 = gemma2.requires_grad_(False)
```
Then we’ll define a function for getting the activation of a particular SAE latent. Importantly, we’ll grab the latent before it is passed through the ReLU activation function or any other thresholding operation. That makes optimization easier because the pre-activation still provides a useful ranking and gradient signal even when the feature is not firing. The generated `f` function will be passed as the `get_feature_and_logits` argument to the `dream` function.
```python
def gemma_sae_encoder(gemma2, layer, feature_idx):
    def f(input_ids=None, inputs_embeds=None, attention_mask=None):
        out = {}

        def get_res(module, input, output):
            out["res"] = input[0]

        model_hooks = [
            (gemma2.model.layers[layer], get_res),
        ]
        with add_fwd_hooks(model_hooks):
            if inputs_embeds is not None:
                logits = gemma2(
                    inputs_embeds=inputs_embeds, attention_mask=attention_mask
                ).logits
            else:
                logits = gemma2(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits

        def get_sae_pre(module, input, output):
            out["sae_pre"] = input[0]

        sae = load_sae(layer)
        sae_hooks = [(sae.hook_sae_acts_pre, get_sae_pre)]
        with add_fwd_hooks(sae_hooks):
            sae.encode(out["res"])

        return out["sae_pre"][:, -1, feature_idx], logits

    return f
```
Let’s look at layer 12, feature 0. According to Neuronpedia, this feature seems to respond strongly to variants of the word “label”. We run a bunch of examples through the SAE to see how they score:
```python
layer = 12
feature_idx = 0
for text in [
    "label",
    "labl",
    "LBL",
    "lbl",
    "LABEL",
    "labal",
    "L",
    "B",
    "bel",
    "la",
    "abel",
]:
    inputs = tokenizer([text], padding=True, return_tensors="pt")
    feature, logits = gemma_sae_encoder(gemma2, layer, feature_idx)(
        input_ids=inputs["input_ids"]
    )
    print(f"Text = {text!r:8} | Activation = {feature.item():.2f}")
```
Text = 'label' | Activation = 55.03
Text = 'labl' | Activation = 16.46
Text = 'LBL' | Activation = 15.48
Text = 'lbl' | Activation = 26.45
Text = 'LABEL' | Activation = 49.56
Text = 'labal' | Activation = 14.07
Text = 'L' | Activation = -3.95
Text = 'B' | Activation = -6.75
Text = 'bel' | Activation = 3.60
Text = 'la' | Activation = -0.55
Text = 'abel' | Activation = 15.67
And now let’s run our feature visualization on this encoder direction. After about 20 iterations, the optimization finds the “labl” phrase and then after about 100 iterations, it finds “label”.
The output of the `dream` function itself is hidden below. Download the original notebook to see it.
```python
encoder_history = dream(
    gemma2,
    tokenizer,
    get_feature_and_logits=gemma_sae_encoder(gemma2, layer, feature_idx),
    iters=150,
    init_prompt="help",
    xe_max=10.0,
)
```
```python
%config InlineBackend.figure_format = 'retina'
plt.figure(figsize=(9, 4))
activations = [h["activation"] for h in encoder_history]
plt.plot(activations)
plt.annotate(
    repr(encoder_history[20]["seq"]),
    xy=(20, activations[20]),
    xytext=(40, activations[20] - 30),
    arrowprops=dict(facecolor="black", shrink=0.05),
    horizontalalignment="center",
    verticalalignment="bottom",
)
plt.annotate(
    repr(encoder_history[110]["seq"]),
    xy=(110, activations[110]),
    xytext=(110, activations[110] - 40),
    arrowprops=dict(facecolor="black", shrink=0.05),
    horizontalalignment="center",
    verticalalignment="bottom",
)
plt.xlabel("Iteration")
plt.ylabel("Activation")
plt.title("Activation History with Annotations")
plt.show()
```
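The run above uses the hard cross-entropy cap `xe_max`. For comparison, the soft-penalty mode could be invoked as sketched below; this variant was not run in the original post and the regularization strength of 0.5 is just an illustrative value.

```python
# Hypothetical alternative run: penalize disfluency in the loss via
# xe_regularization instead of discarding candidates above a hard xe_max cap.
encoder_history_reg = dream(
    gemma2,
    tokenizer,
    get_feature_and_logits=gemma_sae_encoder(gemma2, layer, feature_idx),
    iters=150,
    init_prompt="help",
    xe_regularization=0.5,
)
```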
Dreaming decoder directions
Out of curiosity, do we get similar results if we apply feature visualization to the corresponding decoder direction?
```python
def gemma_sae_decoder(gemma2, layer, feature_idx):
    def f(input_ids=None, inputs_embeds=None, attention_mask=None):
        out = {}

        def get_res(module, input, output):
            out["res"] = input[0]

        model_hooks = [
            (gemma2.model.layers[layer], get_res),
        ]
        with add_fwd_hooks(model_hooks):
            if inputs_embeds is not None:
                logits = gemma2(
                    inputs_embeds=inputs_embeds, attention_mask=attention_mask
                ).logits
            else:
                logits = gemma2(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits

        sae = load_sae(layer)
        # Project the residual stream onto the decoder direction.
        activation = out["res"][:, -1].to(sae.dtype) @ sae.W_dec[feature_idx]

        return activation, logits

    return f
```
Based on the correlation of 0.99, the answer appears to be an emphatic yes for this particular feature. However, it might be different for other features.
```python
layer = 12
feature_idx = 0
texts = [
    "label",
    "labl",
    "LBL",
    "lbl",
    "LABEL",
    "labal",
    "L",
    "B",
    "bel",
    "la",
    "abel",
]

encoder_activations = []
decoder_activations = []

for text in texts:
    inputs = tokenizer([text], padding=True, return_tensors="pt")
    input_ids = inputs["input_ids"]
    encoder_activation, _ = gemma_sae_encoder(gemma2, layer, feature_idx)(
        input_ids=input_ids
    )
    decoder_activation, _ = gemma_sae_decoder(gemma2, layer, feature_idx)(
        input_ids=input_ids
    )
    encoder_activations.append(encoder_activation.item())
    decoder_activations.append(decoder_activation.item())

plt.figure(figsize=(10, 6))
plt.scatter(encoder_activations, decoder_activations)
plt.xlabel("Encoder Activations")
plt.ylabel("Decoder Activations")
plt.title("Correlation between Encoder and Decoder Activations")

# Add text labels for each point
for i, text in enumerate(texts):
    plt.annotate(
        text,
        (encoder_activations[i], decoder_activations[i]),
        xytext=(5, 5),
        textcoords="offset points",
    )

# Calculate and display correlation coefficient
correlation = np.corrcoef(encoder_activations, decoder_activations)[0, 1]
plt.text(
    0.05,
    0.95,
    f"Correlation: {correlation:.2f}",
    transform=plt.gca().transAxes,
    verticalalignment="top",
)

plt.tight_layout()
plt.show()
```
Feature visualization on the decoder direction also acquires the “label” phrase:
```python
decoder_history = dream(
    gemma2,
    tokenizer,
    get_feature_and_logits=gemma_sae_decoder(gemma2, layer, feature_idx),
    iters=150,
    init_prompt="help",
    xe_max=10.0,
)
```
References
Footnotes
As a specific motivation for this post, Joseph Bloom recently got in touch to ask about applying feature visualization techniques to SAEs. I wrote this with him in mind.↩︎
As a disclaimer, the code below has not been heavily used and was written just for this post. I expect inefficiencies and possibly some bugs.↩︎
Citation
@online{thompson2024,
author = {Thompson, T. Ben},
title = {Dreaming with Sparse Autoencoder Features},
date = {2024-10-14},
url = {https://confirmlabs.org/posts/sae_dream.html},
langid = {en}
}