""" Inference pipeline for LexiMind. Unified interface for summarization, emotion detection, and topic classification with batched processing and device management. Author: Oliver Perrin Date: December 2025 """ from __future__ import annotations from dataclasses import dataclass, fields, replace from typing import Any, Dict, List, Sequence, cast import torch import torch.nn.functional as F from ..data.preprocessing import Batch, TextPreprocessor from ..data.tokenization import Tokenizer # --------------- Configuration --------------- @dataclass class InferenceConfig: """Pipeline settings.""" summary_max_length: int = 128 emotion_threshold: float = 0.5 device: str | None = None @dataclass class EmotionPrediction: labels: List[str] scores: List[float] @dataclass class TopicPrediction: label: str confidence: float # --------------- Pipeline --------------- class InferencePipeline: """Multi-task inference with batched processing.""" def __init__( self, model: torch.nn.Module, tokenizer: Tokenizer, *, preprocessor: TextPreprocessor | None = None, emotion_labels: Sequence[str] | None = None, topic_labels: Sequence[str] | None = None, config: InferenceConfig | None = None, device: torch.device | str | None = None, ) -> None: self.model = model self.tokenizer = tokenizer self.config = config or InferenceConfig() # Resolve device chosen = device or self.config.device if chosen is None: param = next(model.parameters(), None) chosen = param.device if param else "cpu" self.device = torch.device(chosen) self.model.to(self.device) self.model.eval() self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer) self.emotion_labels = list(emotion_labels) if emotion_labels else None self.topic_labels = list(topic_labels) if topic_labels else None # --------------- Summarization --------------- def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]: """Generate summaries for input texts.""" if not texts: return [] batch = self._to_device(self.preprocessor.batch_encode(texts)) src_ids = batch.input_ids src_mask = batch.attention_mask max_len = max_length or self.config.summary_max_length model = cast(Any, self.model) if not hasattr(model, "encoder") or not hasattr(model, "decoder"): raise RuntimeError("Model must have encoder and decoder for summarization") with torch.inference_mode(): # Encode enc_mask = ( src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None ) memory = model.encoder(src_ids, mask=enc_mask) # Decode with constraints to improve quality ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id] unk = getattr(self.tokenizer._tokenizer, "unk_token_id", None) if isinstance(unk, int): ban_ids.append(unk) generated = model.decoder.greedy_decode( memory=memory, max_len=max_len, start_token_id=self.tokenizer.bos_token_id, end_token_id=self.tokenizer.eos_token_id, device=self.device, min_len=10, ban_token_ids=[i for i in ban_ids if i is not None], no_repeat_ngram_size=3, memory_mask=src_mask, ) return self.tokenizer.decode_batch(generated.tolist()) # --------------- Emotion --------------- def predict_emotions( self, texts: Sequence[str], *, threshold: float | None = None, ) -> List[EmotionPrediction]: """Predict emotions for input texts.""" if not texts: return [] if not self.emotion_labels: raise RuntimeError("emotion_labels required for emotion prediction") batch = self._to_device(self.preprocessor.batch_encode(texts)) inputs = self._model_inputs(batch) thresh = threshold or self.config.emotion_threshold with 
            logits = self.model.forward("emotion", inputs)
            probs = torch.sigmoid(logits)

        results: List[EmotionPrediction] = []
        for row in probs.cpu():
            pairs = [
                (label, score)
                for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
                if score >= thresh
            ]
            results.append(
                EmotionPrediction(
                    labels=[label for label, _ in pairs],
                    scores=[score for _, score in pairs],
                )
            )
        return results

    # --------------- Topic ---------------

    def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
        """Predict topic for input texts."""
        if not texts:
            return []
        if not self.topic_labels:
            raise RuntimeError("topic_labels required for topic prediction")

        batch = self._to_device(self.preprocessor.batch_encode(texts))
        inputs = self._model_inputs(batch)

        with torch.inference_mode():
            logits = self.model.forward("topic", inputs)
            probs = F.softmax(logits, dim=-1)

        results: List[TopicPrediction] = []
        for row in probs.cpu():
            idx = int(row.argmax().item())
            results.append(
                TopicPrediction(
                    label=self.topic_labels[idx],
                    confidence=row[idx].item(),
                )
            )
        return results

    # --------------- Batch Prediction ---------------

    def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
        """Run all three tasks on input texts."""
        if not self.emotion_labels or not self.topic_labels:
            raise RuntimeError("Both emotion_labels and topic_labels required")
        text_list = list(texts)
        return {
            "summaries": self.summarize(text_list),
            "emotion": self.predict_emotions(text_list),
            "topic": self.predict_topics(text_list),
        }

    # --------------- Helpers ---------------

    def _to_device(self, batch: Batch) -> Batch:
        """Move batch tensors to device with non_blocking for speed."""
        updates: Dict[str, torch.Tensor] = {}
        for f in fields(batch):
            val = getattr(batch, f.name)
            if torch.is_tensor(val):
                updates[f.name] = val.to(self.device, non_blocking=True)
        return replace(batch, **updates) if updates else batch

    @staticmethod
    def _model_inputs(batch: Batch) -> Dict[str, torch.Tensor]:
        """Extract model inputs from batch."""
        inputs: Dict[str, torch.Tensor] = {"input_ids": batch.input_ids}
        if batch.attention_mask is not None:
            inputs["attention_mask"] = batch.attention_mask
        return inputs
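
# --------------- Usage sketch ---------------
#
# Illustrative only: the model/tokenizer construction and the label sets
# below are hypothetical placeholders, not part of this module. Assumes
# `model` is a multi-task module exposing `encoder`, `decoder`, and
# `forward(task, inputs)` as the pipeline above expects.
#
#   pipeline = InferencePipeline(
#       model=model,                                     # trained multi-task model
#       tokenizer=tokenizer,                             # matching Tokenizer instance
#       emotion_labels=["joy", "anger", "sadness"],      # placeholder label set
#       topic_labels=["tech", "sports", "politics"],     # placeholder label set
#       config=InferenceConfig(summary_max_length=96, emotion_threshold=0.4),
#   )
#   out = pipeline.batch_predict(["Some article text ..."])
#   print(out["summaries"][0])
#   print(out["emotion"][0].labels, out["topic"][0].label)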