Spaces:
Running
Running
File size: 7,122 Bytes
590a604 ee1a8a3 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 2286a5e 1fbc47b 590a604 1fbc47b 2286a5e 1fbc47b 2286a5e 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 374a07d 590a604 374a07d 1fbc47b 590a604 d18b34d 590a604 d18b34d 590a604 ea3248a 374a07d 8951fba f0493d8 8951fba 590a604 ea3248a f0493d8 8951fba d18b34d 590a604 d18b34d 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 60f8a12 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 d18b34d 590a604 d18b34d 1fbc47b 590a604 1fbc47b 590a604 1fbc47b 590a604 1fbc47b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
"""
Inference pipeline for LexiMind.
Unified interface for summarization, emotion detection, and topic classification
with batched processing and device management.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, List, Sequence, cast
import torch
import torch.nn.functional as F
from ..data.preprocessing import Batch, TextPreprocessor
from ..data.tokenization import Tokenizer
# --------------- Configuration ---------------
@dataclass
class InferenceConfig:
    """Tunable defaults consumed by the inference pipeline in this module."""

    # Fallback generation length used by summarization when the caller
    # does not pass an explicit ``max_length``.
    summary_max_length: int = 128
    # Fallback sigmoid-probability cutoff for keeping an emotion label.
    emotion_threshold: float = 0.5
    # Device to run on; ``None`` lets the pipeline pick one itself
    # (an explicit ``device`` argument to the pipeline takes precedence).
    device: str | None = None
@dataclass
class EmotionPrediction:
    """Multi-label emotion result for a single input text.

    ``labels`` and ``scores`` are parallel lists: ``scores[i]`` belongs to
    ``labels[i]``.
    """

    # Names of the emotions that were kept for this text.
    labels: List[str]
    # Score for each kept label, in the same order as ``labels``.
    scores: List[float]
@dataclass
class TopicPrediction:
    """Single-label topic result for one input text."""

    # Name of the chosen topic.
    label: str
    # Score attached to the chosen topic.
    confidence: float
# --------------- Pipeline ---------------
class InferencePipeline:
    """Multi-task inference with batched processing.

    Wraps one model and one tokenizer and exposes three tasks:
    summarization, multi-label emotion detection and single-label topic
    classification. On construction the model is moved to the resolved
    device and switched to eval mode.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Tokenizer,
        *,
        preprocessor: TextPreprocessor | None = None,
        emotion_labels: Sequence[str] | None = None,
        topic_labels: Sequence[str] | None = None,
        config: InferenceConfig | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        """Set up the pipeline.

        Args:
            model: Module used for all tasks; moved to the chosen device
                and put in eval mode.
            tokenizer: Tokenizer used for special-token ids and decoding.
            preprocessor: Text preprocessor; a default ``TextPreprocessor``
                built around ``tokenizer`` is used when omitted.
            emotion_labels: Label names for the emotion head. Missing or
                empty means emotion prediction is unavailable.
            topic_labels: Label names for the topic head. Missing or empty
                means topic prediction is unavailable.
            config: Pipeline settings; defaults to ``InferenceConfig()``.
            device: Explicit device. Falls back to ``config.device``, then
                to the device of the model's first parameter, then "cpu".
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or InferenceConfig()

        # Resolve device
        chosen = device or self.config.device
        if chosen is None:
            param = next(model.parameters(), None)
            # Compare against None explicitly: truth-testing a tensor raises
            # for multi-element parameters and inspects the *value* of
            # single-element ones.
            chosen = param.device if param is not None else "cpu"
        self.device = torch.device(chosen)

        self.model.to(self.device)
        self.model.eval()

        self.preprocessor = preprocessor or TextPreprocessor(tokenizer=tokenizer)
        self.emotion_labels = list(emotion_labels) if emotion_labels else None
        self.topic_labels = list(topic_labels) if topic_labels else None

    # --------------- Summarization ---------------
    def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
        """Generate summaries for input texts.

        Args:
            texts: Texts to summarize; an empty sequence returns ``[]``.
            max_length: Maximum generation length. ``None`` (or 0) falls
                back to ``config.summary_max_length``.

        Returns:
            One decoded summary string per input text.

        Raises:
            RuntimeError: If the model has no ``encoder`` or ``decoder``
                attribute.
        """
        if not texts:
            return []

        batch = self._to_device(self.preprocessor.batch_encode(texts))
        src_ids = batch.input_ids
        src_mask = batch.attention_mask
        max_len = max_length or self.config.summary_max_length

        model = cast(Any, self.model)
        if not hasattr(model, "encoder") or not hasattr(model, "decoder"):
            raise RuntimeError("Model must have encoder and decoder for summarization")

        # Tokens the decoder must never emit: BOS, PAD and (when the
        # tokenizer exposes one) UNK. The UNK id lives on the wrapped
        # backend tokenizer; look it up defensively so tokenizers without
        # that private attribute still work.
        ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
        backend = getattr(self.tokenizer, "_tokenizer", None)
        unk = getattr(backend, "unk_token_id", None)
        if isinstance(unk, int):
            ban_ids.append(unk)

        with torch.inference_mode():
            # Encode. The pairwise mask keeps position pair (i, j) only when
            # both positions are unmasked.
            # NOTE(review): assumes attention_mask is a (batch, seq) bool/int
            # tensor so that `&` is defined - confirm against the preprocessor.
            enc_mask = (
                src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
            )
            memory = model.encoder(src_ids, mask=enc_mask)

            # Decode with constraints to improve quality
            generated = model.decoder.greedy_decode(
                memory=memory,
                max_len=max_len,
                start_token_id=self.tokenizer.bos_token_id,
                end_token_id=self.tokenizer.eos_token_id,
                device=self.device,
                min_len=10,
                ban_token_ids=[i for i in ban_ids if i is not None],
                no_repeat_ngram_size=3,
                memory_mask=src_mask,
            )

        return self.tokenizer.decode_batch(generated.tolist())

    # --------------- Emotion ---------------
    def predict_emotions(
        self,
        texts: Sequence[str],
        *,
        threshold: float | None = None,
    ) -> List[EmotionPrediction]:
        """Predict emotions for input texts.

        Args:
            texts: Texts to classify; an empty sequence returns ``[]``.
            threshold: Minimum sigmoid probability for a label to be kept.
                ``None`` uses ``config.emotion_threshold``; an explicit
                ``0.0`` is honoured and keeps every label.

        Returns:
            One ``EmotionPrediction`` per text holding the labels whose
            probability is at or above the threshold, with their scores.

        Raises:
            RuntimeError: If no emotion labels were configured.
        """
        if not texts:
            return []
        if not self.emotion_labels:
            raise RuntimeError("emotion_labels required for emotion prediction")

        batch = self._to_device(self.preprocessor.batch_encode(texts))
        inputs = self._model_inputs(batch)
        # `threshold or default` would discard a deliberate 0.0.
        thresh = self.config.emotion_threshold if threshold is None else threshold

        with torch.inference_mode():
            logits = self.model.forward("emotion", inputs)
            probs = torch.sigmoid(logits)

        results = []
        for row in probs.cpu():
            # strict=False: a label/logit count mismatch is truncated to the
            # shorter of the two rather than raising.
            pairs = [
                (label, score)
                for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
                if score >= thresh
            ]
            results.append(
                EmotionPrediction(
                    labels=[label for label, _ in pairs],
                    scores=[score for _, score in pairs],
                )
            )
        return results

    # --------------- Topic ---------------
    def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
        """Predict topic for input texts.

        Args:
            texts: Texts to classify; an empty sequence returns ``[]``.

        Returns:
            One ``TopicPrediction`` per text with the highest-probability
            label and its softmax probability.

        Raises:
            RuntimeError: If no topic labels were configured.
        """
        if not texts:
            return []
        if not self.topic_labels:
            raise RuntimeError("topic_labels required for topic prediction")

        batch = self._to_device(self.preprocessor.batch_encode(texts))
        inputs = self._model_inputs(batch)

        with torch.inference_mode():
            logits = self.model.forward("topic", inputs)
            probs = F.softmax(logits, dim=-1)

        results = []
        for row in probs.cpu():
            idx = int(row.argmax().item())
            results.append(
                TopicPrediction(
                    label=self.topic_labels[idx],
                    confidence=row[idx].item(),
                )
            )
        return results

    # --------------- Batch Prediction ---------------
    def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
        """Run all three tasks on input texts.

        Returns:
            Dict with keys ``"summaries"``, ``"emotion"`` and ``"topic"``,
            each holding that task's per-text results.

        Raises:
            RuntimeError: If emotion or topic labels are missing (checked
                before any work is done, even for empty input).
        """
        if not self.emotion_labels or not self.topic_labels:
            raise RuntimeError("Both emotion_labels and topic_labels required")

        # Materialise once so a one-shot iterable can feed all three tasks.
        text_list = list(texts)
        return {
            "summaries": self.summarize(text_list),
            "emotion": self.predict_emotions(text_list),
            "topic": self.predict_topics(text_list),
        }

    # --------------- Helpers ---------------
    def _to_device(self, batch: Batch) -> Batch:
        """Move batch tensors to device with non_blocking for speed.

        Non-tensor fields are left untouched. Returns a copy made with
        ``dataclasses.replace`` when at least one tensor field was moved,
        otherwise the original batch object.
        """
        updates = {}
        for f in fields(batch):
            val = getattr(batch, f.name)
            if torch.is_tensor(val):
                updates[f.name] = val.to(self.device, non_blocking=True)
        return replace(batch, **updates) if updates else batch

    @staticmethod
    def _model_inputs(batch: Batch) -> Dict[str, torch.Tensor]:
        """Extract model inputs from batch.

        Always includes ``"input_ids"``; ``"attention_mask"`` is added only
        when the batch carries one.
        """
        inputs = {"input_ids": batch.input_ids}
        if batch.attention_mask is not None:
            inputs["attention_mask"] = batch.attention_mask
        return inputs
|