| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import Phi3Config, Phi3ForCausalLM |
| | from typing import Optional, Dict |
| |
|
| | |
class VectorMemoryHead(nn.Module):
    """Compress a sequence into a fixed number of memory slots and decode it back.

    The input is passed through a single-layer Transformer encoder. A set of
    learned query vectors then attends over the encoded sequence to produce
    ``num_memory_slots`` memory vectors, and finally the encoded sequence
    attends back over those memory vectors to produce a reconstruction.

    Args:
        hidden_dim: Feature size of the input, the memory slots and the outputs.
        num_memory_slots: Number of learned query vectors (memory slots).
        num_heads: Number of attention heads used by every attention module.
        ff_dim: Hidden width of the encoder and decoder feed-forward networks.
        device: Device on which the parameters are created.
        dtype: Data type of the created parameters.
        dropout: Dropout probability used by the encoder layer and by both
            attention modules. Defaults to 0.1.
    """

    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
                 device=None, dtype=None, *, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        factory_kwargs = {"device": device, "dtype": dtype}

        # Encoder applied to the raw input sequence.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True, **factory_kwargs
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

        # Learned queries; the leading dimension of 1 is expanded to the batch
        # size in ``forward``.
        self.memory_queries = nn.Parameter(
            torch.randn(1, num_memory_slots, hidden_dim, **factory_kwargs)
        )
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout,
            batch_first=True, **factory_kwargs
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, **factory_kwargs)

        # Decoder: encoded sequence attends over the compressed memory.
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout,
            batch_first=True, **factory_kwargs
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, **factory_kwargs)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, **factory_kwargs),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, **factory_kwargs),
        )

    def forward(self, memory_input_sequence: torch.Tensor):
        """Encode, compress and reconstruct ``memory_input_sequence``.

        Args:
            memory_input_sequence: Batch-first tensor of shape
                ``(batch, seq_len, hidden_dim)``.

        Returns:
            A tuple ``(compressed_memory, reconstructed_vectors)`` with shapes
            ``(batch, num_memory_slots, hidden_dim)`` and
            ``(batch, seq_len, hidden_dim)`` respectively.
        """
        batch_size = memory_input_sequence.shape[0]
        encoded_vectors = self.encoder(memory_input_sequence)
        queries = self.memory_queries.expand(batch_size, -1, -1)

        # The attention weights are never used, so skip computing them; this
        # also lets PyTorch use its optimized attention implementation.
        compressed_memory, _ = self.memory_attention(
            query=queries, key=encoded_vectors, value=encoded_vectors,
            need_weights=False,
        )
        # Residual connection around the attention, followed by layer norm.
        compressed_memory = self.memory_layernorm(compressed_memory + queries)

        reconstructed, _ = self.decoder_attention(
            query=encoded_vectors, key=compressed_memory, value=compressed_memory,
            need_weights=False,
        )
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors
| |
|
| | |
class GCVectorMemoryLayer(nn.Module):
    """Wrap a linear layer and apply a gated, memory-derived correction to its output.

    The wrapped layer's output is scaled by a sigmoid gate and shifted by an
    additive value. Both are predicted from a ``VectorMemoryHead`` that sees,
    for every position, a projection of the layer input ("local" state) and a
    projection of the tensor stored under ``global_state_storage['embeds']``
    ("global" state).

    Args:
        original_layer: Linear layer to wrap; its weights keep being used.
        global_input_dim: Feature size of the tensor stored under ``'embeds'``.
        memory_dim: Feature size used inside the memory head.
        num_memory_slots: Number of memory slots in the memory head.
        memory_num_heads: Number of attention heads in the memory head.
        global_state_storage: Dict shared with the owner of this layer; the
            key ``'embeds'`` is read on every forward pass.
    """

    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.linear = original_layer

        # New sub-modules are created on the wrapped layer's device and dtype.
        device, dtype = self.linear.weight.device, self.linear.weight.dtype

        self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads, ff_dim=memory_dim * 2, device=device, dtype=dtype
        )
        # Emits the gate and the additive value, hence twice the output width.
        self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)

        # How many times the same gate/value correction is applied; values
        # below 1 disable the correction entirely.
        self.num_correction_passes: int = 1

        # Tensors from the most recent training-mode forward pass.
        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped layer's output, corrected when global state is available.

        Args:
            x: Input of shape ``(batch, seq_len, in_features)``.

        Returns:
            The plain linear output when no ``'embeds'`` entry is stored or
            ``num_correction_passes < 1``; otherwise the gated and shifted output.
        """
        base_output = self.linear(x)

        if 'embeds' not in self.global_state_storage or self.num_correction_passes < 1:
            return base_output

        # The cached embeddings may have been produced on another device than
        # this layer; ``.to`` is a no-op when the devices already match.
        global_embeds = self.global_state_storage['embeds'].to(x.device)
        if global_embeds.shape[1] != x.shape[1]:
            # Keep only the trailing positions so both sequences line up.
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        B, S, _ = x.shape

        # Gradients are tracked here even if the caller disabled them.
        # NOTE(review): nesting under enable_grad is reconstructed — confirm.
        with torch.enable_grad():
            proj_local = self.local_state_proj(x)
            proj_global = self.global_state_proj(global_embeds)

            # Each position becomes its own length-2 sequence: (global, local).
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, recon_flat = self.memory_head(memory_input_flat)

            # Average over the memory slots to get one vector per position.
            aggregated_thought_flat = compressed_mem_flat.mean(dim=1)
            aggregated_thought = aggregated_thought_flat.view(B, S, self.memory_dim)
            raw_correction = self.correction_head(aggregated_thought)
            gate, value = torch.chunk(raw_correction, 2, dim=-1)

            # Gate and offset do not change between passes; compute them once.
            gate_scale = torch.sigmoid(gate.to(x.dtype))
            offset = value.to(x.dtype)
            corrected_activation = base_output
            for _ in range(self.num_correction_passes):
                corrected_activation = corrected_activation * gate_scale + offset

            # Expose intermediates only while training.
            if self.training:
                self.last_corrected_activation = corrected_activation
                self.last_additive_correction = value
                self.last_memory_input = memory_input_flat
                self.last_reconstructed_from_memory = recon_flat

        return corrected_activation
| |
|
| | |
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    """Phi-3 causal LM with one MLP projection wrapped in a ``GCVectorMemoryLayer``.

    On construction the module at ``target_layer_path`` is replaced by a
    ``GCVectorMemoryLayer`` wrapping it. A forward hook on the token embedding
    layer stores the detached embedding output in ``global_state_storage``
    under the key ``'embeds'``, which the wrapped layer reads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.global_state_storage = {}
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"

        def _capture_embeddings(module, inputs, output):
            # Returns None, so the embedding output itself is left unchanged.
            self.global_state_storage['embeds'] = output.detach()

        # NOTE(review): the cache is refreshed only when ``embed_tokens`` runs;
        # presumably a call that bypasses it would reuse stale embeddings —
        # confirm against the callers.
        self.model.embed_tokens.register_forward_hook(_capture_embeddings)

        # Only the lookup is guarded: an AttributeError raised while building
        # or installing the wrapper is a real error and must not be reported
        # as a missing layer.
        try:
            original_layer = self.get_submodule(self.target_layer_path)
        except AttributeError:
            print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")
        else:
            custom_layer = GCVectorMemoryLayer(
                original_layer=original_layer, global_input_dim=config.hidden_size,
                memory_dim=64, num_memory_slots=8, memory_num_heads=4,
                global_state_storage=self.global_state_storage
            )
            parent_path, _, child_name = self.target_layer_path.rpartition('.')
            setattr(self.get_submodule(parent_path), child_name, custom_layer)
            print(f"Successfully replaced '{self.target_layer_path}' with GCVectorMemoryLayer.")
| | |
| |
|