from torch import nn
from transformers import PreTrainedModel, AutoModel

from .configuration_semcsemulti import SemCSEMultiConfig


class SemCSEMulti(PreTrainedModel):
    """Multi-aspect sentence encoder: a shared transformer backbone followed
    by one linear projection head per aspect identifier."""

    config_class = SemCSEMultiConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AutoModel.from_pretrained(config.encoder_checkpoint)
        # Use the encoder's own hidden size unless the config overrides it.
        self.hidden_size = config.encoder_hidden_dim or self.encoder.config.hidden_size
        self.embedding_dim = config.embedding_dim
        self.aspect_identifiers = config.aspect_identifiers
        # One bias-free linear projection per aspect, mapping the shared
        # encoder representation into that aspect's embedding space.
        self.prompt_projections = nn.ModuleDict({
            p: nn.Linear(self.hidden_size, self.embedding_dim, bias=False)
            for p in self.aspect_identifiers
        })
        # Small-variance Gaussian init keeps the aspect heads near zero
        # at the start of training.
        for module in self.prompt_projections.values():
            nn.init.normal_(module.weight, mean=0.0, std=1e-2)

    def forward(self, input_ids, attention_mask, **kwargs):
        # [CLS] token of the final encoder layer as the shared base embedding.
        base_embedding = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1][:, 0]
        # Project the shared embedding once per aspect.
        return {
            p: self.prompt_projections[p](base_embedding)
            for p in self.aspect_identifiers
        }
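
# --- Usage sketch ----------------------------------------------------------
# A minimal, illustrative example of running the model on a small batch.
# The encoder checkpoint, aspect identifiers, and dimensions below are
# placeholder assumptions, not values from this repository; the real
# defaults live in SemCSEMultiConfig. The config keyword arguments mirror
# the attributes read in __init__ above and are likewise assumed. Because
# of the relative import at the top of this file, run it as a module
# (e.g. `python -m <package>.modeling_semcsemulti`) rather than as a script.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    config = SemCSEMultiConfig(
        encoder_checkpoint="bert-base-uncased",  # assumed backbone
        encoder_hidden_dim=None,                 # fall back to the encoder's hidden size
        embedding_dim=256,                       # assumed projection size
        aspect_identifiers=["method", "task"],   # assumed aspects
    )
    model = SemCSEMulti(config).eval()

    tokenizer = AutoTokenizer.from_pretrained(config.encoder_checkpoint)
    batch = tokenizer(
        ["A sample scientific abstract.", "Another abstract."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])

    # One embedding tensor per aspect, each of shape (batch, embedding_dim).
    for aspect, emb in embeddings.items():
        print(aspect, tuple(emb.shape))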