# SemCSE-Multi-Medical / modeling_semcsemulti.py
import torch
from torch import nn
from transformers import PreTrainedModel, AutoModel

from .configuration_semcsemulti import SemCSEMultiConfig


class SemCSEMulti(PreTrainedModel):
    config_class = SemCSEMultiConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone transformer encoder, loaded from the checkpoint named in the config.
        self.encoder = AutoModel.from_pretrained(config.encoder_checkpoint)
        # Use the encoder's own hidden size unless the config overrides it.
        self.hidden_size = config.encoder_hidden_dim or self.encoder.config.hidden_size
        self.embedding_dim = config.embedding_dim
        self.aspect_identifiers = config.aspect_identifiers
        # One linear projection head per aspect, mapping the shared [CLS] embedding
        # into an aspect-specific embedding space.
        self.prompt_projections = nn.ModuleDict({
            p: nn.Linear(self.hidden_size, self.embedding_dim, bias=False)
            for p in self.aspect_identifiers
        })
        for module in self.prompt_projections.values():
            nn.init.normal_(module.weight, mean=0.0, std=1e-2)

    def forward(self, input_ids, attention_mask, **kwargs):
        # Encode the input and take the final-layer hidden state of the first
        # ([CLS]) token as the shared sentence representation.
        base_embedding = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1][:, 0]
        # Project the shared representation into each aspect-specific space.
        embeddings = {}
        for p in self.aspect_identifiers:
            embeddings[p] = self.prompt_projections[p](base_embedding)
        return embeddings
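

if __name__ == "__main__":
    # Usage sketch, not part of the modeling code: a minimal example of how the
    # model could be loaded and queried. The repository id below is an assumption
    # inferred from the upload path; the aspect keys of the returned dict come
    # from SemCSEMultiConfig.aspect_identifiers of the uploaded config.
    from transformers import AutoTokenizer

    repo_id = "marcfelix12/SemCSE-Multi-Medical"  # assumed Hub repository id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = SemCSEMulti.from_pretrained(repo_id)
    model.eval()

    batch = tokenizer(
        ["Aspirin reduces the risk of recurrent myocardial infarction."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # One embedding of shape (batch_size, embedding_dim) per aspect identifier.
    for aspect, emb in embeddings.items():
        print(aspect, tuple(emb.shape))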