#!/usr/bin/env python3
"""
DualHeadModel: Inference-ready dual-head classifier for onderwerp + beleving.
Minimal PyTorch nn.Module that loads from an HF-compatible checkpoint.
"""
import os
import json

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class DualHeadModel(nn.Module):
    """
    Dual-head multi-label classifier on top of a HuggingFace encoder.
    Two classification heads: onderwerp (topic) and beleving (experience).
    """

    def __init__(self, encoder, num_onderwerp, num_beleving, dropout=0.1):
        super().__init__()
        self.encoder = encoder
        hidden_size = encoder.config.hidden_size
        self.onderwerp_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, num_onderwerp),
        )
        self.beleving_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, num_beleving),
        )

    def forward(self, input_ids, attention_mask):
        """Forward pass: encoder + dual heads."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool the [CLS] token representation (first position).
        pooled = outputs.last_hidden_state[:, 0, :]
        onderwerp_logits = self.onderwerp_head(pooled)
        beleving_logits = self.beleving_head(pooled)
        return onderwerp_logits, beleving_logits

    @classmethod
    def from_pretrained(cls, model_dir, device="cpu"):
        """
        Load the model from an HF-compatible checkpoint directory.

        Args:
            model_dir: Path to a directory containing the encoder, tokenizer,
                dual_head_state.pt, and label_names.json.
            device: torch.device or string ('cpu', 'cuda', 'mps').

        Returns:
            model: DualHeadModel ready for inference.
            tokenizer: Loaded tokenizer.
            config: Dict with metadata (max_length, label names, etc.).
        """
        # Load encoder and tokenizer
        encoder = AutoModel.from_pretrained(model_dir).to(device).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Load head states and metadata
        state_path = os.path.join(model_dir, "dual_head_state.pt")
        state = torch.load(state_path, map_location="cpu")

        # Load label names
        labels_path = os.path.join(model_dir, "label_names.json")
        with open(labels_path, encoding="utf-8") as f:
            labels = json.load(f)

        # Extract config
        num_onderwerp = int(state["num_onderwerp"])
        num_beleving = int(state["num_beleving"])
        dropout = float(state.get("dropout", 0.1))
        max_length = int(state.get("max_length", 512))

        # Build model (dropout is automatically disabled in eval mode)
        model = cls(encoder, num_onderwerp, num_beleving, dropout)

        # Load head weights
        model.onderwerp_head.load_state_dict(state["onderwerp_head_state"], strict=True)
        model.beleving_head.load_state_dict(state["beleving_head_state"], strict=True)
        model = model.to(device).eval()

        # Return model, tokenizer, and config
        config = {
            "max_length": max_length,
            "labels": labels,
            "num_onderwerp": num_onderwerp,
            "num_beleving": num_beleving,
        }
        return model, tokenizer, config

    @torch.inference_mode()
    def predict(self, input_ids, attention_mask):
        """Inference with sigmoid activation (independent per-label probabilities)."""
        onderwerp_logits, beleving_logits = self.forward(input_ids, attention_mask)
        onderwerp_probs = torch.sigmoid(onderwerp_logits)
        beleving_probs = torch.sigmoid(beleving_logits)
        return onderwerp_probs, beleving_probs
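

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module API). Assumptions are
# flagged inline: the checkpoint path "./checkpoint" is hypothetical, and
# label_names.json is assumed to map each head to a list of label strings,
# i.e. {"onderwerp": [...], "beleving": [...]} -- adjust to your actual layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model, tokenizer, config = DualHeadModel.from_pretrained("./checkpoint")

    texts = ["De bus was te laat, maar de chauffeur was vriendelijk."]
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=config["max_length"],
        return_tensors="pt",
    )

    onderwerp_probs, beleving_probs = model.predict(
        batch["input_ids"], batch["attention_mask"]
    )

    # Multi-label decoding: keep every label whose probability clears a
    # threshold. 0.5 is a placeholder; per-label thresholds tuned on a
    # validation set usually work better.
    threshold = 0.5
    for head, probs in (("onderwerp", onderwerp_probs), ("beleving", beleving_probs)):
        names = config["labels"].get(head, [])  # assumed JSON layout, see above
        for name, p in zip(names, probs[0].tolist()):
            if p >= threshold:
                print(f"{head}: {name} ({p:.2f})")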