# concept_steerer.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional


class ConceptSteerer:
    def __init__(
        self,
        model_name: str = "unsloth/Llama-3.2-1B-Instruct",
        device: str = "auto"
    ):
        """
        A robust class for performing activation steering on LLMs.

        Args:
            model_name: The Hugging Face model name.
            device: The device to load the model on ("auto", "cuda", "cpu").
        """
        print(f"Loading model {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
            attn_implementation="sdpa",  # Use optimized attention
            trust_remote_code=False
        )
        self.model.eval()

        self.num_layers = len(self.model.model.layers)
        self.concepts = {}  # name -> steering vector

    def _format_prompt_for_model(self, prompt: str) -> str:
        """Format the prompt according to the model's chat template if available."""
        # Check for an actual chat template rather than hasattr(): every recent
        # tokenizer exposes apply_chat_template, but calling it with no template fails.
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": prompt}]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return prompt

    def _get_mean_activation(self, prompts: List[str], layer: int, token_pos: int = -1) -> torch.Tensor:
        """Get the mean activation for a set of prompts at a specific layer and token position."""
        acts = []
        for prompt in prompts:
            formatted_prompt = self._format_prompt_for_model(prompt)
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
            # Resolve token index (clamp positive positions to the sequence length)
            seq_len = inputs.input_ids.shape[1]
            if token_pos >= 0:
                idx = min(token_pos, seq_len - 1)
            else:
                idx = seq_len + token_pos
            # hidden_states[0] is the embedding output, so hidden_states[layer + 1]
            # is the output of decoder layer `layer` -- the same tensor that the
            # steering hook in generate() modifies.
            act = outputs.hidden_states[layer + 1][0, idx, :].float().cpu()
            acts.append(act)
        return torch.stack(acts).mean(dim=0)

    def register_concept(
        self,
        name: str,
        positive_prompts: List[str],
        negative_prompts: List[str],
        layer: int = -1,  # Default to last layer
        token_pos: int = -1
    ):
        """Create and register a steering vector from contrastive examples."""
        if layer < 0:
            layer = self.num_layers + layer
        pos_acts = self._get_mean_activation(positive_prompts, layer, token_pos)
        neg_acts = self._get_mean_activation(negative_prompts, layer, token_pos)
        steering_vec = pos_acts - neg_acts
        # Normalize to unit vector for consistent scaling
        self.concepts[name] = steering_vec / steering_vec.norm()

    def steer_by_relation(
        self,
        name: str,
        A: str, B: str, C: str, D: str,
        layer: int = -1,
        token_pos: int = -1,
        num_examples: int = 5
    ):
        """
        Create a composite concept using the relation (A is to B) as (C is to D).
        Generates examples on-the-fly using the model itself.
        """
        if layer < 0:
            layer = self.num_layers + layer

        def generate_examples(seed_prompt: str, num: int) -> List[str]:
            examples = []
            for _ in range(num):
                inputs = self.tokenizer(
                    self._format_prompt_for_model(seed_prompt),
                    return_tensors="pt"
                ).to(self.model.device)
                with torch.no_grad():
                    out = self.model.generate(
                        **inputs,
                        max_new_tokens=20,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.pad_token_id
                    )
                # Decode only the newly generated tokens; slicing the decoded string
                # by len(seed_prompt) would be wrong because the prompt is wrapped in
                # the chat template before tokenization.
                generated_ids = out[0][inputs.input_ids.shape[1]:]
                generated = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
                examples.append(generated)
            return examples
        # Generate example completions for each side of the analogy
        pos_examples = generate_examples(f"{A} is to {B} as {C} is to", num_examples)
        neg_examples = generate_examples(f"{A} is to {B} as {D} is to", num_examples)
        # Composite vector: the word-level relation (A - B) + (C - D), plus the contrast
        # between the model's own completions so the generated examples are actually used
        example_vec = self._get_mean_activation(pos_examples, layer, token_pos) - self._get_mean_activation(neg_examples, layer, token_pos)
        AB_vec = self._get_mean_activation([A], layer, -1) - self._get_mean_activation([B], layer, -1)
        CD_vec = self._get_mean_activation([C], layer, -1) - self._get_mean_activation([D], layer, -1)
        composite_vec = AB_vec + CD_vec + example_vec
        self.concepts[name] = composite_vec / composite_vec.norm()
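
    # Illustrative (hypothetical) usage of the analogy-style registration above,
    # e.g. steering toward "royal femaleness" via the classic word-analogy relation:
    #   steerer.steer_by_relation("royal_female", A="queen", B="king",
    #                             C="woman", D="man", layer=8)
    # The words and the layer index here are placeholders, not values defined in this module.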

    def generate(
        self,
        prompt: str,
        steering_config: Optional[Dict[str, float]] = None,
        layer: int = -1,
        token_pos: int = -1,
        max_new_tokens: int = 100,
        **gen_kwargs
    ) -> str:
        """Generate text with optional activation steering."""
        if layer < 0:
            layer = self.num_layers + layer
        if steering_config is None:
            steering_config = {}

        inputs = self.tokenizer(
            self._format_prompt_for_model(prompt),
            return_tensors="pt"
        ).to(self.model.device)

        # Resolve token index for the hook
        seq_len = inputs.input_ids.shape[1]
        if token_pos >= 0:
            hook_token_idx = min(token_pos, seq_len - 1)
        else:
            hook_token_idx = seq_len + token_pos

        def hook_fn(module, input, output):
            # The decoder layer may return a tuple (hidden_states, ...) or a bare tensor
            hidden = output[0] if isinstance(output, tuple) else output
            # During decoding with the KV cache the layer only sees the newest token,
            # so clamp the index to the current sequence length: the requested prompt
            # position is steered on the prefill pass, then each new token afterwards.
            idx = min(hook_token_idx, hidden.shape[1] - 1)
            total_steer = torch.zeros_like(hidden[0, idx, :])
            for concept_name, strength in steering_config.items():
                if concept_name in self.concepts:
                    vec = self.concepts[concept_name].to(hidden.device, dtype=hidden.dtype)
                    total_steer += vec * strength
            hidden[0, idx, :] += total_steer
            return output

        handle = self.model.model.layers[layer].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=self.tokenizer.pad_token_id,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    **gen_kwargs
                )
            # Decode only the newly generated tokens; the input already contains the
            # chat-template-formatted prompt, so string-prefix stripping is unreliable.
            generated_ids = out[0][inputs.input_ids.shape[1]:]
            return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        finally:
            handle.remove()

    def get_concept_names(self) -> List[str]:
        return list(self.concepts.keys())
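

# Minimal usage sketch (assumptions: the prompts, concept name, layer index, and
# steering strengths below are illustrative placeholders, not part of the class above).
if __name__ == "__main__":
    steerer = ConceptSteerer()

    # Build a "formality" direction from contrastive example prompts.
    steerer.register_concept(
        "formality",
        positive_prompts=[
            "Dear Sir or Madam, I am writing to formally request your assistance.",
            "We would be most grateful if you could kindly review the attached report.",
        ],
        negative_prompts=[
            "hey, can u check this out real quick?",
            "lol no worries, just ping me whenever",
        ],
        layer=8,  # illustrative mid-network layer; must match the layer used in generate()
    )

    # Positive strength pushes generations toward the concept, negative away from it.
    formal = steerer.generate(
        "Write a short email asking a colleague for feedback.",
        steering_config={"formality": 4.0},
        layer=8,
    )
    casual = steerer.generate(
        "Write a short email asking a colleague for feedback.",
        steering_config={"formality": -4.0},
        layer=8,
    )
    print("Registered concepts:", steerer.get_concept_names())
    print("--- steered formal ---\n", formal)
    print("--- steered casual ---\n", casual)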