LexiMind/src/inference/pipeline.py
"""
Inference pipeline for LexiMind.
Unified interface for summarization, emotion detection, and topic classification
with batched processing and device management.
Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

from dataclasses import dataclass, fields, replace
from typing import Any, Dict, List, Sequence, cast

import torch
import torch.nn.functional as F

from ..data.preprocessing import Batch, TextPreprocessor
from ..data.tokenization import Tokenizer

# --------------- Configuration ---------------
@dataclass
class InferenceConfig:
"""Pipeline settings."""
summary_max_length: int = 128
emotion_threshold: float = 0.5
device: str | None = None
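
# A minimal override sketch (the values here are illustrative, not tuned defaults):
#
#     config = InferenceConfig(summary_max_length=64, emotion_threshold=0.3, device="cuda")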


@dataclass
class EmotionPrediction:
labels: List[str]
scores: List[float]


@dataclass
class TopicPrediction:
label: str
confidence: float


# --------------- Pipeline ---------------
class InferencePipeline:
"""Multi-task inference with batched processing."""
def __init__(
self,
model: torch.nn.Module,
tokenizer: Tokenizer,
*,
preprocessor: TextPreprocessor | None = None,
emotion_labels: Sequence[str] | None = None,
topic_labels: Sequence[str] | None = None,
config: InferenceConfig | None = None,
device: torch.device | str | None = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
self.config = config or InferenceConfig()
# Resolve device
chosen = device or self.config.device
if chosen is None:
param = next(model.parameters(), None)
            chosen = param.device if param is not None else "cpu"
self.device = torch.device(chosen)
self.model.to(self.device)
self.model.eval()
self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
self.emotion_labels = list(emotion_labels) if emotion_labels else None
self.topic_labels = list(topic_labels) if topic_labels else None
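
    # Construction sketch (assumes a trained multi-task model and a fitted
    # Tokenizer; the label lists are illustrative, not the project's real
    # taxonomies):
    #
    #     pipeline = InferencePipeline(
    #         model=model,
    #         tokenizer=tokenizer,
    #         emotion_labels=["joy", "anger", "sadness"],
    #         topic_labels=["world", "sports", "business"],
    #     )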

    # --------------- Summarization ---------------
def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
"""Generate summaries for input texts."""
if not texts:
return []
batch = self._to_device(self.preprocessor.batch_encode(texts))
src_ids = batch.input_ids
src_mask = batch.attention_mask
        max_len = max_length if max_length is not None else self.config.summary_max_length
model = cast(Any, self.model)
if not hasattr(model, "encoder") or not hasattr(model, "decoder"):
raise RuntimeError("Model must have encoder and decoder for summarization")
with torch.inference_mode():
            # Encode: expand the (B, L) padding mask to a (B, L, L) pairwise attention mask
enc_mask = (
src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
)
memory = model.encoder(src_ids, mask=enc_mask)
            # Decode greedily; ban special tokens, enforce a minimum length, and block repeated trigrams
ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
unk = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
if isinstance(unk, int):
ban_ids.append(unk)
generated = model.decoder.greedy_decode(
memory=memory,
max_len=max_len,
start_token_id=self.tokenizer.bos_token_id,
end_token_id=self.tokenizer.eos_token_id,
device=self.device,
min_len=10,
ban_token_ids=[i for i in ban_ids if i is not None],
no_repeat_ngram_size=3,
memory_mask=src_mask,
)
return self.tokenizer.decode_batch(generated.tolist())
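
    # Usage sketch: one decoded summary string per input, e.g.
    #
    #     summaries = pipeline.summarize(["<long article text>"], max_length=64)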

    # --------------- Emotion ---------------
def predict_emotions(
self,
texts: Sequence[str],
*,
threshold: float | None = None,
) -> List[EmotionPrediction]:
"""Predict emotions for input texts."""
if not texts:
return []
if not self.emotion_labels:
raise RuntimeError("emotion_labels required for emotion prediction")
batch = self._to_device(self.preprocessor.batch_encode(texts))
inputs = self._model_inputs(batch)
        thresh = threshold if threshold is not None else self.config.emotion_threshold
with torch.inference_mode():
logits = self.model.forward("emotion", inputs)
probs = torch.sigmoid(logits)
        results: List[EmotionPrediction] = []
for row in probs.cpu():
pairs = [
(label, score)
for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
if score >= thresh
]
results.append(
EmotionPrediction(
labels=[label for label, _ in pairs],
scores=[score for _, score in pairs],
)
)
return results
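
    # Usage sketch (scores are illustrative): a single text with threshold=0.3
    # might come back as
    #
    #     [EmotionPrediction(labels=["joy"], scores=[0.91])]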

    # --------------- Topic ---------------
def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
"""Predict topic for input texts."""
if not texts:
return []
if not self.topic_labels:
raise RuntimeError("topic_labels required for topic prediction")
batch = self._to_device(self.preprocessor.batch_encode(texts))
inputs = self._model_inputs(batch)
with torch.inference_mode():
logits = self.model.forward("topic", inputs)
probs = F.softmax(logits, dim=-1)
        results: List[TopicPrediction] = []
for row in probs.cpu():
idx = int(row.argmax().item())
results.append(
TopicPrediction(
label=self.topic_labels[idx],
confidence=row[idx].item(),
)
)
return results
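
    # Usage sketch (label set and confidence are illustrative):
    #
    #     pipeline.predict_topics(["Markets rallied today."])
    #     # -> [TopicPrediction(label="business", confidence=0.87)]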

    # --------------- Batch Prediction ---------------
def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
"""Run all three tasks on input texts."""
if not self.emotion_labels or not self.topic_labels:
raise RuntimeError("Both emotion_labels and topic_labels required")
text_list = list(texts)
return {
"summaries": self.summarize(text_list),
"emotion": self.predict_emotions(text_list),
"topic": self.predict_topics(text_list),
}
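
    # The combined result maps task name to per-text outputs:
    #
    #     {"summaries": [str, ...], "emotion": [EmotionPrediction, ...], "topic": [TopicPrediction, ...]}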

    # --------------- Helpers ---------------
def _to_device(self, batch: Batch) -> Batch:
"""Move batch tensors to device with non_blocking for speed."""
        updates: Dict[str, torch.Tensor] = {}
for f in fields(batch):
val = getattr(batch, f.name)
if torch.is_tensor(val):
updates[f.name] = val.to(self.device, non_blocking=True)
return replace(batch, **updates) if updates else batch

    @staticmethod
def _model_inputs(batch: Batch) -> Dict[str, torch.Tensor]:
"""Extract model inputs from batch."""
inputs = {"input_ids": batch.input_ids}
if batch.attention_mask is not None:
inputs["attention_mask"] = batch.attention_mask
return inputs