Commit 8951fba
Parent(s): 3318356
Updated Gradio Demo to include model visualizations

Files changed:
- scripts/demo_gradio.py +168 -417
- src/inference/pipeline.py +9 -63
scripts/demo_gradio.py
CHANGED
@@ -1,15 +1,14 @@
 """
-Gradio
-
+Minimal Gradio demo for the LexiMind multitask model.
+Shows raw model outputs without any post-processing tricks.
 """
+from __future__ import annotations
+
 import json
-import re
 import sys
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import Iterable, Sequence
-from textwrap import dedent
-from collections import Counter
 
 import gradio as gr
 from gradio.themes import Soft
@@ -19,9 +18,10 @@ import seaborn as sns
 import torch
 from matplotlib.figure import Figure
 
-#
-
-
+# Make local packages importable when running the script directly
+PROJECT_ROOT = Path(__file__).resolve().parent.parent
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
 
 from src.inference.factory import create_inference_pipeline
 from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
@@ -30,275 +30,140 @@ from src.utils.logging import configure_logging, get_logger
 configure_logging()
 logger = get_logger(__name__)
 
-_pipeline: InferencePipeline | None = None
-
-
-
-    "
-    "
-    "
-    "
-
-
-    "and",
-    "in",
-    "it",
-    "that",
-    "for",
-    "on",
-    "with",
-    "as",
-    "by",
-    "be",
-    "are",
-    "was",
-    "were",
-    "this",
-    "which",
-    "at",
-    "or",
-    "from",
-    "but",
-    "has",
-    "have",
-    "had",
-    "can",
-    "will",
-    "would",
-}
-
-EMOTION_THRESHOLDS = {
-    "anger": 0.6,
-    "fear": 0.85,
-    "joy": 0.6,
-    "love": 0.25,
-    "sadness": 0.3,
-    "surprise": 0.55,
-}
-
-EMOTION_KEYWORDS = {
-    "love": {
-        "love",
-        "loved",
-        "loving",
-        "beloved",
-        "romance",
-        "romantic",
-        "affection",
-        "passion",
-        "sweetheart",
-        "valentine",
-        "dear",
-        "cherish",
-        "ador",
-        "marriage",
-        "wedding",
-    }
-}
+_pipeline: InferencePipeline | None = None
+
+VISUALIZATION_DIR = PROJECT_ROOT / "outputs"
+VISUALIZATION_ASSETS: list[tuple[str, str]] = [
+    ("attention_visualization.png", "Attention weights (single head)"),
+    ("multihead_attention_visualization.png", "Multi-head attention comparison"),
+    ("single_vs_multihead.png", "Single vs multi-head attention"),
+    ("positional_encoding_heatmap.png", "Positional encoding heatmap"),
+]
+
 
 def get_pipeline() -> InferencePipeline:
-
-    global _pipeline, _label_metadata
+    global _pipeline
     if _pipeline is None:
-
-
-
-
-
-
-
-            _pipeline = pipeline
-            _label_metadata = label_metadata
-            logger.info("Pipeline loaded successfully")
-        except Exception as e:
-            logger.error(f"Failed to load pipeline: {e}")
-            raise RuntimeError("Could not initialize inference pipeline. Check logs for details.")
+        logger.info("Loading inference pipeline ...")
+        _pipeline, _ = create_inference_pipeline(
+            tokenizer_dir="artifacts/hf_tokenizer/",
+            checkpoint_path="checkpoints/best.pt",
+            labels_path="artifacts/labels.json",
+        )
+        logger.info("Pipeline loaded")
     return _pipeline
 
+
+def map_compression_to_length(compression: int, max_model_length: int = 512) -> int:
+    ratio = (100 - compression) / 100
+    return max(16, int(ratio * max_model_length))
+
+
 def count_tokens(text: str) -> str:
-    """Count tokens in the input text."""
     if not text:
         return "Tokens: 0"
-    try:
+    try:
         pipeline = get_pipeline()
-
-
-
-        logger.error(f"Token counting error: {e}")
+        return f"Tokens: {len(pipeline.tokenizer.encode(text))}"
+    except Exception as exc:  # pragma: no cover - surfaced in UI
+        logger.error("Token counting failed: %s", exc, exc_info=True)
         return "Token count unavailable"
-
-def map_compression_to_length(compression: int, max_model_length: int = 512):
-    """
-    Map Compression slider (20-80%) to max summary length.
-    Higher compression = shorter summary output.
-    """
-    # Invert, 20% compression = 80% of max length
-    ratio = (100 - compression) / 100
-    return int(ratio * max_model_length)
+
 
 def predict(text: str, compression: int):
-    """Run the full pipeline and prepare Gradio outputs."""
     hidden_download = gr.update(value=None, visible=False)
     if not text or not text.strip():
         return (
-            "Please enter
+            "Please enter text to analyze.",
            None,
-            "No topic prediction available",
+            "No topic prediction available.",
            None,
            hidden_download,
        )
+
    try:
        pipeline = get_pipeline()
        max_len = map_compression_to_length(compression)
-        logger.info("Generating summary with max length
+        logger.info("Generating summary with max length %s", max_len)
 
-        summary = pipeline.summarize([text], max_length=max_len)[0]
-
-        filtered_emotions = filter_emotions(raw_emotion_pairs, text)
+        summary = pipeline.summarize([text], max_length=max_len)[0].strip()
+        emotions = pipeline.predict_emotions([text])[0]
        topic = pipeline.predict_topics([text])[0]
 
-
-
-
-
-
-            summary_source = clean_summary
+        summary_html = format_summary(text, summary)
+        emotion_plot = create_emotion_plot(emotions)
+        topic_markdown = format_topic(topic)
+        if summary:
+            attention_fig = create_attention_heatmap(text, summary, pipeline)
        else:
-
-            summary_source = fallback_summary
-            if clean_summary:
-                logger.info("Neural summary flagged as low-overlap; showing extractive fallback instead")
-                summary_notice = dedent(
-                    f"""
-                    <p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> The neural summary looked off-topic, so an extractive fallback is shown above.</p>
-                    <details style=\"margin-top: 4px;\">
-                    <summary style=\"color: #b45309; cursor: pointer;\">View the original neural summary</summary>
-                    <p style=\"margin-top: 8px; background-color: #fff7ed; padding: 10px; border-radius: 4px; color: #7c2d12; white-space: pre-wrap;\">
-                    {clean_summary}
-                    </p>
-                    </details>
-                    """
-                ).strip()
-            else:
-                summary_notice = (
-                    "<p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> "
-                    "The model did not produce a summary, so an extractive fallback is shown instead.</p>"
-                )
+            attention_fig = render_message_figure("Attention heatmap unavailable: summary was empty.")
 
-
-        emotion_plot = create_emotion_plot(filtered_emotions)
-        if emotion_plot is None:
-            emotion_plot = render_unavailable_message(
-                "No emotion met the confidence threshold."
-            )
-        topic_output = format_topic(topic)
-        if clean_summary and fallback_summary is None:
-            attention_fig = create_attention_heatmap(text, clean_summary, pipeline)
-        else:
-            attention_fig = render_unavailable_message(
-                "Attention heatmap unavailable because the neural summary was empty or flagged as unreliable."
-            )
-        download_path = prepare_download(
-            text,
-            summary_source,
-            filtered_emotions,
-            topic,
-            neural_summary=clean_summary or None,
-            fallback_summary=fallback_summary,
-            raw_emotions=raw_emotion_pairs,
-        )
+        download_path = prepare_download(text, summary, emotions, topic)
        download_update = gr.update(value=download_path, visible=True)
 
-        return summary_html, emotion_plot,
+        return summary_html, emotion_plot, topic_markdown, attention_fig, download_update
 
    except Exception as exc:  # pragma: no cover - surfaced in UI
        logger.error("Prediction error: %s", exc, exc_info=True)
-
-
-
-def format_summary(original: str, summary: str
-
-
-
-
-
+        return "Prediction failed. Check logs for details.", None, "Error", None, hidden_download
+
+
+def format_summary(original: str, summary: str) -> str:
+    if not summary:
+        summary = "(Model returned an empty summary. Consider retraining the summarization head.)"
+
+    return f"""
+    <div style="padding: 12px; border-radius: 6px; background-color: #fafafa; color: #222;">
+    <h3 style="margin-top: 0; color: #222;">Original Text</h3>
+    <p style="background-color: #f0f0f0; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #222;">
    {original}
    </p>
-    <h3 style="color: #
-    <p style="background-color: #e6f3ff; padding: 10px; border-radius:
+    <h3 style="color: #222;">Model Summary</h3>
+    <p style="background-color: #e6f3ff; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #111;">
    {summary}
    </p>
-
+    <p style="margin-top: 12px; color: #6b7280; font-size: 0.9rem;">
+    Outputs are shown exactly as produced by the checkpoint.
+    </p>
    </div>
-    """
-    return dedent(html).strip()
+    """.strip()
 
-
-
-
-
-
-
-
-
-
-
-    for label, score in raw_pairs:
-        try:
-            score_val = float(score)
-        except (TypeError, ValueError):
-            continue
-        pairs.append((str(label), score_val))
-
-    if not pairs:
-        return None
-
-    pairs = sorted(pairs, key=lambda item: item[1], reverse=True)
-    filtered = [item for item in pairs if item[1] >= 0.2]
-    if not filtered:
-        filtered = pairs[:3]
-
-    labels = [label for label, _ in filtered[:5]]
-    scores = [score for _, score in filtered[:5]]
-
-    df = pd.DataFrame({"Emotion": labels, "Probability": scores})
-    fig, ax = plt.subplots(figsize=(8, 5))
-    colors = sns.color_palette("Set2", len(labels))
+
+def create_emotion_plot(emotions: EmotionPrediction) -> Figure | None:
+    if not emotions.labels:
+        return render_message_figure("No emotions cleared the model threshold.")
+
+    df = pd.DataFrame({"Emotion": emotions.labels, "Probability": emotions.scores}).sort_values(
+        "Probability", ascending=True
+    )
+    fig, ax = plt.subplots(figsize=(6, 4))
+    colors = sns.color_palette("crest", len(df))
    bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
-    ax.set_xlabel("Probability"
-    ax.
-    ax.set_title("Emotion Detection Results", fontsize=14, fontweight="bold")
+    ax.set_xlabel("Probability")
+    ax.set_title("Emotion Scores")
    ax.set_xlim(0, 1)
    for bar in bars:
        width = bar.get_width()
-        ax.text(
-            width,
-            bar.get_y() + bar.get_height() / 2,
-            f"{width:.2%}",
-            ha="left",
-            va="center",
-            fontsize=10,
-            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
-        )
+        ax.text(width + 0.02, bar.get_y() + bar.get_height() / 2, f"{width:.2%}", va="center")
    plt.tight_layout()
    return fig
 
+
 def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
-    """Format topic prediction output as markdown."""
    if isinstance(topic, TopicPrediction):
        label = topic.label
-
+        confidence = topic.confidence
    else:
        label = str(topic.get("label", "Unknown"))
-
-
+        confidence = float(topic.get("score", 0.0))
    return f"""
    ### Predicted Topic
-
+
    **{label}**
-
-    Confidence: {
-    """
+
+    Confidence: {confidence:.2%}
+    """.strip()
+
 
 def _clean_tokens(tokens: Iterable[str]) -> list[str]:
    cleaned: list[str] = []
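The `map_compression_to_length` helper added above maps the compression slider inversely onto the token budget passed to `pipeline.summarize`, with a floor of 16 tokens. A quick standalone check of the mapping (a sketch mirroring the helper from this diff, not repo code) shows the resulting budgets:

    def map_compression_to_length(compression: int, max_model_length: int = 512) -> int:
        # Higher compression means a shorter summary budget.
        ratio = (100 - compression) / 100
        return max(16, int(ratio * max_model_length))

    for compression in (20, 50, 80):
        print(compression, map_compression_to_length(compression))
    # 20 -> 409, 50 -> 256, 80 -> 102 tokens

At the default slider value of 50 the decoder is therefore capped at 256 tokens; the 16-token floor only matters if the slider range were ever widened.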
@@ -307,105 +172,14 @@ def _clean_tokens(tokens: Iterable[str]) -> list[str]:
    cleaned.append(item.strip() if item.strip() else token)
    return cleaned
 
-def extract_emotion_pairs(
-    emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
-) -> list[tuple[str, float]]:
-    if isinstance(emotions, EmotionPrediction):
-        return list(zip(map(str, emotions.labels), map(float, emotions.scores)))
-    labels = emotions.get("labels", [])
-    scores = emotions.get("scores", [])
-    return [(str(label), float(score)) for label, score in zip(labels, scores)]
-
-def filter_emotions(pairs: list[tuple[str, float]], text: str) -> EmotionPrediction:
-    filtered: list[tuple[str, float]] = []
-    lowered_text = text.lower()
-
-    for label, score in pairs:
-        threshold = EMOTION_THRESHOLDS.get(label, 0.5)
-        if score < threshold:
-            continue
-
-        if label == "love":
-            keywords = EMOTION_KEYWORDS.get("love", set())
-            if score < 0.6 and not any(keyword in lowered_text for keyword in keywords):
-                continue
-
-        filtered.append((label, score))
-
-    if filtered:
-        labels, scores = zip(*filtered)
-        return EmotionPrediction(labels=list(labels), scores=list(scores))
-
-    return EmotionPrediction(labels=[], scores=[])
-
-def summary_is_plausible(
-    summary: str,
-    original: str,
-    *,
-    min_overlap: float = 0.2,
-    min_unique_ratio: float = 0.3,
-    max_repeat_ratio: float = 0.6,
-) -> bool:
-    """Heuristic filter to catch off-topic or repetitive neural summaries."""
-
-    summary_tokens = re.findall(r"\w+", summary.lower())
-    if not summary_tokens:
-        return False
-
-    summary_content = [token for token in summary_tokens if token not in STOPWORDS]
-    if not summary_content:
-        return False
-
-    original_vocab = {token for token in re.findall(r"\w+", original.lower()) if token not in STOPWORDS}
-    overlap = sum(1 for token in summary_content if token in original_vocab)
-    overlap_ratio = overlap / max(1, len(summary_content))
-    if overlap_ratio < min_overlap:
-        return False
-
-    token_counts = Counter(summary_content)
-    most_common_ratio = token_counts.most_common(1)[0][1] / len(summary_content)
-    unique_ratio = len(token_counts) / len(summary_content)
-    if unique_ratio < min_unique_ratio:
-        return False
-    if most_common_ratio > max_repeat_ratio:
-        return False
-    return True
-
-def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
-    """Build a lightweight extractive summary when the model generates nothing."""
-    if not text.strip():
-        return ""
-
-    sections = [segment.strip() for segment in text.replace("\n", " ").split(". ") if segment.strip()]
-    if not sections:
-        return text.strip()[:max_chars]
-
-    summary_fragments: list[str] = []
-    chars_used = 0
-    for segment in sections:
-        candidate = segment if segment.endswith(".") else f"{segment}."
-        if chars_used + len(candidate) > max_chars and summary_fragments:
-            break
-        summary_fragments.append(candidate)
-        chars_used += len(candidate)
-
-    fallback = " ".join(summary_fragments)
-    if not fallback:
-        fallback = text.strip()[:max_chars]
-    return fallback
 
 def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
-    """Generate a seaborn heatmap of decoder cross-attention averaged over heads."""
-    if not summary:
-        return None
    try:
        batch = pipeline.preprocessor.batch_encode([text])
        batch = pipeline._batch_to_device(batch)
        src_ids = batch.input_ids
        src_mask = batch.attention_mask
-        encoder_mask = None
-        if src_mask is not None:
-            encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
+        encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
 
        with torch.inference_mode():
            memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
@@ -423,11 +197,11 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
                memory_mask=memory_mask,
                collect_attn=True,
            )
+
        if not attn_list:
            return None
-        cross_attn = attn_list[-1]["cross"]
+        cross_attn = attn_list[-1]["cross"]
        attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
-
        source_len = batch.lengths[0]
        attn_matrix = attn_matrix[:target_len, :source_len]
 
@@ -439,7 +213,7 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
            pipeline.tokenizer.bos_token_id,
            pipeline.tokenizer.eos_token_id,
        }
-        keep_indices = [
+        keep_indices = [idx for idx, token_id in enumerate(target_id_list) if token_id not in special_ids]
        if not keep_indices:
            return None
 
@@ -447,10 +221,9 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
        tokenizer_impl = pipeline.tokenizer.tokenizer
        convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
        if convert_tokens is None:
-            logger.warning("Tokenizer does not expose convert_ids_to_tokens; skipping attention heatmap.")
            return None
 
-        summary_tokens_raw = convert_tokens([target_id_list[
+        summary_tokens_raw = convert_tokens([target_id_list[idx] for idx in keep_indices])
        source_tokens_raw = convert_tokens(source_ids)
 
        summary_tokens = _clean_tokens(summary_tokens_raw)
@@ -477,14 +250,13 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
 
    except Exception as exc:
        logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
-        return
+        return render_message_figure("Unable to render attention heatmap for this example.")
 
 
-def
-    """Render a simple Matplotlib figure containing an informational message."""
+def render_message_figure(message: str) -> Figure:
    fig, ax = plt.subplots(figsize=(6, 2))
    ax.axis("off")
-    ax.text(0.5, 0.5, message, ha="center", va="center",
+    ax.text(0.5, 0.5, message, ha="center", va="center", wrap=True)
    fig.tight_layout()
    return fig
 
@@ -494,12 +266,7 @@ def prepare_download(
    summary: str,
    emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
    topic: TopicPrediction | dict[str, float | str],
-    *,
-    neural_summary: str | None = None,
-    fallback_summary: str | None = None,
-    raw_emotions: Sequence[tuple[str, float]] | None = None,
 ) -> str:
-    """Persist JSON payload to a temporary file and return its path for download."""
    if isinstance(emotions, EmotionPrediction):
        emotion_payload = {
            "labels": list(emotions.labels),
@@ -512,10 +279,7 @@ def prepare_download(
        }
 
    if isinstance(topic, TopicPrediction):
-        topic_payload = {
-            "label": topic.label,
-            "confidence": topic.confidence,
-        }
+        topic_payload = {"label": topic.label, "confidence": topic.confidence}
    else:
        topic_payload = {
            "label": str(topic.get("label", topic.get("topic", "Unknown"))),
@@ -525,106 +289,99 @@ def prepare_download(
    payload = {
        "original_text": text,
        "summary": summary,
-        "neural_summary": neural_summary,
-        "fallback_summary": fallback_summary,
        "emotions": emotion_payload,
        "topic": topic_payload,
    }
-
-    payload["raw_emotions"] = [
-        {"label": label, "score": float(score)} for label, score in raw_emotions
-    ]
+
    with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
-
-
-
-
-
-
-
-
-
-
-
-
-
+    return handle.name
+
+
+def load_visualization_gallery() -> list[tuple[str, str]]:
+    """Collect visualization images produced by model tests."""
+    items: list[tuple[str, str]] = []
+    for filename, label in VISUALIZATION_ASSETS:
+        path = VISUALIZATION_DIR / filename
+        if path.exists():
+            items.append((str(path), label))
+        else:
+            logger.debug("Visualization asset missing: %s", path)
+    return items
+
+
+SAMPLE_TEXT = (
+    "Artificial intelligence is rapidly transforming the technology landscape. "
+    "Machine learning algorithms are now capable of processing vast amounts of data, "
+    "identifying patterns, and making predictions with unprecedented accuracy. "
+    "From healthcare diagnostics to financial forecasting, AI applications are "
+    "revolutionizing industries worldwide. However, ethical considerations around "
+    "privacy, bias, and transparency remain critical challenges that must be addressed "
+    "as these technologies continue to evolve."
+)
+
 
 def create_interface() -> gr.Blocks:
    with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
-        gr.Markdown(
-
-
-
-
-
-
+        gr.Markdown(
+            """
+            # LexiMind NLP Demo
+
+            This demo streams the raw outputs from the saved LexiMind checkpoint.
+            Results may be noisy; retraining is recommended for production use.
+            """
+        )
+
        with gr.Row():
-            # Left column - Input
            with gr.Column(scale=1):
-                gr.Markdown("### Input")
                input_text = gr.Textbox(
-                    label="
-                    placeholder="Paste or type your text here...",
                    lines=10,
-                    value=SAMPLE_TEXT
-
-                token_count = gr.Textbox(
-                    label="Token Count",
-                    value="Tokens: 0",
-                    interactive=False
+                    label="Input Text",
+                    value=SAMPLE_TEXT,
+                    placeholder="Paste or type your text here...",
                )
+                token_box = gr.Textbox(label="Token Count", value="Tokens: 0", interactive=False)
                compression = gr.Slider(
                    minimum=20,
                    maximum=80,
                    value=50,
                    step=5,
                    label="Compression %",
-                    info="Higher
+                    info="Higher values request shorter summaries.",
                )
-
-
+                analyze_btn = gr.Button("Run Analysis", variant="primary")
+
            with gr.Column(scale=2):
-                gr.Markdown("### Result")
                with gr.Tabs():
                    with gr.TabItem("Summary"):
                        summary_output = gr.HTML(label="Summary")
                    with gr.TabItem("Emotions"):
-                        emotion_output = gr.Plot(label="Emotion
+                        emotion_output = gr.Plot(label="Emotion Probabilities")
                    with gr.TabItem("Topic"):
                        topic_output = gr.Markdown(label="Topic Prediction")
-                    with gr.TabItem("Attention
-                        attention_output = gr.Plot(label="Attention
-                        gr.Markdown("*
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            [SAMPLE_TEXT, 50],
-            [
-                "Climate change poses significant risks to global ecosystems. Rising temperatures, melting ice caps, and extreme weather events are becoming more frequent. Scientists urge immediate action to reduce carbon emissions and transition to renewable energy sources.",
-                40,
-            ],
-        ],
-        inputs=[input_text, compression],
-        label="Try these examples:",
-        )
+                    with gr.TabItem("Attention"):
+                        attention_output = gr.Plot(label="Attention Heatmap")
+                        gr.Markdown("*Shows decoder attention if a summary is available.*")
+                    with gr.TabItem("Model Visuals"):
+                        visuals = gr.Gallery(
+                            label="Test Visualizations",
+                            value=load_visualization_gallery(),
+                            columns=2,
+                            height=400,
+                        )
+                        gr.Markdown(
+                            "These PNGs come from the visualization-focused tests in `tests/test_models` and are consumed as-is."
+                        )
+                gr.Markdown("### Download Results")
+                download_btn = gr.DownloadButton("Download JSON", visible=False)
+
+        input_text.change(fn=count_tokens, inputs=[input_text], outputs=[token_box])
+        analyze_btn.click(
+            fn=predict,
+            inputs=[input_text, compression],
+            outputs=[summary_output, emotion_output, topic_output, attention_output, download_btn],
+        )
+
    return demo
 
 
@@ -635,14 +392,8 @@ app = demo
 if __name__ == "__main__":
    try:
        get_pipeline()
-        demo.queue().launch(
-
-
-
-            show_error=True,
-        )
-    except Exception as e:
-        logger.error("Failed to launch demo: %s", e, exc_info=True)
-        print(f"Error: {e}")
-        sys.exit(1)
+        demo.queue().launch(share=False)
+    except Exception as exc:  # pragma: no cover - surfaced in console
+        logger.error("Failed to launch demo: %s", exc, exc_info=True)
+        raise
 
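The new "Model Visuals" tab is just a `gr.Gallery` fed with whatever test-generated PNGs exist on disk, so missing files never break the UI. A minimal standalone sketch of the same pattern (hypothetical asset names and directory; the demo itself uses PROJECT_ROOT / "outputs" and the VISUALIZATION_ASSETS list above):

    from pathlib import Path

    import gradio as gr

    # Hypothetical assets for illustration only.
    ASSETS = [
        ("attention_visualization.png", "Attention weights"),
        ("positional_encoding_heatmap.png", "Positional encoding"),
    ]
    OUTPUT_DIR = Path("outputs")

    def existing_images() -> list[tuple[str, str]]:
        # Keep only files that are actually present so the gallery never gets a dead path.
        return [(str(OUTPUT_DIR / name), label) for name, label in ASSETS if (OUTPUT_DIR / name).exists()]

    with gr.Blocks() as demo:
        gr.Gallery(label="Test Visualizations", value=existing_images(), columns=2)

    # demo.launch()

Passing (image_path, caption) tuples lets the gallery show the same labels used in the VISUALIZATION_ASSETS table, and computing the list at interface-build time keeps the tab static and cheap.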
src/inference/pipeline.py
CHANGED
@@ -75,13 +75,15 @@ class InferencePipeline:
        with torch.inference_mode():
            encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
            memory = self.model.encoder(src_ids, mask=encoder_mask)
-            generated = self.
+            generated = self.model.decoder.greedy_decode(
+                memory=memory,
+                max_len=max_len,
+                start_token_id=self.tokenizer.bos_token_id,
+                end_token_id=self.tokenizer.eos_token_id,
+                device=self.device,
+            )
 
-
-        for row in generated.cpu().tolist():
-            trimmed_sequences.append(self._trim_special_tokens(row))
-
-        return self.tokenizer.decode_batch(trimmed_sequences)
+        return self.tokenizer.decode_batch(generated.tolist())
 
    def predict_emotions(
        self,
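After this change, `summarize` delegates decoding to the decoder's own `greedy_decode` method instead of the pipeline-level `_constrained_greedy_decode` loop removed further down. That method's implementation is not part of this diff; a minimal sketch consistent with the call signature above (argument names taken from the call site, shapes and the decoder forward signature are assumptions) might look like:

    import torch

    def greedy_decode(self, memory, max_len, start_token_id, end_token_id, device):
        # Assumed convention: memory is (batch, src_len, d_model) from the encoder.
        batch_size = memory.size(0)
        generated = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            logits = self(generated, memory)  # decoder forward pass; signature assumed
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            if torch.all(next_token.squeeze(-1) == end_token_id):
                break
        return generated

Moving the loop into the decoder keeps the pipeline free of token-level bookkeeping, which is why the BOS/PAD trimming helpers below could be deleted.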
@@ -96,7 +98,7 @@ class InferencePipeline:
 
        batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
        model_inputs = self._batch_to_model_inputs(batch)
-        decision_threshold = self.config.emotion_threshold
+        decision_threshold = threshold or self.config.emotion_threshold
 
        with torch.inference_mode():
            logits = self.model.forward("emotion", model_inputs)
@@ -146,62 +148,6 @@ class InferencePipeline:
            "topic": self.predict_topics(text_list),
        }
 
-    def _constrained_greedy_decode(
-        self,
-        memory: torch.Tensor,
-        max_len: int,
-        *,
-        memory_mask: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        """Run greedy decoding while banning BOS/PAD tokens from the generated sequence."""
-
-        device = memory.device
-        batch_size = memory.size(0)
-        bos = self.tokenizer.bos_token_id
-        pad = getattr(self.tokenizer, "pad_token_id", None)
-        eos = getattr(self.tokenizer, "eos_token_id", None)
-
-        generated = torch.full((batch_size, 1), bos, dtype=torch.long, device=device)
-        expanded_memory_mask = None
-        if memory_mask is not None:
-            expanded_memory_mask = memory_mask.to(device=device, dtype=torch.bool)
-
-        for _ in range(max(1, max_len) - 1):
-            decoder_out = self.model.decoder(generated, memory, memory_mask=expanded_memory_mask)
-            logits = decoder_out if isinstance(decoder_out, torch.Tensor) else decoder_out[0]
-
-            step_logits = logits[:, -1, :].clone()
-            if bos is not None and bos < step_logits.size(-1):
-                step_logits[:, bos] = float("-inf")
-            if pad is not None and pad < step_logits.size(-1):
-                step_logits[:, pad] = float("-inf")
-
-            next_token = step_logits.argmax(dim=-1, keepdim=True)
-            generated = torch.cat([generated, next_token], dim=1)
-
-            if eos is not None and torch.all(next_token.squeeze(-1) == eos):
-                break
-
-        return generated
-
-    def _trim_special_tokens(self, sequence: Sequence[int]) -> List[int]:
-        """Remove leading BOS and trailing PAD/EOS tokens from a generated sequence."""
-
-        bos = self.tokenizer.bos_token_id
-        pad = getattr(self.tokenizer, "pad_token_id", None)
-        eos = getattr(self.tokenizer, "eos_token_id", None)
-
-        trimmed: List[int] = []
-        for idx, token in enumerate(sequence):
-            if idx == 0 and bos is not None and token == bos:
-                continue
-            if pad is not None and token == pad:
-                continue
-            if eos is not None and token == eos:
-                break
-            trimmed.append(int(token))
-        return trimmed
-
    def _batch_to_device(self, batch: Batch) -> Batch:
        tensor_updates: dict[str, torch.Tensor] = {}
        for item in fields(batch):
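The emotion head now uses the checkpoint's configured threshold only when the caller does not supply one. Assuming `predict_emotions` exposes the `threshold` argument implied by the changed line above (the full signature is not shown in this diff), usage would look roughly like:

    pipeline, label_metadata = create_inference_pipeline(
        tokenizer_dir="artifacts/hf_tokenizer/",
        checkpoint_path="checkpoints/best.pt",
        labels_path="artifacts/labels.json",
    )

    texts = ["I absolutely loved this movie!"]
    # Fall back to the configured emotion_threshold from the pipeline config:
    default_preds = pipeline.predict_emotions(texts)
    # Override it for a single call (argument name inferred from the diff):
    strict_preds = pipeline.predict_emotions(texts, threshold=0.8)

Because the override uses `threshold or self.config.emotion_threshold`, passing 0 (or None) falls back to the configured value rather than disabling the threshold entirely.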