File size: 7,572 Bytes
3d558ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# concept_steerer.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional, Tuple
import numpy as np

class ConceptSteerer:
    """Build concept steering vectors and apply them while generating text.

    A concept is a unit-length direction in the residual stream, measured at
    the output of one decoder layer. During generation the weighted sum of the
    selected concept vectors is added to that layer's output at one prompt
    token position.

    Vectors are captured and applied at the same point (the output of
    ``self.model.model.layers[layer]``), so a concept should be used with the
    same ``layer`` it was registered with.
    """

    def __init__(
        self,
        model_name: str = "unsloth/Llama-3.2-1B-Instruct",
        device: str = "auto"
    ):
        """
        Load the tokenizer and the causal LM used for activation steering.

        Args:
            model_name: The Hugging Face model name.
            device: The device to load the model on ("auto", "cuda", "cpu").
        """
        print(f"Loading model {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Fall back to EOS as the padding token when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            # Half precision unless the CPU was requested explicitly.
            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
            attn_implementation="sdpa",  # Use optimized attention
            trust_remote_code=False
        )
        self.model.eval()
        self.num_layers = len(self.model.model.layers)
        self.concepts = {}  # name -> unit-length steering vector (CPU, float32)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unwrap_hidden(output):
        """Return the hidden-state tensor from a decoder layer's output.

        The layer output is handled both as a tuple whose first element is the
        hidden states and as a bare tensor, so the hooks work with either
        return convention.
        """
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    @staticmethod
    def _resolve_token_index(token_pos: int, seq_len: int) -> int:
        """Map a possibly negative token position onto ``[0, seq_len - 1]``.

        Out-of-range positions are clamped on both sides, so a negative
        position longer than the sequence selects the first token instead of
        wrapping around to an unrelated one.
        """
        if token_pos >= 0:
            return min(token_pos, seq_len - 1)
        return max(seq_len + token_pos, 0)

    @staticmethod
    def _unit_vector(vec: torch.Tensor, name: str) -> torch.Tensor:
        """Scale ``vec`` to unit length.

        Raises:
            ValueError: If the norm is zero or not finite; dividing anyway
                would store a NaN vector that corrupts every later generation.
        """
        norm = vec.norm()
        if not bool(torch.isfinite(norm)) or norm.item() == 0.0:
            raise ValueError(
                f"Cannot build concept {name!r}: the contrast vector has a "
                "zero or non-finite norm (the two sides produced the same "
                "activation at the chosen layer/token position)."
            )
        return vec / norm

    def _resolve_layer(self, layer: int) -> int:
        """Turn a possibly negative layer index into a validated absolute one.

        Raises:
            IndexError: If the index does not address one of the model's
                ``self.num_layers`` decoder layers.
        """
        resolved = layer + self.num_layers if layer < 0 else layer
        if not 0 <= resolved < self.num_layers:
            raise IndexError(
                f"layer {layer} is out of range for a model with "
                f"{self.num_layers} layers"
            )
        return resolved

    def _format_prompt_for_model(self, prompt: str) -> str:
        """Format the prompt according to the model's chat template if available.

        The raw prompt is returned unchanged when the tokenizer has no
        ``apply_chat_template`` method or no ``chat_template`` configured.
        """
        apply_template = getattr(self.tokenizer, "apply_chat_template", None)
        has_template = getattr(self.tokenizer, "chat_template", None) is not None
        if apply_template is None or not has_template:
            return prompt
        messages = [{"role": "user", "content": prompt}]
        return apply_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def _encode_prompt(self, prompt: str, **tokenizer_kwargs):
        """Format and tokenize a single prompt, moved to the model's device.

        When the formatted text already begins with the BOS token (as rendered
        chat templates may), the tokenizer is told not to add special tokens
        again; otherwise the sequence would start with a duplicated BOS.
        """
        text = self._format_prompt_for_model(prompt)
        bos = getattr(self.tokenizer, "bos_token", None)
        already_has_bos = bool(bos) and text.startswith(bos)
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=not already_has_bos,
            **tokenizer_kwargs
        )
        return encoded.to(self.model.device)

    def _get_mean_activation(self, prompts: List[str], layer: int, token_pos: int = -1) -> torch.Tensor:
        """Get the mean activation for a set of prompts at a specific layer and token position.

        The activation is read with a forward hook from the output of
        ``self.model.model.layers[layer]``, i.e. the same tensor that
        ``generate`` modifies when steering.

        Args:
            prompts: Prompts to average over; each is formatted and run alone.
            layer: Decoder layer index (negative values count from the end).
            token_pos: Token position to read (negative values count from the
                end; out-of-range values are clamped).

        Returns:
            A 1-D float32 CPU tensor: the mean of the per-prompt activations.

        Raises:
            ValueError: If ``prompts`` is empty.
            IndexError: If ``layer`` is out of range.
        """
        if not prompts:
            raise ValueError("prompts must contain at least one prompt")
        layer = self._resolve_layer(layer)

        captured = []

        def capture_hook(module, args, output):
            captured.append(self._unwrap_hidden(output))

        handle = self.model.model.layers[layer].register_forward_hook(capture_hook)
        acts = []
        try:
            for prompt in prompts:
                inputs = self._encode_prompt(prompt, truncation=True, max_length=512)
                captured.clear()
                with torch.no_grad():
                    self.model(**inputs)
                hidden = captured[-1]
                idx = self._resolve_token_index(token_pos, hidden.shape[1])
                acts.append(hidden[0, idx, :].detach().float().cpu())
        finally:
            # Never leave the capture hook attached, even if a forward pass fails.
            handle.remove()

        return torch.stack(acts).mean(dim=0)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def register_concept(
        self,
        name: str,
        positive_prompts: List[str],
        negative_prompts: List[str],
        layer: int = -1, # Default to last layer
        token_pos: int = -1
    ):
        """Create and register a steering vector from contrastive examples.

        The vector is the mean activation of ``positive_prompts`` minus that of
        ``negative_prompts``, normalized to unit length so that the strength
        passed to ``generate`` has a consistent scale across concepts.

        Args:
            name: Key under which the concept is stored in ``self.concepts``.
            positive_prompts: Prompts that express the concept.
            negative_prompts: Prompts that express its opposite / absence.
            layer: Decoder layer to read from (negative counts from the end).
            token_pos: Token position to read (negative counts from the end).

        Raises:
            ValueError: If a prompt list is empty or both sides yield the same
                mean activation (zero-length contrast vector).
            IndexError: If ``layer`` is out of range.
        """
        layer = self._resolve_layer(layer)
        pos_acts = self._get_mean_activation(positive_prompts, layer, token_pos)
        neg_acts = self._get_mean_activation(negative_prompts, layer, token_pos)
        self.concepts[name] = self._unit_vector(pos_acts - neg_acts, name)

    def steer_by_relation(
        self,
        name: str,
        A: str, B: str, C: str, D: str,
        layer: int = -1,
        token_pos: int = -1,
        num_examples: int = 5
    ):
        """
        Create a composite concept from the relation (A is to B) as (C is to D).

        The stored vector is the unit-normalized sum ``(A - B) + (C - D)`` of
        activation differences, each term being the activation of the single
        prompt A, B, C or D.

        Args:
            name: Key under which the concept is stored in ``self.concepts``.
            A, B, C, D: The four prompts that define the relation.
            layer: Decoder layer to read from (negative counts from the end).
            token_pos: Token position to read (negative counts from the end).
            num_examples: Accepted for backward compatibility and ignored. The
                sampled example generations it used to control never fed into
                the stored vector, so they are no longer run.

        Raises:
            ValueError: If the composite vector has zero length.
            IndexError: If ``layer`` is out of range.
        """
        layer = self._resolve_layer(layer)

        def activation(prompt: str) -> torch.Tensor:
            return self._get_mean_activation([prompt], layer, token_pos)

        # Create the composite vector: (A-B) + (C-D)
        ab_vec = activation(A) - activation(B)
        cd_vec = activation(C) - activation(D)
        self.concepts[name] = self._unit_vector(ab_vec + cd_vec, name)

    def generate(
        self,
        prompt: str,
        steering_config: Optional[Dict[str, float]] = None,
        layer: int = -1,
        token_pos: int = -1,
        max_new_tokens: int = 100,
        **gen_kwargs
    ) -> str:
        """Generate text with optional activation steering.

        Args:
            prompt: User prompt; formatted with the chat template if available.
            steering_config: Mapping of concept name -> strength. Names that
                are not registered are ignored.
            layer: Decoder layer whose output is steered (negative counts from
                the end).
            token_pos: Prompt token position that is steered (negative counts
                from the end; out-of-range values are clamped).
            max_new_tokens: Maximum number of tokens to generate.
            **gen_kwargs: Extra arguments for ``model.generate``. They override
                the defaults used here (``do_sample=True``, ``temperature=0.6``,
                ``top_p=0.9``, ``pad_token_id``).

        Returns:
            The decoded completion only (prompt tokens are dropped), stripped
            of surrounding whitespace.

        Raises:
            IndexError: If ``layer`` is out of range.
        """
        layer = self._resolve_layer(layer)

        inputs = self._encode_prompt(prompt)
        prompt_len = inputs["input_ids"].shape[1]
        hook_token_idx = self._resolve_token_index(token_pos, prompt_len)

        # Combine all requested concepts once instead of on every forward pass.
        total_steer = None
        for concept_name, strength in (steering_config or {}).items():
            vec = self.concepts.get(concept_name)
            if vec is None:
                continue
            contribution = vec.float() * strength
            total_steer = contribution if total_steer is None else total_steer + contribution

        def hook_fn(module, args, output):
            hidden = self._unwrap_hidden(output)
            # With KV caching, passes after the first only carry the newly
            # generated token, so the prompt position is not present. Steer
            # only when the pass covers the whole prompt.
            if hidden.shape[1] < prompt_len:
                return None
            steer = total_steer.to(hidden.device, dtype=hidden.dtype)
            # All batch rows (beams / multiple return sequences) share the prompt.
            hidden[:, hook_token_idx, :] += steer
            return output

        # Caller-supplied generation arguments win over the defaults; passing
        # them twice as keywords would raise a TypeError.
        generation_kwargs = {
            "pad_token_id": self.tokenizer.pad_token_id,
            "do_sample": True,
        }
        generation_kwargs.update(gen_kwargs)
        if generation_kwargs["do_sample"]:
            generation_kwargs.setdefault("temperature", 0.6)
            generation_kwargs.setdefault("top_p", 0.9)

        handle = None
        if total_steer is not None:
            handle = self.model.model.layers[layer].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    **generation_kwargs
                )
        finally:
            if handle is not None:
                handle.remove()

        # Slice by token count: the decoded text contains the chat-template
        # wrapping, so it cannot be stripped by matching the raw prompt string.
        new_tokens = out[0][prompt_len:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def get_concept_names(self) -> List[str]:
        """Return the names of all registered concepts, in registration order."""
        return list(self.concepts)