OliverPerrin committed
Commit 8951fba · 1 Parent(s): 3318356

Updated Gradio Demo to include model visualizations

Files changed (2)
  1. scripts/demo_gradio.py +168 -417
  2. src/inference/pipeline.py +9 -63
scripts/demo_gradio.py CHANGED
@@ -1,15 +1,14 @@
1
  """
2
- Gradio Demo interface for LexiMind NLP pipeline.
3
- Showcases summarization, emotion detection, and topic prediction.
4
  """
 
 
5
  import json
6
- import re
7
  import sys
8
  from pathlib import Path
9
  from tempfile import NamedTemporaryFile
10
  from typing import Iterable, Sequence
11
- from textwrap import dedent
12
- from collections import Counter
13
 
14
  import gradio as gr
15
  from gradio.themes import Soft
@@ -19,9 +18,10 @@ import seaborn as sns
19
  import torch
20
  from matplotlib.figure import Figure
21
 
22
- # Add project root to the path, going up two folder levels from this file
23
- project_root = Path(__file__).parent.parent
24
- sys.path.insert(0, str(project_root))
 
25
 
26
  from src.inference.factory import create_inference_pipeline
27
  from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
@@ -30,275 +30,140 @@ from src.utils.logging import configure_logging, get_logger
30
  configure_logging()
31
  logger = get_logger(__name__)
32
 
33
- _pipeline: InferencePipeline | None = None # Global pipeline instance
34
- _label_metadata = None # Cached label metadata
35
-
36
- STOPWORDS = {
37
- "the",
38
- "is",
39
- "a",
40
- "an",
41
- "to",
42
- "of",
43
- "and",
44
- "in",
45
- "it",
46
- "that",
47
- "for",
48
- "on",
49
- "with",
50
- "as",
51
- "by",
52
- "be",
53
- "are",
54
- "was",
55
- "were",
56
- "this",
57
- "which",
58
- "at",
59
- "or",
60
- "from",
61
- "but",
62
- "has",
63
- "have",
64
- "had",
65
- "can",
66
- "will",
67
- "would",
68
- }
69
-
70
- EMOTION_THRESHOLDS = {
71
- "anger": 0.6,
72
- "fear": 0.85,
73
- "joy": 0.6,
74
- "love": 0.25,
75
- "sadness": 0.3,
76
- "surprise": 0.55,
77
- }
78
-
79
- EMOTION_KEYWORDS = {
80
- "love": {
81
- "love",
82
- "loved",
83
- "loving",
84
- "beloved",
85
- "romance",
86
- "romantic",
87
- "affection",
88
- "passion",
89
- "sweetheart",
90
- "valentine",
91
- "dear",
92
- "cherish",
93
- "ador",
94
- "marriage",
95
- "wedding",
96
- }
97
- }
98
 
99
  def get_pipeline() -> InferencePipeline:
100
- """Lazy Loading and Caching the inference pipeline"""
101
- global _pipeline, _label_metadata
102
  if _pipeline is None:
103
- try:
104
- logger.info("Loading inference pipeline...")
105
- pipeline, label_metadata = create_inference_pipeline(
106
- tokenizer_dir="artifacts/hf_tokenizer/",
107
- checkpoint_path="checkpoints/best.pt",
108
- labels_path="artifacts/labels.json",
109
- )
110
- _pipeline = pipeline
111
- _label_metadata = label_metadata
112
- logger.info("Pipeline loaded successfully")
113
- except Exception as e:
114
- logger.error(f"Failed to load pipeline: {e}")
115
- raise RuntimeError("Could not initialize inference pipeline. Check logs for details.")
116
  return _pipeline
117
 
118
  def count_tokens(text: str) -> str:
119
- """Count tokens in the input text."""
120
  if not text:
121
  return "Tokens: 0"
122
- try:
123
  pipeline = get_pipeline()
124
- token_count = len(pipeline.tokenizer.encode(text))
125
- return f"Tokens: {token_count}"
126
- except Exception as e:
127
- logger.error(f"Token counting error: {e}")
128
  return "Token count unavailable"
129
-
130
- def map_compression_to_length(compression: int, max_model_length: int = 512):
131
- """
132
- Map Compression slider (20-80%) to max summary length.
133
- Higher compression = shorter summary output.
134
- """
135
- # Invert, 20% compression = 80% of max length
136
- ratio = (100 - compression) / 100
137
- return int(ratio * max_model_length)
138
 
139
  def predict(text: str, compression: int):
140
- """Run the full pipeline and prepare Gradio outputs."""
141
  hidden_download = gr.update(value=None, visible=False)
142
  if not text or not text.strip():
143
  return (
144
- "Please enter some text to analyze.",
145
  None,
146
- "No topic prediction available",
147
  None,
148
  hidden_download,
149
  )
 
150
  try:
151
  pipeline = get_pipeline()
152
  max_len = map_compression_to_length(compression)
153
- logger.info("Generating summary with max length of %s", max_len)
154
 
155
- summary = pipeline.summarize([text], max_length=max_len)[0]
156
- raw_emotion_pairs = extract_emotion_pairs(pipeline.predict_emotions([text], threshold=0.0)[0])
157
- filtered_emotions = filter_emotions(raw_emotion_pairs, text)
158
  topic = pipeline.predict_topics([text])[0]
159
 
160
- clean_summary = summary.strip()
161
- summary_notice = ""
162
- fallback_summary: str | None = None
163
-
164
- if clean_summary and summary_is_plausible(clean_summary, text):
165
- summary_source = clean_summary
166
  else:
167
- fallback_summary = generate_fallback_summary(text)
168
- summary_source = fallback_summary
169
- if clean_summary:
170
- logger.info("Neural summary flagged as low-overlap; showing extractive fallback instead")
171
- summary_notice = dedent(
172
- f"""
173
- <p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> The neural summary looked off-topic, so an extractive fallback is shown above.</p>
174
- <details style=\"margin-top: 4px;\">
175
- <summary style=\"color: #b45309; cursor: pointer;\">View the original neural summary</summary>
176
- <p style=\"margin-top: 8px; background-color: #fff7ed; padding: 10px; border-radius: 4px; color: #7c2d12; white-space: pre-wrap;\">
177
- {clean_summary}
178
- </p>
179
- </details>
180
- """
181
- ).strip()
182
- else:
183
- summary_notice = (
184
- "<p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> "
185
- "The model did not produce a summary, so an extractive fallback is shown instead.</p>"
186
- )
187
 
188
- summary_html = format_summary(text, summary_source, notice=summary_notice)
189
- emotion_plot = create_emotion_plot(filtered_emotions)
190
- if emotion_plot is None:
191
- emotion_plot = render_unavailable_message(
192
- "No emotion met the confidence threshold."
193
- )
194
- topic_output = format_topic(topic)
195
- if clean_summary and fallback_summary is None:
196
- attention_fig = create_attention_heatmap(text, clean_summary, pipeline)
197
- else:
198
- attention_fig = render_unavailable_message(
199
- "Attention heatmap unavailable because the neural summary was empty or flagged as unreliable."
200
- )
201
- download_path = prepare_download(
202
- text,
203
- summary_source,
204
- filtered_emotions,
205
- topic,
206
- neural_summary=clean_summary or None,
207
- fallback_summary=fallback_summary,
208
- raw_emotions=raw_emotion_pairs,
209
- )
210
  download_update = gr.update(value=download_path, visible=True)
211
 
212
- return summary_html, emotion_plot, topic_output, attention_fig, download_update
213
 
214
  except Exception as exc: # pragma: no cover - surfaced in UI
215
  logger.error("Prediction error: %s", exc, exc_info=True)
216
- error_msg = "Prediction failed. Check logs for details."
217
- return error_msg, None, "Error", None, hidden_download
218
-
219
- def format_summary(original: str, summary: str, *, notice: str = "") -> str:
220
- """Format original and summary text for display."""
221
- html = f"""
222
- <div style="padding: 10px; border-radius: 5px; color: #111;">
223
- <h3 style="color: #111;">Original Text</h3>
224
- <p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px; color: #111; white-space: pre-wrap;">
 
 
225
  {original}
226
  </p>
227
- <h3 style="color: #111;">Summary</h3>
228
- <p style="background-color: #e6f3ff; padding: 10px; border-radius: 3px; color: #111; white-space: pre-wrap;">
229
  {summary}
230
  </p>
231
- {notice}
 
 
232
  </div>
233
- """
234
- return dedent(html).strip()
235
 
236
- def create_emotion_plot(
237
- emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
238
- ) -> Figure | None:
239
- """Create a horizontal bar chart for emotion predictions."""
240
- if isinstance(emotions, EmotionPrediction):
241
- raw_pairs = list(zip(emotions.labels, emotions.scores))
242
- else:
243
- raw_pairs = list(zip(emotions.get("labels", []), emotions.get("scores", [])))
244
-
245
- pairs: list[tuple[str, float]] = []
246
- for label, score in raw_pairs:
247
- try:
248
- score_val = float(score)
249
- except (TypeError, ValueError):
250
- continue
251
- pairs.append((str(label), score_val))
252
-
253
- if not pairs:
254
- return None
255
-
256
- pairs = sorted(pairs, key=lambda item: item[1], reverse=True)
257
- filtered = [item for item in pairs if item[1] >= 0.2]
258
- if not filtered:
259
- filtered = pairs[:3]
260
-
261
- labels = [label for label, _ in filtered[:5]]
262
- scores = [score for _, score in filtered[:5]]
263
-
264
- df = pd.DataFrame({"Emotion": labels, "Probability": scores})
265
- fig, ax = plt.subplots(figsize=(8, 5))
266
- colors = sns.color_palette("Set2", len(labels))
267
  bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
268
- ax.set_xlabel("Probability", fontsize=12)
269
- ax.set_ylabel("Emotion", fontsize=12)
270
- ax.set_title("Emotion Detection Results", fontsize=14, fontweight="bold")
271
  ax.set_xlim(0, 1)
272
  for bar in bars:
273
  width = bar.get_width()
274
- ax.text(
275
- width,
276
- bar.get_y() + bar.get_height() / 2,
277
- f"{width:.2%}",
278
- ha="left",
279
- va="center",
280
- fontsize=10,
281
- bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
282
- )
283
  plt.tight_layout()
284
  return fig
285
 
 
286
  def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
287
- """Format topic prediction output as markdown."""
288
  if isinstance(topic, TopicPrediction):
289
  label = topic.label
290
- score = topic.confidence
291
  else:
292
  label = str(topic.get("label", "Unknown"))
293
- score = float(topic.get("score", 0.0))
294
-
295
  return f"""
296
  ### Predicted Topic
297
-
298
  **{label}**
299
-
300
- Confidence: {score:.2%}
301
- """
 
302
 
303
  def _clean_tokens(tokens: Iterable[str]) -> list[str]:
304
  cleaned: list[str] = []
@@ -307,105 +172,14 @@ def _clean_tokens(tokens: Iterable[str]) -> list[str]:
307
  cleaned.append(item.strip() if item.strip() else token)
308
  return cleaned
309
 
310
- def extract_emotion_pairs(
311
- emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
312
- ) -> list[tuple[str, float]]:
313
- if isinstance(emotions, EmotionPrediction):
314
- return list(zip(map(str, emotions.labels), map(float, emotions.scores)))
315
- labels = emotions.get("labels", [])
316
- scores = emotions.get("scores", [])
317
- return [(str(label), float(score)) for label, score in zip(labels, scores)]
318
-
319
- def filter_emotions(pairs: list[tuple[str, float]], text: str) -> EmotionPrediction:
320
- filtered: list[tuple[str, float]] = []
321
- lowered_text = text.lower()
322
-
323
- for label, score in pairs:
324
- threshold = EMOTION_THRESHOLDS.get(label, 0.5)
325
- if score < threshold:
326
- continue
327
-
328
- if label == "love":
329
- keywords = EMOTION_KEYWORDS.get("love", set())
330
- if score < 0.6 and not any(keyword in lowered_text for keyword in keywords):
331
- continue
332
-
333
- filtered.append((label, score))
334
-
335
- if filtered:
336
- labels, scores = zip(*filtered)
337
- return EmotionPrediction(labels=list(labels), scores=list(scores))
338
-
339
- return EmotionPrediction(labels=[], scores=[])
340
-
341
- def summary_is_plausible(
342
- summary: str,
343
- original: str,
344
- *,
345
- min_overlap: float = 0.2,
346
- min_unique_ratio: float = 0.3,
347
- max_repeat_ratio: float = 0.6,
348
- ) -> bool:
349
- """Heuristic filter to catch off-topic or repetitive neural summaries."""
350
-
351
- summary_tokens = re.findall(r"\w+", summary.lower())
352
- if not summary_tokens:
353
- return False
354
-
355
- summary_content = [token for token in summary_tokens if token not in STOPWORDS]
356
- if not summary_content:
357
- return False
358
-
359
- original_vocab = {token for token in re.findall(r"\w+", original.lower()) if token not in STOPWORDS}
360
- overlap = sum(1 for token in summary_content if token in original_vocab)
361
- overlap_ratio = overlap / max(1, len(summary_content))
362
- if overlap_ratio < min_overlap:
363
- return False
364
-
365
- token_counts = Counter(summary_content)
366
- most_common_ratio = token_counts.most_common(1)[0][1] / len(summary_content)
367
- unique_ratio = len(token_counts) / len(summary_content)
368
- if unique_ratio < min_unique_ratio:
369
- return False
370
- if most_common_ratio > max_repeat_ratio:
371
- return False
372
- return True
373
-
374
- def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
375
- """Build a lightweight extractive summary when the model generates nothing."""
376
- if not text.strip():
377
- return ""
378
-
379
- sections = [segment.strip() for segment in text.replace("\n", " ").split(". ") if segment.strip()]
380
- if not sections:
381
- return text.strip()[:max_chars]
382
-
383
- summary_fragments: list[str] = []
384
- chars_used = 0
385
- for segment in sections:
386
- candidate = segment if segment.endswith(".") else f"{segment}."
387
- if chars_used + len(candidate) > max_chars and summary_fragments:
388
- break
389
- summary_fragments.append(candidate)
390
- chars_used += len(candidate)
391
-
392
- fallback = " ".join(summary_fragments)
393
- if not fallback:
394
- fallback = text.strip()[:max_chars]
395
- return fallback
396
 
397
  def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
398
- """Generate a seaborn heatmap of decoder cross-attention averaged over heads."""
399
- if not summary:
400
- return None
401
  try:
402
  batch = pipeline.preprocessor.batch_encode([text])
403
  batch = pipeline._batch_to_device(batch)
404
  src_ids = batch.input_ids
405
  src_mask = batch.attention_mask
406
- encoder_mask = None
407
- if src_mask is not None:
408
- encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
409
 
410
  with torch.inference_mode():
411
  memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
@@ -423,11 +197,11 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
423
  memory_mask=memory_mask,
424
  collect_attn=True,
425
  )
 
426
  if not attn_list:
427
  return None
428
- cross_attn = attn_list[-1]["cross"] # (B, heads, T, S)
429
  attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
430
-
431
  source_len = batch.lengths[0]
432
  attn_matrix = attn_matrix[:target_len, :source_len]
433
 
@@ -439,7 +213,7 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
439
  pipeline.tokenizer.bos_token_id,
440
  pipeline.tokenizer.eos_token_id,
441
  }
442
- keep_indices = [index for index, token_id in enumerate(target_id_list) if token_id not in special_ids]
443
  if not keep_indices:
444
  return None
445
 
@@ -447,10 +221,9 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
447
  tokenizer_impl = pipeline.tokenizer.tokenizer
448
  convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
449
  if convert_tokens is None:
450
- logger.warning("Tokenizer does not expose convert_ids_to_tokens; skipping attention heatmap.")
451
  return None
452
 
453
- summary_tokens_raw = convert_tokens([target_id_list[index] for index in keep_indices])
454
  source_tokens_raw = convert_tokens(source_ids)
455
 
456
  summary_tokens = _clean_tokens(summary_tokens_raw)
@@ -477,14 +250,13 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
477
 
478
  except Exception as exc:
479
  logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
480
- return None
481
 
482
 
483
- def render_unavailable_message(message: str) -> Figure:
484
- """Render a simple Matplotlib figure containing an informational message."""
485
  fig, ax = plt.subplots(figsize=(6, 2))
486
  ax.axis("off")
487
- ax.text(0.5, 0.5, message, ha="center", va="center", fontsize=11, wrap=True)
488
  fig.tight_layout()
489
  return fig
490
 
@@ -494,12 +266,7 @@ def prepare_download(
494
  summary: str,
495
  emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
496
  topic: TopicPrediction | dict[str, float | str],
497
- *,
498
- neural_summary: str | None = None,
499
- fallback_summary: str | None = None,
500
- raw_emotions: Sequence[tuple[str, float]] | None = None,
501
  ) -> str:
502
- """Persist JSON payload to a temporary file and return its path for download."""
503
  if isinstance(emotions, EmotionPrediction):
504
  emotion_payload = {
505
  "labels": list(emotions.labels),
@@ -512,10 +279,7 @@ def prepare_download(
512
  }
513
 
514
  if isinstance(topic, TopicPrediction):
515
- topic_payload = {
516
- "label": topic.label,
517
- "confidence": topic.confidence,
518
- }
519
  else:
520
  topic_payload = {
521
  "label": str(topic.get("label", topic.get("topic", "Unknown"))),
@@ -525,106 +289,99 @@ def prepare_download(
525
  payload = {
526
  "original_text": text,
527
  "summary": summary,
528
- "neural_summary": neural_summary,
529
- "fallback_summary": fallback_summary,
530
  "emotions": emotion_payload,
531
  "topic": topic_payload,
532
  }
533
- if raw_emotions is not None:
534
- payload["raw_emotions"] = [
535
- {"label": label, "score": float(score)} for label, score in raw_emotions
536
- ]
537
  with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
538
  json.dump(payload, handle, ensure_ascii=False, indent=2)
539
- temp_path = handle.name
540
- return temp_path
541
-
542
- # Sample data for the demo
543
- SAMPLE_TEXT = """
544
- Artificial intelligence is rapidly transforming the technology landscape.
545
- Machine learning algorithms are now capable of processing vast amounts of data,
546
- identifying patterns, and making predictions with unprecedented accuracy.
547
- From healthcare diagnostics to financial forecasting, AI applications are
548
- revolutionizing industries worldwide. However, ethical considerations around
549
- privacy, bias, and transparency remain critical challenges that must be addressed
550
- as these technologies continue to evolve.
551
- """
552
 
553
  def create_interface() -> gr.Blocks:
554
  with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
555
- gr.Markdown("""
556
- # LexiMind NLP Pipeline Demo
557
-
558
- **Full pipleine for text summarization, emotion detection, and topic prediction.**
559
-
560
- Enter text below and adjust compressoin to see the results.
561
- """)
 
 
562
  with gr.Row():
563
- # Left column - Input
564
  with gr.Column(scale=1):
565
- gr.Markdown("### Input")
566
  input_text = gr.Textbox(
567
- label="Enter text",
568
- placeholder="Paste or type your text here...",
569
  lines=10,
570
- value=SAMPLE_TEXT
571
- )
572
- token_count = gr.Textbox(
573
- label="Token Count",
574
- value="Tokens: 0",
575
- interactive=False
576
  )
 
577
  compression = gr.Slider(
578
  minimum=20,
579
  maximum=80,
580
  value=50,
581
  step=5,
582
  label="Compression %",
583
- info="Higher = shorter summary"
584
  )
585
- predict_btn = gr.Button("🚀 Analyze", variant="primary", size="lg")
586
- # Right column - Outputs
587
  with gr.Column(scale=2):
588
- gr.Markdown("### Result")
589
  with gr.Tabs():
590
  with gr.TabItem("Summary"):
591
  summary_output = gr.HTML(label="Summary")
592
  with gr.TabItem("Emotions"):
593
- emotion_output = gr.Plot(label="Emotion Analysis")
594
  with gr.TabItem("Topic"):
595
  topic_output = gr.Markdown(label="Topic Prediction")
596
- with gr.TabItem("Attention Heatmap"):
597
- attention_output = gr.Plot(label="Attention Weights")
598
- gr.Markdown("*Visualizes which parts of the input the model focused on.*")
599
- # Download section
600
- gr.Markdown("### Export Results")
601
- download_btn = gr.DownloadButton(
602
- "Download Results (JSON)",
603
- visible=False,
604
- )
605
- # Event Handlers
606
- input_text.change(
607
- fn=count_tokens,
608
- inputs=[input_text],
609
- outputs=[token_count]
610
- )
611
- predict_btn.click(
612
- fn=predict,
613
- inputs=[input_text, compression],
614
- outputs=[summary_output, emotion_output, topic_output, attention_output, download_btn],
615
- )
616
- # Examples
617
- gr.Examples(
618
- examples=[
619
- [SAMPLE_TEXT, 50],
620
- [
621
- "Climate change poses significant risks to global ecosystems. Rising temperatures, melting ice caps, and extreme weather events are becoming more frequent. Scientists urge immediate action to reduce carbon emissions and transition to renewable energy sources.",
622
- 40,
623
- ],
624
- ],
625
- inputs=[input_text, compression],
626
- label="Try these examples:",
627
- )
628
  return demo
629
 
630
 
@@ -635,14 +392,8 @@ app = demo
635
  if __name__ == "__main__":
636
  try:
637
  get_pipeline()
638
- demo.queue().launch(
639
- share=True,
640
- server_name="0.0.0.0",
641
- server_port=7860,
642
- show_error=True,
643
- )
644
- except Exception as e:
645
- logger.error("Failed to launch demo: %s", e, exc_info=True)
646
- print(f"Error: {e}")
647
- sys.exit(1)
648
 
 
1
  """
2
+ Minimal Gradio demo for the LexiMind multitask model.
3
+ Shows raw model outputs without any post-processing tricks.
4
  """
5
+ from __future__ import annotations
6
+
7
  import json
 
8
  import sys
9
  from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
  from typing import Iterable, Sequence
 
 
12
 
13
  import gradio as gr
14
  from gradio.themes import Soft
 
18
  import torch
19
  from matplotlib.figure import Figure
20
 
21
+ # Make local packages importable when running the script directly
22
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
23
+ if str(PROJECT_ROOT) not in sys.path:
24
+ sys.path.insert(0, str(PROJECT_ROOT))
25
 
26
  from src.inference.factory import create_inference_pipeline
27
  from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
 
30
  configure_logging()
31
  logger = get_logger(__name__)
32
 
33
+ _pipeline: InferencePipeline | None = None
34
+
35
+ VISUALIZATION_DIR = PROJECT_ROOT / "outputs"
36
+ VISUALIZATION_ASSETS: list[tuple[str, str]] = [
37
+ ("attention_visualization.png", "Attention weights (single head)"),
38
+ ("multihead_attention_visualization.png", "Multi-head attention comparison"),
39
+ ("single_vs_multihead.png", "Single vs multi-head attention"),
40
+ ("positional_encoding_heatmap.png", "Positional encoding heatmap"),
41
+ ]
42
+
43
 
44
  def get_pipeline() -> InferencePipeline:
45
+ global _pipeline
 
46
  if _pipeline is None:
47
+ logger.info("Loading inference pipeline ...")
48
+ _pipeline, _ = create_inference_pipeline(
49
+ tokenizer_dir="artifacts/hf_tokenizer/",
50
+ checkpoint_path="checkpoints/best.pt",
51
+ labels_path="artifacts/labels.json",
52
+ )
53
+ logger.info("Pipeline loaded")
54
  return _pipeline
55
 
56
+
57
+ def map_compression_to_length(compression: int, max_model_length: int = 512) -> int:
58
+ ratio = (100 - compression) / 100
59
+ return max(16, int(ratio * max_model_length))
60
+
61
+
62
  def count_tokens(text: str) -> str:
 
63
  if not text:
64
  return "Tokens: 0"
65
+ try:
66
  pipeline = get_pipeline()
67
+ return f"Tokens: {len(pipeline.tokenizer.encode(text))}"
68
+ except Exception as exc: # pragma: no cover - surfaced in UI
69
+ logger.error("Token counting failed: %s", exc, exc_info=True)
 
70
  return "Token count unavailable"
71
+
72
 
73
  def predict(text: str, compression: int):
 
74
  hidden_download = gr.update(value=None, visible=False)
75
  if not text or not text.strip():
76
  return (
77
+ "Please enter text to analyze.",
78
  None,
79
+ "No topic prediction available.",
80
  None,
81
  hidden_download,
82
  )
83
+
84
  try:
85
  pipeline = get_pipeline()
86
  max_len = map_compression_to_length(compression)
87
+ logger.info("Generating summary with max length %s", max_len)
88
 
89
+ summary = pipeline.summarize([text], max_length=max_len)[0].strip()
90
+ emotions = pipeline.predict_emotions([text])[0]
 
91
  topic = pipeline.predict_topics([text])[0]
92
 
93
+ summary_html = format_summary(text, summary)
94
+ emotion_plot = create_emotion_plot(emotions)
95
+ topic_markdown = format_topic(topic)
96
+ if summary:
97
+ attention_fig = create_attention_heatmap(text, summary, pipeline)
 
98
  else:
99
+ attention_fig = render_message_figure("Attention heatmap unavailable: summary was empty.")
100
 
101
+ download_path = prepare_download(text, summary, emotions, topic)
102
  download_update = gr.update(value=download_path, visible=True)
103
 
104
+ return summary_html, emotion_plot, topic_markdown, attention_fig, download_update
105
 
106
  except Exception as exc: # pragma: no cover - surfaced in UI
107
  logger.error("Prediction error: %s", exc, exc_info=True)
108
+ return "Prediction failed. Check logs for details.", None, "Error", None, hidden_download
109
+
110
+
111
+ def format_summary(original: str, summary: str) -> str:
112
+ if not summary:
113
+ summary = "(Model returned an empty summary. Consider retraining the summarization head.)"
114
+
115
+ return f"""
116
+ <div style="padding: 12px; border-radius: 6px; background-color: #fafafa; color: #222;">
117
+ <h3 style="margin-top: 0; color: #222;">Original Text</h3>
118
+ <p style="background-color: #f0f0f0; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #222;">
119
  {original}
120
  </p>
121
+ <h3 style="color: #222;">Model Summary</h3>
122
+ <p style="background-color: #e6f3ff; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #111;">
123
  {summary}
124
  </p>
125
+ <p style="margin-top: 12px; color: #6b7280; font-size: 0.9rem;">
126
+ Outputs are shown exactly as produced by the checkpoint.
127
+ </p>
128
  </div>
129
+ """.strip()
 
130
 
131
+
132
+ def create_emotion_plot(emotions: EmotionPrediction) -> Figure | None:
133
+ if not emotions.labels:
134
+ return render_message_figure("No emotions cleared the model threshold.")
135
+
136
+ df = pd.DataFrame({"Emotion": emotions.labels, "Probability": emotions.scores}).sort_values(
137
+ "Probability", ascending=True
138
+ )
139
+ fig, ax = plt.subplots(figsize=(6, 4))
140
+ colors = sns.color_palette("crest", len(df))
141
  bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
142
+ ax.set_xlabel("Probability")
143
+ ax.set_title("Emotion Scores")
 
144
  ax.set_xlim(0, 1)
145
  for bar in bars:
146
  width = bar.get_width()
147
+ ax.text(width + 0.02, bar.get_y() + bar.get_height() / 2, f"{width:.2%}", va="center")
148
  plt.tight_layout()
149
  return fig
150
 
151
+
152
  def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
 
153
  if isinstance(topic, TopicPrediction):
154
  label = topic.label
155
+ confidence = topic.confidence
156
  else:
157
  label = str(topic.get("label", "Unknown"))
158
+ confidence = float(topic.get("score", 0.0))
 
159
  return f"""
160
  ### Predicted Topic
161
+
162
  **{label}**
163
+
164
+ Confidence: {confidence:.2%}
165
+ """.strip()
166
+
167
 
168
  def _clean_tokens(tokens: Iterable[str]) -> list[str]:
169
  cleaned: list[str] = []
 
172
  cleaned.append(item.strip() if item.strip() else token)
173
  return cleaned
174
 
175
 
176
  def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
177
  try:
178
  batch = pipeline.preprocessor.batch_encode([text])
179
  batch = pipeline._batch_to_device(batch)
180
  src_ids = batch.input_ids
181
  src_mask = batch.attention_mask
182
+ encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
 
 
183
 
184
  with torch.inference_mode():
185
  memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
 
197
  memory_mask=memory_mask,
198
  collect_attn=True,
199
  )
200
+
201
  if not attn_list:
202
  return None
203
+ cross_attn = attn_list[-1]["cross"]
204
  attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
 
205
  source_len = batch.lengths[0]
206
  attn_matrix = attn_matrix[:target_len, :source_len]
207
 
 
213
  pipeline.tokenizer.bos_token_id,
214
  pipeline.tokenizer.eos_token_id,
215
  }
216
+ keep_indices = [idx for idx, token_id in enumerate(target_id_list) if token_id not in special_ids]
217
  if not keep_indices:
218
  return None
219
 
 
221
  tokenizer_impl = pipeline.tokenizer.tokenizer
222
  convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
223
  if convert_tokens is None:
 
224
  return None
225
 
226
+ summary_tokens_raw = convert_tokens([target_id_list[idx] for idx in keep_indices])
227
  source_tokens_raw = convert_tokens(source_ids)
228
 
229
  summary_tokens = _clean_tokens(summary_tokens_raw)
 
250
 
251
  except Exception as exc:
252
  logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
253
+ return render_message_figure("Unable to render attention heatmap for this example.")
254
 
255
 
256
+ def render_message_figure(message: str) -> Figure:
 
257
  fig, ax = plt.subplots(figsize=(6, 2))
258
  ax.axis("off")
259
+ ax.text(0.5, 0.5, message, ha="center", va="center", wrap=True)
260
  fig.tight_layout()
261
  return fig
262
 
 
266
  summary: str,
267
  emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
268
  topic: TopicPrediction | dict[str, float | str],
269
  ) -> str:
 
270
  if isinstance(emotions, EmotionPrediction):
271
  emotion_payload = {
272
  "labels": list(emotions.labels),
 
279
  }
280
 
281
  if isinstance(topic, TopicPrediction):
282
+ topic_payload = {"label": topic.label, "confidence": topic.confidence}
283
  else:
284
  topic_payload = {
285
  "label": str(topic.get("label", topic.get("topic", "Unknown"))),
 
289
  payload = {
290
  "original_text": text,
291
  "summary": summary,
 
 
292
  "emotions": emotion_payload,
293
  "topic": topic_payload,
294
  }
295
+
296
  with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
297
  json.dump(payload, handle, ensure_ascii=False, indent=2)
298
+ return handle.name
299
+
300
+
301
+ def load_visualization_gallery() -> list[tuple[str, str]]:
302
+ """Collect visualization images produced by model tests."""
303
+ items: list[tuple[str, str]] = []
304
+ for filename, label in VISUALIZATION_ASSETS:
305
+ path = VISUALIZATION_DIR / filename
306
+ if path.exists():
307
+ items.append((str(path), label))
308
+ else:
309
+ logger.debug("Visualization asset missing: %s", path)
310
+ return items
311
+
312
+
313
+ SAMPLE_TEXT = (
314
+ "Artificial intelligence is rapidly transforming the technology landscape. "
315
+ "Machine learning algorithms are now capable of processing vast amounts of data, "
316
+ "identifying patterns, and making predictions with unprecedented accuracy. "
317
+ "From healthcare diagnostics to financial forecasting, AI applications are "
318
+ "revolutionizing industries worldwide. However, ethical considerations around "
319
+ "privacy, bias, and transparency remain critical challenges that must be addressed "
320
+ "as these technologies continue to evolve."
321
+ )
322
+
323
 
324
  def create_interface() -> gr.Blocks:
325
  with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
326
+ gr.Markdown(
327
+ """
328
+ # LexiMind NLP Demo
329
+
330
+ This demo streams the raw outputs from the saved LexiMind checkpoint.
331
+ Results may be noisy; retraining is recommended for production use.
332
+ """
333
+ )
334
+
335
  with gr.Row():
 
336
  with gr.Column(scale=1):
 
337
  input_text = gr.Textbox(
338
+ label="Input Text",
 
339
  lines=10,
340
+ value=SAMPLE_TEXT,
341
+ placeholder="Paste or type your text here...",
342
  )
343
+ token_box = gr.Textbox(label="Token Count", value="Tokens: 0", interactive=False)
344
  compression = gr.Slider(
345
  minimum=20,
346
  maximum=80,
347
  value=50,
348
  step=5,
349
  label="Compression %",
350
+ info="Higher values request shorter summaries.",
351
  )
352
+ analyze_btn = gr.Button("Run Analysis", variant="primary")
353
+
354
  with gr.Column(scale=2):
 
355
  with gr.Tabs():
356
  with gr.TabItem("Summary"):
357
  summary_output = gr.HTML(label="Summary")
358
  with gr.TabItem("Emotions"):
359
+ emotion_output = gr.Plot(label="Emotion Probabilities")
360
  with gr.TabItem("Topic"):
361
  topic_output = gr.Markdown(label="Topic Prediction")
362
+ with gr.TabItem("Attention"):
363
+ attention_output = gr.Plot(label="Attention Heatmap")
364
+ gr.Markdown("*Shows decoder attention if a summary is available.*")
365
+ with gr.TabItem("Model Visuals"):
366
+ visuals = gr.Gallery(
367
+ label="Test Visualizations",
368
+ value=load_visualization_gallery(),
369
+ columns=2,
370
+ height=400,
371
+ )
372
+ gr.Markdown(
373
+ "These PNGs come from the visualization-focused tests in `tests/test_models` and are consumed as-is."
374
+ )
375
+ gr.Markdown("### Download Results")
376
+ download_btn = gr.DownloadButton("Download JSON", visible=False)
377
+
378
+ input_text.change(fn=count_tokens, inputs=[input_text], outputs=[token_box])
379
+ analyze_btn.click(
380
+ fn=predict,
381
+ inputs=[input_text, compression],
382
+ outputs=[summary_output, emotion_output, topic_output, attention_output, download_btn],
383
+ )
384
+
385
  return demo
386
 
387
 
 
392
  if __name__ == "__main__":
393
  try:
394
  get_pipeline()
395
+ demo.queue().launch(share=False)
396
+ except Exception as exc: # pragma: no cover - surfaced in console
397
+ logger.error("Failed to launch demo: %s", exc, exc_info=True)
398
+ raise
399
 
src/inference/pipeline.py CHANGED
@@ -75,13 +75,15 @@ class InferencePipeline:
75
  with torch.inference_mode():
76
  encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
77
  memory = self.model.encoder(src_ids, mask=encoder_mask)
78
- generated = self._constrained_greedy_decode(memory, max_len, memory_mask=src_mask)
79
 
80
- trimmed_sequences: List[List[int]] = []
81
- for row in generated.cpu().tolist():
82
- trimmed_sequences.append(self._trim_special_tokens(row))
83
-
84
- return self.tokenizer.decode_batch(trimmed_sequences)
85
 
86
  def predict_emotions(
87
  self,
@@ -96,7 +98,7 @@ class InferencePipeline:
96
 
97
  batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
98
  model_inputs = self._batch_to_model_inputs(batch)
99
- decision_threshold = self.config.emotion_threshold if threshold is None else float(threshold)
100
 
101
  with torch.inference_mode():
102
  logits = self.model.forward("emotion", model_inputs)
@@ -146,62 +148,6 @@ class InferencePipeline:
146
  "topic": self.predict_topics(text_list),
147
  }
148
 
149
- def _constrained_greedy_decode(
150
- self,
151
- memory: torch.Tensor,
152
- max_len: int,
153
- *,
154
- memory_mask: torch.Tensor | None = None,
155
- ) -> torch.Tensor:
156
- """Run greedy decoding while banning BOS/PAD tokens from the generated sequence."""
157
-
158
- device = memory.device
159
- batch_size = memory.size(0)
160
- bos = self.tokenizer.bos_token_id
161
- pad = getattr(self.tokenizer, "pad_token_id", None)
162
- eos = getattr(self.tokenizer, "eos_token_id", None)
163
-
164
- generated = torch.full((batch_size, 1), bos, dtype=torch.long, device=device)
165
- expanded_memory_mask = None
166
- if memory_mask is not None:
167
- expanded_memory_mask = memory_mask.to(device=device, dtype=torch.bool)
168
-
169
- for _ in range(max(1, max_len) - 1):
170
- decoder_out = self.model.decoder(generated, memory, memory_mask=expanded_memory_mask)
171
- logits = decoder_out if isinstance(decoder_out, torch.Tensor) else decoder_out[0]
172
-
173
- step_logits = logits[:, -1, :].clone()
174
- if bos is not None and bos < step_logits.size(-1):
175
- step_logits[:, bos] = float("-inf")
176
- if pad is not None and pad < step_logits.size(-1):
177
- step_logits[:, pad] = float("-inf")
178
-
179
- next_token = step_logits.argmax(dim=-1, keepdim=True)
180
- generated = torch.cat([generated, next_token], dim=1)
181
-
182
- if eos is not None and torch.all(next_token.squeeze(-1) == eos):
183
- break
184
-
185
- return generated
186
-
187
- def _trim_special_tokens(self, sequence: Sequence[int]) -> List[int]:
188
- """Remove leading BOS and trailing PAD/EOS tokens from a generated sequence."""
189
-
190
- bos = self.tokenizer.bos_token_id
191
- pad = getattr(self.tokenizer, "pad_token_id", None)
192
- eos = getattr(self.tokenizer, "eos_token_id", None)
193
-
194
- trimmed: List[int] = []
195
- for idx, token in enumerate(sequence):
196
- if idx == 0 and bos is not None and token == bos:
197
- continue
198
- if pad is not None and token == pad:
199
- continue
200
- if eos is not None and token == eos:
201
- break
202
- trimmed.append(int(token))
203
- return trimmed
204
-
205
  def _batch_to_device(self, batch: Batch) -> Batch:
206
  tensor_updates: dict[str, torch.Tensor] = {}
207
  for item in fields(batch):
 
75
  with torch.inference_mode():
76
  encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
77
  memory = self.model.encoder(src_ids, mask=encoder_mask)
78
+ generated = self.model.decoder.greedy_decode(
79
+ memory=memory,
80
+ max_len=max_len,
81
+ start_token_id=self.tokenizer.bos_token_id,
82
+ end_token_id=self.tokenizer.eos_token_id,
83
+ device=self.device,
84
+ )
85
 
86
+ return self.tokenizer.decode_batch(generated.tolist())
87
 
88
  def predict_emotions(
89
  self,
 
98
 
99
  batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
100
  model_inputs = self._batch_to_model_inputs(batch)
101
+ decision_threshold = threshold or self.config.emotion_threshold
102
 
103
  with torch.inference_mode():
104
  logits = self.model.forward("emotion", model_inputs)
 
148
  "topic": self.predict_topics(text_list),
149
  }
150
 
151
  def _batch_to_device(self, batch: Batch) -> Batch:
152
  tensor_updates: dict[str, torch.Tensor] = {}
153
  for item in fields(batch):