import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional, Tuple
import numpy as np


class ConceptSteerer:
    def __init__(
        self,
        model_name: str = "unsloth/Llama-3.2-1B-Instruct",
        device: str = "auto"
    ):
        """
        A robust class for performing activation steering on LLMs.

        Args:
            model_name: The Hugging Face model name.
            device: The device to load the model on ("auto", "cuda", "cpu").
        """
        print(f"Loading model {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
            attn_implementation="sdpa",
            trust_remote_code=False
        )
        self.model.eval()
        self.num_layers = len(self.model.model.layers)
        self.concepts = {}

    def _format_prompt_for_model(self, prompt: str) -> str:
        """Format the prompt according to the model's chat template if available."""
        if hasattr(self.tokenizer, 'apply_chat_template'):
            messages = [{"role": "user", "content": prompt}]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return prompt

    def _get_mean_activation(self, prompts: List[str], layer: int, token_pos: int = -1) -> torch.Tensor:
        """Get the mean activation for a set of prompts at a specific layer and token position."""
        acts = []
        for prompt in prompts:
            formatted_prompt = self._format_prompt_for_model(prompt)
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            seq_len = inputs.input_ids.shape[1]
            if token_pos >= 0:
                idx = min(token_pos, seq_len - 1)
            else:
                idx = seq_len + token_pos

            # hidden_states[0] is the embedding output, so hidden_states[layer + 1]
            # is the output of decoder layer `layer` -- the same tensor the forward
            # hook in generate() modifies.
            act = outputs.hidden_states[layer + 1][0, idx, :].float().cpu()
            acts.append(act)

        return torch.stack(acts).mean(dim=0)

    def register_concept(
        self,
        name: str,
        positive_prompts: List[str],
        negative_prompts: List[str],
        layer: int = -1,
        token_pos: int = -1
    ):
        """Create and register a steering vector from contrastive examples."""
        if layer < 0:
            layer = self.num_layers + layer

        pos_acts = self._get_mean_activation(positive_prompts, layer, token_pos)
        neg_acts = self._get_mean_activation(negative_prompts, layer, token_pos)
        steering_vec = pos_acts - neg_acts

        self.concepts[name] = steering_vec / steering_vec.norm()

    def steer_by_relation(
        self,
        name: str,
        A: str, B: str, C: str, D: str,
        layer: int = -1,
        token_pos: int = -1,
        num_examples: int = 5
    ):
        """
        Create a composite concept using the relation (A is to B) as (C is to D).
        Generates contrastive examples on-the-fly using the model itself.
        """
        if layer < 0:
            layer = self.num_layers + layer

        def generate_examples(seed_prompt: str, num: int) -> List[str]:
            examples = []
            for _ in range(num):
                inputs = self.tokenizer(
                    self._format_prompt_for_model(seed_prompt),
                    return_tensors="pt"
                ).to(self.model.device)
                with torch.no_grad():
                    out = self.model.generate(
                        **inputs,
                        max_new_tokens=20,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.pad_token_id
                    )
                # Decode only the newly generated tokens; slicing the decoded string
                # by len(seed_prompt) breaks once the chat template rewrites the prompt.
                generated = self.tokenizer.decode(
                    out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
                ).strip()
                examples.append(generated)
            return examples

        pos_examples = generate_examples(f"{A} is to {B} as {C} is to", num_examples)
        neg_examples = generate_examples(f"{A} is to {B} as {D} is to", num_examples)

        # Combine the two relational difference vectors with the contrast between
        # the generated positive and negative completions.
        AB_vec = self._get_mean_activation([A], layer, -1) - self._get_mean_activation([B], layer, -1)
        CD_vec = self._get_mean_activation([C], layer, -1) - self._get_mean_activation([D], layer, -1)
        example_vec = (
            self._get_mean_activation(pos_examples, layer, token_pos)
            - self._get_mean_activation(neg_examples, layer, token_pos)
        )
        composite_vec = AB_vec + CD_vec + example_vec
        self.concepts[name] = composite_vec / composite_vec.norm()

    def generate(
        self,
        prompt: str,
        steering_config: Optional[Dict[str, float]] = None,
        layer: int = -1,
        token_pos: int = -1,
        max_new_tokens: int = 100,
        **gen_kwargs
    ) -> str:
        """Generate text with optional activation steering."""
        if layer < 0:
            layer = self.num_layers + layer

        if steering_config is None:
            steering_config = {}

        inputs = self.tokenizer(
            self._format_prompt_for_model(prompt),
            return_tensors="pt"
        ).to(self.model.device)

        seq_len = inputs.input_ids.shape[1]
        if token_pos >= 0:
            hook_token_idx = min(token_pos, seq_len - 1)
        else:
            hook_token_idx = seq_len + token_pos

        def hook_fn(module, input, output):
            # With KV caching, every decoding step after the prompt passes only a
            # single token through the layer, so clamp the index to the current
            # sequence length instead of indexing past it.
            idx = min(hook_token_idx, output[0].shape[1] - 1)
            total_steer = torch.zeros_like(output[0][0, idx, :])
            for concept_name, strength in steering_config.items():
                if concept_name in self.concepts:
                    vec = self.concepts[concept_name].to(output[0].device, dtype=output[0].dtype)
                    total_steer += vec * strength
            output[0][0, idx, :] += total_steer
            return output

        handle = self.model.model.layers[layer].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=self.tokenizer.pad_token_id,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    **gen_kwargs
                )
            # Decode only the newly generated tokens so the chat-templated prompt
            # is not echoed back in the result.
            result = self.tokenizer.decode(
                out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
            ).strip()
            return result
        finally:
            handle.remove()

    def get_concept_names(self) -> List[str]:
        return list(self.concepts.keys())
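

# Usage sketch (illustrative only): the prompts, layer choice, and steering
# strength below are made-up assumptions for demonstration, not tuned values.
if __name__ == "__main__":
    steerer = ConceptSteerer()

    # Build a "formality" direction from a handful of contrastive prompts.
    steerer.register_concept(
        name="formality",
        positive_prompts=[
            "I would be delighted to assist you with this matter.",
            "Please find attached the requested documentation.",
        ],
        negative_prompts=[
            "yeah sure lol send it over",
            "nah that's fine, whatever works",
        ],
        layer=-8,
    )
    print("Registered concepts:", steerer.get_concept_names())

    # Generate with the concept pushed in (positive strength) at the same layer
    # it was extracted from; negative strengths steer away from the concept.
    print(steerer.generate(
        "Write a short note declining a meeting invitation.",
        steering_config={"formality": 4.0},
        layer=-8,
        max_new_tokens=60,
    ))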