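"""Hugging Face `transformers` wrapper for the LIME causal language model.

Exposes the project's custom `Transformer` through the standard
`PreTrainedModel` / `GenerationMixin` interface so the model works with
`from_pretrained`, `generate`, and the `Trainer`.

Minimal usage sketch (illustrative only; the exact `LIMEConfig` constructor
arguments live in `configuration_lime.py`):

    config = LIMEConfig()                    # assumes the config's defaults are usable
    model = LIMEForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    out.loss.backward()                      # training step
    # generated = model.generate(input_ids, max_new_tokens=20)  # via GenerationMixin
"""
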
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from ukraine.research.transformer.transformer import Transformer
from ukraine.research.transformer.layers import SiLUFeedForward
from ukraine.research.transformer.masking import generate_square_subsequent_mask

from .configuration_lime import LIMEConfig


def make_ff(config: LIMEConfig):
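    """Build the SiLU feed-forward block for a single transformer layer."""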
    return SiLUFeedForward(
        d_model=config.d_model,
        dff=config.dff,
        multiple_of=config.multiple_of,
    )


def make_norm(config: LIMEConfig):
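    """Build an RMSNorm layer sized to the model width."""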
    return nn.RMSNorm(config.d_model)


class LIMEForCausalLM(PreTrainedModel, GenerationMixin):
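    """Causal language model head exposing the custom `Transformer` through the
    Hugging Face `PreTrainedModel` / `GenerationMixin` interface.

    Feed-forward and normalization layers are injected via `make_ff` / `make_norm`,
    and the output projection can be tied to the input embedding through
    `config.tie_word_embeddings`.
    """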

    config_class = LIMEConfig
    base_model_prefix = "lime"
    _tied_weights_keys = ["transformer.output_fc.weight"]

    def __init__(self, config: LIMEConfig):
        super().__init__(config)
        self.config = config
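
        # Core sequence model; every architectural choice comes from the HF config,
        # including the injected feed-forward and norm factories.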
        self.transformer = Transformer(
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            d_model=config.d_model,
            num_heads=config.num_heads,
            input_vocab_size=config.vocab_size,
            target_vocab_size=config.vocab_size,
            dropout_rate=config.dropout_rate,
            ff_factory=lambda: make_ff(config),
            norm_factory=lambda: make_norm(config),
            pad_token_id=config.pad_token_id,
            use_encoder=config.use_encoder,
            use_flash=config.use_flash,
        )
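
        # Initialize weights and apply final processing (e.g. optional weight tying).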
        self.post_init()
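
    # Embedding plumbing used by `PreTrainedModel` for resizing and weight tying.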

    def get_input_embeddings(self):
        return self.transformer.decoder.embedding

    def set_input_embeddings(self, value):
        self.transformer.decoder.embedding = value

    def get_output_embeddings(self):
        return self.transformer.output_fc

    def set_output_embeddings(self, new_embeddings):
        self.transformer.output_fc = new_embeddings
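
    # Share the output projection with the input embedding when the config
    # enables `tie_word_embeddings`.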
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.transformer.output_fc,
                self.get_input_embeddings(),
            )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
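        """Run the decoder over `input_ids` and optionally compute the shifted
        next-token cross-entropy loss from `labels`.

        `attention_mask` and `token_type_ids` are accepted for pipeline and
        `generate()` compatibility but are not forwarded to the transformer call.
        """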

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.shape
        device = input_ids.device
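
        # Causal mask: position i may only attend to positions <= i.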
        tgt_mask = generate_square_subsequent_mask(seq_len, device)
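
        # The PAD-token key-padding mask is only built for the training path
        # (labels provided); the incoming `attention_mask` is not consulted here.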
        if labels is not None:
            tgt_key_padding_mask = input_ids.eq(self.config.pad_token_id)
        else:
            tgt_key_padding_mask = None
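
        # The custom Transformer returns a (logits, extras) pair; only the logits
        # are needed for language modeling.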
        logits, _ = self.transformer(
            src=input_ids,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
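
        # Standard causal-LM objective: predict token t+1 from tokens <= t,
        # ignoring positions labeled -100.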
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            criterion = nn.CrossEntropyLoss(ignore_index=-100)
            loss = criterion(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )