# === Inlined ARMT for HF Hub (single-file) ===
# ---- act_utils.py ----
from torch import nn
import torch
import numpy as np
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer
def gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
"""
Generates a [1, length, channels] timing signal consisting of sinusoids
Adapted from:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
"""
position = np.arange(length)
num_timescales = channels // 2
log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (float(num_timescales) - 1))
inv_timescales = min_timescale * np.exp(np.arange(num_timescales).astype(float) * -log_timescale_increment)
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, channels % 2]],
'constant', constant_values=[0.0, 0.0])
signal = signal.reshape([1, length, channels])
return torch.from_numpy(signal).type(torch.FloatTensor)
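# Shape sketch: gen_timing_signal(128, 512) returns a FloatTensor of shape
# [1, 128, 512], matching the [batch, length, channels] layout used below.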
class ACT_basic(nn.Module):
def __init__(self,hidden_size):
super(ACT_basic, self).__init__()
self.sigma = nn.Sigmoid()
self.p = nn.Linear(hidden_size,1)
self.p.bias.data.fill_(1)
self.threshold = 1 - 0.1
self.eps = 0.1
def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
# init_hdd
## [B, S]
        noisy_halting = kwargs.pop('noisy_halting', False)
        halting_probability = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S]
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs)
step = 0
# for l in range(self.num_layers):
rest = None
while( ((halting_probability<self.threshold) & (n_updates < max_hop)).byte().any()):
# Add timing signal
# state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
# state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)
p = self.sigma(self.p(state)).squeeze(-1)
if noisy_halting and self.training:
p = p + torch.randn_like(p) * self.eps
# Mask for inputs which have not halted yet
still_running = (halting_probability < 1.0).float()
# Mask of inputs which halted at this step
new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
# Mask of inputs which haven't halted, and didn't halt this step
still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
# Add the halting probability for this step to the halting
# probabilities for those input which haven't halted yet
halting_probability = halting_probability + p * still_running
# Compute remainders for the inputs which halted at this step
remainders = remainders + new_halted * (1 - halting_probability)
# Add the remainders to those inputs which halted at this step
halting_probability = halting_probability + new_halted * remainders
# Increment n_updates for all inputs which are still running
n_updates = n_updates + still_running + new_halted
# Compute the weight to be applied to the new state and output
# 0 when the input has already halted
# p when the input hasn't halted yet
# the remainders when it halted this step
update_weights = p * still_running + new_halted * remainders
            if encoder_output is not None:
state, _ = fn((state,encoder_output))
else:
# apply transformation on the state
state = fn(state, *args, **kwargs)
if isinstance(state, tuple):
rest = state[1:]
state = state[0]
# update running part in the weighted state and keep the rest
previous_state = ((state * update_weights.unsqueeze(-1)) + (previous_state * (1 - update_weights.unsqueeze(-1))))
            ## previous_state is actually the new state at the end of the loop;
            ## assigning to previous_state here keeps the next iteration correct,
            ## and notice that indeed we return previous_state below.
step+=1
if rest is None:
return previous_state, (remainders,n_updates)
else:
return (previous_state, *rest), (remainders, n_updates)
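# Usage sketch for ACT_basic (hypothetical sizes; `fn` is any state -> state map,
# e.g. a wrapped transformer layer):
#   act = ACT_basic(hidden_size=64)
#   x = torch.randn(2, 10, 64)
#   out, (remainders, n_updates) = act(
#       state=x, inputs=x, fn=lambda s: s,
#       time_enc=None, pos_enc=None, max_hop=4)
# `out` is the halting-weighted mixture of per-step states; `remainders.mean()`
# is the ponder-cost term that the wrappers below add to the training loss.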
class ACT_constant_depth():
def __init__(self):
super(ACT_constant_depth, self).__init__()
def __call__(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
# init_hdd
## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S]
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs)
step = 0
# for l in range(self.num_layers):
rest = None
while(step < max_hop):
# Add timing signal
# state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
# state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)
            if encoder_output is not None:
state, _ = fn((state,encoder_output))
else:
# apply transformation on the state
state = fn(state, *args, **kwargs)
if isinstance(state, tuple):
rest = state[1:]
state = state[0]
            # keep the latest state; with constant depth there is no halting-weighted mix
            previous_state = state
            ## previous_state is actually the new state at the end of the loop;
            ## assigning to previous_state here keeps the next iteration correct,
            ## and notice that indeed we return previous_state below.
step+=1
if rest is None:
return previous_state, (remainders,n_updates)
else:
return (previous_state, *rest), (remainders, n_updates)
class ACTForWholeARMT(nn.Module):
def __init__(self,hidden_size):
super(ACTForWholeARMT, self).__init__()
self.sigma = nn.Sigmoid()
self.p = nn.Linear(hidden_size,1)
self.p.bias.data.fill_(1)
self.threshold = 1 - 0.1
def forward(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
# init_hdd
## [B, S]
        halting_probability = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S]
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs)
step = 0
# for l in range(self.num_layers):
rest = None
while( ((halting_probability < self.threshold) & (n_updates < max_hop)).byte().any()):
# Add timing signal
# state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
# state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)
p = self.sigma(self.p(state)).squeeze(-1)
# Mask for inputs which have not halted yet
still_running = (halting_probability < 1.0).float()
# Mask of inputs which halted at this step
new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
# Mask of inputs which haven't halted, and didn't halt this step
still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
# Add the halting probability for this step to the halting
# probabilities for those input which haven't halted yet
halting_probability = halting_probability + p * still_running
# Compute remainders for the inputs which halted at this step
remainders = remainders + new_halted * (1 - halting_probability)
# Add the remainders to those inputs which halted at this step
halting_probability = halting_probability + new_halted * remainders
# Increment n_updates for all inputs which are still running
n_updates = n_updates + still_running + new_halted
# Compute the weight to be applied to the new state and output
# 0 when the input has already halted
# p when the input hasn't halted yet
# the remainders when it halted this step
update_weights = p * still_running + new_halted * remainders
            if encoder_output is not None:
if ((halting_probability<self.threshold) & (n_updates < max_hop)).byte().any():
state, _ = fn_no_update((state,encoder_output))
else:
state, _ = fn_update((state, encoder_output))
else:
# apply transformation on the state
if ((halting_probability<self.threshold) & (n_updates < max_hop)).byte().any():
state = fn_no_update(state, *args, **kwargs)
else:
state = fn_update(state, *args, **kwargs)
if isinstance(state, tuple):
rest = state[1:]
state = state[0]
# update running part in the weighted state and keep the rest
previous_state = ((state * update_weights.unsqueeze(-1)) + (previous_state * (1 - update_weights.unsqueeze(-1))))
            ## previous_state is actually the new state at the end of the loop;
            ## assigning to previous_state here keeps the next iteration correct,
            ## and notice that indeed we return previous_state below.
step+=1
if rest is None:
return previous_state, (remainders,n_updates)
else:
return (previous_state, *rest), (remainders, n_updates)
class ACTForWholeARMT_constant_depth():
def __init__(self):
super(ACTForWholeARMT_constant_depth, self).__init__()
def __call__(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
print("\n\n\n\n\n\n\n\n\n\nCONSTANT DEPTH TRUE")
# init_hdd
## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=inputs.device)
        ## [B, S]
        n_updates = torch.full((inputs.shape[0], inputs.shape[1]), max_hop, device=inputs.device)
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs)
step = 0
# for l in range(self.num_layers):
rest = None
while(step < max_hop):
# Add timing signal
# state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
# state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)
            if encoder_output is not None:
                # run fn_update (which writes to memory) only on the final hop
                if step < max_hop - 1:
                    state, _ = fn_no_update((state, encoder_output))
                else:
                    state, _ = fn_update((state, encoder_output))
            else:
                # apply transformation on the state
                if step < max_hop - 1:
                    state = fn_no_update(state, *args, **kwargs)
                else:
                    state = fn_update(state, *args, **kwargs)
if isinstance(state, tuple):
rest = state[1:]
state = state[0]
            # keep the latest state; with constant depth there is no halting-weighted mix
            previous_state = state
            ## previous_state is actually the new state at the end of the loop;
            ## assigning to previous_state here keeps the next iteration correct,
            ## and notice that indeed we return previous_state below.
step+=1
if rest is None:
return previous_state, (remainders,n_updates)
else:
return (previous_state, *rest), (remainders, n_updates)
class ACT_transformer(nn.Module):
def __init__(self, hidden_size, num_heads=4, num_transformer_layers=1, dropout=0.1):
super(ACT_transformer, self).__init__()
# Transformer encoder
transformer_layer = TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=hidden_size,
dropout=dropout,
norm_first=True
)
self.transformer = TransformerEncoder(transformer_layer,
num_layers=num_transformer_layers)
# Feedforward layer for logits
self.logit_ff = nn.Linear(hidden_size, 1)
self.logit_ff.bias.data.fill_(1)
# Halting threshold
self.sigma = nn.Sigmoid()
self.threshold = 1 - 0.1
def generate_causal_mask(self, seq_len):
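        # Additive causal mask, e.g. generate_causal_mask(3) ->
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]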
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
batch_size, seq_len, hidden_size = inputs.shape
        halting_probability = torch.zeros(batch_size, seq_len, device=inputs.device)
        remainders = torch.zeros(batch_size, seq_len, device=inputs.device)
        n_updates = torch.zeros(batch_size, seq_len, device=inputs.device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None
        causal_mask = self.generate_causal_mask(seq_len).to(inputs.device)
while ((halting_probability < self.threshold) & (n_updates < max_hop)).byte().any():
state_transformed = self.transformer(
state.permute(1, 0, 2), # [S, B, H]
mask=causal_mask
) # [S, B, H]
state_transformed = state_transformed.permute(1, 0, 2) # [B, S, H]
# Pass through linear layer and sigmoid
p = self.sigma(self.logit_ff(state_transformed)).squeeze(-1) # [B, S]
# Update halting logic
still_running = (halting_probability < 1.0).float()
new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
halting_probability = halting_probability + p * still_running
remainders = remainders + new_halted * (1 - halting_probability)
halting_probability = halting_probability + new_halted * remainders
n_updates = n_updates + still_running + new_halted
update_weights = p * still_running + new_halted * remainders
if encoder_output is not None:
state, _ = fn((state, encoder_output))
else:
state = fn(state, *args, **kwargs)
if isinstance(state, tuple):
rest = state[1:]
state = state[0]
previous_state = (
(state * update_weights.unsqueeze(-1)) +
(previous_state * (1 - update_weights.unsqueeze(-1)))
)
step += 1
if rest is None:
return previous_state, (remainders, n_updates)
else:
return (previous_state, *rest), (remainders, n_updates)
# ---- language_modeling.py ----
import math
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.cache_utils import Cache, DynamicCache
from torch.nn.functional import relu as r
import torch.nn.functional as F
from munch import Munch
import os
# inlined act_utils: removed import ACT_basic, gen_timing_signal, ACTForWholeARMT, ACT_transformer, ACT_constant_depth, ACTForWholeARMT_constant_depth
try:
from baselines.rwkv.language_modeling import RWKVModel
RWKV_imported = True
except ImportError:
print("*** Can't import RWKV model ***")
RWKV_imported = False
def dpfp(x, nu=1):
x = torch.cat([r(x), r(-x)], dim=-1)
x_rolled = torch.cat([x.roll(shifts=j, dims=-1)
for j in range(1,nu+1)], dim=-1)
x_repeat = torch.cat([x] * nu, dim=-1)
return x_repeat * x_rolled
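# Shape sketch: dpfp expands the last dimension d -> 2*nu*d with non-negative
# features (a DPFP-style map from the fast weight programmer literature), e.g.
#   dpfp(torch.randn(2, 8), nu=3).shape == torch.Size([2, 48])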
class DPFP:
def __init__(self, nu):
self.nu = nu
def __call__(self, x):
nu = self.nu
x = torch.cat([r(x), r(-x)], dim=-1)
x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1,nu+1)], dim=-1)
x_repeat = torch.cat([x] * nu, dim=-1)
return x_repeat * x_rolled
def attn_mask_to_4d(attn_mask, upper, query_len):
if attn_mask is None:
return None
seg_len = attn_mask.size(-1)
if upper:
tri = torch.triu(torch.ones(query_len, seg_len))
else:
tri = torch.tril(torch.ones(query_len, seg_len))
mask = torch.einsum('bj,ij->bij', attn_mask, tri.to(attn_mask.device))
mask = mask.unsqueeze(1)
return mask
def invert_attn_mask(attn_mask, dtype):
min_dtype = torch.finfo(dtype).min
new_mask = (1.0 - attn_mask) * min_dtype
return new_mask
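# Mask helper sketch: attn_mask_to_4d broadcasts a [B, S] key mask against a
# triangular [Q, S] pattern into a [B, 1, Q, S] 0/1 mask, and invert_attn_mask
# turns a 0/1 "keep" mask into the additive form HF attention expects
# (0 where kept, torch.finfo(dtype).min where masked).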
class AssociativeLayerWrapper(torch.nn.Module):
def __init__(self, layer, d_model, num_mem_tokens, d_mem, n_heads=1, correction=True, info=None, use_denom=True, gating=False) -> None:
super().__init__()
self.info = info
self.seg_num = 0
self.d_model = d_model
self.num_mem_tokens = num_mem_tokens
self.d_mem = d_mem
self.n_heads = n_heads
self.gating = gating
nu = 3
self.d_key = 2 * nu * d_mem
assert self.d_mem % n_heads == 0 and self.d_model % n_heads == 0
self.phi = DPFP(nu)
# self.d_key = d_mem
# self.phi = torch.nn.Identity()
self.use_denom = use_denom
# Get the proper dtype from the layer
layer_dtype = next(layer.parameters()).dtype
self.W_mq = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
# torch.nn.init.zeros_(self.W_mq.weight)
self.W_mk = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
self.W_mv = torch.nn.Linear(d_model, d_model, bias=False, dtype=layer_dtype)
if gating:
self.W_mb = torch.nn.Linear(d_model, d_model, dtype=layer_dtype)
else:
self.W_mb = torch.nn.Linear(d_model, n_heads, dtype=layer_dtype)
torch.nn.init.zeros_(self.W_mv.weight)
s = 1/math.sqrt(d_model)
# torch.nn.init.uniform_(self.W_mq.weight, -s, s)
# torch.nn.init.uniform_(self.W_mk.weight, -s, s)
# torch.nn.init.uniform_(self.W_mb.weight, -s, s)
# self.ln = torch.nn.LayerNorm(d_model)
self.layer = layer
self.generate_mode = False
self.first_seg = True
self.correction = correction
self.zero_mem()
def _to_heads(self, x):
bsz, seq_len, d_model = x.shape
x = x.reshape(bsz, seq_len, self.n_heads, d_model // self.n_heads)
x = x.permute(0, 2, 1, 3)
return x
def _from_heads(self, x):
bsz, n_heads, seq_len, d_head = x.shape
x = x.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * d_head)
return x
def associate(self, hidden_states):
bsz, seq_len, d_model = hidden_states.shape
self.W_mem = self.W_mem.to(hidden_states.device)
if self.use_denom:
self.z = self.z.to(hidden_states.device)
q = self._to_heads(self.W_mq(hidden_states))
mq = self.phi(q) # (bsz, n_heads, seq_len, 2 * d_head * nu)
mq = F.normalize(mq, dim=-1, p=2.0)
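        # Fast-weight read-out: each query retrieves phi(q) @ W_mem, optionally
        # normalized by the accumulated key mass z (linear-attention style).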
# crutch for dataparallel
# mq += 0 * self.W_mb(hidden_states).sum() * self.W_mk(hidden_states).sum() * self.W_mv(hidden_states).sum()
num = torch.einsum('ihjk,ihkt->ihjt', mq, self.W_mem)
if self.use_denom:
denom = torch.einsum("ihk,ihjk->ihj", self.z, mq)[..., None] + 1e-5
hidden_states = num / denom # (bsz, n_heads, seq_len, d_model // n_heads)
else:
hidden_states = num
hidden_states = self._from_heads(hidden_states)
return hidden_states
def forward(self, hidden_states, *args, **kwargs):
if not self.first_seg:
hidden_states = self.associate(
# self.ln(
hidden_states
# )
) + hidden_states
out = self.layer(hidden_states, *args, **kwargs)
if not self.generate_mode:
# The layer output contains hidden states, not logits
# For transformer layers, the output is typically the hidden states
if isinstance(out, tuple):
mem_tokens = out[0][:, -self.num_mem_tokens:]
else:
mem_tokens = out[:, -self.num_mem_tokens:]
self.update_mem(mem_tokens)
return out
def forward_no_update(self, hidden_states, *args, **kwargs):
if not self.first_seg:
hidden_states = self.associate(
# self.ln(
hidden_states
# )
)+ hidden_states
out = self.layer(hidden_states, *args, **kwargs)
return out
def update_mem(self, mem_tokens):
self.W_mem = self.W_mem.to(mem_tokens.device)
if self.use_denom:
self.z = self.z.to(mem_tokens.device)
k = self._to_heads(self.W_mk(mem_tokens))
mk = self.phi(k)
mk = F.normalize(mk, dim=-1, p=2.0)
new_mv = self._to_heads(self.W_mv(mem_tokens)) # (bsz, n_heads, num_mem_tokens, d_model)
if not self.first_seg:
num = torch.einsum('ihjk,ihkt->ihjt', mk, self.W_mem)
if self.use_denom:
denom = torch.einsum("ihj,ihkj->ihk", self.z, mk)[..., None] + 1e-5
prev_mv = num / denom
if self.correction:
new_info_coef = (1 - denom / (torch.linalg.norm(mk, dim=-1) ** 2)[..., None])
new_info_coef = torch.clip(new_info_coef, 0, 1).detach()
else:
new_info_coef = 1
else:
prev_mv = num
else:
prev_mv = torch.zeros_like(new_mv, device=new_mv.device)
new_info_coef = 1
# wandb.log({f"gamma_{self.info['layer']}": new_info_coef.mean(dim=1).item() if isinstance(new_info_coef, torch.Tensor) else 1}, step=self.seg_num)
mv = new_mv - prev_mv
# new_norm = torch.linalg.norm(new_mv, dim=-1)
# old_norm = torch.linalg.norm(prev_mv, dim=-1)
# new_info_coef = torch.clip(1 - old_norm / (new_norm + 1e-5), -10, 10)[..., None].detach()
# new_info_coef = 1 - denom
mb = self._to_heads(torch.sigmoid(self.W_mb(mem_tokens)))
einop = f"ihjk,ihjt,ihj{'t' if self.gating else 'x'}->ihkt"
associations = torch.einsum(einop, mk, mv, mb) # (bsz, n_heads, d_mem, d_model)
self.W_mem = self.W_mem + associations
if self.use_denom:
self.z = self.z + (new_info_coef*mk).sum(dim=-2)
# self.z = self.z + (new_info_coef*mb[..., None]*mk).sum(dim=1)
self.seg_num += 1
self.first_seg = False
def freeze_mem(self):
self.W_mb.weight.requires_grad = False
self.W_mb.bias.requires_grad = False
self.W_mq.weight.requires_grad = False
self.W_mk.weight.requires_grad = False
self.W_mv.weight.requires_grad = False
def zero_mem(self):
self.first_seg = True
# Get the proper dtype from the layer parameters
layer_dtype = next(self.layer.parameters()).dtype
self.W_mem = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, self.d_model // self.n_heads, dtype=layer_dtype)
self.W_mem.requires_grad_(False)
if self.use_denom:
self.z = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, dtype=layer_dtype)
self.z.requires_grad_(False)
self.seg_num = 0
def detach_mem(self):
self.W_mem = self.W_mem.detach()
if self.use_denom:
self.z = self.z.detach()
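# Delta-rule memory write implemented by update_mem above, in matrix terms:
#   W_mem <- W_mem + phi(k)^T (beta * (v_new - v_prev)),   z <- z + gamma * phi(k)
# where v_prev = phi(k) @ W_mem / (phi(k) @ z) is what the memory currently
# returns for the new keys, beta = sigmoid(W_mb(mem_tokens)) gates the write,
# and gamma is the optional `correction` coefficient clipped to [0, 1].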
class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper):
def __init__(self,
layer,
d_model,
num_mem_tokens,
d_mem,
max_hop,
n_heads=1,
correction=True,
info=None,
use_denom=True,
                 gating=False,
                 constant_depth=False,
                 act_format='linear',   # accepted for interface parity with AdaptiveAssociativeLayerWrapper2; this wrapper always uses the linear ACT
                 noisy_halting=False,   # accepted for interface parity; unused here
                 ) -> None:
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)
        self.act = ACT_basic(d_model) if not constant_depth else ACT_constant_depth()
self.depth = max_hop
self.max_length = 1024
self.timing_signal = gen_timing_signal(self.max_length, d_model)
## for t
self.position_signal = gen_timing_signal(self.depth, d_model)
self.remainders = torch.zeros(1,)
self.n_updates = torch.zeros(1,)
self.segments_passed = torch.zeros(1,)
def associate(self, hidden_states):
self.remainders = self.remainders.to(hidden_states.device)
self.n_updates = self.n_updates.to(hidden_states.device)
self.segments_passed = self.segments_passed.to(hidden_states.device)
out, (remainders, n_updates) = self.act(
state=hidden_states,
inputs=hidden_states,
fn=super().associate,
time_enc=self.timing_signal,
pos_enc=self.position_signal,
max_hop=self.depth
)
self.remainders = self.remainders + remainders.mean() # 1 - \sum(h_i); L' = L + tau * mean(remainders)
self.n_updates = self.n_updates + n_updates.mean()
self.segments_passed = self.segments_passed + 1
return out
def zero_mem(self):
self.remainders = torch.zeros(1,)
self.n_updates = torch.zeros(1,)
self.segments_passed = torch.zeros(1,)
return super().zero_mem()
def detach_mem(self):
self.remainders = torch.zeros(1,)
self.n_updates = torch.zeros(1,)
self.segments_passed = torch.zeros(1,)
return super().detach_mem()
class AdaptiveAssociativeLayerWrapper2(AssociativeLayerWrapper):
def __init__(self,
layer,
d_model,
num_mem_tokens,
d_mem,
max_hop,
n_heads=1,
correction=True,
info=None,
use_denom=True,
gating=False,
act_format='linear',
noisy_halting=False,
constant_depth=False,
) -> None:
super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)
if act_format=='transformer':
self.act = ACT_transformer(d_model)
elif constant_depth:
self.act = ACT_constant_depth()
elif act_format == 'linear':
self.act = ACT_basic(d_model)
else:
            raise NotImplementedError(f"Unknown act_format: {act_format}")
self.depth = max_hop
self.max_length = 1024
self.noisy_halting = noisy_halting
self.timing_signal = gen_timing_signal(self.max_length, d_model)
## for t
self.position_signal = gen_timing_signal(self.depth, d_model)
self.remainders = torch.zeros(1,)
self.n_updates = torch.zeros(1,)
self.segments_passed = torch.zeros(1,)
def forward(self, hidden_states, *args, **kwargs):
self.remainders = self.remainders.to(hidden_states.device)
self.n_updates = self.n_updates.to(hidden_states.device)
self.segments_passed = self.segments_passed.to(hidden_states.device)
        if self.noisy_halting and isinstance(self.act, ACT_basic):
            # only ACT_basic consumes (and pops) the noisy_halting kwarg
            kwargs['noisy_halting'] = self.noisy_halting
fwd = super().forward_no_update
out, (remainders, n_updates) = self.act(
*args,
state=hidden_states,
inputs=hidden_states,
fn=fwd,
time_enc=self.timing_signal,
pos_enc=self.position_signal,
max_hop=self.depth,
**kwargs
)
if not self.generate_mode:
mem_tokens = out[0][:, -self.num_mem_tokens:]
# mem_tokens = out[0]
self.update_mem(mem_tokens)
self.first_seg = False
self.remainders = self.remainders + remainders.mean() # 1 - \sum(h_i); L' = L + tau * mean(remainders)
self.n_updates = self.n_updates + n_updates.mean()
self.segments_passed = self.segments_passed + 1
return out
def zero_mem(self):
self.remainders = torch.zeros(1,)
self.n_updates = torch.zeros(1,)
self.segments_passed = torch.zeros(1,)
return super().zero_mem()
def detach_mem(self):
self.remainders = torch.zeros(1,)
self.n_updates = torch.zeros(1,)
self.segments_passed = torch.zeros(1,)
return super().detach_mem()
class AssociativeMemoryCell(torch.nn.Module):
def __init__(self,
base_model,
num_mem_tokens,
d_mem,
layers_attr: str = 'model.layers',
wrap_pos=False,
correction=True,
n_heads=1,
use_denom=True,
gating=False,
freeze_mem=False,
act_on=False,
max_hop=4,
act_type='layer',
act_format='linear',
noisy_halting=False,
constant_depth=False,
attend_to_previous_input=False,
use_sink=False,
**rmt_config
):
super().__init__()
self.model = base_model
self.attend_to_previous_input = attend_to_previous_input
self.previous_input = None
self.use_sink = use_sink
self.RWKV_ARMT = isinstance(self.model, RWKVModel) if RWKV_imported else False
self.num_mem_tokens = num_mem_tokens
self.d_mem = d_mem
self.d_model = base_model.get_input_embeddings().embedding_dim
self.W_mem = []
self.constant_depth = constant_depth
self.layers_attrs = layers_attr.split('.')
def _get_layers_from_model(model_root):
layers_obj = model_root
for attr in self.layers_attrs:
layers_obj = getattr(layers_obj, attr)
return layers_obj
layers = _get_layers_from_model(self.model)
for i in range(len(layers)):
kw = dict(
layer=layers[i],
d_model=self.d_model,
num_mem_tokens=self.num_mem_tokens,
d_mem=self.d_mem,
correction=correction,
info={'layer': i},
n_heads=n_heads,
use_denom=use_denom,
gating=gating,
)
            if act_on and act_type == 'model' and act_format != 'linear':
                raise NotImplementedError
            if act_on and act_type != 'model':
                kw['act_format'] = act_format
                kw['max_hop'] = max_hop
                kw['constant_depth'] = self.constant_depth
                if noisy_halting:
                    kw['noisy_halting'] = noisy_halting
if not act_on:
layers[i] = AssociativeLayerWrapper(**kw)
elif act_type == 'associative':
layers[i] = AdaptiveAssociativeLayerWrapper(**kw)
elif act_type == 'layer':
layers[i] = AdaptiveAssociativeLayerWrapper2(**kw)
elif act_type == 'model':
layers[i] = AssociativeLayerWrapper(**kw)
else:
                raise ValueError(f'Unknown ACT type: {act_type}')
if act_type == 'model':
self.act = ACTForWholeARMT(self.d_model) if not self.constant_depth else ACTForWholeARMT_constant_depth()
self.depth = max_hop
self.max_length = 1024
self.timing_signal = gen_timing_signal(self.max_length, self.d_model)
self.position_signal = gen_timing_signal(self.depth, self.d_model)
self.act_type = act_type
self.create_memory(num_mem_tokens)
self.wrap_pos = wrap_pos
self.act_on = act_on
if wrap_pos:
self.wrap_positional_embeddings(num_mem_tokens)
if freeze_mem:
for layer in _get_layers_from_model(self.model):
layer.freeze_mem()
# Expose a resolver without registering layers as a submodule to avoid shared tensor aliases
self.get_layers = lambda: _get_layers_from_model(self.model)
def generate_mode(self, is_on):
for layer in self.get_layers():
layer.generate_mode = is_on
def create_memory(self, num_mem_tokens):
self.num_mem_tokens = num_mem_tokens
embeddings = self.model.get_input_embeddings()
memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size)
memory_weights = torch.randn((num_mem_tokens, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype) * embeddings.weight.data.std()
self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))
if self.use_sink:
self.sink = torch.nn.Parameter(torch.randn((1, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype), requires_grad=True)
def wrap_positional_embeddings(self, num_mem_tokens):
num_pos_embs, emb_dim = self.model.transformer.wpe.weight.shape
prev_embs = self.model.transformer.wpe.weight.detach()
self.model.transformer.wpe = torch.nn.Embedding(num_mem_tokens + num_pos_embs, emb_dim)
new_num_pos = num_pos_embs + num_mem_tokens
with torch.no_grad():
self.model.transformer.wpe.weight[:len(self.model.transformer.wpe.weight)-num_mem_tokens] = prev_embs
for layer in self.model.transformer.h:
layer.layer.attn.bias = torch.tril(torch.ones((new_num_pos, new_num_pos), dtype=torch.uint8)).view(
1, 1, new_num_pos, new_num_pos
)
def set_memory(self, input_shape):
memory = self.memory.repeat(input_shape[0], 1, 1)
if self.use_sink:
sink = self.sink.repeat(input_shape[0], 1, 1)
else:
sink = None
return memory, sink
def zero_mem(self):
for layer in self.get_layers():
layer.zero_mem()
self.previous_input = None
    def detach_mem(self):
        for layer in self.get_layers():
            layer.detach_mem()
def forward(self, input_ids, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs):
if self.act_type != 'model':
out = self.forward_with_update(input_ids, labels, labels_mask, zero_mem, attention_mask=attention_mask, **kwargs)
else:
seg_kwargs = self.process_input(input_ids=input_ids,
labels=labels,
labels_mask=labels_mask,
zero_mem=zero_mem,
attention_mask=attention_mask,
**kwargs
)
out = self.gptneox_forward_act(**seg_kwargs)
out = self.process_output(out, labels=labels, labels_mask=labels_mask)
return out
def forward_with_update(self, input_ids, labels=None, labels_mask=None, zero_mem=False, **kwargs):
current_input_ids = input_ids.clone()
if self.attend_to_previous_input and self.previous_input is not None:
input_ids = torch.cat([self.previous_input, input_ids], dim=1)
if zero_mem:
self.zero_mem()
seg_kwargs = self.process_input(input_ids, **kwargs)
layers = self.get_layers()
if self.RWKV_ARMT and not layers[0].generate_mode:
input1 = dict()
input2 = dict()
for item in seg_kwargs:
if isinstance(seg_kwargs[item], torch.Tensor):
# if False:
input1[item] = seg_kwargs[item][:, :-self.num_mem_tokens]
input2[item] = seg_kwargs[item][:, -self.num_mem_tokens:]
else:
input1[item] = seg_kwargs[item]
input2[item] = seg_kwargs[item]
self.generate_mode(True)
out = self.model(**input1)
self.generate_mode(False)
state_tmp = tuple([torch.clone(state) for state in out['state']])
out = Munch({k: torch.clone(t) if isinstance(t, torch.Tensor) else t for k, t in out.items()})
input2['state'] = out['state']
_ = self.model(**input2)
out['state'] = state_tmp
# out['state'] = out2['state']
# out = self.model(**seg_kwargs)
# out['logits'] = out['logits'][:, :-self.num_mem_tokens]
else:
out = self.model(**seg_kwargs)
if self.attend_to_previous_input and self.previous_input is not None:
out['logits'] = out['logits'][:, self.previous_input.size(1):]
out = self.process_output(out, labels, labels_mask, **kwargs)
self.previous_input = current_input_ids
return out
def process_input(self, input_ids, **kwargs):
memory_state, sink = self.set_memory(input_ids.shape)
seg_kwargs = dict(**kwargs)
inputs_embeds = kwargs.get('inputs_embeds')
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.use_sink:
inputs_embeds = torch.cat([sink, inputs_embeds, memory_state], dim=1)
else:
inputs_embeds = torch.cat([inputs_embeds, memory_state], dim=1)
seg_kwargs['input_ids'] = None
seg_kwargs['inputs_embeds'] = inputs_embeds
if kwargs.get('attention_mask') is not None:
seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], dtype=inputs_embeds.dtype)
if kwargs.get('prev_attn_mask') is not None:
prev_seg_attn_mask = self.pad_prev_seg_attn_mask(kwargs['prev_attn_mask'], dtype=inputs_embeds.dtype)
seg_kwargs['attention_mask'] = torch.cat([prev_seg_attn_mask, seg_kwargs['attention_mask']], dim=-1)
if 'prev_attn_mask' in seg_kwargs:
seg_kwargs.pop('prev_attn_mask')
seg_kwargs['output_hidden_states'] = True
if self.wrap_pos:
num_pos_embs = self.model.transformer.wpe.weight.shape[0]
ordinary_pos = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device)
write_pos = torch.arange(num_pos_embs - self.num_mem_tokens, num_pos_embs, dtype=torch.long, device=input_ids.device)
seg_kwargs['position_ids'] = torch.cat([
ordinary_pos,
write_pos
]).long().unsqueeze(0)
return seg_kwargs
def pad_attention_mask(self, attention_mask, dtype=float):
if self.num_mem_tokens in {0, None}:
return attention_mask
else:
shape = list(attention_mask.shape)
if len(shape) == 4:
shape[-1] += self.num_mem_tokens + self.use_sink
shape[-2] += self.num_mem_tokens + self.use_sink
mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device)
mask[..., int(self.use_sink):-self.num_mem_tokens, int(self.use_sink):-self.num_mem_tokens] = attention_mask
if self.use_sink:
mask[..., 0, 1:] = 0
mask[..., :-self.num_mem_tokens, -self.num_mem_tokens:] = 0
# mask = torch.tril(mask)
if not os.environ.get("NOT_INVERT_ATTN_MASK"):
mask = invert_attn_mask(mask, dtype)
else:
shape[-1] += self.num_mem_tokens + self.use_sink
mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device)
mask[..., int(self.use_sink):-self.num_mem_tokens] = attention_mask
return mask.to(dtype)
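    # Layout produced above per segment (keep=1 before inversion):
    # [sink? | sequence tokens | num_mem_tokens write slots]; non-memory queries
    # are blocked from reading the write slots, and the sink attends only to itself.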
def pad_prev_seg_attn_mask(self, prev_seg_attn_mask, dtype=float):
if self.num_mem_tokens in {0, None}:
return prev_seg_attn_mask
else:
shape = list(prev_seg_attn_mask.shape)
if len(shape) == 4:
shape[-2] += self.num_mem_tokens + self.use_sink
mask = torch.ones(*shape, dtype=dtype).to(prev_seg_attn_mask.device)
mask[..., int(self.use_sink):-self.num_mem_tokens, :] = prev_seg_attn_mask
if self.use_sink:
mask[..., 0, :] = 0
if not os.environ.get("NOT_INVERT_ATTN_MASK"):
mask = invert_attn_mask(mask, dtype)
else:
mask = prev_seg_attn_mask
return mask.to(dtype)
def process_output(self, model_outputs, labels, labels_mask, **kwargs):
if (self.num_mem_tokens not in {0, None}) and not self.RWKV_ARMT:
out = CausalLMOutputWithCrossAttentions()
out['logits'] = model_outputs.logits[:, int(self.use_sink):-self.num_mem_tokens]
if kwargs.get('output_hidden_states'):
out['hidden_states'] = [lh[:, int(self.use_sink):-self.num_mem_tokens] for lh in model_outputs.hidden_states]
if kwargs.get('output_attentions'):
out['attentions'] = model_outputs['attentions']
else:
out = model_outputs
if labels is not None:
logits = out['logits'][..., :-1, :].contiguous()
flat_logits = logits.view(-1, logits.size(-1))
labels = labels[..., 1:].contiguous()
flat_labels = labels.view(-1)
if labels_mask is not None:
flat_mask = labels_mask[..., :-1].contiguous().view(-1)
flat_logits = flat_logits[flat_mask]
flat_labels = flat_labels[flat_mask]
# Average by number of valid tokens
ce_loss_fn = CrossEntropyLoss(reduction='sum')
ce_loss = ce_loss_fn(flat_logits, flat_labels)
if labels_mask is not None:
denom = labels_mask[..., :-1].contiguous().view(-1).sum()
else:
denom = (flat_labels != -100).sum()
denom = torch.clamp(denom, min=1)
out['ce_loss'] = ce_loss / denom
if kwargs.get('use_cache', False):
out['past_key_values'] = model_outputs.past_key_values
if self.act_on and self.act_type == 'model':
out['remainders'] = model_outputs['remainders']
out['n_updates'] = model_outputs['n_updates']
return out
def generate(self, input_ids, attention_mask, zero_mem=False, **generate_kwargs):
if zero_mem:
self.zero_mem()
self.generate_mode(True)
seg_kwargs = self.process_input(input_ids, attention_mask=attention_mask)
out = self.model.generate(
inputs_embeds=seg_kwargs['inputs_embeds'][:, :-self.num_mem_tokens],
attention_mask=seg_kwargs['attention_mask'][:, :-self.num_mem_tokens],
**generate_kwargs
)
self.generate_mode(False)
return out
def update_past_key_values_sw(self, past_key_values, window_size):
past_key_values = past_key_values.to_legacy_cache()
past_key_values = [
[
k_or_v[..., -(window_size+self.use_sink):, :]
for k_or_v in seg_kv
]
for seg_kv in past_key_values
]
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return past_key_values
def greedy_generate_sw(self, input_ids, attention_mask, prev_attn_mask, **generate_kwargs):
self.generate_mode(True)
window_size = generate_kwargs['window_size']
max_new_tokens = generate_kwargs['max_new_tokens']
past_key_values = self.update_past_key_values_sw(generate_kwargs['past_key_values'], window_size)
eos_token_id = generate_kwargs['eos_token_id']
prev_attn_mask_2d = prev_attn_mask.clone()
attention_mask_2d = attention_mask.clone()
attention_mask = attn_mask_to_4d(attention_mask, upper=False, query_len=attention_mask.size(-1))
prev_attn_mask = attn_mask_to_4d(prev_attn_mask, upper=True, query_len=attention_mask.size(-1))
seg_kwargs = self.process_input(input_ids=input_ids, attention_mask=attention_mask, prev_attn_mask=prev_attn_mask, past_key_values=past_key_values)
seg_kwargs['inputs_embeds'] = seg_kwargs['inputs_embeds'][..., :-self.num_mem_tokens, :]
seg_kwargs['attention_mask'] = seg_kwargs['attention_mask'][..., :-self.num_mem_tokens, :-self.num_mem_tokens]
outputs = self.model(**seg_kwargs, use_cache=True)
next_token_logits = outputs.logits[:, -1, :]
past_key_values = outputs.past_key_values
past_key_values = self.update_past_key_values_sw(past_key_values, window_size)
generated_ids = None
sw_attention_mask = torch.cat([prev_attn_mask_2d, torch.ones(attention_mask_2d.size(0), 1).to(prev_attn_mask_2d.device), attention_mask_2d], dim=-1)
for i in range(max_new_tokens):
# print(next_token_logits[..., :5])
next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
if generated_ids is not None:
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
else:
generated_ids = next_token_id
next_input = next_token_id
sw_attention_mask = torch.cat([sw_attention_mask, torch.ones_like(next_token_id).to(sw_attention_mask.device)], dim=-1)[..., -window_size-1-self.use_sink:]
with torch.no_grad():
outputs = self.model(
input_ids=next_input,
attention_mask=sw_attention_mask,
past_key_values=past_key_values,
use_cache=True,
cache_position=torch.full((1,), window_size + i + input_ids.size(-1) + self.use_sink).to(input_ids.device)
)
past_key_values = self.update_past_key_values_sw(outputs.past_key_values, window_size)
next_token_logits = outputs.logits[:, -1, :]
if (next_token_id[:, 0] == eos_token_id).all():
break
self.generate_mode(False)
return generated_ids
def apply_layers(self, hidden_states, causal_mask, position_ids, cache_position, position_embeddings, update_mem=True):
        layers = self.get_layers()
        if not update_mem:
            # temporarily swap each wrapper's forward for its no-update variant
            tmp = []
            for layer in layers:
                tmp.append(layer.forward)
                layer.forward = layer.forward_no_update
        for layer in layers:
hidden_states = layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
cache_position=cache_position,
position_embeddings=position_embeddings,
)[0]
if not update_mem:
for i, layer in enumerate(self.get_layers()):
layer.forward = tmp[i]
return hidden_states
def gptneox_forward_act(self, inputs_embeds, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs):
drop = self.model.gpt_neox.emb_dropout
hidden_states = drop(inputs_embeds)
seq_length = hidden_states.shape[1]
cache_position = torch.arange(0, seq_length, device=hidden_states.device)
position_ids = cache_position.unsqueeze(0)
position_embeddings = self.model.gpt_neox.rotary_emb(hidden_states, position_ids)
causal_mask = self.model.gpt_neox._update_causal_mask(
attention_mask, hidden_states, cache_position, None, False
)
out, (remainders, n_updates) = self.act(
state=hidden_states,
inputs=hidden_states,
fn_no_update=lambda *args, **kwargs: self.apply_layers(*args, **kwargs, update_mem=False),
fn_update=self.apply_layers,
time_enc=self.timing_signal,
pos_enc=self.position_signal,
max_hop=self.depth,
causal_mask=causal_mask,
position_ids=position_ids,
cache_position=cache_position,
position_embeddings=position_embeddings
)
hidden_states = self.model.gpt_neox.final_layer_norm(out)
lm_logits = self.model.embed_out(hidden_states)
return Munch(logits=lm_logits, n_updates=n_updates, remainders=remainders)
class AssociativeRecurrentWrapper(torch.nn.Module):
def __init__(self, memory_cell, **rmt_kwargs):
super().__init__()
self.memory_cell = memory_cell
self.rmt_config = rmt_kwargs
self.last_state = None
def gradient_checkpointing_enable(self, *args, **kwargs):
self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)
def process_segment(self, segment_kwargs, next_seg_len=None):
sliding_window = self.rmt_config['sliding_window'] if 'sliding_window' in self.rmt_config else False
attend_to_previous_input = self.rmt_config['attend_to_previous_input'] if 'attend_to_previous_input' in self.rmt_config else False
attn_mask = segment_kwargs['attention_mask']
seg_len = segment_kwargs['input_ids'].size(-1)
segment_kwargs['use_cache'] = sliding_window
if segment_kwargs.get('past_key_values') is None:
segment_kwargs['past_key_values'] = None
if segment_kwargs.get('prev_attn_mask') is None:
segment_kwargs['prev_attn_mask'] = None
segment_kwargs['zero_mem'] = False
if sliding_window or attend_to_previous_input:
segment_kwargs['attention_mask'] = attn_mask_to_4d(attn_mask, upper=False, query_len=seg_len)
if 'state' in segment_kwargs and segment_kwargs['state'] is None:
segment_kwargs.pop('state')
num_mem_tokens = self.memory_cell.num_mem_tokens
cell_out = self.memory_cell(**segment_kwargs)
state = cell_out.get('state')
if (sliding_window or attend_to_previous_input) and next_seg_len is not None:
prev_attn_mask = attn_mask_to_4d(attn_mask, upper=True, query_len=next_seg_len)
else:
prev_attn_mask = None
if sliding_window:
past_key_values = [
[
k_or_v[..., -(num_mem_tokens+seg_len):k_or_v.size(-2)-num_mem_tokens, :].detach()
for k_or_v in seg_kv
]
for seg_kv in cell_out['past_key_values']
]
if not isinstance(cell_out['past_key_values'], tuple) and not isinstance(cell_out['past_key_values'], list):
past_key_values = cell_out['past_key_values'].from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
else:
past_key_values = None
next_segment_kwargs = dict()
next_segment_kwargs['use_cache'] = sliding_window
next_segment_kwargs['past_key_values'] = past_key_values
next_segment_kwargs['prev_attn_mask'] = prev_attn_mask
next_segment_kwargs['zero_mem'] = False
if state is not None:
next_segment_kwargs['state'] = state
return cell_out, next_segment_kwargs
def forward(self,
input_ids,
labels=None,
labels_mask=None,
inputs_embeds=None,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
input_segmented=False,
output_only_last_segment=False,
use_previous_batch_state=torch.zeros(1),
num_items_in_batch=None, # Added to handle HF Trainer compatibility
**kwargs # Added to handle any other unexpected kwargs
):
if input_segmented:
n_segs = input_ids.shape[1] if not (input_ids is None) else inputs_embeds.shape[1]
segmented = [dict(
input_ids=input_ids[:, i] if not (input_ids is None) else None,
inputs_embeds=inputs_embeds[:, i] if not (inputs_embeds is None) else None,
attention_mask=attention_mask[:, i],
labels=labels[:, i] if not (labels is None) else None,
labels_mask=labels_mask[:, i] if not (labels_mask is None) else None,
) for i in range(n_segs)]
            if labels is not None:
                labels = torch.cat([labels[:, i] for i in range(n_segs)], dim=1)
if labels_mask is not None:
labels_mask = torch.cat([labels_mask[:, i] for i in range(n_segs)], dim=1)
else:
segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, labels_mask=labels_mask)
cell_outputs = []
if not use_previous_batch_state.all() or self.last_state is None:
self.memory_cell.zero_mem()
state = None
else:
self.memory_cell.detach_mem()
state = self.last_state
next_seg_kwargs = dict(state=state)
for seg_num, segment in enumerate(segmented):
if seg_num != len(segmented) - 1:
next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1)
else:
next_seg_len = None
            # Pass num_items_in_batch through to segment processing (HF Trainer compatibility)
            segment_with_kwargs = dict(**segment, **next_seg_kwargs)
            if num_items_in_batch is not None:
                segment_with_kwargs['num_items_in_batch'] = num_items_in_batch
cell_out, next_seg_kwargs = self.process_segment(segment_with_kwargs, next_seg_len=next_seg_len)
if (not output_only_last_segment) or (seg_num == len(segmented) - 1):
cell_outputs.append(cell_out)
        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   num_items_in_batch=num_items_in_batch)
if not self.training:
self.memory_cell.zero_mem()
self.last_state = None
return out
def segment(self, **kwargs):
segments = []
for k, tensor in kwargs.items():
if tensor is not None:
k_segments = self.split_tensor(tensor)
for s, k_seg in enumerate(k_segments):
if s < len(segments):
segments[s][k] = k_seg
else:
segments.append({k: k_seg})
return segments
def split_tensor(self, tensor):
align = self.rmt_config.get('segment_alignment')
segment_size = self.rmt_config.get('segment_size')
if align in {'left', None}:
split_inds = list(range(0, tensor.shape[1], segment_size)) + [tensor.shape[1]]
segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'right':
split_inds = (list(range(tensor.shape[1], 0, -segment_size)) + [0])[::-1]
segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
elif align == 'center':
n_seg = math.ceil(tensor.shape[1] / segment_size)
segments = torch.chunk(tensor, n_seg, dim=1)
else:
raise NotImplementedError
return segments
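    # e.g. with segment_size=512 and a 1040-token input, 'left' alignment yields
    # segments of 512, 512 and 16 tokens; 'right' yields 16, 512 and 512.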
def process_outputs(self, cell_outputs, **kwargs):
out = CausalLMOutputWithCrossAttentions()
full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)
labels = kwargs.get('labels')
if labels is not None:
labels = labels[:, -full_logits.size(1):]
shift_labels = labels[..., 1:].contiguous()
shift_logits = full_logits[..., :-1, :].contiguous()
flat_labels = shift_labels.view(-1)
flat_logits = shift_logits.view(-1, shift_logits.size(-1))
labels_mask = kwargs.get('labels_mask')
if labels_mask is not None:
labels_mask = labels_mask[:, -full_logits.size(1):]
shift_mask = labels_mask[..., :-1].contiguous()
flat_labels = flat_labels[shift_mask.view(-1)]
flat_logits = flat_logits[shift_mask.view(-1)]
# Average by number of valid tokens
loss_fct = CrossEntropyLoss(reduction='sum')
loss = loss_fct(flat_logits, flat_labels)
if labels_mask is not None:
# Use the same mask used to filter flat logits/labels
denom = labels_mask[..., :-1].contiguous().view(-1).sum()
else:
denom = (flat_labels != -100).sum()
denom = torch.clamp(denom, min=1)
out['loss'] = loss / denom
else:
out['loss'] = 0
if ('HF_Trainer' not in os.environ) or not os.environ['HF_Trainer']:
out['ce_loss'] = out['loss']
out['logits'] = full_logits
segment_keys = ['loss', 'logits']
if kwargs.get('output_attentions'):
segment_keys.append('attentions')
if kwargs.get('output_hidden_states'):
full_hidden_states = tuple([torch.cat(layer_hs, dim=1) for layer_hs in zip(*[o.hidden_states for o in cell_outputs])])
segment_keys.append('hidden_states')
out['hidden_states'] = full_hidden_states
if ('HF_Trainer' not in os.environ) or not os.environ['HF_Trainer']:
for seg_num, o in enumerate(cell_outputs):
for key, value in o.items():
if any([sk in key for sk in segment_keys]):
out[f'{key}_{seg_num}'] = value
remainders = []
n_updates = []
act_on = self.rmt_config['act_on'] if 'act_on' in self.rmt_config else False
if act_on:
if self.memory_cell.act_type != 'model':
for layer in self.memory_cell.get_layers():
remainders.append(layer.remainders / layer.segments_passed)
n_updates.append(layer.n_updates / layer.segments_passed)
remainders = torch.mean(torch.stack(remainders, dim=0))
n_updates = torch.mean(torch.stack(n_updates, dim=0))
else:
remainders = torch.mean(torch.stack([o['remainders'] for o in cell_outputs], dim=0))
n_updates = torch.mean(torch.stack([o['n_updates'] for o in cell_outputs], dim=0))
out['n_updates'] = n_updates.detach().cpu()
out['remainders'] = remainders.detach().cpu()
time_penalty = self.rmt_config['time_penalty']
out['loss'] = out['loss'] + time_penalty * remainders
return out
def generate(self, input_ids, attention_mask, **generate_kwargs):
self.memory_cell.zero_mem()
segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask)
next_seg_kwargs = dict()
for seg_num, segment in enumerate(segmented[:-1]):
next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1)
_, next_seg_kwargs = self.process_segment(dict(**segment, **next_seg_kwargs), next_seg_len=next_seg_len)
final_segment = segmented[-1]
assert next_seg_kwargs.get('past_key_values') is None or isinstance(next_seg_kwargs.get('past_key_values'), Cache), "Sliding Window generation is not implemented for legacy cache"
if next_seg_kwargs.get('past_key_values') is not None:
prev_attn_mask = segmented[-2]['attention_mask']
legacy_cache = next_seg_kwargs['past_key_values'].to_legacy_cache()
seg_len = segmented[-2]['input_ids'].size(-1)
            cache = DynamicCache.from_legacy_cache(legacy_cache)
generate_kwargs['past_key_values'] = cache
generate_kwargs['window_size'] = seg_len
final_segment['prev_attn_mask'] = prev_attn_mask
out = self.memory_cell.greedy_generate_sw(**final_segment, **generate_kwargs)
return out
else:
out = self.memory_cell.generate(**final_segment, **generate_kwargs)
return out
# ---- model.py ----
import math
import torch
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.cache_utils import Cache, DynamicCache
from torch.nn.functional import relu as r
import torch.nn.functional as F
import os
# inlined language_modeling: removed import AssociativeMemoryCell, AssociativeRecurrentWrapper, attn_mask_to_4d, invert_attn_mask
class ARMTConfig(PretrainedConfig):
model_type = "armt"
def __init__(self,
base_model_name=None,
base_model_config=None,
num_mem_tokens=16,
d_mem=512,
segment_size=512,
segment_alignment="left",
sliding_window=False,
attend_to_previous_input=False,
use_sink=False,
layers_attr="model.layers",
wrap_pos=False,
correction=True,
n_heads=1,
use_denom=True,
gating=False,
freeze_mem=False,
act_on=False,
max_hop=4,
act_type="associative",
act_format="linear",
noisy_halting=False,
constant_depth=False,
time_penalty=0.0,
**kwargs):
super().__init__(**kwargs)
        # Validate mutual exclusivity
        if (base_model_name is not None) and (base_model_config is not None):
            raise ValueError("Only one of `base_model_name` or `base_model_config` may be provided; set the other to None.")
self.base_model_name = base_model_name
# Optional alternative to base_model_name: a config (dict/PretrainedConfig/name-or-path)
self.base_model_config = base_model_config
self.num_mem_tokens = num_mem_tokens
self.d_mem = d_mem
self.segment_size = segment_size
self.segment_alignment = segment_alignment
self.sliding_window = sliding_window
self.attend_to_previous_input = attend_to_previous_input
self.use_sink = use_sink
self.layers_attr = layers_attr
self.wrap_pos = wrap_pos
self.correction = correction
self.n_heads = n_heads
self.use_denom = use_denom
self.gating = gating
self.freeze_mem = freeze_mem
self.act_on = act_on
self.max_hop = max_hop
self.act_type = act_type
self.act_format = act_format
self.noisy_halting = noisy_halting
self.constant_depth = constant_depth
self.time_penalty = time_penalty
def get(self, attr: str, default=None):
if hasattr(self, attr):
return getattr(self, attr)
else:
return default
class ARMTForCausalLM(PreTrainedModel):
config_class = ARMTConfig
def __init__(self, config: ARMTConfig, **kwargs):
super().__init__(config, **kwargs)
from transformers import AutoConfig, AutoModelForCausalLM
# Build base model either from name (pretrained weights) or from provided config
base_model = None
        if getattr(config, 'base_model_name', None) is not None and getattr(config, 'base_model_config', None) is not None:
            raise ValueError("Only one of `base_model_name` or `base_model_config` may be provided in ARMTConfig.")
bm_cfg = getattr(config, 'base_model_config', None)
if bm_cfg is not None:
# Prefer explicit config when provided
if isinstance(bm_cfg, PretrainedConfig) and getattr(bm_cfg, 'model_type', None) != ARMTConfig.model_type:
resolved_cfg = bm_cfg
elif isinstance(bm_cfg, dict):
if 'model_type' not in bm_cfg:
raise ValueError("`base_model_config` dict must include a 'model_type' key (e.g., 'gpt_neox', 'llama').")
config_cls_or_instance = AutoConfig.for_model(bm_cfg['model_type'])
# If an instance was returned, update it; if a class was returned, construct from dict
if isinstance(config_cls_or_instance, PretrainedConfig):
resolved_cfg = config_cls_or_instance
for k, v in bm_cfg.items():
setattr(resolved_cfg, k, v)
else:
resolved_cfg = config_cls_or_instance.from_dict(bm_cfg)
elif isinstance(bm_cfg, str):
# Treat as a name or path to load a config
resolved_cfg = AutoConfig.from_pretrained(bm_cfg)
else:
raise TypeError("`base_model_config` must be a transformers.PretrainedConfig, dict, or str (name/path)")
base_model = AutoModelForCausalLM.from_config(resolved_cfg)
elif getattr(config, 'base_model_name', None):
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name)
else:
raise ValueError("ARMTForCausalLM requires either `base_model_config` or `base_model_name` in ARMTConfig.")
self.armt_config = config
# Create the associative memory cell
memory_cell = AssociativeMemoryCell(
base_model=base_model,
num_mem_tokens=config.num_mem_tokens,
d_mem=config.d_mem,
layers_attr=config.layers_attr,
wrap_pos=config.wrap_pos,
correction=config.correction,
n_heads=config.n_heads,
use_denom=config.use_denom,
gating=config.gating,
freeze_mem=config.freeze_mem,
act_on=config.act_on,
max_hop=config.max_hop,
act_type=config.act_type,
# Optional extras
constant_depth=config.get('constant_depth', False),
act_format=config.get('act_format', 'linear'),
noisy_halting=config.get('noisy_halting', False),
attend_to_previous_input=config.attend_to_previous_input,
use_sink=config.use_sink
)
# Create the associative recurrent wrapper
self.armt = AssociativeRecurrentWrapper(
memory_cell,
segment_size=config.segment_size,
segment_alignment=config.segment_alignment,
sliding_window=config.sliding_window,
attend_to_previous_input=config.attend_to_previous_input,
act_on=config.act_on,
time_penalty=config.time_penalty
)
def forward(
self,
input_ids=None,
labels=None,
labels_mask=None,
inputs_embeds=None,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
input_segmented=False,
output_only_last_segment=False,
num_items_in_batch=None,
):
return self.armt(
input_ids=input_ids,
labels=labels,
labels_mask=labels_mask,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
input_segmented=input_segmented,
output_only_last_segment=output_only_last_segment,
num_items_in_batch=num_items_in_batch,
)
def generate(self, *args, **kwargs):
return self.armt.generate(*args, **kwargs)
def load_state_dict(self, state_dict, strict=True, assign=False):
try:
return super().load_state_dict(state_dict, strict, assign)
        except RuntimeError:
            print("Failed to load state dict directly, retrying with the inner ARMT module.")
            result = self.armt.load_state_dict(state_dict, strict=True, assign=assign)
            print("Success!")
            return result
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
# Delegate to the base class to benefit from full shard/format support
return super().from_pretrained(pretrained_model_name_or_path, *args, config=config, **kwargs)
def gradient_checkpointing_enable(self, *args, **kwargs):
self.armt.gradient_checkpointing_enable(*args, **kwargs)
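# --- Minimal smoke-test sketch ---
# Assumptions (not part of the checkpoint): a tiny randomly initialized
# llama-style config stands in for the base model so this runs without
# downloading weights; the actual demo would instead pass a pretrained name,
# e.g. base_model_name="meta-llama/Llama-3.2-1B".
if __name__ == "__main__":
    tiny_llama_cfg = {
        "model_type": "llama",
        "hidden_size": 64,
        "intermediate_size": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
        "vocab_size": 1000,
        "max_position_embeddings": 256,
    }
    cfg = ARMTConfig(base_model_config=tiny_llama_cfg,
                     num_mem_tokens=4, d_mem=16, segment_size=32)
    model = ARMTForCausalLM(cfg)
    input_ids = torch.randint(0, 1000, (1, 96))  # 3 segments of 32 tokens
    out = model(input_ids=input_ids, labels=input_ids,
                attention_mask=torch.ones_like(input_ids))
    print("loss:", float(out["loss"]), "logits:", tuple(out["logits"].shape))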