import torch
from torch import nn
from transformers import PreTrainedModel, AutoModel

from .configuration_semcsemulti import SemCSEMultiConfig
class SemCSEMulti(PreTrainedModel):
    """Multi-aspect sentence encoder: a shared transformer backbone followed by
    one linear projection head per aspect identifier."""

    config_class = SemCSEMultiConfig

    def __init__(self, config):
        super().__init__(config)
        # Shared backbone encoder, loaded from the checkpoint named in the config.
        self.encoder = AutoModel.from_pretrained(config.encoder_checkpoint)
        # Use the encoder's own hidden size unless the config overrides it.
        self.hidden_size = config.encoder_hidden_dim or self.encoder.config.hidden_size
        self.embedding_dim = config.embedding_dim
        self.aspect_identifiers = config.aspect_identifiers
        # One bias-free linear projection head per aspect.
        self.prompt_projections = nn.ModuleDict({
            p: nn.Linear(self.hidden_size, self.embedding_dim, bias=False)
            for p in self.aspect_identifiers
        })
        # Initialize the projection weights with small Gaussian noise.
        for module in self.prompt_projections.values():
            nn.init.normal_(module.weight, mean=0.0, std=1e-2)
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
        # Encode the batch and take the final-layer [CLS] token embedding
        # as the shared base representation.
        base_embedding = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1][:, 0]
        # Project the shared representation into one embedding space per aspect.
        embeddings = {
            p: self.prompt_projections[p](base_embedding)
            for p in self.aspect_identifiers
        }
        return embeddings
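
# --- Usage sketch (illustrative only) ---
# A minimal example of how this model could be loaded and queried, assuming a
# matching tokenizer is published alongside the weights. The repository id
# "org/semcse-multi" is a hypothetical placeholder; substitute the real one.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("org/semcse-multi")
#   model = SemCSEMulti.from_pretrained("org/semcse-multi")
#   model.eval()
#
#   batch = tokenizer(
#       ["A paper abstract about contrastive scientific embeddings."],
#       padding=True,
#       truncation=True,
#       return_tensors="pt",
#   )
#   with torch.no_grad():
#       embeddings = model(**batch)
#   # `embeddings` maps each aspect identifier to a tensor of shape
#   # (batch_size, embedding_dim).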