# LexiMind/scripts/demo_gradio.py
"""
Gradio demo interface for the LexiMind NLP pipeline.
Showcases summarization, emotion detection, and topic prediction.
"""
import html
import json
import sys
import tempfile
from pathlib import Path
from typing import Iterable, Sequence
import gradio as gr
from gradio.themes import Soft
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from matplotlib.figure import Figure
# Add the project root (two levels up from this file) to sys.path so `src` imports resolve
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.inference.factory import create_inference_pipeline
from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
from src.utils.logging import configure_logging, get_logger
configure_logging()
logger = get_logger(__name__)
_pipeline: InferencePipeline | None = None # Global pipeline instance
_label_metadata = None # Cached label metadata
def get_pipeline() -> InferencePipeline:
"""Lazy Loading and Caching the inference pipeline"""
global _pipeline, _label_metadata
if _pipeline is None:
try:
logger.info("Loading inference pipeline...")
pipeline, label_metadata = create_inference_pipeline(
tokenizer_dir="data/tokenization",
checkpoint_path="checkpoints/best.pt",
labels_path="data/labels.json",
)
_pipeline = pipeline
_label_metadata = label_metadata
logger.info("Pipeline loaded successfully")
except Exception as e:
logger.error(f"Failed to load pipeline: {e}")
            raise RuntimeError("Could not initialize inference pipeline. Check logs for details.") from e
return _pipeline
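# The module-level cache above means only the first request pays the model-load
# cost; every later call to get_pipeline() reuses the same instance.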
def count_tokens(text: str) -> str:
"""Count tokens in the input text."""
if not text:
return "Tokens: 0"
try:
pipeline = get_pipeline()
token_count = len(pipeline.tokenizer.encode(text))
return f"Tokens: {token_count}"
except Exception as e:
logger.error(f"Token counting error: {e}")
return "Token count unavailable"
def map_compression_to_length(compression: int, max_model_length: int = 512) -> int:
"""
Map Compression slider (20-80%) to max summary length.
Higher compression = shorter summary output.
"""
# Invert, 20% compression = 80% of max length
ratio = (100 - compression) / 100
return int(ratio * max_model_length)
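# Worked example with the default max_model_length of 512:
#   compression=50 -> ratio 0.50 -> max summary length 256 tokens
#   compression=80 -> ratio 0.20 -> max summary length 102 tokens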
def predict(text: str, compression: int):
"""
    Main prediction function for the Gradio interface.
Args:
text: Text to process
compression: Compression percentage (20-80)
Returns:
Tuple of (summary_html, emotion_plot, topic_output, attention_fig, download_data)
"""
if not text or not text.strip():
return ("Please enter some text to analyze.",
None,
"No topic prediction available",
None,
None)
try:
pipeline = get_pipeline()
max_len = map_compression_to_length(compression)
        logger.info(f"Generating summary with max length {max_len}")
# Get the predictions
summary = pipeline.summarize([text], max_length=max_len)[0]
emotions = pipeline.predict_emotions([text])[0]
topic = pipeline.predict_topics([text])[0]
summary_html = format_summary(text, summary)
emotion_plot = create_emotion_plot(emotions)
topic_output = format_topic(topic)
attention_fig = create_attention_heatmap(text, summary, pipeline)
download_data = prepare_download(text, summary, emotions, topic)
return summary_html, emotion_plot, topic_output, attention_fig, gr.update(
value=download_data,
visible=True,
)
except Exception as e:
logger.error(f"Prediction error: {e}", exc_info=True)
error_msg = "Prediction failed. Check logs for details."
return error_msg, None, "Error", None, None
def format_summary(original: str, summary: str) -> str:
    """Format the original and summary text for display, escaping HTML."""
    original = html.escape(original)
    summary = html.escape(summary)
    html_out = f"""
    <div style="padding: 10px; border-radius: 5px;">
        <h3>Original Text</h3>
        <p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">
            {original}
        </p>
        <h3>Summary</h3>
        <p style="background-color: #e6f3ff; padding: 10px; border-radius: 3px;">
            {summary}
        </p>
    </div>
    """
    return html_out
def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]) -> Figure | None:
"""
Create bar plot for emotion predictions.
Args:
        emotions: EmotionPrediction or a dict with 'labels' and 'scores' keys
"""
if isinstance(emotions, EmotionPrediction):
labels = emotions.labels
scores = emotions.scores
else:
labels = list(emotions.get("labels", []))
scores = list(emotions.get("scores", []))
if not labels or not scores:
return None
df = pd.DataFrame({
"Emotion": labels,
"Probability": scores,
})
fig, ax = plt.subplots(figsize=(8, 5))
colors = sns.color_palette("Set2", len(labels))
bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
ax.set_xlabel("Probability", fontsize=12)
ax.set_ylabel("Emotion", fontsize=12)
ax.set_title("Emotion Detection Results", fontsize=14, fontweight="bold")
ax.set_xlim(0, 1)
for bar in bars:
width = bar.get_width()
ax.text(
width,
bar.get_y() + bar.get_height() / 2,
f"{width:.2%}",
ha="left",
va="center",
fontsize=10,
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)
plt.tight_layout()
return fig
def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
"""
Format topic prediction output.
Args:
        topic: TopicPrediction or a dict with 'label' and 'score' keys
"""
if isinstance(topic, TopicPrediction):
label = topic.label
score = topic.confidence
else:
label = str(topic.get("label", "Unknown"))
score = float(topic.get("score", 0.0))
output = f"""
### Predicted Topic
**{label}**
Confidence: {score:.2%}
"""
return output
def _clean_tokens(tokens: Iterable[str]) -> list[str]:
    """Strip BPE/SentencePiece whitespace markers from tokens for display."""
    cleaned: list[str] = []
    for token in tokens:
        item = token.replace("Ġ", " ").replace("▁", " ")
        cleaned.append(item.strip() if item.strip() else token)
    return cleaned
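# e.g., ["ĠArtificial", "▁intelligence"] -> ["Artificial", "intelligence"]; a token
# that is only a marker (like "Ġ") is kept as-is, so the token count, and hence the
# heatmap axes, stay aligned with the attention matrix.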
def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
"""Generate a seaborn heatmap of decoder cross-attention averaged over heads."""
if not summary:
return None
try:
batch = pipeline.preprocessor.batch_encode([text])
batch = pipeline._batch_to_device(batch)
src_ids = batch.input_ids
src_mask = batch.attention_mask
        # Expand the (B, S) padding mask into a (B, S, S) pairwise mask for encoder self-attention.
        encoder_mask = None
        if src_mask is not None:
            encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
with torch.inference_mode():
memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
target_enc = pipeline.tokenizer.batch_encode([summary])
target_ids = target_enc["input_ids"].to(pipeline.device)
target_mask = target_enc["attention_mask"].to(pipeline.device)
target_len = int(target_mask.sum().item())
decoder_inputs = pipeline.tokenizer.prepare_decoder_inputs(target_ids)
decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
target_ids = target_ids[:, :target_len]
memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
_, attn_list = pipeline.model.decoder(
decoder_inputs,
memory,
memory_mask=memory_mask,
collect_attn=True,
)
if not attn_list:
return None
cross_attn = attn_list[-1]["cross"] # (B, heads, T, S)
attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
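        # Row i of attn_matrix shows how strongly summary token i attends to each source token.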
source_len = batch.lengths[0]
attn_matrix = attn_matrix[:target_len, :source_len]
source_ids = src_ids[0, :source_len].tolist()
target_id_list = target_ids[0].tolist()
special_ids = {
pipeline.tokenizer.pad_token_id,
pipeline.tokenizer.bos_token_id,
pipeline.tokenizer.eos_token_id,
}
keep_indices = [index for index, token_id in enumerate(target_id_list) if token_id not in special_ids]
if not keep_indices:
return None
pruned_matrix = attn_matrix[keep_indices, :]
tokenizer_impl = pipeline.tokenizer.tokenizer
convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
if convert_tokens is None:
logger.warning("Tokenizer does not expose convert_ids_to_tokens; skipping attention heatmap.")
return None
summary_tokens_raw = convert_tokens([target_id_list[index] for index in keep_indices])
source_tokens_raw = convert_tokens(source_ids)
summary_tokens = _clean_tokens(summary_tokens_raw)
source_tokens = _clean_tokens(source_tokens_raw)
height = max(4.0, 0.4 * len(summary_tokens))
width = max(6.0, 0.4 * len(source_tokens))
fig, ax = plt.subplots(figsize=(width, height))
sns.heatmap(
pruned_matrix,
cmap="mako",
xticklabels=source_tokens,
yticklabels=summary_tokens,
ax=ax,
cbar_kws={"label": "Attention"},
)
ax.set_xlabel("Input Tokens")
ax.set_ylabel("Summary Tokens")
ax.set_title("Cross-Attention (decoder last layer)")
ax.tick_params(axis="x", rotation=90)
ax.tick_params(axis="y", rotation=0)
fig.tight_layout()
return fig
except Exception as exc:
logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
return None
def prepare_download(
text: str,
summary: str,
emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
topic: TopicPrediction | dict[str, float | str],
) -> str:
"""Prepare JSON data for download."""
if isinstance(emotions, EmotionPrediction):
emotion_payload = {
"labels": list(emotions.labels),
"scores": list(emotions.scores),
}
else:
emotion_payload = emotions
if isinstance(topic, TopicPrediction):
topic_payload = {
"label": topic.label,
"confidence": topic.confidence,
}
else:
topic_payload = topic
    data = {
        "original_text": text,
        "summary": summary,
        "emotions": emotion_payload,
        "topic": topic_payload,
    }
    # gr.DownloadButton expects a file path rather than raw content, so write
    # the JSON to a temporary file and hand back its path.
    tmp_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    )
    with tmp_file:
        json.dump(data, tmp_file, indent=2)
    return tmp_file.name
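# Illustrative shape of the downloaded JSON (values are made up; dict inputs
# are passed through with their original keys):
#   {
#     "original_text": "...",
#     "summary": "...",
#     "emotions": {"labels": ["joy", "surprise"], "scores": [0.91, 0.42]},
#     "topic": {"label": "technology", "confidence": 0.87}
#   }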
# Sample data for the demo
SAMPLE_TEXT = """
Artificial intelligence is rapidly transforming the technology landscape.
Machine learning algorithms are now capable of processing vast amounts of data,
identifying patterns, and making predictions with unprecedented accuracy.
From healthcare diagnostics to financial forecasting, AI applications are
revolutionizing industries worldwide. However, ethical considerations around
privacy, bias, and transparency remain critical challenges that must be addressed
as these technologies continue to evolve.
"""
def create_interface() -> gr.Blocks:
with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
gr.Markdown("""
# LexiMind NLP Pipeline Demo
**Full pipleine for text summarization, emotion detection, and topic prediction.**
Enter text below and adjust compressoin to see the results.
""")
with gr.Row():
# Left column - Input
with gr.Column(scale=1):
gr.Markdown("### Input")
input_text = gr.Textbox(
label="Enter text",
placeholder="Paste or type your text here...",
lines=10,
value=SAMPLE_TEXT
)
token_count = gr.Textbox(
label="Token Count",
value="Tokens: 0",
interactive=False
)
compression = gr.Slider(
minimum=20,
maximum=80,
value=50,
step=5,
label="Compression %",
info="Higher = shorter summary"
)
predict_btn = gr.Button("🚀 Analyze", variant="primary", size="lg")
# Right column - Outputs
with gr.Column(scale=2):
gr.Markdown("### Result")
with gr.Tabs():
with gr.TabItem("Summary"):
summary_output = gr.HTML(label="Summary")
with gr.TabItem("Emotions"):
emotion_output = gr.Plot(label="Emotion Analysis")
with gr.TabItem("Topic"):
topic_output = gr.Markdown(label="Topic Prediction")
with gr.TabItem("Attention Heatmap"):
attention_output = gr.Plot(label="Attention Weights")
gr.Markdown("*Visualizes which parts of the input the model focused on.*")
# Download section
gr.Markdown("### Export Results")
download_btn = gr.DownloadButton(
"Download Results (JSON)",
visible=False,
)
# Event Handlers
input_text.change(
fn=count_tokens,
inputs=[input_text],
outputs=[token_count]
)
predict_btn.click(
fn=predict,
inputs=[input_text, compression],
outputs=[summary_output, emotion_output, topic_output, attention_output, download_btn],
)
# Examples
gr.Examples(
examples=[
[SAMPLE_TEXT, 50],
[
"Climate change poses significant risks to global ecosystems. Rising temperatures, melting ice caps, and extreme weather events are becoming more frequent. Scientists urge immediate action to reduce carbon emissions and transition to renewable energy sources.",
40,
],
],
inputs=[input_text, compression],
label="Try these examples:",
)
return demo
# Module-level handles so hosting environments (e.g. Hugging Face Spaces) that
# import this file rather than running it as __main__ can find the app.
demo = create_interface()
app = demo
if __name__ == "__main__":
try:
get_pipeline()
demo.queue().launch(
share=True,
server_name="0.0.0.0",
server_port=7860,
show_error=True,
)
except Exception as e:
logger.error("Failed to launch demo: %s", e, exc_info=True)
print(f"Error: {e}")
sys.exit(1)
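# To run locally: `python scripts/demo_gradio.py`, then open http://localhost:7860.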