# === Inlined ARMT for HF Hub (single-file) ===

# ---- act_utils.py ----
from torch import nn
import torch
import numpy as np
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer


def gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (float(num_timescales) - 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(float) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)

    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])

    return torch.from_numpy(signal).type(torch.FloatTensor)


class ACT_basic(nn.Module):
    def __init__(self, hidden_size):
        super(ACT_basic, self).__init__()
        self.sigma = nn.Sigmoid()
        self.p = nn.Linear(hidden_size, 1)
        self.p.bias.data.fill_(1)
        self.threshold = 1 - 0.1
        self.eps = 0.1

    def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        # init_hdd
        # noisy_halting is stripped from kwargs so it is not forwarded to fn
        noisy_halting = False
        if 'noisy_halting' in kwargs:
            noisy_halting = kwargs['noisy_halting']
            kwargs.pop('noisy_halting')
        ## [B, S]
        halting_probability = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S]
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs).cuda()

        step = 0
        rest = None
        # for l in range(self.num_layers):
        while (((halting_probability < self.threshold) & (n_updates < max_hop)).byte().any()):
            # Add timing signal
            # state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
            # state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)

            p = self.sigma(self.p(state)).squeeze(-1)
            # Mask for inputs which have not halted yet
            still_running = (halting_probability < 1.0).float()

            # Mask of inputs which halted at this step
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running

            # Mask of inputs which haven't halted, and didn't halt this step
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running

            # Add the halting probability for this step to the halting
            # probabilities for those input which haven't halted yet
            halting_probability = halting_probability + p * still_running

            # Compute remainders for the inputs which halted at this step
            remainders = remainders + new_halted * (1 - halting_probability)

            # Add the remainders to those inputs which halted at this step
            halting_probability = halting_probability + new_halted * remainders

            # Increment n_updates for all inputs which are still running
            n_updates = n_updates + still_running + new_halted

            # Compute the weight to be applied to the new state and output
            # 0 when the input has already halted
            # p when the input hasn't halted yet
            # the remainders when it halted this step
            update_weights = p * still_running + new_halted * remainders

            if(encoder_output):
                state, _ = fn((state, encoder_output))
            else:
                # apply transformation on the state
                state = fn(state, *args, **kwargs)

            if isinstance(state, tuple):
                rest = state[1:]
                state = state[0]

            # update running part in the weighted state and keep the rest
            previous_state = ((state * update_weights.unsqueeze(-1)) +
                              (previous_state * (1 - update_weights.unsqueeze(-1))))

            ## previous_state is actually the new_state at end of the loop
            ## to save a line I assigned to previous_state so in the next
            ## iteration is correct. Notice that indeed we return previous_state
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)


class ACT_constant_depth():
    def __init__(self):
        super(ACT_constant_depth, self).__init__()

    def __call__(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        # init_hdd
        ## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S]
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs).cuda()

        step = 0
        rest = None
        # for l in range(self.num_layers):
        while (step < max_hop):
            print('constant depth TRUE')
            # Add timing signal
            # state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
            # state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)
            if(encoder_output):
                state, _ = fn((state, encoder_output))
            else:
                # apply transformation on the state
                state = fn(state, *args, **kwargs)
            if isinstance(state, tuple):
                rest = state[1:]
                state = state[0]
            # update running part in the weighted state and keep the rest
            # print(state.shape, previous_state.shape, update_weights.shape)
            # print(state.dtype, previous_state.dtype, update_weights.dtype)
            previous_state = state
            ## previous_state is actually the new_state at end of the loop
            ## to save a line I assigned to previous_state so in the next
            ## iteration is correct. Notice that indeed we return previous_state
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)


class ACTForWholeARMT(nn.Module):
    def __init__(self, hidden_size):
        super(ACTForWholeARMT, self).__init__()
        self.sigma = nn.Sigmoid()
        self.p = nn.Linear(hidden_size, 1)
        self.p.bias.data.fill_(1)
        self.threshold = 1 - 0.1

    def forward(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        # init_hdd
        ## [B, S]
        halting_probability = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S]
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S]
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1]).cuda()
        ## [B, S, HDD]
        previous_state = torch.zeros_like(inputs).cuda()

        step = 0
        rest = None
        # for l in range(self.num_layers):
        while (((halting_probability < self.threshold) & (n_updates < max_hop)).byte().any()):
            # Add timing signal
            # state = state + time_enc[:, :inputs.shape[1], :].type_as(inputs.data)
            # state = state + pos_enc[:, step, :].unsqueeze(1).repeat(1,inputs.shape[1],1).type_as(inputs.data)

            p = self.sigma(self.p(state)).squeeze(-1)
            # Mask for inputs which have not halted yet
            still_running = (halting_probability < 1.0).float()
            # Mask of inputs which halted at this step
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            # Mask of inputs which haven't halted, and didn't halt this step
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
            # Add the halting probability for this step to the halting
            # probabilities for those input which haven't halted yet
            halting_probability = halting_probability + p * still_running
            # Compute remainders for the inputs which halted at this step
            remainders = remainders + new_halted * (1 - halting_probability)
            # Add the remainders to those inputs which halted at this step
            halting_probability = halting_probability + new_halted * remainders
            # Increment n_updates for all inputs which are still running
            n_updates = n_updates + still_running + new_halted
            # Compute the weight to be applied to the new state and output
            # 0 when the input has already halted
            # p when the input hasn't halted yet
            # the remainders when it halted this step
            update_weights = p * still_running + new_halted * remainders

            if encoder_output is not None:
                state, _ = fn_update((state, encoder_output))
            else:
                # Run the wrapped layer stack without touching the associative memory
                # while any token may still take another hop; commit the memory update
                # on the final hop only.
                if ((halting_probability < self.threshold) & (n_updates < max_hop)).byte().any():
                    state = fn_no_update(state, *args, **kwargs)
                else:
                    state = fn_update(state, *args, **kwargs)

            if isinstance(state, tuple):
                rest = state[1:]
                state = state[0]

            previous_state = ((state * update_weights.unsqueeze(-1)) +
                              (previous_state * (1 - update_weights.unsqueeze(-1))))
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        else:
            return (previous_state, *rest), (remainders, n_updates)


# ---- language_modeling.py ----
import math
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.cache_utils import Cache, DynamicCache
from torch.nn.functional import relu as r
import torch.nn.functional as F
import wandb
from munch import Munch
import os
# inlined act_utils: removed import ACT_basic, gen_timing_signal, ACTForWholeARMT, ACT_transformer, ACT_constant_depth, ACTForWholeARMT_constant_depth

try:
    from baselines.rwkv.language_modeling import RWKVModel
    RWKV_imported = True
except ImportError:
    print("*** Can't import RWKV model ***")
    RWKV_imported = False


def dpfp(x, nu=1):
    x = torch.cat([r(x), r(-x)], dim=-1)
    x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    x_repeat = torch.cat([x] * nu, dim=-1)
    return x_repeat * x_rolled


class DPFP:
    def __init__(self, nu):
        self.nu = nu

    def __call__(self, x):
        nu = self.nu
        x = torch.cat([r(x), r(-x)], dim=-1)
        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
        x_repeat = torch.cat([x] * nu, dim=-1)
        return x_repeat * x_rolled


def attn_mask_to_4d(attn_mask, upper, query_len):
    if attn_mask is None:
        return None
    seg_len = attn_mask.size(-1)
    if upper:
        tri = torch.triu(torch.ones(query_len, seg_len))
    else:
        tri = torch.tril(torch.ones(query_len, seg_len))
    mask = torch.einsum('bj,ij->bij', attn_mask, tri.to(attn_mask.device))
    mask = mask.unsqueeze(1)
    return mask


def invert_attn_mask(attn_mask, dtype):
    min_dtype = torch.finfo(dtype).min
    new_mask = (1.0 - attn_mask) * min_dtype
    return new_mask


class AssociativeLayerWrapper(torch.nn.Module):

    def __init__(self, layer, d_model, num_mem_tokens, d_mem, n_heads=1,
                 correction=True, info=None, use_denom=True, gating=False) -> None:
        super().__init__()
        self.info = info
        self.seg_num = 0
        self.d_model = d_model
        self.num_mem_tokens = num_mem_tokens
        self.d_mem = d_mem
        self.n_heads = n_heads
        self.gating = gating

        nu = 3
        self.d_key = 2 * nu * d_mem
        assert self.d_mem % n_heads == 0 and self.d_model % n_heads == 0
        self.phi = DPFP(nu)
        # self.d_key = d_mem
        # self.phi = torch.nn.Identity()

        self.use_denom = use_denom

        # Get the proper dtype from the layer
        layer_dtype = next(layer.parameters()).dtype
        self.W_mq = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
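        # Hedged illustration (added note, not part of the original module): with nu = 3
        # the DPFP feature map phi maps a d_mem-dimensional key/query to 2 * nu * d_mem
        # features, which is why self.d_key = 2 * nu * d_mem above. A minimal shape
        # check, assuming d_mem = 64 and a (batch, seq, d_mem) input:
        #     phi = DPFP(nu=3)
        #     q = torch.randn(2, 8, 64)
        #     assert phi(q).shape == (2, 8, 2 * 3 * 64)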
torch.nn.init.zeros_(self.W_mq.weight) self.W_mk = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype) self.W_mv = torch.nn.Linear(d_model, d_model, bias=False, dtype=layer_dtype) if gating: self.W_mb = torch.nn.Linear(d_model, d_model, dtype=layer_dtype) else: self.W_mb = torch.nn.Linear(d_model, n_heads, dtype=layer_dtype) torch.nn.init.zeros_(self.W_mv.weight) s = 1/math.sqrt(d_model) # torch.nn.init.uniform_(self.W_mq.weight, -s, s) # torch.nn.init.uniform_(self.W_mk.weight, -s, s) # torch.nn.init.uniform_(self.W_mb.weight, -s, s) # self.ln = torch.nn.LayerNorm(d_model) self.layer = layer self.generate_mode = False self.first_seg = True self.correction = correction self.zero_mem() def _to_heads(self, x): bsz, seq_len, d_model = x.shape x = x.reshape(bsz, seq_len, self.n_heads, d_model // self.n_heads) x = x.permute(0, 2, 1, 3) return x def _from_heads(self, x): bsz, n_heads, seq_len, d_head = x.shape x = x.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * d_head) return x def associate(self, hidden_states): bsz, seq_len, d_model = hidden_states.shape self.W_mem = self.W_mem.to(hidden_states.device) if self.use_denom: self.z = self.z.to(hidden_states.device) q = self._to_heads(self.W_mq(hidden_states)) mq = self.phi(q) # (bsz, n_heads, seq_len, 2 * d_head * nu) mq = F.normalize(mq, dim=-1, p=2.0) # crutch for dataparallel # mq += 0 * self.W_mb(hidden_states).sum() * self.W_mk(hidden_states).sum() * self.W_mv(hidden_states).sum() num = torch.einsum('ihjk,ihkt->ihjt', mq, self.W_mem) if self.use_denom: denom = torch.einsum("ihk,ihjk->ihj", self.z, mq)[..., None] + 1e-5 hidden_states = num / denom # (bsz, n_heads, seq_len, d_model // n_heads) else: hidden_states = num hidden_states = self._from_heads(hidden_states) return hidden_states def forward(self, hidden_states, *args, **kwargs): if not self.first_seg: hidden_states = self.associate( # self.ln( hidden_states # ) ) + hidden_states out = self.layer(hidden_states, *args, **kwargs) if not self.generate_mode: # The layer output contains hidden states, not logits # For transformer layers, the output is typically the hidden states if isinstance(out, tuple): mem_tokens = out[0][:, -self.num_mem_tokens:] else: mem_tokens = out[:, -self.num_mem_tokens:] self.update_mem(mem_tokens) return out def forward_no_update(self, hidden_states, *args, **kwargs): if not self.first_seg: hidden_states = self.associate( # self.ln( hidden_states # ) )+ hidden_states out = self.layer(hidden_states, *args, **kwargs) return out def forward_no_update(self, hidden_states, *args, **kwargs): if not self.first_seg: hidden_states = self.associate( # self.ln( hidden_states # ) ) + hidden_states out = self.layer(hidden_states, *args, **kwargs) return out def update_mem(self, mem_tokens): self.W_mem = self.W_mem.to(mem_tokens.device) if self.use_denom: self.z = self.z.to(mem_tokens.device) k = self._to_heads(self.W_mk(mem_tokens)) mk = self.phi(k) mk = F.normalize(mk, dim=-1, p=2.0) new_mv = self._to_heads(self.W_mv(mem_tokens)) # (bsz, n_heads, num_mem_tokens, d_model) if not self.first_seg: num = torch.einsum('ihjk,ihkt->ihjt', mk, self.W_mem) if self.use_denom: denom = torch.einsum("ihj,ihkj->ihk", self.z, mk)[..., None] + 1e-5 prev_mv = num / denom if self.correction: new_info_coef = (1 - denom / (torch.linalg.norm(mk, dim=-1) ** 2)[..., None]) new_info_coef = torch.clip(new_info_coef, 0, 1).detach() else: new_info_coef = 1 else: prev_mv = num else: prev_mv = torch.zeros_like(new_mv, device=new_mv.device) new_info_coef = 1 # 
wandb.log({f"gamma_{self.info['layer']}": new_info_coef.mean(dim=1).item() if isinstance(new_info_coef, torch.Tensor) else 1}, step=self.seg_num) mv = new_mv - prev_mv # new_norm = torch.linalg.norm(new_mv, dim=-1) # old_norm = torch.linalg.norm(prev_mv, dim=-1) # new_info_coef = torch.clip(1 - old_norm / (new_norm + 1e-5), -10, 10)[..., None].detach() # new_info_coef = 1 - denom mb = self._to_heads(torch.sigmoid(self.W_mb(mem_tokens))) einop = f"ihjk,ihjt,ihj{'t' if self.gating else 'x'}->ihkt" associations = torch.einsum(einop, mk, mv, mb) # (bsz, n_heads, d_mem, d_model) self.W_mem = self.W_mem + associations if self.use_denom: self.z = self.z + (new_info_coef*mk).sum(dim=-2) # self.z = self.z + (new_info_coef*mb[..., None]*mk).sum(dim=1) self.seg_num += 1 self.first_seg = False def freeze_mem(self): self.W_mb.weight.requires_grad = False self.W_mb.bias.requires_grad = False self.W_mq.weight.requires_grad = False self.W_mk.weight.requires_grad = False self.W_mv.weight.requires_grad = False def zero_mem(self): self.first_seg = True # Get the proper dtype from the layer parameters layer_dtype = next(self.layer.parameters()).dtype self.W_mem = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, self.d_model // self.n_heads, dtype=layer_dtype) self.W_mem.requires_grad_(False) if self.use_denom: self.z = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, dtype=layer_dtype) self.z.requires_grad_(False) self.seg_num = 0 def detach_mem(self): self.W_mem = self.W_mem.detach() if self.use_denom: self.z = self.z.detach() class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper): def __init__(self, layer, d_model, num_mem_tokens, d_mem, max_hop, n_heads=1, correction=True, info=None, use_denom=True, gating=False, constant_depth=False, ) -> None: super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating) self.act = ACT_basic(d_model) if not constant_depth else ACT_constant_depth() self.depth = max_hop self.max_length = 1024 self.timing_signal = gen_timing_signal(self.max_length, d_model) ## for t self.position_signal = gen_timing_signal(self.depth, d_model) self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) def associate(self, hidden_states): self.remainders = self.remainders.to(hidden_states.device) self.n_updates = self.n_updates.to(hidden_states.device) self.segments_passed = self.segments_passed.to(hidden_states.device) out, (remainders, n_updates) = self.act( state=hidden_states, inputs=hidden_states, fn=super().associate, time_enc=self.timing_signal, pos_enc=self.position_signal, max_hop=self.depth ) self.remainders = self.remainders + remainders.mean() # 1 - \sum(h_i); L' = L + tau * mean(remainders) self.n_updates = self.n_updates + n_updates.mean() self.segments_passed = self.segments_passed + 1 return out def zero_mem(self): self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) return super().zero_mem() def detach_mem(self): self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) return super().detach_mem() class AdaptiveAssociativeLayerWrapper2(AssociativeLayerWrapper): def __init__(self, layer, d_model, num_mem_tokens, d_mem, max_hop, n_heads=1, correction=True, info=None, use_denom=True, gating=False, act_format='linear', noisy_halting=False, constant_depth=False, ) -> None: super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating) 
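# Hedged sketch (illustrative only; not part of the original sources): the write performed
# by AssociativeLayerWrapper.update_mem above is a delta-rule update. Per head, with
# L2-normalised keys k = phi(W_mk m), values v = W_mv m and gate b = sigmoid(W_mb m):
#     prev_v = (k @ W_mem) / (z . k + eps)        # what the memory currently recalls for k
#     W_mem  = W_mem + k^T ((v - prev_v) * b)     # move the stored value toward v
#     z      = z + sum_over_mem_tokens(gamma * k) # gamma = clip(1 - denom / ||k||^2, 0, 1)
# A minimal single-head numeric check (hypothetical sizes d_key = 4, d_model = 2):
#     W_mem = torch.zeros(4, 2); z = torch.zeros(4)
#     k = F.normalize(torch.rand(1, 4), dim=-1); v = torch.randn(1, 2)
#     W_mem = W_mem + k.T @ v; z = z + k.sum(0)
#     recalled = (k @ W_mem) / (k @ z + 1e-5)     # ~= v after this first write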
if act_format=='transformer': self.act = ACT_transformer(d_model) elif constant_depth: self.act = ACT_constant_depth() elif act_format == 'linear': self.act = ACT_basic(d_model) else: raise NotImplemetedError self.depth = max_hop self.max_length = 1024 self.noisy_halting = noisy_halting self.timing_signal = gen_timing_signal(self.max_length, d_model) ## for t self.position_signal = gen_timing_signal(self.depth, d_model) self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) def forward(self, hidden_states, *args, **kwargs): self.remainders = self.remainders.to(hidden_states.device) self.n_updates = self.n_updates.to(hidden_states.device) self.segments_passed = self.segments_passed.to(hidden_states.device) if self.noisy_halting: kwargs['noisy_halting'] = self.noisy_halting fwd = super().forward_no_update out, (remainders, n_updates) = self.act( *args, state=hidden_states, inputs=hidden_states, fn=fwd, time_enc=self.timing_signal, pos_enc=self.position_signal, max_hop=self.depth, **kwargs ) if not self.generate_mode: mem_tokens = out[0][:, -self.num_mem_tokens:] # mem_tokens = out[0] self.update_mem(mem_tokens) self.first_seg = False self.remainders = self.remainders + remainders.mean() # 1 - \sum(h_i); L' = L + tau * mean(remainders) self.n_updates = self.n_updates + n_updates.mean() self.segments_passed = self.segments_passed + 1 return out def zero_mem(self): self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) return super().zero_mem() def detach_mem(self): self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) return super().detach_mem() class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper): def __init__(self, layer, d_model, num_mem_tokens, d_mem, max_hop, n_heads=1, correction=True, info=None, use_denom=True, gating=False, ) -> None: super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating) self.act = ACT_basic(d_model) self.depth = max_hop self.max_length = 1024 self.timing_signal = gen_timing_signal(self.max_length, d_model) ## for t self.position_signal = gen_timing_signal(self.depth, d_model) self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) def associate(self, hidden_states): self.remainders = self.remainders.to(hidden_states.device) self.n_updates = self.n_updates.to(hidden_states.device) self.segments_passed = self.segments_passed.to(hidden_states.device) out, (remainders, n_updates) = self.act( state=hidden_states, inputs=hidden_states, fn=super().associate, time_enc=self.timing_signal, pos_enc=self.position_signal, max_hop=self.depth ) self.remainders = self.remainders + remainders # 1 - \sum(h_i); L' = L + tau * mean(remainders) self.n_updates = self.n_updates + n_updates self.segments_passed = self.segments_passed + 1 return out def zero_mem(self): self.remainders = torch.zeros(1,) self.n_updates = torch.zeros(1,) self.segments_passed = torch.zeros(1,) return super().zero_mem() class AssociativeMemoryCell(torch.nn.Module): def __init__(self, base_model, num_mem_tokens, d_mem, layers_attr: str = 'model.layers', wrap_pos=False, correction=True, n_heads=1, use_denom=True, gating=False, freeze_mem=False, act_on=False, max_hop=4, act_type='layer', act_format='linear', noisy_halting=False, constant_depth=False, attend_to_previous_input=False, use_sink=False, **rmt_config ): super().__init__() self.model = base_model 
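# Hedged note (added illustration, not original code): when act_on=True the adaptive
# wrappers above rerun their layer up to `max_hop` times per segment and accumulate
# `remainders` / `n_updates`; AssociativeRecurrentWrapper.process_outputs later adds the
# ACT ponder cost  loss = ce_loss + time_penalty * mean(remainders).  A configuration
# sketch with hypothetical sizes might look like:
#     cell = AssociativeMemoryCell(base_model, num_mem_tokens=16, d_mem=64,
#                                  act_on=True, act_type='layer', max_hop=4)
#     model = AssociativeRecurrentWrapper(cell, segment_size=512, act_on=True,
#                                         time_penalty=1e-3)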
self.attend_to_previous_input = attend_to_previous_input self.previous_input = None self.use_sink = use_sink self.RWKV_ARMT = isinstance(self.model, RWKVModel) if RWKV_imported else False self.num_mem_tokens = num_mem_tokens self.d_mem = d_mem self.d_model = base_model.get_input_embeddings().embedding_dim self.W_mem = [] self.constant_depth = constant_depth self.layers_attrs = layers_attr.split('.') def _get_layers_from_model(model_root): layers_obj = model_root for attr in self.layers_attrs: layers_obj = getattr(layers_obj, attr) return layers_obj layers = _get_layers_from_model(self.model) for i in range(len(layers)): kw = dict( layer=layers[i], d_model=self.d_model, num_mem_tokens=self.num_mem_tokens, d_mem=self.d_mem, correction=correction, info={'layer': i}, n_heads=n_heads, use_denom=use_denom, gating=gating, ) if act_on and act_type != 'model': kw['act_format'] = act_format if act_on and act_type == 'model' and act_format != 'linear': raise NotImplementedError if act_on and (act_type != 'model'): kw['max_hop'] = max_hop kw['constant_depth'] = self.constant_depth kw['act_format'] = act_format if act_on and noisy_halting: kw['noisy_halting'] = noisy_halting if not act_on: layers[i] = AssociativeLayerWrapper(**kw) elif act_type == 'associative': layers[i] = AdaptiveAssociativeLayerWrapper(**kw) elif act_type == 'layer': layers[i] = AdaptiveAssociativeLayerWrapper2(**kw) elif act_type == 'model': layers[i] = AssociativeLayerWrapper(**kw) else: raise f'Unknown ACT type: {act_type}' if act_type == 'model': self.act = ACTForWholeARMT(self.d_model) if not self.constant_depth else ACTForWholeARMT_constant_depth() self.depth = max_hop self.max_length = 1024 self.timing_signal = gen_timing_signal(self.max_length, self.d_model) self.position_signal = gen_timing_signal(self.depth, self.d_model) self.act_type = act_type self.create_memory(num_mem_tokens) self.wrap_pos = wrap_pos self.act_on = act_on if wrap_pos: self.wrap_positional_embeddings(num_mem_tokens) if freeze_mem: for layer in _get_layers_from_model(self.model): layer.freeze_mem() # Expose a resolver without registering layers as a submodule to avoid shared tensor aliases self.get_layers = lambda: _get_layers_from_model(self.model) def generate_mode(self, is_on): for layer in self.get_layers(): layer.generate_mode = is_on def create_memory(self, num_mem_tokens): self.num_mem_tokens = num_mem_tokens embeddings = self.model.get_input_embeddings() memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size) memory_weights = torch.randn((num_mem_tokens, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype) * embeddings.weight.data.std() self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True)) if self.use_sink: self.sink = torch.nn.Parameter(torch.randn((1, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype), requires_grad=True) def wrap_positional_embeddings(self, num_mem_tokens): num_pos_embs, emb_dim = self.model.transformer.wpe.weight.shape prev_embs = self.model.transformer.wpe.weight.detach() self.model.transformer.wpe = torch.nn.Embedding(num_mem_tokens + num_pos_embs, emb_dim) new_num_pos = num_pos_embs + num_mem_tokens with torch.no_grad(): self.model.transformer.wpe.weight[:len(self.model.transformer.wpe.weight)-num_mem_tokens] = prev_embs for layer in self.model.transformer.h: layer.layer.attn.bias = torch.tril(torch.ones((new_num_pos, new_num_pos), dtype=torch.uint8)).view( 1, 1, new_num_pos, new_num_pos ) def 
set_memory(self, input_shape): memory = self.memory.repeat(input_shape[0], 1, 1) if self.use_sink: sink = self.sink.repeat(input_shape[0], 1, 1) else: sink = None return memory, sink def zero_mem(self): for layer in self.get_layers(): layer.zero_mem() self.previous_input = None def detach_mem(self): for layer in self.get_layers(): layer.detach_mem() pass def forward(self, input_ids, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs): if self.act_type != 'model': out = self.forward_with_update(input_ids, labels, labels_mask, zero_mem, attention_mask=attention_mask, **kwargs) else: seg_kwargs = self.process_input(input_ids=input_ids, labels=labels, labels_mask=labels_mask, zero_mem=zero_mem, attention_mask=attention_mask, **kwargs ) out = self.gptneox_forward_act(**seg_kwargs) out = self.process_output(out, labels=labels, labels_mask=labels_mask) return out def forward_with_update(self, input_ids, labels=None, labels_mask=None, zero_mem=False, **kwargs): current_input_ids = input_ids.clone() if self.attend_to_previous_input and self.previous_input is not None: input_ids = torch.cat([self.previous_input, input_ids], dim=1) if zero_mem: self.zero_mem() seg_kwargs = self.process_input(input_ids, **kwargs) layers = self.get_layers() if self.RWKV_ARMT and not layers[0].generate_mode: input1 = dict() input2 = dict() for item in seg_kwargs: if isinstance(seg_kwargs[item], torch.Tensor): # if False: input1[item] = seg_kwargs[item][:, :-self.num_mem_tokens] input2[item] = seg_kwargs[item][:, -self.num_mem_tokens:] else: input1[item] = seg_kwargs[item] input2[item] = seg_kwargs[item] self.generate_mode(True) out = self.model(**input1) self.generate_mode(False) state_tmp = tuple([torch.clone(state) for state in out['state']]) out = Munch({k: torch.clone(t) if isinstance(t, torch.Tensor) else t for k, t in out.items()}) input2['state'] = out['state'] _ = self.model(**input2) out['state'] = state_tmp # out['state'] = out2['state'] # out = self.model(**seg_kwargs) # out['logits'] = out['logits'][:, :-self.num_mem_tokens] else: out = self.model(**seg_kwargs) if self.attend_to_previous_input and self.previous_input is not None: out['logits'] = out['logits'][:, self.previous_input.size(1):] out = self.process_output(out, labels, labels_mask, **kwargs) self.previous_input = current_input_ids return out def process_input(self, input_ids, **kwargs): memory_state, sink = self.set_memory(input_ids.shape) seg_kwargs = dict(**kwargs) inputs_embeds = kwargs.get('inputs_embeds') if inputs_embeds is None: inputs_embeds = self.model.get_input_embeddings()(input_ids) if self.use_sink: inputs_embeds = torch.cat([sink, inputs_embeds, memory_state], dim=1) else: inputs_embeds = torch.cat([inputs_embeds, memory_state], dim=1) seg_kwargs['input_ids'] = None seg_kwargs['inputs_embeds'] = inputs_embeds if kwargs.get('attention_mask') is not None: seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], dtype=inputs_embeds.dtype) if kwargs.get('prev_attn_mask') is not None: prev_seg_attn_mask = self.pad_prev_seg_attn_mask(kwargs['prev_attn_mask'], dtype=inputs_embeds.dtype) seg_kwargs['attention_mask'] = torch.cat([prev_seg_attn_mask, seg_kwargs['attention_mask']], dim=-1) if 'prev_attn_mask' in seg_kwargs: seg_kwargs.pop('prev_attn_mask') seg_kwargs['output_hidden_states'] = True if self.wrap_pos: num_pos_embs = self.model.transformer.wpe.weight.shape[0] ordinary_pos = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device) write_pos = 
torch.arange(num_pos_embs - self.num_mem_tokens, num_pos_embs, dtype=torch.long, device=input_ids.device) seg_kwargs['position_ids'] = torch.cat([ ordinary_pos, write_pos ]).long().unsqueeze(0) return seg_kwargs def pad_attention_mask(self, attention_mask, dtype=float): if self.num_mem_tokens in {0, None}: return attention_mask else: shape = list(attention_mask.shape) if len(shape) == 4: shape[-1] += self.num_mem_tokens + self.use_sink shape[-2] += self.num_mem_tokens + self.use_sink mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device) mask[..., int(self.use_sink):-self.num_mem_tokens, int(self.use_sink):-self.num_mem_tokens] = attention_mask if self.use_sink: mask[..., 0, 1:] = 0 mask[..., :-self.num_mem_tokens, -self.num_mem_tokens:] = 0 # mask = torch.tril(mask) if not os.environ.get("NOT_INVERT_ATTN_MASK"): mask = invert_attn_mask(mask, dtype) else: shape[-1] += self.num_mem_tokens + self.use_sink mask = torch.ones(*shape, dtype=dtype).to(attention_mask.device) mask[..., int(self.use_sink):-self.num_mem_tokens] = attention_mask return mask.to(dtype) def pad_prev_seg_attn_mask(self, prev_seg_attn_mask, dtype=float): if self.num_mem_tokens in {0, None}: return prev_seg_attn_mask else: shape = list(prev_seg_attn_mask.shape) if len(shape) == 4: shape[-2] += self.num_mem_tokens + self.use_sink mask = torch.ones(*shape, dtype=dtype).to(prev_seg_attn_mask.device) mask[..., int(self.use_sink):-self.num_mem_tokens, :] = prev_seg_attn_mask if self.use_sink: mask[..., 0, :] = 0 if not os.environ.get("NOT_INVERT_ATTN_MASK"): mask = invert_attn_mask(mask, dtype) else: mask = prev_seg_attn_mask return mask.to(dtype) def process_output(self, model_outputs, labels, labels_mask, **kwargs): if (self.num_mem_tokens not in {0, None}) and not self.RWKV_ARMT: out = CausalLMOutputWithCrossAttentions() out['logits'] = model_outputs.logits[:, int(self.use_sink):-self.num_mem_tokens] if kwargs.get('output_hidden_states'): out['hidden_states'] = [lh[:, int(self.use_sink):-self.num_mem_tokens] for lh in model_outputs.hidden_states] if kwargs.get('output_attentions'): out['attentions'] = model_outputs['attentions'] else: out = model_outputs if labels is not None: logits = out['logits'][..., :-1, :].contiguous() flat_logits = logits.view(-1, logits.size(-1)) labels = labels[..., 1:].contiguous() flat_labels = labels.view(-1) if labels_mask is not None: flat_mask = labels_mask[..., :-1].contiguous().view(-1) flat_logits = flat_logits[flat_mask] flat_labels = flat_labels[flat_mask] # Average by number of valid tokens ce_loss_fn = CrossEntropyLoss(reduction='sum') ce_loss = ce_loss_fn(flat_logits, flat_labels) if labels_mask is not None: denom = labels_mask[..., :-1].contiguous().view(-1).sum() else: denom = (flat_labels != -100).sum() denom = torch.clamp(denom, min=1) out['ce_loss'] = ce_loss / denom if kwargs.get('use_cache', False): out['past_key_values'] = model_outputs.past_key_values if self.act_on and self.act_type == 'model': out['remainders'] = model_outputs['remainders'] out['n_updates'] = model_outputs['n_updates'] return out def generate(self, input_ids, attention_mask, zero_mem=False, **generate_kwargs): if zero_mem: self.zero_mem() self.generate_mode(True) seg_kwargs = self.process_input(input_ids, attention_mask=attention_mask) out = self.model.generate( inputs_embeds=seg_kwargs['inputs_embeds'][:, :-self.num_mem_tokens], attention_mask=seg_kwargs['attention_mask'][:, :-self.num_mem_tokens], **generate_kwargs ) self.generate_mode(False) return out def update_past_key_values_sw(self, 
past_key_values, window_size): past_key_values = past_key_values.to_legacy_cache() past_key_values = [ [ k_or_v[..., -(window_size+self.use_sink):, :] for k_or_v in seg_kv ] for seg_kv in past_key_values ] past_key_values = DynamicCache.from_legacy_cache(past_key_values) return past_key_values def greedy_generate_sw(self, input_ids, attention_mask, prev_attn_mask, **generate_kwargs): self.generate_mode(True) window_size = generate_kwargs['window_size'] max_new_tokens = generate_kwargs['max_new_tokens'] past_key_values = self.update_past_key_values_sw(generate_kwargs['past_key_values'], window_size) eos_token_id = generate_kwargs['eos_token_id'] prev_attn_mask_2d = prev_attn_mask.clone() attention_mask_2d = attention_mask.clone() attention_mask = attn_mask_to_4d(attention_mask, upper=False, query_len=attention_mask.size(-1)) prev_attn_mask = attn_mask_to_4d(prev_attn_mask, upper=True, query_len=attention_mask.size(-1)) seg_kwargs = self.process_input(input_ids=input_ids, attention_mask=attention_mask, prev_attn_mask=prev_attn_mask, past_key_values=past_key_values) seg_kwargs['inputs_embeds'] = seg_kwargs['inputs_embeds'][..., :-self.num_mem_tokens, :] seg_kwargs['attention_mask'] = seg_kwargs['attention_mask'][..., :-self.num_mem_tokens, :-self.num_mem_tokens] outputs = self.model(**seg_kwargs, use_cache=True) next_token_logits = outputs.logits[:, -1, :] past_key_values = outputs.past_key_values past_key_values = self.update_past_key_values_sw(past_key_values, window_size) generated_ids = None sw_attention_mask = torch.cat([prev_attn_mask_2d, torch.ones(attention_mask_2d.size(0), 1).to(prev_attn_mask_2d.device), attention_mask_2d], dim=-1) for i in range(max_new_tokens): # print(next_token_logits[..., :5]) next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) if generated_ids is not None: generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) else: generated_ids = next_token_id next_input = next_token_id sw_attention_mask = torch.cat([sw_attention_mask, torch.ones_like(next_token_id).to(sw_attention_mask.device)], dim=-1)[..., -window_size-1-self.use_sink:] with torch.no_grad(): outputs = self.model( input_ids=next_input, attention_mask=sw_attention_mask, past_key_values=past_key_values, use_cache=True, cache_position=torch.full((1,), window_size + i + input_ids.size(-1) + self.use_sink).to(input_ids.device) ) past_key_values = self.update_past_key_values_sw(outputs.past_key_values, window_size) next_token_logits = outputs.logits[:, -1, :] if (next_token_id[:, 0] == eos_token_id).all(): break self.generate_mode(False) return generated_ids def apply_layers(self, hidden_states, causal_mask, position_ids, cache_position, position_embeddings, update_mem=True): if not update_mem: tmp = [] for i in range(len(self.layers)): tmp.append(self.layers[i].forward) self.layers[i].forward = self.layers[i].forward_no_update for layer in self.get_layers(): hidden_states = layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, cache_position=cache_position, position_embeddings=position_embeddings, )[0] if not update_mem: for i, layer in enumerate(self.get_layers()): layer.forward = tmp[i] return hidden_states def gptneox_forward_act(self, inputs_embeds, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs): drop = self.model.gpt_neox.emb_dropout hidden_states = drop(inputs_embeds) seq_length = hidden_states.shape[1] cache_position = torch.arange(0, seq_length, device=hidden_states.device) position_ids = cache_position.unsqueeze(0) 
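# Hedged illustration (added, not part of the original method): update_past_key_values_sw
# above keeps only the last `window_size` (+ optional sink) positions of every layer's
# K/V tensors before re-wrapping them, which is what lets greedy_generate_sw decode with
# a bounded sliding-window cache:
#     legacy = cache.to_legacy_cache()                          # tuple of (k, v) per layer
#     legacy = [[t[..., -window_size:, :] for t in kv] for kv in legacy]
#     cache = DynamicCache.from_legacy_cache(legacy)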
position_embeddings = self.model.gpt_neox.rotary_emb(hidden_states, position_ids) causal_mask = self.model.gpt_neox._update_causal_mask( attention_mask, hidden_states, cache_position, None, False ) out, (remainders, n_updates) = self.act( state=hidden_states, inputs=hidden_states, fn_no_update=lambda *args, **kwargs: self.apply_layers(*args, **kwargs, update_mem=False), fn_update=self.apply_layers, time_enc=self.timing_signal, pos_enc=self.position_signal, max_hop=self.depth, causal_mask=causal_mask, position_ids=position_ids, cache_position=cache_position, position_embeddings=position_embeddings ) hidden_states = self.model.gpt_neox.final_layer_norm(out) lm_logits = self.model.embed_out(hidden_states) return Munch(logits=lm_logits, n_updates=n_updates, remainders=remainders) class AssociativeRecurrentWrapper(torch.nn.Module): def __init__(self, memory_cell, **rmt_kwargs): super().__init__() self.memory_cell = memory_cell self.rmt_config = rmt_kwargs self.last_state = None def gradient_checkpointing_enable(self, *args, **kwargs): self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs) def process_segment(self, segment_kwargs, next_seg_len=None): sliding_window = self.rmt_config['sliding_window'] if 'sliding_window' in self.rmt_config else False attend_to_previous_input = self.rmt_config['attend_to_previous_input'] if 'attend_to_previous_input' in self.rmt_config else False attn_mask = segment_kwargs['attention_mask'] seg_len = segment_kwargs['input_ids'].size(-1) segment_kwargs['use_cache'] = sliding_window if segment_kwargs.get('past_key_values') is None: segment_kwargs['past_key_values'] = None if segment_kwargs.get('prev_attn_mask') is None: segment_kwargs['prev_attn_mask'] = None segment_kwargs['zero_mem'] = False if sliding_window or attend_to_previous_input: segment_kwargs['attention_mask'] = attn_mask_to_4d(attn_mask, upper=False, query_len=seg_len) if 'state' in segment_kwargs and segment_kwargs['state'] is None: segment_kwargs.pop('state') num_mem_tokens = self.memory_cell.num_mem_tokens cell_out = self.memory_cell(**segment_kwargs) state = cell_out.get('state') if (sliding_window or attend_to_previous_input) and next_seg_len is not None: prev_attn_mask = attn_mask_to_4d(attn_mask, upper=True, query_len=next_seg_len) else: prev_attn_mask = None if sliding_window: past_key_values = [ [ k_or_v[..., -(num_mem_tokens+seg_len):k_or_v.size(-2)-num_mem_tokens, :].detach() for k_or_v in seg_kv ] for seg_kv in cell_out['past_key_values'] ] if not isinstance(cell_out['past_key_values'], tuple) and not isinstance(cell_out['past_key_values'], list): past_key_values = cell_out['past_key_values'].from_legacy_cache(past_key_values) else: past_key_values = DynamicCache.from_legacy_cache(past_key_values) else: past_key_values = None next_segment_kwargs = dict() next_segment_kwargs['use_cache'] = sliding_window next_segment_kwargs['past_key_values'] = past_key_values next_segment_kwargs['prev_attn_mask'] = prev_attn_mask next_segment_kwargs['zero_mem'] = False if state is not None: next_segment_kwargs['state'] = state return cell_out, next_segment_kwargs def forward(self, input_ids, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None, output_attentions=None, output_hidden_states=None, input_segmented=False, output_only_last_segment=False, use_previous_batch_state=torch.zeros(1), num_items_in_batch=None, # Added to handle HF Trainer compatibility **kwargs # Added to handle any other unexpected kwargs ): if input_segmented: n_segs = input_ids.shape[1] if not 
(input_ids is None) else inputs_embeds.shape[1] segmented = [dict( input_ids=input_ids[:, i] if not (input_ids is None) else None, inputs_embeds=inputs_embeds[:, i] if not (inputs_embeds is None) else None, attention_mask=attention_mask[:, i], labels=labels[:, i] if not (labels is None) else None, labels_mask=labels_mask[:, i] if not (labels_mask is None) else None, ) for i in range(n_segs)] labels = torch.cat([labels[:, i] for i in range(n_segs)], dim=1) if labels_mask is not None: labels_mask = torch.cat([labels_mask[:, i] for i in range(n_segs)], dim=1) else: segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, labels_mask=labels_mask) cell_outputs = [] if not use_previous_batch_state.all() or self.last_state is None: self.memory_cell.zero_mem() state = None else: self.memory_cell.detach_mem() state = self.last_state next_seg_kwargs = dict(state=state) for seg_num, segment in enumerate(segmented): if seg_num != len(segmented) - 1: next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1) else: next_seg_len = None # Pass num_items_in_batch to segment processing segment_with_kwargs = dict(**segment, **next_seg_kwargs) if kwargs.get('num_items_in_batch') is not None: segment_with_kwargs['num_items_in_batch'] = kwargs['num_items_in_batch'] cell_out, next_seg_kwargs = self.process_segment(segment_with_kwargs, next_seg_len=next_seg_len) if (not output_only_last_segment) or (seg_num == len(segmented) - 1): cell_outputs.append(cell_out) out = self.process_outputs(cell_outputs, labels=labels, labels_mask=labels_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, num_items_in_batch=kwargs.get('num_items_in_batch')) if not self.training: self.memory_cell.zero_mem() self.last_state = None return out def segment(self, **kwargs): segments = [] for k, tensor in kwargs.items(): if tensor is not None: k_segments = self.split_tensor(tensor) for s, k_seg in enumerate(k_segments): if s < len(segments): segments[s][k] = k_seg else: segments.append({k: k_seg}) return segments def split_tensor(self, tensor): align = self.rmt_config.get('segment_alignment') segment_size = self.rmt_config.get('segment_size') if align in {'left', None}: split_inds = list(range(0, tensor.shape[1], segment_size)) + [tensor.shape[1]] segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])] elif align in {'right', None}: split_inds = (list(range(tensor.shape[1], 0, -segment_size)) + [0])[::-1] segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])] elif align == 'center': n_seg = math.ceil(tensor.shape[1] / segment_size) segments = torch.chunk(tensor, n_seg, dim=1) else: raise NotImplementedError return segments def process_outputs(self, cell_outputs, **kwargs): out = CausalLMOutputWithCrossAttentions() full_logits = torch.cat([o.logits for o in cell_outputs], dim=1) labels = kwargs.get('labels') if labels is not None: labels = labels[:, -full_logits.size(1):] shift_labels = labels[..., 1:].contiguous() shift_logits = full_logits[..., :-1, :].contiguous() flat_labels = shift_labels.view(-1) flat_logits = shift_logits.view(-1, shift_logits.size(-1)) labels_mask = kwargs.get('labels_mask') if labels_mask is not None: labels_mask = labels_mask[:, -full_logits.size(1):] shift_mask = labels_mask[..., :-1].contiguous() flat_labels = flat_labels[shift_mask.view(-1)] flat_logits = flat_logits[shift_mask.view(-1)] # Average by number of valid tokens loss_fct = 
CrossEntropyLoss(reduction='sum') loss = loss_fct(flat_logits, flat_labels) if labels_mask is not None: # Use the same mask used to filter flat logits/labels denom = labels_mask[..., :-1].contiguous().view(-1).sum() else: denom = (flat_labels != -100).sum() denom = torch.clamp(denom, min=1) out['loss'] = loss / denom else: out['loss'] = 0 if ('HF_Trainer' not in os.environ) or not os.environ['HF_Trainer']: out['ce_loss'] = out['loss'] out['logits'] = full_logits segment_keys = ['loss', 'logits'] if kwargs.get('output_attentions'): segment_keys.append('attentions') if kwargs.get('output_hidden_states'): full_hidden_states = tuple([torch.cat(layer_hs, dim=1) for layer_hs in zip(*[o.hidden_states for o in cell_outputs])]) segment_keys.append('hidden_states') out['hidden_states'] = full_hidden_states if ('HF_Trainer' not in os.environ) or not os.environ['HF_Trainer']: for seg_num, o in enumerate(cell_outputs): for key, value in o.items(): if any([sk in key for sk in segment_keys]): out[f'{key}_{seg_num}'] = value remainders = [] n_updates = [] act_on = self.rmt_config['act_on'] if 'act_on' in self.rmt_config else False if act_on: if self.memory_cell.act_type != 'model': for layer in self.memory_cell.get_layers(): remainders.append(layer.remainders / layer.segments_passed) n_updates.append(layer.n_updates / layer.segments_passed) remainders = torch.mean(torch.stack(remainders, dim=0)) n_updates = torch.mean(torch.stack(n_updates, dim=0)) else: remainders = torch.mean(torch.stack([o['remainders'] for o in cell_outputs], dim=0)) n_updates = torch.mean(torch.stack([o['n_updates'] for o in cell_outputs], dim=0)) out['n_updates'] = n_updates.detach().cpu() out['remainders'] = remainders.detach().cpu() time_penalty = self.rmt_config['time_penalty'] out['loss'] = out['loss'] + time_penalty * remainders return out def generate(self, input_ids, attention_mask, **generate_kwargs): self.memory_cell.zero_mem() segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask) next_seg_kwargs = dict() for seg_num, segment in enumerate(segmented[:-1]): next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1) _, next_seg_kwargs = self.process_segment(dict(**segment, **next_seg_kwargs), next_seg_len=next_seg_len) final_segment = segmented[-1] assert next_seg_kwargs.get('past_key_values') is None or isinstance(next_seg_kwargs.get('past_key_values'), Cache), "Sliding Window generation is not implemented for legacy cache" if next_seg_kwargs.get('past_key_values') is not None: prev_attn_mask = segmented[-2]['attention_mask'] legacy_cache = next_seg_kwargs['past_key_values'].to_legacy_cache() seg_len = segmented[-2]['input_ids'].size(-1) cache = DynamicCache().from_legacy_cache(legacy_cache) generate_kwargs['past_key_values'] = cache generate_kwargs['window_size'] = seg_len final_segment['prev_attn_mask'] = prev_attn_mask out = self.memory_cell.greedy_generate_sw(**final_segment, **generate_kwargs) return out else: out = self.memory_cell.generate(**final_segment, **generate_kwargs) return out # ---- model.py ---- import math import torch from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.cache_utils import Cache, DynamicCache from torch.nn.functional import relu as r import torch.nn.functional as F import os # inlined language_modeling: removed import AssociativeMemoryCell, AssociativeRecurrentWrapper, attn_mask_to_4d, invert_attn_mask class 
ARMTConfig(PretrainedConfig): model_type = "armt" def __init__(self, base_model_name=None, base_model_config=None, num_mem_tokens=16, d_mem=512, segment_size=512, segment_alignment="left", sliding_window=False, attend_to_previous_input=False, use_sink=False, layers_attr="model.layers", wrap_pos=False, correction=True, n_heads=1, use_denom=True, gating=False, freeze_mem=False, act_on=False, max_hop=4, act_type="associative", act_format="linear", noisy_halting=False, constant_depth=False, time_penalty=0.0, **kwargs): super().__init__(**kwargs) # Validate mutual exclusivity if (base_model_name is not None) and (base_model_config is not None): raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided. Set the other to None.") self.base_model_name = base_model_name # Optional alternative to base_model_name: a config (dict/PretrainedConfig/name-or-path) self.base_model_config = base_model_config self.num_mem_tokens = num_mem_tokens self.d_mem = d_mem self.segment_size = segment_size self.segment_alignment = segment_alignment self.sliding_window = sliding_window self.attend_to_previous_input = attend_to_previous_input self.use_sink = use_sink self.layers_attr = layers_attr self.wrap_pos = wrap_pos self.correction = correction self.n_heads = n_heads self.use_denom = use_denom self.gating = gating self.freeze_mem = freeze_mem self.act_on = act_on self.max_hop = max_hop self.act_type = act_type self.act_format = act_format self.noisy_halting = noisy_halting self.constant_depth = constant_depth self.time_penalty = time_penalty def get(self, attr: str, default=None): if hasattr(self, attr): return getattr(self, attr) else: return default class ARMTForCausalLM(PreTrainedModel): config_class = ARMTConfig def __init__(self, config: ARMTConfig, **kwargs): super().__init__(config, **kwargs) from transformers import AutoConfig, AutoModelForCausalLM # Build base model either from name (pretrained weights) or from provided config base_model = None if getattr(config, 'base_model_name', None) is not None and getattr(config, 'base_model_config', None) is not None: raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided in ARMTConfig.") bm_cfg = getattr(config, 'base_model_config', None) if bm_cfg is not None: # Prefer explicit config when provided if isinstance(bm_cfg, PretrainedConfig) and getattr(bm_cfg, 'model_type', None) != ARMTConfig.model_type: resolved_cfg = bm_cfg elif isinstance(bm_cfg, dict): if 'model_type' not in bm_cfg: raise ValueError("`base_model_config` dict must include a 'model_type' key (e.g., 'gpt_neox', 'llama').") config_cls_or_instance = AutoConfig.for_model(bm_cfg['model_type']) # If an instance was returned, update it; if a class was returned, construct from dict if isinstance(config_cls_or_instance, PretrainedConfig): resolved_cfg = config_cls_or_instance for k, v in bm_cfg.items(): setattr(resolved_cfg, k, v) else: resolved_cfg = config_cls_or_instance.from_dict(bm_cfg) elif isinstance(bm_cfg, str): # Treat as a name or path to load a config resolved_cfg = AutoConfig.from_pretrained(bm_cfg) else: raise TypeError("`base_model_config` must be a transformers.PretrainedConfig, dict, or str (name/path)") base_model = AutoModelForCausalLM.from_config(resolved_cfg) elif getattr(config, 'base_model_name', None): base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name) else: raise ValueError("ARMTForCausalLM requires either `base_model_config` or `base_model_name` in ARMTConfig.") self.armt_config = config # 
Create the associative memory cell memory_cell = AssociativeMemoryCell( base_model=base_model, num_mem_tokens=config.num_mem_tokens, d_mem=config.d_mem, layers_attr=config.layers_attr, wrap_pos=config.wrap_pos, correction=config.correction, n_heads=config.n_heads, use_denom=config.use_denom, gating=config.gating, freeze_mem=config.freeze_mem, act_on=config.act_on, max_hop=config.max_hop, act_type=config.act_type, # Optional extras constant_depth=config.get('constant_depth', False), act_format=config.get('act_format', 'linear'), noisy_halting=config.get('noisy_halting', False), attend_to_previous_input=config.attend_to_previous_input, use_sink=config.use_sink ) # Create the associative recurrent wrapper self.armt = AssociativeRecurrentWrapper( memory_cell, segment_size=config.segment_size, segment_alignment=config.segment_alignment, sliding_window=config.sliding_window, attend_to_previous_input=config.attend_to_previous_input, act_on=config.act_on, time_penalty=config.time_penalty ) def forward( self, input_ids=None, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None, output_attentions=None, output_hidden_states=None, input_segmented=False, output_only_last_segment=False, num_items_in_batch=None, ): return self.armt( input_ids=input_ids, labels=labels, labels_mask=labels_mask, inputs_embeds=inputs_embeds, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, input_segmented=input_segmented, output_only_last_segment=output_only_last_segment, num_items_in_batch=num_items_in_batch, ) def generate(self, *args, **kwargs): return self.armt.generate(*args, **kwargs) def load_state_dict(self, state_dict, strict=True, assign=False): try: return super().load_state_dict(state_dict, strict, assign) except RuntimeError: print("Failed to load state, retrying with ARMT loader.") self.armt.load_state_dict(state_dict, strict=True, assign=assign) print("Success!") @classmethod def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs): # Delegate to the base class to benefit from full shard/format support return super().from_pretrained(pretrained_model_name_or_path, *args, config=config, **kwargs) def gradient_checkpointing_enable(self, *args, **kwargs): self.armt.gradient_checkpointing_enable(*args, **kwargs)
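# ---------------------------------------------------------------------------
# Hedged usage sketch (added, untested, not part of the original sources): builds a
# tiny randomly initialised GPT-NeoX backbone from a config dict and wraps it with
# ARMT. All sizes are illustrative, and `layers_attr` has to point at the backbone's
# decoder layers ('gpt_neox.layers' for GPT-NeoX models).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ARMTConfig(
        base_model_config={
            "model_type": "gpt_neox",
            "hidden_size": 128,
            "intermediate_size": 256,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "vocab_size": 1024,
            "max_position_embeddings": 2048,
        },
        num_mem_tokens=8,
        d_mem=64,
        segment_size=128,
        layers_attr="gpt_neox.layers",
    )
    model = ARMTForCausalLM(config)

    # A 3-segment input: the recurrent wrapper splits it into segment_size-token
    # segments and carries the associative memory across them.
    input_ids = torch.randint(0, 1024, (1, 384))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
    print(out["loss"], out["logits"].shape)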