# lime-1b-instruct / modeling_lime.py
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, Union
from ukraine.research.transformer.transformer import Transformer
from ukraine.research.transformer.layers import SiLUFeedForward
from ukraine.research.transformer.masking import generate_square_subsequent_mask
from .configuration_lime import LIMEConfig


def make_ff(config: LIMEConfig):
    # Factory for the SiLU feed-forward block passed to the Transformer layers.
    return SiLUFeedForward(
        d_model=config.d_model,
        dff=config.dff,
        multiple_of=config.multiple_of
    )


def make_norm(config: LIMEConfig):
    # Factory for the RMSNorm normalization layers.
    return nn.RMSNorm(config.d_model)
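

# For reference, the LIMEConfig attributes read in this module (all expected to
# be defined in configuration_lime.py): vocab_size, d_model, num_heads,
# num_encoder_layers, num_decoder_layers, dff, multiple_of, dropout_rate,
# pad_token_id, use_encoder, use_flash and tie_word_embeddings; use_return_dict
# comes from the PretrainedConfig base class.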


class LIMEForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper exposing the LIME Transformer through the transformers API."""

    config_class = LIMEConfig
    base_model_prefix = "lime"
    _tied_weights_keys = ["transformer.output_fc.weight"]

    def __init__(self, config: LIMEConfig):
        super().__init__(config)
        self.config = config
        # Core Transformer from the ukraine.research package; layer internals
        # (feed-forward and norm) come from the factories above.
        self.transformer = Transformer(
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            d_model=config.d_model,
            num_heads=config.num_heads,
            input_vocab_size=config.vocab_size,
            target_vocab_size=config.vocab_size,
            dropout_rate=config.dropout_rate,
            ff_factory=lambda: make_ff(config),
            norm_factory=lambda: make_norm(config),
            pad_token_id=config.pad_token_id,
            use_encoder=config.use_encoder,
            use_flash=config.use_flash
        )
        # Initializes weights and ties input/output embeddings if configured.
        self.post_init()
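
    # Note: forward() feeds input_ids into the Transformer as `src` under a
    # causal mask, so the module is exercised purely as a causal language
    # model; the use_encoder flag passed above presumably controls whether an
    # encoder stack is built at all inside the ukraine.research Transformer.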

    # Embedding accessors required by the transformers library.
    def get_input_embeddings(self):
        return self.transformer.decoder.embedding

    def set_input_embeddings(self, value):
        self.transformer.decoder.embedding = value

    def get_output_embeddings(self):
        return self.transformer.output_fc

    def set_output_embeddings(self, new_embeddings):
        self.transformer.output_fc = new_embeddings

    def _tie_weights(self):
        # Share the output projection weights with the input embedding matrix
        # when config.tie_word_embeddings is set.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.transformer.output_fc,
                self.get_input_embeddings()
            )
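
    # Generation note: no prepare_inputs_for_generation override is provided
    # and past_key_values is always returned as None, so GenerationMixin's
    # defaults are used without a KV cache and each decoding step re-runs the
    # full sequence through the transformer.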

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # attention_mask and token_type_ids are accepted so tokenizer outputs
        # and generate() do not break the signature; they are not used here,
        # since padding is handled via tgt_key_padding_mask below.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Causal mask so each position attends only to earlier positions.
        tgt_mask = generate_square_subsequent_mask(seq_len, device)

        # The padding mask is only needed for training (when labels are given);
        # it is skipped at inference time.
        if labels is not None:
            tgt_key_padding_mask = input_ids.eq(self.config.pad_token_id)
        else:
            tgt_key_padding_mask = None

        logits, _ = self.transformer(
            src=input_ids,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        loss = None
        if labels is not None:
            # Shift so that position i predicts token i + 1: for a sequence
            # [t0, t1, t2, t3], logits at positions 0..2 are scored against
            # labels [t1, t2, t3].
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # The same ignore index (-100) was used during SFT training.
            criterion = nn.CrossEntropyLoss(ignore_index=-100)
            loss = criterion(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1)
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None
        )
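

if __name__ == "__main__":
    # Minimal usage sketch, not part of the modeling code. Assumptions: the hub
    # repo id below and the presence of a tokenizer in that repo; adjust both
    # to the actual checkpoint.
    from transformers import AutoTokenizer

    repo_id = "anarlavrenov/lime-1b-instruct"  # assumed hub repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = LIMEForCausalLM.from_pretrained(repo_id)
    model.eval()

    prompt = "Hello!"
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            max_new_tokens=32,
            pad_token_id=model.config.pad_token_id,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))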