Spaces · Sleeping

Commit f9edbb4
Parent(s): 1bdd1c1

Updated Summarizer and Preprocessor to run on my custom transformer, and added a basic Streamlit frontend demo
Changed files:
- requirements-dev.txt                   +2   -1
- requirements.txt                       +5   -1
- src/api/inference/__init__.py          +7   -0
- src/api/inference/inference.py         +133 -0
- src/data/download.py                   +47  -41
- src/data/preprocessing.py              +251 -254
- src/inference/__init__.py              +7   -0
- src/inference/baseline_summarizer.py   +39  -220
- src/models/__init__.py                 +33  -0
- src/ui/streamlit_app.py                +108 -0
requirements-dev.txt
CHANGED
@@ -6,4 +6,5 @@ isort>=5.12.0
 flake8>=6.0.0
 mypy>=1.4.0
 jupyter>=1.0.0
-ipywidgets>=8.0.0
+ipywidgets>=8.0.0
+pre-commit>=3.4.0
requirements.txt
CHANGED
@@ -15,4 +15,8 @@ omegaconf>=2.3.0
 tensorboard>=2.13.0
 gradio>=3.35.0
 requests>=2.31.0
-
+kaggle>=1.5.12
+streamlit>=1.25.0
+plotly>=5.18.0
+faiss-cpu==1.9.0; platform_system != "Windows"
+faiss-cpu==1.9.0; platform_system == "Windows"
src/api/inference/__init__.py
ADDED
@@ -0,0 +1,7 @@
+"""
+API inference module for LexiMind.
+"""
+
+from .inference import load_models, summarize_text, classify_emotion, topic_for_text
+
+__all__ = ["load_models", "summarize_text", "classify_emotion", "topic_for_text"]
src/api/inference/inference.py
ADDED
@@ -0,0 +1,133 @@
+"""Minimal inference helpers that rely on the custom transformer stack."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from ...data.preprocessing import TextPreprocessor, TransformerTokenizer
+from ...models.multitask import MultiTaskModel
+
+
+def _load_tokenizer(tokenizer_path: Path) -> TransformerTokenizer:
+    if not tokenizer_path.exists():
+        raise FileNotFoundError(f"tokenizer file '{tokenizer_path}' not found")
+    return TransformerTokenizer.load(tokenizer_path)
+
+
+def load_models(config: Dict[str, Any]) -> Dict[str, Any]:
+    """Load MultiTaskModel together with the tokenizer-driven preprocessor."""
+
+    device = torch.device(config.get("device", "cpu"))
+    tokenizer_path = config.get("tokenizer_path")
+    if tokenizer_path is None:
+        raise ValueError("'tokenizer_path' missing in config")
+
+    tokenizer = _load_tokenizer(Path(tokenizer_path))
+    preprocessor = TextPreprocessor(
+        max_length=int(config.get("max_length", 512)),
+        tokenizer=tokenizer,
+        min_freq=int(config.get("min_freq", 1)),
+        lowercase=bool(config.get("lowercase", True)),
+    )
+
+    encoder_kwargs = dict(config.get("encoder", {}))
+    decoder_kwargs = dict(config.get("decoder", {}))
+
+    encoder = preprocessor.build_encoder(**encoder_kwargs)
+    decoder = preprocessor.build_decoder(**decoder_kwargs)
+    model = MultiTaskModel(encoder=encoder, decoder=decoder)
+
+    checkpoint_path = config.get("checkpoint_path")
+    if checkpoint_path:
+        state = torch.load(checkpoint_path, map_location=device)
+        if isinstance(state, dict) and "state_dict" in state:
+            state = state["state_dict"]
+        model.load_state_dict(state, strict=False)
+
+    model.to(device)
+
+    return {
+        "loaded": True,
+        "device": device,
+        "mt": model,
+        "preprocessor": preprocessor,
+    }
+
+
+def summarize_text(
+    text: str,
+    compression: float = 0.25,
+    collect_attn: bool = False,
+    models: Optional[Dict[str, Any]] = None,
+) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
+    if models is None or not models.get("loaded"):
+        raise RuntimeError("Models must be loaded via load_models before summarize_text is called")
+
+    model: MultiTaskModel = models["mt"]
+    preprocessor: TextPreprocessor = models["preprocessor"]
+    device: torch.device = models["device"]
+
+    batch = preprocessor.batch_encode([text])
+    tokenizer = preprocessor.tokenizer
+    encoder = model.encoder
+    decoder = model.decoder
+    if tokenizer is None or encoder is None or decoder is None:
+        raise RuntimeError("Encoder, decoder, and tokenizer must be configured before summarization")
+    input_ids = batch.input_ids.to(device)
+    memory = encoder(input_ids)
+    src_len = batch.lengths[0]
+    max_tgt = max(4, int(src_len * compression))
+    generated = decoder.greedy_decode(
+        memory,
+        max_len=min(preprocessor.max_length, max_tgt),
+        start_token_id=tokenizer.bos_id,
+        end_token_id=tokenizer.eos_id,
+    )
+    summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
+    return summary.strip(), None if not collect_attn else {}
+
+
+def classify_emotion(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[List[float], List[str]]:
+    if models is None or not models.get("loaded"):
+        raise RuntimeError("Models must be loaded via load_models before classify_emotion is called")
+
+    model: MultiTaskModel = models["mt"]
+    preprocessor: TextPreprocessor = models["preprocessor"]
+    device: torch.device = models["device"]
+
+    batch = preprocessor.batch_encode([text])
+    input_ids = batch.input_ids.to(device)
+    result = model.forward("emotion", {"input_ids": input_ids})
+    logits = result[1] if isinstance(result, tuple) else result
+    scores = torch.sigmoid(logits).squeeze(0).detach().cpu().tolist()
+    labels = models.get("emotion_labels") or [
+        "joy",
+        "sadness",
+        "anger",
+        "fear",
+        "surprise",
+        "disgust",
+    ]
+    return scores, labels[: len(scores)]
+
+
+def topic_for_text(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[int, List[str]]:
+    if models is None or not models.get("loaded"):
+        raise RuntimeError("Models must be loaded via load_models before topic_for_text is called")
+
+    model: MultiTaskModel = models["mt"]
+    preprocessor: TextPreprocessor = models["preprocessor"]
+    device: torch.device = models["device"]
+
+    batch = preprocessor.batch_encode([text])
+    input_ids = batch.input_ids.to(device)
+    encoder = model.encoder
+    if encoder is None:
+        raise RuntimeError("Encoder must be configured before topic_for_text is called")
+    memory = encoder(input_ids)
+    embedding = memory.mean(dim=1).detach().cpu()
+    _ = embedding  # placeholder for downstream clustering hook
+    return 0, ["topic_stub"]
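For orientation, a minimal usage sketch of the new API (assumptions: the repo root is on PYTHONPATH, and both artifact paths below are placeholders that must point at a real tokenizer JSON and checkpoint):

    from src.api.inference import load_models, summarize_text, classify_emotion

    # Sketch only: config keys mirror the ones load_models reads above.
    models = load_models({
        "tokenizer_path": "artifacts/tokenizer.json",  # written by TextPreprocessor.save_tokenizer
        "checkpoint_path": "checkpoints/best.pt",      # optional; weights are skipped when omitted
        "device": "cpu",
        "max_length": 512,
    })

    summary, _ = summarize_text("Some long passage to compress...", compression=0.25, models=models)
    scores, labels = classify_emotion("I loved every minute of it!", models=models)
    print(summary)
    print(dict(zip(labels, scores)))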
src/data/download.py
CHANGED
@@ -1,57 +1,63 @@
+"""
+Download helpers for datasets.
+
+This version:
+- Adds robust error handling when the Kaggle API is not configured.
+- Stores files under data/raw/ subfolders.
+- Keeps the Gutenberg direct download example.
+
+Make sure you have Kaggle credentials configured if you call Kaggle downloads.
+"""
 import os
 import requests
- …
+
+def download_gutenberg(out_dir="data/raw/books", gutenberg_id: int = 1342, filename: str = "pride_and_prejudice.txt"):
+    """Download a Gutenberg text file by direct URL template (best-effort)."""
+    url = f"https://www.gutenberg.org/files/{gutenberg_id}/{gutenberg_id}-0.txt"
+    os.makedirs(out_dir, exist_ok=True)
+    out_path = os.path.join(out_dir, filename)
+    if os.path.exists(out_path):
+        print("Already downloaded:", out_path)
+        return out_path
+    try:
+        r = requests.get(url, timeout=30)
+        r.raise_for_status()
         with open(out_path, "wb") as f:
             f.write(r.content)
+        print("Downloaded:", out_path)
+        return out_path
+    except Exception as e:
+        print("Failed to download Gutenberg file:", e)
+        return None
-
-# Kaggle dataset download helpers
+
+# Kaggle helpers: optional, wrapped to avoid hard failure when Kaggle isn't configured.
+def _safe_kaggle_download(dataset: str, path: str):
+    try:
+        import kaggle
+    except Exception as e:
+        print("Kaggle package not available or not configured. Please install 'kaggle' and configure API token. Error:", e)
+        return False
+    try:
+        os.makedirs(path, exist_ok=True)
+        kaggle.api.authenticate()
+        kaggle.api.dataset_download_files(dataset, path=path, unzip=True)
+        print(f"Downloaded Kaggle dataset {dataset} to {path}")
+        return True
+    except Exception as e:
+        print("Failed to download Kaggle dataset:", e)
+        return False
 
 def download_emotion_dataset():
-    """Download the emotions dataset from Kaggle."""
     target_dir = "data/raw/emotion"
-
-    # Downloading using Kaggle Python API
-    kaggle.api.authenticate()
-    kaggle.api.dataset_download_files(
-        'praveengovi/emotions-dataset-for-nlp',
-        path=target_dir,
-        unzip=True
-    )
-    print("Downloaded Kaggle emotion dataset to", target_dir)
+    return _safe_kaggle_download('praveengovi/emotions-dataset-for-nlp', target_dir)
 
 def download_cnn_dailymail():
-    """Download the CNN/DailyMail summarization dataset from Kaggle."""
     target_dir = "data/raw/summarization"
-
-    # Downloading using Kaggle Python API
-    kaggle.api.authenticate()
-    kaggle.api.dataset_download_files(
-        'gowrishankarp/newspaper-text-summarization-cnn-dailymail',
-        path=target_dir,
-        unzip=True
-    )
-    print("Downloaded Kaggle CNN/DailyMail dataset to", target_dir)
+    return _safe_kaggle_download('gowrishankarp/newspaper-text-summarization-cnn-dailymail', target_dir)
 
 def download_ag_news():
-    """Download the AG News dataset from Kaggle."""
     target_dir = "data/raw/topic"
-
-    # Downloading using Kaggle Python API
-    kaggle.api.authenticate()
-    kaggle.api.dataset_download_files(
-        'amananandrai/ag-news-classification-dataset',
-        path=target_dir,
-        unzip=True
-    )
-    print("Downloaded Kaggle AG News dataset to", target_dir)
+    return _safe_kaggle_download('amananandrai/ag-news-classification-dataset', target_dir)
 
 if __name__ == "__main__":
    download_gutenberg()
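A usage sketch for the download helpers (assumption: the kaggle package reads an API token from ~/.kaggle/kaggle.json or the KAGGLE_USERNAME/KAGGLE_KEY environment variables):

    from src.data.download import (
        download_gutenberg,
        download_emotion_dataset,
        download_cnn_dailymail,
        download_ag_news,
    )

    download_gutenberg()             # -> data/raw/books/pride_and_prejudice.txt (or None on failure)
    ok = all([
        download_emotion_dataset(),  # -> data/raw/emotion/
        download_cnn_dailymail(),    # -> data/raw/summarization/
        download_ag_news(),          # -> data/raw/topic/
    ])
    print("all Kaggle downloads succeeded:", ok)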
src/data/preprocessing.py
CHANGED
@@ -1,263 +1,260 @@
- …
-import json
- …
-        self.max_length = max_length
- …
-    def clean_text(self, text: str) -> str:
- …
-            truncation=True,
-            padding=True,
-            max_length=self.max_length,
-            return_tensors='tf'
-        )
- …
-        start = 0
-        while start < len(words):
-            end = start + chunk_size
- …
-            chunks.append(chunk)
-            start += chunk_size - overlap
-        return chunks
-
- …
-        print(f"Processed and saved {filename} → {out_file}")
-
-    # ----- Dataset-specific processing methods ------
-
-    def process_summarization_dataset(self):
-        """Process summarization dataset: clean, split, and save."""
-        input_folder = "data/raw/summarization/cnn_dailymail"
-        output_folder = "data/processed/summarization"
-        os.makedirs(output_folder, exist_ok=True)
-
-        # Process each CSV file separately (train.csv, validation.csv, test.csv)
-        file_mapping = {
-            'train.csv': 'train',
-            'validation.csv': 'val',
-            'test.csv': 'test'
-        }
-
-        for csv_file, split_name in file_mapping.items():
-            file_path = os.path.join(input_folder, csv_file)
-            if not os.path.exists(file_path):
-                print(f"Missing file: {file_path}")
-                continue
-
-            print(f"Processing {csv_file}...")
-            df = pd.read_csv(file_path)
-
-            # Check for required columns (article and highlights)
-            if 'article' not in df.columns or 'highlights' not in df.columns:
-                print(f"CSV {csv_file} must have 'article' and 'highlights' columns.")
-                continue
-
-            # Clean the text data
-            df['article'] = df['article'].astype(str).apply(self.clean_text)
-            df['summary'] = df['highlights'].astype(str).apply(self.clean_text)  # rename highlights to summary
-
-            # Convert to records format
-            records = df[['article', 'summary']].to_dict(orient='records')
-
-            # Save as JSON
-            output_file = os.path.join(output_folder, f"{split_name}.json")
-            with open(output_file, "w", encoding="utf-8") as f:
-                json.dump(records, f, ensure_ascii=False, indent=2)
-            print(f"Processed {csv_file}: {len(records)} samples saved to {split_name}.json")
-
-        print("Summarization dataset processed and saved.")
-
-    def process_emotion_dataset(self):
-        """Process emotion dataset: clean, split, and save."""
-        input_folder = "data/raw/emotion"
-        output_folder = "data/processed/emotion"
-        os.makedirs(output_folder, exist_ok=True)
-
-        # Process each txt file (train.txt, val.txt, test.txt)
-        for split_file in ['train.txt', 'val.txt', 'test.txt']:
-            file_path = os.path.join(input_folder, split_file)
-            if not os.path.exists(file_path):
-                print(f"Missing file: {file_path}")
-                continue
-
-            records = []
-            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                for line in f:
-                    line = line.strip()
-                    if line and ';' in line:
-                        # Split on the last semicolon to handle semicolons in text
-                        text, label = line.rsplit(';', 1)
-                        records.append({
-                            'text': self.clean_text(text),
-                            'label': label.strip()
-                        })
-
-            # Save as JSON
-            split_name = split_file.replace('.txt', '')
-            output_file = os.path.join(output_folder, f"{split_name}.json")
-            with open(output_file, "w", encoding="utf-8") as f:
-                json.dump(records, f, ensure_ascii=False, indent=2)
-            print(f"Processed {split_file}: {len(records)} samples saved to {split_name}.json")
-
-        print("Emotion dataset processed and saved.")
-
-    def process_topic_dataset(self):
-        """Process topic dataset: clean, split, and save."""
-        input_folder = "data/raw/topic"
-        output_folder = "data/processed/topic"
-        os.makedirs(output_folder, exist_ok=True)
-
-        # Process each CSV file separately (train.csv, test.csv)
-        file_mapping = {
-            'train.csv': 'train',
-            'test.csv': 'test'
-        }
-
-        # Class index to topic name mapping for AG News dataset
-        class_map = {
-            1: 'World',
-            2: 'Sports',
-            3: 'Business',
-            4: 'Science/Technology'
-        }
-
-        for csv_file, split_name in file_mapping.items():
-            file_path = os.path.join(input_folder, csv_file)
-            if not os.path.exists(file_path):
-                print(f"Missing file: {file_path}")
-                continue
-
-            print(f"Processing {csv_file}...")
-            df = pd.read_csv(file_path)
-
-            # Check for required columns
-            if 'Class Index' not in df.columns:
-                print(f"CSV {csv_file} must have 'Class Index' column.")
-                continue
-
-            # Concatenate title and description
-            if 'Title' in df.columns and 'Description' in df.columns:
-                text = df['Title'].astype(str) + ". " + df['Description'].astype(str)
-            elif 'Title' in df.columns:
-                text = df['Title'].astype(str)
-            elif 'Description' in df.columns:
-                text = df['Description'].astype(str)
-            else:
-                print("CSV must have 'Title' or 'Description' columns.")
-                continue
-
-            df['text'] = text.apply(self.clean_text)
-
-            # Map numeric labels to category names
-            df['label'] = df['Class Index'].map(class_map)
-
-            # Convert to records format
-            records = df[['text', 'label']].to_dict(orient='records')
-
-            # Save as JSON
-            output_file = os.path.join(output_folder, f"{split_name}.json")
-            with open(output_file, "w", encoding="utf-8") as f:
-                json.dump(records, f, ensure_ascii=False, indent=2)
-            print(f"Processed {csv_file}: {len(records)} samples saved to {split_name}.json")
-
-        # Create validation split from training data
-        if os.path.exists(os.path.join(output_folder, "train.json")):
-            print("Creating validation split from training data...")
-            with open(os.path.join(output_folder, "train.json"), "r", encoding="utf-8") as f:
-                train_data = json.load(f)
-
-            # Split training data into train and validation
-            train_records, val_records = train_test_split(train_data, test_size=0.2, random_state=42)
-
-            # Save updated train and new validation files
-            with open(os.path.join(output_folder, "train.json"), "w", encoding="utf-8") as f:
-                json.dump(train_records, f, ensure_ascii=False, indent=2)
-
-            with open(os.path.join(output_folder, "val.json"), "w", encoding="utf-8") as f:
-                json.dump(val_records, f, ensure_ascii=False, indent=2)
-
-            print(f"Updated train.json: {len(train_records)} samples")
-            print(f"Created val.json: {len(val_records)} samples")
-
-        print("Topic dataset processed and saved.")
-
-
-# ----- Main function for quick testing ------
-
-if __name__ == "__main__":
-    preprocessor = textPreprocessor(max_length=128)
-
-    # Process and save all books
-    preprocessor.save_preprocessed_books(data=None)
-
-    # Load a processed book back
-    import json
-    with open("data/processed/books/pride_and_prejudice.json", "r") as f:
-        chunks = json.load(f)
-    print(f"Loaded {len(chunks)} chunks from Pride and Prejudice")
-    print(chunks[0][:200])  # printing first 200 chars of chunk
-
-    # Process new datasets
-    preprocessor.process_summarization_dataset()
-    preprocessor.process_emotion_dataset()
-    preprocessor.process_topic_dataset()
+"""Lightweight preprocessing utilities built around the in-repo transformer."""
+
+from __future__ import annotations
+
+from collections import Counter
+from dataclasses import dataclass
+import json
+from pathlib import Path
+import re
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+
+import torch
+
+from ..models.decoder import TransformerDecoder
+from ..models.encoder import TransformerEncoder
+
+SPECIAL_TOKENS: Tuple[str, str, str, str] = ("<pad>", "<bos>", "<eos>", "<unk>")
+
+
+def _normalize(text: str, lowercase: bool) -> str:
+    text = text.strip()
+    text = re.sub(r"\s+", " ", text)
+    if lowercase:
+        text = text.lower()
+    return text
+
+
+def _basic_tokenize(text: str) -> List[str]:
+    return re.findall(r"\b\w+\b|[.,;:?!]", text)
+
+
+class TransformerTokenizer:
+    """Minimal tokenizer that keeps vocabulary aligned with the custom transformer."""
+
+    def __init__(
+        self,
+        stoi: Dict[str, int],
+        itos: List[str],
+        specials: Sequence[str] = SPECIAL_TOKENS,
+        lowercase: bool = True,
+    ) -> None:
+        self.stoi = stoi
+        self.itos = itos
+        self.specials = tuple(specials)
+        self.lowercase = lowercase
+        self.pad_id = self._lookup(self.specials[0])
+        self.bos_id = self._lookup(self.specials[1])
+        self.eos_id = self._lookup(self.specials[2])
+        self.unk_id = self._lookup(self.specials[3])
+
+    @classmethod
+    def build(
+        cls,
+        texts: Iterable[str],
+        min_freq: int = 1,
+        lowercase: bool = True,
+        specials: Sequence[str] = SPECIAL_TOKENS,
+    ) -> "TransformerTokenizer":
+        counter: Counter[str] = Counter()
+        for text in texts:
+            normalized = _normalize(text, lowercase)
+            counter.update(_basic_tokenize(normalized))
+
+        ordered_specials = list(dict.fromkeys(specials))
+        itos: List[str] = ordered_specials.copy()
+        for token, freq in counter.most_common():
+            if freq < min_freq:
+                continue
+            if token in itos:
+                continue
+            itos.append(token)
+
+        stoi = {token: idx for idx, token in enumerate(itos)}
+        return cls(stoi=stoi, itos=itos, specials=ordered_specials, lowercase=lowercase)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.itos)
+
+    def tokenize(self, text: str) -> List[str]:
+        normalized = _normalize(text, self.lowercase)
+        return _basic_tokenize(normalized)
+
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        tokens = self.tokenize(text)
+        pieces = [self.stoi.get(tok, self.unk_id) for tok in tokens]
+        if add_special_tokens:
+            pieces = [self.bos_id] + pieces + [self.eos_id]
+
+        if max_length is not None and len(pieces) > max_length:
+            if add_special_tokens and max_length >= 2:
+                inner_max = max_length - 2
+                trimmed = pieces[1:-1][:inner_max]
+                pieces = [self.bos_id] + trimmed + [self.eos_id]
+            else:
+                pieces = pieces[:max_length]
+        return pieces
+
+    def decode(self, ids: Sequence[int], skip_special_tokens: bool = True) -> str:
+        tokens: List[str] = []
+        for idx in ids:
+            if idx < 0 or idx >= len(self.itos):
+                continue
+            token = self.itos[idx]
+            if skip_special_tokens and token in self.specials:
+                continue
+            tokens.append(token)
+        return " ".join(tokens).strip()
+
+    def pad_batch(
+        self,
+        sequences: Sequence[Sequence[int]],
+        pad_to_length: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not sequences:
+            raise ValueError("pad_batch requires at least one sequence")
+        if pad_to_length is None:
+            pad_to_length = max(len(seq) for seq in sequences)
+        padded: List[List[int]] = []
+        mask: List[List[int]] = []
+        for seq in sequences:
+            trimmed = list(seq[:pad_to_length])
+            pad_len = pad_to_length - len(trimmed)
+            padded.append(trimmed + [self.pad_id] * pad_len)
+            mask.append([1] * len(trimmed) + [0] * pad_len)
+        return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)
+
+    def save(self, path: Path) -> None:
+        payload = {
+            "itos": self.itos,
+            "specials": list(self.specials),
+            "lowercase": self.lowercase,
+        }
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
+
+    @classmethod
+    def load(cls, path: Path) -> "TransformerTokenizer":
+        data = json.loads(path.read_text(encoding="utf-8"))
+        itos = list(data["itos"])
+        stoi = {token: idx for idx, token in enumerate(itos)}
+        specials = data.get("specials", list(SPECIAL_TOKENS))
+        lowercase = bool(data.get("lowercase", True))
+        return cls(stoi=stoi, itos=itos, specials=specials, lowercase=lowercase)
+
+    def _lookup(self, token: str) -> int:
+        if token not in self.stoi:
+            raise ValueError(f"token '{token}' missing from vocabulary")
+        return self.stoi[token]
+
+
+@dataclass
+class Batch:
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    lengths: List[int]
+
+
+class TextPreprocessor:
+    """Prepares text so it can flow directly into the custom transformer stack."""
+
+    def __init__(
+        self,
+        max_length: int = 512,
+        tokenizer: Optional[TransformerTokenizer] = None,
+        *,
+        min_freq: int = 1,
+        lowercase: bool = True,
+    ) -> None:
+        self.max_length = max_length
+        self.min_freq = min_freq
+        self.lowercase = lowercase
+        self.tokenizer = tokenizer
+
+    def clean_text(self, text: str) -> str:
+        return _normalize(text, self.lowercase)
+
+    def fit_tokenizer(self, texts: Iterable[str]) -> TransformerTokenizer:
+        cleaned = [self.clean_text(text) for text in texts]
+        self.tokenizer = TransformerTokenizer.build(
+            cleaned,
+            min_freq=self.min_freq,
+            lowercase=False,
+        )
+        return self.tokenizer
+
+    def encode(self, text: str, *, add_special_tokens: bool = True) -> List[int]:
+        if self.tokenizer is None:
+            raise RuntimeError("Tokenizer not fitted")
+        cleaned = self.clean_text(text)
+        return self.tokenizer.encode(cleaned, add_special_tokens=add_special_tokens, max_length=self.max_length)
+
+    def batch_encode(self, texts: Sequence[str]) -> Batch:
+        if self.tokenizer is None:
+            raise RuntimeError("Tokenizer not fitted")
+        sequences = [self.encode(text) for text in texts]
+        lengths = [len(seq) for seq in sequences]
+        input_ids, attention_mask = self.tokenizer.pad_batch(sequences, pad_to_length=self.max_length)
+        return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
+
+    def build_encoder(self, **encoder_kwargs) -> TransformerEncoder:
+        if self.tokenizer is None:
+            raise RuntimeError("Tokenizer not fitted")
+        return TransformerEncoder(
+            vocab_size=self.tokenizer.vocab_size,
+            max_len=self.max_length,
+            pad_token_id=self.tokenizer.pad_id,
+            **encoder_kwargs,
+        )
+
+    def build_decoder(self, **decoder_kwargs) -> TransformerDecoder:
+        if self.tokenizer is None:
+            raise RuntimeError("Tokenizer not fitted")
+        return TransformerDecoder(
+            vocab_size=self.tokenizer.vocab_size,
+            max_len=self.max_length,
+            pad_token_id=self.tokenizer.pad_id,
+            **decoder_kwargs,
+        )
+
+    def save_tokenizer(self, path: Path) -> None:
+        if self.tokenizer is None:
+            raise RuntimeError("Tokenizer not fitted")
+        self.tokenizer.save(path)
+
+    def load_tokenizer(self, path: Path) -> TransformerTokenizer:
+        self.tokenizer = TransformerTokenizer.load(path)
+        return self.tokenizer
+
+    def chunk_text(self, text: str, *, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
+        if chunk_size <= overlap:
+            raise ValueError("chunk_size must be larger than overlap")
+        words = self.clean_text(text).split()
+        chunks: List[str] = []
+        start = 0
+        while start < len(words):
+            end = min(start + chunk_size, len(words))
+            chunks.append(" ".join(words[start:end]))
+            start += chunk_size - overlap
+        return chunks
+
+    def save_book_chunks(
+        self,
+        input_path: Path,
+        out_dir: Path,
+        *,
+        chunk_size: int = 1000,
+        overlap: int = 100,
+    ) -> Path:
+        out_dir.mkdir(parents=True, exist_ok=True)
+        raw_text = input_path.read_text(encoding="utf-8", errors="ignore")
+        chunks = self.chunk_text(raw_text, chunk_size=chunk_size, overlap=overlap)
+        out_file = out_dir / f"{input_path.stem}.json"
+        out_file.write_text(json.dumps(chunks, ensure_ascii=False, indent=2), encoding="utf-8")
+        return out_file
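A round-trip sketch of the new tokenizer/preprocessor pair (the two-sentence corpus is made up; everything else uses only the API defined above):

    from pathlib import Path
    from src.data.preprocessing import TextPreprocessor

    pre = TextPreprocessor(max_length=32, min_freq=1, lowercase=True)
    pre.fit_tokenizer(["The quick brown fox jumps.", "A lazy dog sleeps, soundly."])

    ids = pre.encode("the quick dog")      # [bos, ..., eos], truncated to max_length
    batch = pre.batch_encode(["the quick dog", "a fox"])
    print(batch.input_ids.shape)           # torch.Size([2, 32]); padded to max_length
    print(batch.lengths)                   # unpadded lengths, special tokens included

    print(pre.tokenizer.decode(ids))       # "the quick dog" (specials skipped)

    pre.save_tokenizer(Path("artifacts/tokenizer.json"))
    pre.load_tokenizer(Path("artifacts/tokenizer.json"))  # same vocabulary restored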
src/inference/__init__.py
CHANGED
@@ -0,0 +1,7 @@
+"""
+Inference utilities for LexiMind.
+"""
+
+from .baseline_summarizer import Summarizer, TransformerSummarizer
+
+__all__ = ["Summarizer", "TransformerSummarizer"]
src/inference/baseline_summarizer.py
CHANGED
@@ -1,222 +1,41 @@
-
-import json
-from typing import Any, List, Dict, Optional
-import torch
-from torch.utils.data import Dataset, DataLoader
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-
-class Summarizer:
-    def __init__(self, model_name: str = "t5-small", max_input: int = 512, max_output: int = 128, device: Optional[str] = None):
-        self.model_name = model_name
-        self.max_input = max_input
-        self.max_output = max_output
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-        self.device = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-
-    def load_data(self, split: str = "train", limit: Optional[int] = None) -> List[Dict[str, str]]:
-        """
-        Load processed summarization data from JSON files.
-
-        Args:
-            split (str): Data split to load ('train', 'val', 'test')
-            limit (int): Maximum number of samples to load (None for all)
-
-        Returns:
-            list: List of dictionaries with 'article' and 'summary' keys
-        """
-        # Resolve to project root regardless of current working directory
-        root = os.path.dirname(os.path.dirname(__file__))
-        file_path = os.path.join(root, "data", "processed", "summarization", f"{split}.json")
-
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"Data file not found: {file_path}")
-
-        with open(file_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-
-        if limit:
-            data = data[:limit]
-        return data
-
-    def encode(self, articles: List[str] | str, summaries: Optional[List[str] | str] = None):
-        if isinstance(articles, str):
-            articles = [articles]
-        if summaries is not None and isinstance(summaries, str):
-            summaries = [summaries]
-
-        inputs = self.tokenizer(
-            [f"summarize: {a}" for a in articles],
-            max_length=self.max_input,
-            truncation=True,
-            padding="max_length",
-            return_tensors="pt"
-        )
-
-        result = {
-            "input_ids": inputs.input_ids.to(self.device),
-            "attention_mask": inputs.attention_mask.to(self.device)
-        }
-
-        if summaries is not None:
-            labels = self.tokenizer(
-                summaries,
-                max_length=self.max_output,
-                truncation=True,
-                padding="max_length",
-                return_tensors="pt"
-            ).input_ids
-            # Mask pad tokens in labels with -100 for loss
-            labels[labels == self.tokenizer.pad_token_id] = -100
-            result["labels"] = labels.to(self.device)
-        return result
-
-    def train(self, epochs: int = 3, batch_size: int = 4, train_limit: int = 2000, val_limit: int = 500, learning_rate: float = 5e-5):
-        train_data = self.load_data("train", limit=train_limit)
-        val_data = self.load_data("val", limit=val_limit)
-
-        train_ds = _SummarizationDataset(train_data, self.tokenizer, self.max_input, self.max_output)
-        val_ds = _SummarizationDataset(val_data, self.tokenizer, self.max_input, self.max_output) if val_data else None
-        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
-        val_loader = DataLoader(val_ds, batch_size=batch_size) if val_ds else None
-
-        optim = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
- …
-        if not text.strip():
-            return ""
-        inputs = self.tokenizer(
-            f"summarize: {text}",
-            return_tensors="pt",
-            max_length=self.max_input,
-            truncation=True,
-            padding=True
-        )
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-        with torch.no_grad():
-            summary_ids = self.model.generate(
-                inputs["input_ids"],
-                attention_mask=inputs.get("attention_mask"),
-                max_length=max_length or self.max_output,
-                num_beams=num_beams,
-                length_penalty=2.0,
-                early_stopping=True
-            )
-        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
-
-    def save(self, path: str = "models/summarizer"):
-        """
-        Save the trained model and tokenizer.
-
-        Args:
-            path (str): Directory path to save the model
-        """
-        os.makedirs(path, exist_ok=True)
-        self.model.save_pretrained(path)
-        self.tokenizer.save_pretrained(path)
-
-    @classmethod
-    def load(cls, path: str = "models/summarizer"):
-        """
-        Load a pre-trained model from disk.
-
-        Args:
-            path (str): Directory path containing the saved model
-
-        Returns:
-            Summarizer: Loaded summarizer instance
-        """
-        obj = cls.__new__(cls)
-        obj.model_name = path
-        obj.max_input = 512
-        obj.max_output = 128
-        obj.tokenizer = AutoTokenizer.from_pretrained(path)
-        obj.model = AutoModelForSeq2SeqLM.from_pretrained(path)
-        obj.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        obj.model.to(obj.device)
-        return obj
-
-class _SummarizationDataset(Dataset):
-    def __init__(self, data: List[Dict[str, str]], tokenizer: Any, max_input: int, max_output: int):
-        self.data = data
-        self.tokenizer = tokenizer
-        self.max_input = max_input
-        self.max_output = max_output
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx: int):
-        item = self.data[idx]
-        inputs = self.tokenizer(
-            f"summarize: {item['article']}",
-            max_length=self.max_input,
-            truncation=True,
-            padding="max_length",
-            return_tensors="pt"
-        )
-        labels = self.tokenizer(
-            item['summary'],
-            max_length=self.max_output,
-            truncation=True,
-            padding="max_length",
-            return_tensors="pt"
-        ).input_ids
-        labels[labels == self.tokenizer.pad_token_id] = -100
-        return {
-            "input_ids": inputs.input_ids.squeeze(0),
-            "attention_mask": inputs.attention_mask.squeeze(0),
-            "labels": labels.squeeze(0),
-        }
-
-if __name__ == "__main__":
-    print("Initializing summarizer...", flush=True)
-    summarizer = Summarizer(model_name="t5-small")
-    print("Starting a short training run...", flush=True)
-    summarizer.train(epochs=3, batch_size=2, train_limit=100, val_limit=50)
-    test_text = (
-        "The quick brown fox jumps over the lazy dog. This is a common "
-        "pangram used in typography and printing. It contains every letter of the "
-        "alphabet at least once, making it useful for testing fonts and keyboards."
-    )
-    print("Generating summary...", flush=True)
-    summary = summarizer.summarize(test_text)
-    print(f"\nOriginal text: {test_text}")
-    print(f"Summary: {summary}")
-    summarizer.save()
+"""Thin wrapper around the custom transformer summarizer."""
+
+from __future__ import annotations
+from typing import Any, Dict, Optional, Tuple
+import torch
+from ..api.inference import load_models
+
+
+class TransformerSummarizer:
+    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
+        models = load_models(config or {})
+        if not models.get("loaded"):
+            raise RuntimeError("load_models returned an unloaded model; check configuration")
+        self.model = models["mt"]
+        self.preprocessor = models["preprocessor"]
+        self.device = models["device"]
+
+    def summarize(
+        self,
+        text: str,
+        compression: float = 0.25,
+        collect_attn: bool = False,
+    ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
+        batch = self.preprocessor.batch_encode([text])
+        tokenizer = self.preprocessor.tokenizer
+        encoder = self.model.encoder
+        decoder = self.model.decoder
+        if tokenizer is None or encoder is None or decoder is None:
+            raise RuntimeError("Model components are missing; ensure encoder, decoder, and tokenizer are set")
+        input_ids = batch.input_ids.to(self.device)
+        memory = encoder(input_ids)
+        src_len = batch.lengths[0]
+        target_len = max(4, int(src_len * compression))
+        generated = decoder.greedy_decode(
+            memory,
+            max_len=min(self.preprocessor.max_length, target_len),
+            start_token_id=tokenizer.bos_id,
+            end_token_id=tokenizer.eos_id,
+        )
+        summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
+        return summary.strip(), None if not collect_attn else {}
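The wrapper keeps a one-call interface; a sketch with the same placeholder artifact paths used elsewhere in this commit:

    from src.inference.baseline_summarizer import TransformerSummarizer

    summarizer = TransformerSummarizer({
        "tokenizer_path": "artifacts/tokenizer.json",
        "checkpoint_path": "checkpoints/best.pt",
    })
    summary, _ = summarizer.summarize("A long article to condense...", compression=0.2)
    print(summary)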
src/models/__init__.py
CHANGED
@@ -0,0 +1,33 @@
+"""
+LexiMind custom transformer models.
+
+This package provides a from-scratch transformer implementation with:
+- TransformerEncoder/TransformerDecoder
+- MultiHeadAttention, FeedForward, PositionalEncoding
+- Task heads: ClassificationHead, TokenClassificationHead, LMHead
+- MultiTaskModel: composable wrapper for encoder/decoder + task heads
+"""
+
+from .encoder import TransformerEncoder, TransformerEncoderLayer
+from .decoder import TransformerDecoder, TransformerDecoderLayer, create_causal_mask
+from .attention import MultiHeadAttention
+from .feedforward import FeedForward
+from .positional_encoding import PositionalEncoding
+from .heads import ClassificationHead, TokenClassificationHead, LMHead, ProjectionHead
+from .multitask import MultiTaskModel
+
+__all__ = [
+    "TransformerEncoder",
+    "TransformerEncoderLayer",
+    "TransformerDecoder",
+    "TransformerDecoderLayer",
+    "create_causal_mask",
+    "MultiHeadAttention",
+    "FeedForward",
+    "PositionalEncoding",
+    "ClassificationHead",
+    "TokenClassificationHead",
+    "LMHead",
+    "ProjectionHead",
+    "MultiTaskModel",
+]
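A composition sketch for the package (assumption: the encoder/decoder constructors default every hyperparameter other than the three that TextPreprocessor.build_encoder/build_decoder inject):

    from src.data.preprocessing import TextPreprocessor
    from src.models import MultiTaskModel

    pre = TextPreprocessor(max_length=128)
    pre.fit_tokenizer(["a tiny toy corpus", "just enough text to build a vocabulary"])

    # build_encoder/build_decoder size the models to the fitted vocabulary and
    # forward any extra keyword arguments straight to the constructors.
    model = MultiTaskModel(encoder=pre.build_encoder(), decoder=pre.build_decoder())

    batch = pre.batch_encode(["a toy input"])
    memory = model.encoder(batch.input_ids)  # contextual token representations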
src/ui/streamlit_app.py
ADDED
@@ -0,0 +1,108 @@
+"""
+Streamlit prototype for LexiMind (summarization, emotion, topic).
+Run from repo root: streamlit run src/ui/streamlit_app.py
+"""
+import streamlit as st
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.figure_factory as ff
+
+# Stable absolute import; ensure the repo root is on PYTHONPATH (running from the repo root is standard)
+try:
+    from src.api.inference import load_models, summarize_text, classify_emotion, topic_for_text
+except Exception as e:
+    st.error(f"Failed to import inference helpers: {e}")
+    raise
+
+st.set_page_config(page_title="LexiMind demo", layout="wide")
+
+MODEL_CONFIG = {
+    "checkpoint_path": "checkpoints/best.pt",      # change to your trained checkpoint
+    "tokenizer_path": "artifacts/tokenizer.json",  # JSON produced by TextPreprocessor.save_tokenizer
+    "device": "cpu",
+}
+try:
+    models = load_models(MODEL_CONFIG)
+except Exception as exc:
+    st.error(f"Failed to load models: {exc}")
+    st.stop()
+
+st.sidebar.title("LexiMind")
+task = st.sidebar.selectbox("Task", ["Summarize", "Emotion", "Topic", "Search demo"])
+compression = st.sidebar.slider("Compression (summary length)", 0.1, 1.0, 0.25)
+show_attn = st.sidebar.checkbox("Show attention heatmap (collect_attn)", value=False)
+
+st.sidebar.markdown("Demo controls")
+sample_choice = st.sidebar.selectbox("Use sample text", ["None", "Gutenberg sample", "News sample"])
+
+SAMPLES = {
+    "Gutenberg sample": (
+        "It was the best of times, it was the worst of times, it was the age of wisdom, "
+        "it was the age of foolishness..."
+    ),
+    "News sample": (
+        "Markets rallied today as tech stocks posted gains amid broad optimism over earnings..."
+    ),
+}
+
+st.title("LexiMind — Summarization, Emotion, Topic (Prototype)")
+
+if sample_choice != "None":
+    input_text = st.text_area("Input text", value=SAMPLES[sample_choice], height=280)
+else:
+    input_text = st.text_area("Input text", value="", height=280)
+
+col1, col2 = st.columns([2, 1])
+
+with col1:
+    st.subheader("Output")
+    if st.button("Run"):
+        if not input_text.strip():
+            st.warning("Enter some text or select a sample to run the model.")
+        else:
+            if task == "Summarize":
+                summary, attn_data = summarize_text(input_text, compression=compression, collect_attn=show_attn, models=models)
+                st.markdown("**Summary**")
+                st.write(summary)
+                if show_attn and attn_data is not None:
+                    st.markdown("**Attention heatmap (averaged heads)**")
+                    src_tokens = attn_data.get("src_tokens", None)
+                    tgt_tokens = attn_data.get("tgt_tokens", None)
+                    weights = attn_data.get("weights", None)
+                    if weights is not None:
+                        arr = np.array(weights)
+                        if arr.ndim == 4:
+                            arr = arr.mean(axis=(0, 1))
+                        elif arr.ndim == 3:
+                            arr = arr.mean(axis=0)
+                        fig = ff.create_annotated_heatmap(
+                            z=arr.tolist(),
+                            x=src_tokens if src_tokens else [f"tok{i}" for i in range(arr.shape[1])],
+                            y=tgt_tokens if tgt_tokens else [f"tok{i}" for i in range(arr.shape[0])],
+                            colorscale="Viridis",
+                        )
+                        st.plotly_chart(fig, use_container_width=True)
+                    else:
+                        st.info("Attention data not available from the model.")
+            elif task == "Emotion":
+                probs, labels = classify_emotion(input_text, models=models)
+                st.markdown("**Emotion predictions (multi-label probabilities)**")
+                df = pd.DataFrame({"emotion": labels, "prob": probs})
+                fig = px.bar(df, x="emotion", y="prob", color="prob", range_y=[0, 1])
+                st.plotly_chart(fig, use_container_width=True)
+            elif task == "Topic":
+                topic_id, topic_terms = topic_for_text(input_text, models=models)
+                st.markdown("**Topic cluster**")
+                st.write(f"Cluster ID: {topic_id}")
+                st.write("Top terms:", ", ".join(topic_terms))
+            elif task == "Search demo":
+                st.info("Search demo will be available when ingestion is run (see scripts).")
+
+with col2:
+    st.subheader("Model & Info")
+    st.markdown(f"*Model loaded:* {'yes' if models.get('loaded', False) else 'no'}")
+    st.markdown(f"*Device:* {models.get('device', MODEL_CONFIG['device'])}")
+    st.markdown("**Notes**")
+    st.markdown("- Attention visualization depends on model support to return attention.")
+    st.markdown("- For long inputs the UI truncates tokens for heatmap clarity.")