Spaces:
Sleeping
Sleeping
| """ | |
| Gradio Demo interface for LexiMind NLP pipeline. | |
| Showcases summarization, emotion detection, and topic prediction. | |
| """ | |
| import json | |
| import sys | |
| from io import StringIO | |
| from pathlib import Path | |
| from typing import Iterable, Sequence | |
| import gradio as gr | |
| from gradio.themes import Soft | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import seaborn as sns | |
| import torch | |
| from matplotlib.figure import Figure | |
# Make the ``src`` package importable when this script is run directly: the
# project root is taken to be the parent of the directory holding this file.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# These imports must come after the sys.path tweak above.
from src.inference.factory import create_inference_pipeline  # noqa: E402
from src.inference.pipeline import (  # noqa: E402
    EmotionPrediction,
    InferencePipeline,
    TopicPrediction,
)
from src.utils.logging import configure_logging, get_logger  # noqa: E402

configure_logging()
logger = get_logger(__name__)

# Module-level caches, filled in lazily by get_pipeline().
_pipeline: InferencePipeline | None = None  # shared pipeline instance
_label_metadata = None  # label metadata returned together with the pipeline
def get_pipeline() -> InferencePipeline:
    """Return the shared inference pipeline, building it on first use.

    The pipeline and the label metadata returned by
    ``create_inference_pipeline`` are stored in the module-level
    ``_pipeline`` / ``_label_metadata`` globals so later calls reuse them.

    Returns:
        The cached ``InferencePipeline`` instance.

    Raises:
        RuntimeError: If building the pipeline fails. The original exception
            is attached as ``__cause__`` and its traceback is logged.
    """
    global _pipeline, _label_metadata

    # Fast path: already initialised by an earlier call.
    if _pipeline is not None:
        return _pipeline

    logger.info("Loading inference pipeline...")
    try:
        pipeline, label_metadata = create_inference_pipeline(
            tokenizer_dir="data/tokenization",
            checkpoint_path="checkpoints/best.pt",
            labels_path="data/labels.json",
        )
    except Exception as e:
        # Log the full traceback here; callers only see the generic message.
        logger.error(f"Failed to load pipeline: {e}", exc_info=True)
        raise RuntimeError(
            "Could not initialize inference pipeline. Check logs for details."
        ) from e

    _pipeline = pipeline
    _label_metadata = label_metadata
    logger.info("Pipeline loaded successfully")
    return _pipeline
def count_tokens(text: str) -> str:
    """Build the token-count label shown under the input box.

    Args:
        text: Current contents of the input textbox.

    Returns:
        ``"Tokens: N"`` where N is the length of the tokenizer encoding,
        ``"Tokens: 0"`` for empty input, or ``"Token count unavailable"`` when
        the pipeline or tokenizer raises.
    """
    # Empty input never needs the (potentially expensive) pipeline.
    if not text:
        return "Tokens: 0"

    try:
        encoded = get_pipeline().tokenizer.encode(text)
        return f"Tokens: {len(encoded)}"
    except Exception as e:
        logger.error(f"Token counting error: {e}")
        return "Token count unavailable"
def map_compression_to_length(compression: int, max_model_length: int = 512) -> int:
    """Map the compression slider value to a maximum summary length.

    Higher compression means a shorter summary: 20% compression keeps 80% of
    ``max_model_length``, 80% compression keeps 20%.

    Args:
        compression: Compression percentage. The UI slider sends 20-80; values
            outside 0-100 are clamped instead of producing a zero, negative or
            over-long length.
        max_model_length: Upper bound on the summary length.

    Returns:
        The summary length limit, always at least 1.
    """
    # Invert the percentage: the ratio is the share of the budget that is kept.
    ratio = (100 - compression) / 100
    # Clamp so out-of-range input cannot exceed the budget or go negative.
    ratio = min(1.0, max(0.0, ratio))
    # A limit of 0 would leave no room for any output token.
    return max(1, int(ratio * max_model_length))
def predict(text: str, compression: int):
    """
    Main prediction function for the Gradio interface.

    Args:
        text: Text to process
        compression: Compression percentage (20-80)

    Returns:
        Tuple of (summary_html, emotion_plot, topic_output, attention_fig,
        download_update). On empty input or on failure the plot and download
        slots are ``None`` and the text slots carry a short message.
    """
    # Local import: only needed to materialise the downloadable file.
    import tempfile

    if not text or not text.strip():
        return (
            "Please enter some text to analyze.",
            None,
            "No topic prediction available",
            None,
            None,
        )

    try:
        pipeline = get_pipeline()
        max_len = map_compression_to_length(compression)
        logger.info(f"Generating summary with max length of {max_len}")

        # Run the three model heads on the single input text.
        summary = pipeline.summarize([text], max_length=max_len)[0]
        emotions = pipeline.predict_emotions([text])[0]
        topic = pipeline.predict_topics([text])[0]

        summary_html = format_summary(text, summary)
        emotion_plot = create_emotion_plot(emotions)
        topic_output = format_topic(topic)
        attention_fig = create_attention_heatmap(text, summary, pipeline)
        download_data = prepare_download(text, summary, emotions, topic)

        # gr.DownloadButton expects a file path (or URL) as its value, not the
        # file contents, so persist the JSON to a temp file and hand over the
        # path. delete=False keeps the file around until it has been served.
        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".json",
            prefix="leximind_results_",
            delete=False,
            encoding="utf-8",
        ) as handle:
            handle.write(download_data)
            download_path = handle.name

        return summary_html, emotion_plot, topic_output, attention_fig, gr.update(
            value=download_path,
            visible=True,
        )
    except Exception as e:
        logger.error(f"Prediction error: {e}", exc_info=True)
        error_msg = "Prediction failed. Check logs for details."
        return error_msg, None, "Error", None, None
def format_summary(original: str, summary: str) -> str:
    """Render the original text and its summary as an HTML fragment.

    Both strings are HTML-escaped before interpolation: ``original`` comes
    straight from the user and ``summary`` is model output derived from it, so
    neither may be allowed to inject markup into the page.

    Args:
        original: The text the user submitted.
        summary: The generated summary.

    Returns:
        An HTML string with an "Original Text" and a "Summary" section.
    """
    # Local import; the result is not called ``html`` to avoid shadowing it.
    from html import escape

    safe_original = escape(original)
    safe_summary = escape(summary)
    return (
        "\n"
        '<div style="padding: 10px; border-radius: 5px;">\n'
        "<h3>Original Text</h3>\n"
        '<p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">\n'
        f"{safe_original}\n"
        "</p>\n"
        "<h3>Summary</h3>\n"
        '<p style="background-color: #e6f3ff; padding: 10px; border-radius: 3px;">\n'
        f"{safe_summary}\n"
        "</p>\n"
        "</div>\n"
    )
def create_emotion_plot(
    emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
) -> Figure | None:
    """
    Create a horizontal bar plot for emotion predictions.

    Args:
        emotions: An ``EmotionPrediction`` or a dict with 'labels' and
            'scores' keys.

    Returns:
        The matplotlib figure, or ``None`` when there are no labels or scores.
        The figure is already unregistered from pyplot when it is returned.
    """
    if isinstance(emotions, EmotionPrediction):
        labels = emotions.labels
        scores = emotions.scores
    else:
        labels = list(emotions.get("labels", []))
        scores = list(emotions.get("scores", []))

    if not labels or not scores:
        return None

    df = pd.DataFrame({
        "Emotion": labels,
        "Probability": scores,
    })

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        colors = sns.color_palette("Set2", len(labels))
        bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
        ax.set_xlabel("Probability", fontsize=12)
        ax.set_ylabel("Emotion", fontsize=12)
        ax.set_title("Emotion Detection Results", fontsize=14, fontweight="bold")
        ax.set_xlim(0, 1)

        # Annotate each bar with its value formatted as a percentage.
        for bar in bars:
            width = bar.get_width()
            ax.text(
                width,
                bar.get_y() + bar.get_height() / 2,
                f"{width:.2%}",
                ha="left",
                va="center",
                fontsize=10,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
            )

        # Lay out this figure explicitly; plt.tight_layout() would act on
        # whatever pyplot considers the current figure.
        fig.tight_layout()
    finally:
        # pyplot keeps every figure alive until it is closed, which leaks
        # memory in a long-running server. Closing only drops pyplot's
        # reference; the Figure object itself is still returned to the caller.
        plt.close(fig)
    return fig
def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
    """
    Render the topic prediction as a short Markdown snippet.

    Args:
        topic: A ``TopicPrediction`` (``label`` / ``confidence`` attributes) or
            a dict with 'label' and 'score' keys; missing dict keys fall back
            to "Unknown" and 0.0.

    Returns:
        Markdown with a heading, the bold label and the confidence percentage.
    """
    if not isinstance(topic, TopicPrediction):
        label = str(topic.get("label", "Unknown"))
        score = float(topic.get("score", 0.0))
    else:
        label = topic.label
        score = topic.confidence

    # Leading and trailing empty entries reproduce the surrounding newlines.
    parts = [
        "",
        "### Predicted Topic",
        f"**{label}**",
        f"Confidence: {score:.2%}",
        "",
    ]
    return "\n".join(parts)
| def _clean_tokens(tokens: Iterable[str]) -> list[str]: | |
| cleaned: list[str] = [] | |
| for token in tokens: | |
| item = token.replace("Ġ", " ").replace("▁", " ") | |
| cleaned.append(item.strip() if item.strip() else token) | |
| return cleaned | |
def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
    """Generate a seaborn heatmap of decoder cross-attention averaged over heads.

    The input text is re-encoded, the summary is fed through the decoder with
    attention collection enabled, and the last layer's cross-attention is
    plotted with summary tokens on the y-axis and input tokens on the x-axis.

    Args:
        text: The original input text.
        summary: The generated summary; an empty summary yields ``None``.
        pipeline: The inference pipeline providing preprocessor, tokenizer,
            model and device.

    Returns:
        The matplotlib figure, or ``None`` when there is nothing to plot, the
        tokenizer cannot map ids back to tokens, or any step raises (the error
        is logged). The figure is already unregistered from pyplot.
    """
    if not summary:
        return None

    try:
        batch = pipeline.preprocessor.batch_encode([text])
        batch = pipeline._batch_to_device(batch)
        src_ids = batch.input_ids
        src_mask = batch.attention_mask

        encoder_mask = None
        if src_mask is not None:
            # Pairwise mask built from the padding mask. Assumes src_mask is a
            # boolean (batch, seq) tensor, since `&` is used — TODO confirm.
            encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)

        # NOTE(review): the extent of this block is inferred; it is assumed to
        # cover both the encoder and the decoder forward passes.
        with torch.inference_mode():
            memory = pipeline.model.encoder(src_ids, mask=encoder_mask)

            target_enc = pipeline.tokenizer.batch_encode([summary])
            target_ids = target_enc["input_ids"].to(pipeline.device)
            target_mask = target_enc["attention_mask"].to(pipeline.device)
            # Number of non-padded summary positions.
            target_len = int(target_mask.sum().item())

            decoder_inputs = pipeline.tokenizer.prepare_decoder_inputs(target_ids)
            decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
            target_ids = target_ids[:, :target_len]

            memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
            _, attn_list = pipeline.model.decoder(
                decoder_inputs,
                memory,
                memory_mask=memory_mask,
                collect_attn=True,
            )

        if not attn_list:
            return None

        cross_attn = attn_list[-1]["cross"]  # (B, heads, T, S) per the original note
        # Average over heads, take the single batch item.
        attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()

        source_len = batch.lengths[0]
        attn_matrix = attn_matrix[:target_len, :source_len]

        source_ids = src_ids[0, :source_len].tolist()
        target_id_list = target_ids[0].tolist()

        # Drop rows that belong to special tokens so only real summary tokens
        # are labelled on the y-axis.
        special_ids = {
            pipeline.tokenizer.pad_token_id,
            pipeline.tokenizer.bos_token_id,
            pipeline.tokenizer.eos_token_id,
        }
        keep_indices = [
            index for index, token_id in enumerate(target_id_list) if token_id not in special_ids
        ]
        if not keep_indices:
            return None

        pruned_matrix = attn_matrix[keep_indices, :]

        tokenizer_impl = pipeline.tokenizer.tokenizer
        convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
        if convert_tokens is None:
            logger.warning("Tokenizer does not expose convert_ids_to_tokens; skipping attention heatmap.")
            return None

        summary_tokens_raw = convert_tokens([target_id_list[index] for index in keep_indices])
        source_tokens_raw = convert_tokens(source_ids)
        summary_tokens = _clean_tokens(summary_tokens_raw)
        source_tokens = _clean_tokens(source_tokens_raw)

        # Scale the canvas with the token counts so tick labels stay legible.
        height = max(4.0, 0.4 * len(summary_tokens))
        width = max(6.0, 0.4 * len(source_tokens))

        fig, ax = plt.subplots(figsize=(width, height))
        try:
            sns.heatmap(
                pruned_matrix,
                cmap="mako",
                xticklabels=source_tokens,
                yticklabels=summary_tokens,
                ax=ax,
                cbar_kws={"label": "Attention"},
            )
            ax.set_xlabel("Input Tokens")
            ax.set_ylabel("Summary Tokens")
            ax.set_title("Cross-Attention (decoder last layer)")
            ax.tick_params(axis="x", rotation=90)
            ax.tick_params(axis="y", rotation=0)
            fig.tight_layout()
        finally:
            # pyplot keeps figures alive until closed; release its reference
            # on both the success and the failure path. The Figure object is
            # still valid for the caller.
            plt.close(fig)
        return fig

    except Exception as exc:
        logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
        return None
def prepare_download(
    text: str,
    summary: str,
    emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
    topic: TopicPrediction | dict[str, float | str],
) -> str:
    """Serialise the analysis results to a JSON string for download.

    ``EmotionPrediction`` / ``TopicPrediction`` objects are converted to plain
    dicts; values that are already dicts are passed through unchanged.

    Returns:
        A JSON document (2-space indent) with the keys ``original_text``,
        ``summary``, ``emotions`` and ``topic``.
    """
    emotion_payload = (
        {"labels": list(emotions.labels), "scores": list(emotions.scores)}
        if isinstance(emotions, EmotionPrediction)
        else emotions
    )
    topic_payload = (
        {"label": topic.label, "confidence": topic.confidence}
        if isinstance(topic, TopicPrediction)
        else topic
    )
    return json.dumps(
        {
            "original_text": text,
            "summary": summary,
            "emotions": emotion_payload,
            "topic": topic_payload,
        },
        indent=2,
    )
# Default text pre-filled in the input box and reused as the first example.
# Written as explicit lines; the value starts and ends with a newline.
SAMPLE_TEXT = (
    "\n"
    "Artificial intelligence is rapidly transforming the technology landscape.\n"
    "Machine learning algorithms are now capable of processing vast amounts of data,\n"
    "identifying patterns, and making predictions with unprecedented accuracy.\n"
    "From healthcare diagnostics to financial forecasting, AI applications are\n"
    "revolutionizing industries worldwide. However, ethical considerations around\n"
    "privacy, bias, and transparency remain critical challenges that must be addressed\n"
    "as these technologies continue to evolve.\n"
)
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI for the demo.

    Layout: an input column (textbox, token counter, compression slider,
    analyze button) next to a results column (summary, emotions, topic and
    attention tabs plus a download button), followed by clickable examples.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` app.
    """
    # NOTE(review): the nesting of the layout below was reconstructed from the
    # statement order; confirm the download section belongs in the results
    # column.
    with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
        gr.Markdown(
            "\n"
            "# LexiMind NLP Pipeline Demo\n"
            "**Full pipeline for text summarization, emotion detection, and topic prediction.**\n"
            "Enter text below and adjust compression to see the results.\n"
        )

        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                gr.Markdown("### Input")
                input_text = gr.Textbox(
                    label="Enter text",
                    placeholder="Paste or type your text here...",
                    lines=10,
                    value=SAMPLE_TEXT,
                )
                token_count = gr.Textbox(
                    label="Token Count",
                    value="Tokens: 0",
                    interactive=False,
                )
                compression = gr.Slider(
                    minimum=20,
                    maximum=80,
                    value=50,
                    step=5,
                    label="Compression %",
                    info="Higher = shorter summary",
                )
                predict_btn = gr.Button("🚀 Analyze", variant="primary", size="lg")

            # Right column - Outputs
            with gr.Column(scale=2):
                gr.Markdown("### Result")
                with gr.Tabs():
                    with gr.TabItem("Summary"):
                        summary_output = gr.HTML(label="Summary")
                    with gr.TabItem("Emotions"):
                        emotion_output = gr.Plot(label="Emotion Analysis")
                    with gr.TabItem("Topic"):
                        topic_output = gr.Markdown(label="Topic Prediction")
                    with gr.TabItem("Attention Heatmap"):
                        attention_output = gr.Plot(label="Attention Weights")
                        gr.Markdown("*Visualizes which parts of the input the model focused on.*")

                # Download section; the button stays hidden until predict()
                # returns an update that makes it visible.
                gr.Markdown("### Export Results")
                download_btn = gr.DownloadButton(
                    "Download Results (JSON)",
                    visible=False,
                )

        # Event Handlers
        input_text.change(
            fn=count_tokens,
            inputs=[input_text],
            outputs=[token_count],
        )
        # The textbox is pre-filled, so compute the count once on page load;
        # otherwise the counter shows "Tokens: 0" until the first edit.
        demo.load(
            fn=count_tokens,
            inputs=[input_text],
            outputs=[token_count],
        )
        predict_btn.click(
            fn=predict,
            inputs=[input_text, compression],
            outputs=[summary_output, emotion_output, topic_output, attention_output, download_btn],
        )

        # Examples
        gr.Examples(
            examples=[
                [SAMPLE_TEXT, 50],
                [
                    "Climate change poses significant risks to global ecosystems. Rising temperatures, melting ice caps, and extreme weather events are becoming more frequent. Scientists urge immediate action to reduce carbon emissions and transition to renewable energy sources.",
                    40,
                ],
            ],
            inputs=[input_text, compression],
            label="Try these examples:",
        )

    return demo
# Built at import time so the app object exists when the module is imported
# rather than executed.
demo = create_interface()
app = demo  # alias; presumably for hosts that look for ``app`` — confirm


def main() -> None:
    """Warm up the pipeline and launch the demo server.

    The pipeline is loaded before launching so a broken checkpoint fails fast.
    Any failure is logged with its traceback, reported on stderr, and the
    process exits with status 1.
    """
    try:
        get_pipeline()
        demo.queue().launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
        )
    except Exception as e:
        logger.error("Failed to launch demo: %s", e, exc_info=True)
        # Diagnostics belong on stderr, not stdout.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()