OliverPerrin committed
Commit f9edbb4 · 1 Parent(s): 1bdd1c1

Updated the Summarizer and Preprocessor to run on my custom transformer, and added a basic Streamlit frontend demo

requirements-dev.txt CHANGED
@@ -6,4 +6,5 @@ isort>=5.12.0
  flake8>=6.0.0
  mypy>=1.4.0
  jupyter>=1.0.0
- ipywidgets>=8.0.0
+ ipywidgets>=8.0.0
+ pre-commit>=3.4.0
requirements.txt CHANGED
@@ -15,4 +15,8 @@ omegaconf>=2.3.0
  tensorboard>=2.13.0
  gradio>=3.35.0
  requests>=2.31.0
- kagglehub>=0.2.0
+ kaggle>=1.5.12
+ streamlit>=1.25.0
+ plotly>=5.18.0
+ faiss-cpu==1.9.0; platform_system != "Windows"
+ faiss-cpu==1.9.0; platform_system == "Windows"
src/api/inference/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """
+ API inference module for LexiMind.
+ """
+
+ from .inference import load_models, summarize_text, classify_emotion, topic_for_text
+
+ __all__ = ["load_models", "summarize_text", "classify_emotion", "topic_for_text"]
src/api/inference/inference.py ADDED
@@ -0,0 +1,133 @@
+ """Minimal inference helpers that rely on the custom transformer stack."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+
+ from ...data.preprocessing import TextPreprocessor, TransformerTokenizer
+ from ...models.multitask import MultiTaskModel
+
+
+ def _load_tokenizer(tokenizer_path: Path) -> TransformerTokenizer:
+     if not tokenizer_path.exists():
+         raise FileNotFoundError(f"tokenizer file '{tokenizer_path}' not found")
+     return TransformerTokenizer.load(tokenizer_path)
+
+
+ def load_models(config: Dict[str, Any]) -> Dict[str, Any]:
+     """Load MultiTaskModel together with the tokenizer-driven preprocessor."""
+
+     device = torch.device(config.get("device", "cpu"))
+     tokenizer_path = config.get("tokenizer_path")
+     if tokenizer_path is None:
+         raise ValueError("'tokenizer_path' missing in config")
+
+     tokenizer = _load_tokenizer(Path(tokenizer_path))
+     preprocessor = TextPreprocessor(
+         max_length=int(config.get("max_length", 512)),
+         tokenizer=tokenizer,
+         min_freq=int(config.get("min_freq", 1)),
+         lowercase=bool(config.get("lowercase", True)),
+     )
+
+     encoder_kwargs = dict(config.get("encoder", {}))
+     decoder_kwargs = dict(config.get("decoder", {}))
+
+     encoder = preprocessor.build_encoder(**encoder_kwargs)
+     decoder = preprocessor.build_decoder(**decoder_kwargs)
+     model = MultiTaskModel(encoder=encoder, decoder=decoder)
+
+     checkpoint_path = config.get("checkpoint_path")
+     if checkpoint_path:
+         state = torch.load(checkpoint_path, map_location=device)
+         if isinstance(state, dict) and "state_dict" in state:
+             state = state["state_dict"]
+         model.load_state_dict(state, strict=False)
+
+     model.to(device)
+
+     return {
+         "loaded": True,
+         "device": device,
+         "mt": model,
+         "preprocessor": preprocessor,
+     }
+
+
+ def summarize_text(
+     text: str,
+     compression: float = 0.25,
+     collect_attn: bool = False,
+     models: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
+     if models is None or not models.get("loaded"):
+         raise RuntimeError("Models must be loaded via load_models before summarize_text is called")
+
+     model: MultiTaskModel = models["mt"]
+     preprocessor: TextPreprocessor = models["preprocessor"]
+     device: torch.device = models["device"]
+
+     batch = preprocessor.batch_encode([text])
+     tokenizer = preprocessor.tokenizer
+     encoder = model.encoder
+     decoder = model.decoder
+     if tokenizer is None or encoder is None or decoder is None:
+         raise RuntimeError("Encoder, decoder, and tokenizer must be configured before summarization")
+     input_ids = batch.input_ids.to(device)
+     memory = encoder(input_ids)
+     src_len = batch.lengths[0]
+     max_tgt = max(4, int(src_len * compression))
+     generated = decoder.greedy_decode(
+         memory,
+         max_len=min(preprocessor.max_length, max_tgt),
+         start_token_id=tokenizer.bos_id,
+         end_token_id=tokenizer.eos_id,
+     )
+     summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
+     return summary.strip(), None if not collect_attn else {}
+
+
+ def classify_emotion(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[List[float], List[str]]:
+     if models is None or not models.get("loaded"):
+         raise RuntimeError("Models must be loaded via load_models before classify_emotion is called")
+
+     model: MultiTaskModel = models["mt"]
+     preprocessor: TextPreprocessor = models["preprocessor"]
+     device: torch.device = models["device"]
+
+     batch = preprocessor.batch_encode([text])
+     input_ids = batch.input_ids.to(device)
+     result = model.forward("emotion", {"input_ids": input_ids})
+     logits = result[1] if isinstance(result, tuple) else result
+     scores = torch.sigmoid(logits).squeeze(0).detach().cpu().tolist()
+     labels = models.get("emotion_labels") or [
+         "joy",
+         "sadness",
+         "anger",
+         "fear",
+         "surprise",
+         "disgust",
+     ]
+     return scores, labels[: len(scores)]
+
+
+ def topic_for_text(text: str, models: Optional[Dict[str, Any]] = None) -> Tuple[int, List[str]]:
+     if models is None or not models.get("loaded"):
+         raise RuntimeError("Models must be loaded via load_models before topic_for_text is called")
+
+     model: MultiTaskModel = models["mt"]
+     preprocessor: TextPreprocessor = models["preprocessor"]
+     device: torch.device = models["device"]
+
+     batch = preprocessor.batch_encode([text])
+     input_ids = batch.input_ids.to(device)
+     encoder = model.encoder
+     if encoder is None:
+         raise RuntimeError("Encoder must be configured before topic_for_text is called")
+     memory = encoder(input_ids)
+     embedding = memory.mean(dim=1).detach().cpu()
+     _ = embedding  # placeholder for downstream clustering hook
+     return 0, ["topic_stub"]
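
A minimal usage sketch for these helpers (hedged: the artifact paths below are assumptions mirroring the Streamlit config later in this commit, and a saved tokenizer plus an optional trained checkpoint must already exist):

    # Sketch only: exercises load_models/summarize_text/classify_emotion as defined above.
    from src.api.inference import load_models, summarize_text, classify_emotion

    config = {
        "tokenizer_path": "artifacts/tokenizer.json",   # written by TextPreprocessor.save_tokenizer
        "checkpoint_path": "checkpoints/best.pt",       # optional; weights are skipped if omitted
        "device": "cpu",
    }
    models = load_models(config)
    summary, _ = summarize_text("Some long article text ...", compression=0.25, models=models)
    scores, labels = classify_emotion("I loved this book!", models=models)
    print(summary)
    print(dict(zip(labels, scores)))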
src/data/download.py CHANGED
@@ -1,57 +1,63 @@
+ """
+ Download helpers for datasets.
+
+ This version:
+ - Adds robust error handling when Kaggle API is not configured.
+ - Stores files under data/raw/ subfolders.
+ - Keeps the Gutenberg direct download example.
+
+ Make sure you have Kaggle credentials configured if you call Kaggle downloads.
+ """
  import os
  import requests
- import kaggle
-
- def download_gutenberg():
-     """Example: download Pride and Prejudice"""
-     url = "https://www.gutenberg.org/files/1342/1342-0.txt"
-     os.makedirs("data/raw/books", exist_ok=True)
-     out_path = "data/raw/books/pride_and_prejudice.txt"
-     if not os.path.exists(out_path):
-         r = requests.get(url)
+
+ def download_gutenberg(out_dir="data/raw/books", gutenberg_id: int = 1342, filename: str = "pride_and_prejudice.txt"):
+     """Download a Gutenberg text file by direct URL template (best-effort)."""
+     url = f"https://www.gutenberg.org/files/{gutenberg_id}/{gutenberg_id}-0.txt"
+     os.makedirs(out_dir, exist_ok=True)
+     out_path = os.path.join(out_dir, filename)
+     if os.path.exists(out_path):
+         print("Already downloaded:", out_path)
+         return out_path
+     try:
+         r = requests.get(url, timeout=30)
+         r.raise_for_status()
          with open(out_path, "wb") as f:
              f.write(r.content)
-     print("Downloaded:", out_path)
+         print("Downloaded:", out_path)
+         return out_path
+     except Exception as e:
+         print("Failed to download Gutenberg file:", e)
+         return None

- # Kaggle dataset download helpers
+ # Kaggle helpers: optional, wrapped to avoid hard failure when Kaggle isn't configured.
+ def _safe_kaggle_download(dataset: str, path: str):
+     try:
+         import kaggle
+     except Exception as e:
+         print("Kaggle package not available or not configured. Please install 'kaggle' and configure API token. Error:", e)
+         return False
+     try:
+         os.makedirs(path, exist_ok=True)
+         kaggle.api.authenticate()
+         kaggle.api.dataset_download_files(dataset, path=path, unzip=True)
+         print(f"Downloaded Kaggle dataset {dataset} to {path}")
+         return True
+     except Exception as e:
+         print("Failed to download Kaggle dataset:", e)
+         return False
+
  def download_emotion_dataset():
-     """Download the emotions dataset from Kaggle."""
      target_dir = "data/raw/emotion"
-     os.makedirs(target_dir, exist_ok=True)
-     # Downloading using Kaggle Python API
-     kaggle.api.authenticate()
-     kaggle.api.dataset_download_files(
-         'praveengovi/emotions-dataset-for-nlp',
-         path=target_dir,
-         unzip=True
-     )
-     print("Downloaded Kaggle emotion dataset to", target_dir)
+     return _safe_kaggle_download('praveengovi/emotions-dataset-for-nlp', target_dir)

  def download_cnn_dailymail():
-     """Download the CNN/DailyMail summarization dataset from Kaggle."""
      target_dir = "data/raw/summarization"
-     os.makedirs(target_dir, exist_ok=True)
-     # Downloading using Kaggle Python API
-     kaggle.api.authenticate()
-     kaggle.api.dataset_download_files(
-         'gowrishankarp/newspaper-text-summarization-cnn-dailymail',
-         path=target_dir,
-         unzip=True
-     )
-     print("Downloaded Kaggle CNN/DailyMail dataset to", target_dir)
+     return _safe_kaggle_download('gowrishankarp/newspaper-text-summarization-cnn-dailymail', target_dir)

  def download_ag_news():
-     """Download the AG News dataset from Kaggle."""
      target_dir = "data/raw/topic"
-     os.makedirs(target_dir, exist_ok=True)
-     # Downloading using Kaggle Python API
-     kaggle.api.authenticate()
-     kaggle.api.dataset_download_files(
-         'amananandrai/ag-news-classification-dataset',
-         path=target_dir,
-         unzip=True
-     )
-     print("Downloaded Kaggle AG News dataset to", target_dir)
+     return _safe_kaggle_download('amananandrai/ag-news-classification-dataset', target_dir)

  if __name__ == "__main__":
      download_gutenberg()
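
For reference, a small hedged sketch of driving these helpers (the Kaggle calls require a configured ~/.kaggle/kaggle.json token; return values follow the functions above):

    # Sketch only: assumes the repo root is on PYTHONPATH.
    from src.data.download import download_gutenberg, download_emotion_dataset

    path = download_gutenberg()        # returns the local file path, or None on failure
    ok = download_emotion_dataset()    # True only when the kaggle package and API token are available
    print("gutenberg:", path, "| emotion dataset downloaded:", ok)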
src/data/preprocessing.py CHANGED
@@ -1,263 +1,260 @@
- # src/preprocessing.py
- import re
- import os
  import json
- import tensorflow as tf
- import numpy as np
- import pandas as pd
- from sklearn.model_selection import train_test_split
- from transformers import AutoTokenizer
- import nltk
- from nltk.corpus import stopwords
- from nltk.tokenize import word_tokenize
-
-
- class textPreprocessor:
-     def __init__(self, max_length=512, model_name='bert-base-uncased'):
          self.max_length = max_length
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
      def clean_text(self, text: str) -> str:
-         """Cleaning and Normalizing Text"""
-         text = re.sub(r'\s+', ' ', text)  # Getting rid of extra spaces
-         text = re.sub(r'[^a-zA-Z0-9.,;:?!\'" ]', '', text)  # Removing weird characters
-         return text.strip()
-
-     def tokenize_text(self, texts: list[str]):
-         return self.tokenizer(
-             texts,
-             truncation=True,
-             padding=True,
-             max_length=self.max_length,
-             return_tensors='tf'
          )
-
-     def prepare_data(self, texts: list[str], labels=None):
-         """Preparing Data for Training"""
-         cleaned_texts = [self.clean_text(text) for text in texts]
-         encoded = self.tokenize_text(cleaned_texts)
-
-         if labels is not None:
-             return encoded, tf.convert_to_tensor(labels)
-         return encoded
-
-     def load_books(self, folder_path="data/raw/books") -> list[str]:
-         """Load books from text files in the specific folder"""
-         texts = []
-         for filename in os.listdir(folder_path):
-             if filename.endswith(".txt"):
-                 file_path = os.path.join(folder_path, filename)
-                 with open(file_path, 'r', encoding='utf-8', errors="ignore") as file:
-                     raw_text = file.read()
-                     cleaned = self.clean_text(raw_text)
-                     texts.append(cleaned)
-         return texts
-
-     def chunk_text(self, text: str, chunk_size=1000, overlap=100) -> list[str]:
-         """Splits long texts into smaller segments or chunks"""
-         words = text.split()
-         chunks = []
          start = 0
          while start < len(words):
-             end = start + chunk_size
-             chunk = " ".join(words[start:end])
-             chunks.append(chunk)
              start += chunk_size - overlap
          return chunks

-     def save_preprocessed_books(self, data, input_folder="data/raw/books", output_folder="data/processed/books", chunk_size=1000, overlap=100):
-         os.makedirs(output_folder, exist_ok=True)
-         for filename in os.listdir(input_folder):
-             if filename.endswith(".txt"):
-                 file_path = os.path.join(input_folder, filename)
-                 with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
-                     raw_text = f.read()
-                 cleaned = self.clean_text(raw_text)
-                 chunks = self.chunk_text(cleaned, chunk_size, overlap)
-
-                 # Saving as JSON, one file for each book
-                 out_file = os.path.join(output_folder, filename.replace(".txt", ".json"))
-                 with open(out_file, "w", encoding="utf-8") as out:
-                     json.dump(chunks, out, ensure_ascii=False, indent=2)
-
-                 print(f"Processed and saved {filename} → {out_file}")
-
-
-     # ----- Dataset-specific processing methods ------
-
-     def process_summarization_dataset(self):
-         """Process summarization dataset: clean, split, and save."""
-         input_folder = "data/raw/summarization/cnn_dailymail"
-         output_folder = "data/processed/summarization"
-         os.makedirs(output_folder, exist_ok=True)
-
-         # Process each CSV file separately (train.csv, validation.csv, test.csv)
-         file_mapping = {
-             'train.csv': 'train',
-             'validation.csv': 'val',
-             'test.csv': 'test'
-         }
-
-         for csv_file, split_name in file_mapping.items():
-             file_path = os.path.join(input_folder, csv_file)
-             if not os.path.exists(file_path):
-                 print(f"Missing file: {file_path}")
-                 continue
-
-             print(f"Processing {csv_file}...")
-             df = pd.read_csv(file_path)
-
-             # Check for required columns (article and highlights)
-             if 'article' not in df.columns or 'highlights' not in df.columns:
-                 print(f"CSV {csv_file} must have 'article' and 'highlights' columns.")
-                 continue
-
-             # Clean the text data
-             df['article'] = df['article'].astype(str).apply(self.clean_text)
-             df['summary'] = df['highlights'].astype(str).apply(self.clean_text)  # rename highlights to summary
-
-             # Convert to records format
-             records = df[['article', 'summary']].to_dict(orient='records')
-
-             # Save as JSON
-             output_file = os.path.join(output_folder, f"{split_name}.json")
-             with open(output_file, "w", encoding="utf-8") as f:
-                 json.dump(records, f, ensure_ascii=False, indent=2)
-             print(f"Processed {csv_file}: {len(records)} samples saved to {split_name}.json")
-
-         print("Summarization dataset processed and saved.")
-
-     def process_emotion_dataset(self):
-         """Process emotion dataset: clean, split, and save."""
-         input_folder = "data/raw/emotion"
-         output_folder = "data/processed/emotion"
-         os.makedirs(output_folder, exist_ok=True)
-
-         # Process each txt file (train.txt, val.txt, test.txt)
-         for split_file in ['train.txt', 'val.txt', 'test.txt']:
-             file_path = os.path.join(input_folder, split_file)
-             if not os.path.exists(file_path):
-                 print(f"Missing file: {file_path}")
-                 continue
-
-             records = []
-             with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                 for line in f:
-                     line = line.strip()
-                     if line and ';' in line:
-                         # Split on the last semicolon to handle semicolons in text
-                         text, label = line.rsplit(';', 1)
-                         records.append({
-                             'text': self.clean_text(text),
-                             'label': label.strip()
-                         })
-
-             # Save as JSON
-             split_name = split_file.replace('.txt', '')
-             output_file = os.path.join(output_folder, f"{split_name}.json")
-             with open(output_file, "w", encoding="utf-8") as f:
-                 json.dump(records, f, ensure_ascii=False, indent=2)
-             print(f"Processed {split_file}: {len(records)} samples saved to {split_name}.json")
-
-         print("Emotion dataset processed and saved.")
-
-     def process_topic_dataset(self):
-         """Process topic dataset: clean, split, and save."""
-         input_folder = "data/raw/topic"
-         output_folder = "data/processed/topic"
-         os.makedirs(output_folder, exist_ok=True)
-
-         # Process each CSV file separately (train.csv, test.csv)
-         file_mapping = {
-             'train.csv': 'train',
-             'test.csv': 'test'
-         }
-
-         # Class index to topic name mapping for AG News dataset
-         class_map = {
-             1: 'World',
-             2: 'Sports',
-             3: 'Business',
-             4: 'Science/Technology'
-         }
-
-         for csv_file, split_name in file_mapping.items():
-             file_path = os.path.join(input_folder, csv_file)
-             if not os.path.exists(file_path):
-                 print(f"Missing file: {file_path}")
-                 continue
-
-             print(f"Processing {csv_file}...")
-             df = pd.read_csv(file_path)
-
-             # Check for required columns
-             if 'Class Index' not in df.columns:
-                 print(f"CSV {csv_file} must have 'Class Index' column.")
-                 continue
-
-             # Concatenate title and description
-             if 'Title' in df.columns and 'Description' in df.columns:
-                 text = df['Title'].astype(str) + ". " + df['Description'].astype(str)
-             elif 'Title' in df.columns:
-                 text = df['Title'].astype(str)
-             elif 'Description' in df.columns:
-                 text = df['Description'].astype(str)
-             else:
-                 print("CSV must have 'Title' or 'Description' columns.")
-                 continue
-
-             df['text'] = text.apply(self.clean_text)
-
-             # Map numeric labels to category names
-             df['label'] = df['Class Index'].map(class_map)
-
-             # Convert to records format
-             records = df[['text', 'label']].to_dict(orient='records')
-
-             # Save as JSON
-             output_file = os.path.join(output_folder, f"{split_name}.json")
-             with open(output_file, "w", encoding="utf-8") as f:
-                 json.dump(records, f, ensure_ascii=False, indent=2)
-             print(f"Processed {csv_file}: {len(records)} samples saved to {split_name}.json")
-
-         # Create validation split from training data
-         if os.path.exists(os.path.join(output_folder, "train.json")):
-             print("Creating validation split from training data...")
-             with open(os.path.join(output_folder, "train.json"), "r", encoding="utf-8") as f:
-                 train_data = json.load(f)
-
-             # Split training data into train and validation
-             train_records, val_records = train_test_split(train_data, test_size=0.2, random_state=42)
-
-             # Save updated train and new validation files
-             with open(os.path.join(output_folder, "train.json"), "w", encoding="utf-8") as f:
-                 json.dump(train_records, f, ensure_ascii=False, indent=2)
-
-             with open(os.path.join(output_folder, "val.json"), "w", encoding="utf-8") as f:
-                 json.dump(val_records, f, ensure_ascii=False, indent=2)
-
-             print(f"Updated train.json: {len(train_records)} samples")
-             print(f"Created val.json: {len(val_records)} samples")
-
-         print("Topic dataset processed and saved.")
-
-
- # ----- Main function for quick testing ------
-
- if __name__ == "__main__":
-     preprocessor = textPreprocessor(max_length=128)
-
-     # Process and save all books
-     preprocessor.save_preprocessed_books(data=None)
-
-     # Load a processed book back
-     import json
-     with open("data/processed/books/pride_and_prejudice.json", "r") as f:
-         chunks = json.load(f)
-     print(f"Loaded {len(chunks)} chunks from Pride and Prejudice")
-     print(chunks[0][:200])  # printing first 200 chars of chunk
-
-     # Process new datasets
-     preprocessor.process_summarization_dataset()
-     preprocessor.process_emotion_dataset()
-     preprocessor.process_topic_dataset()
+ """Lightweight preprocessing utilities built around the in-repo transformer."""
+
+ from __future__ import annotations
+
+ from collections import Counter
+ from dataclasses import dataclass
  import json
+ from pathlib import Path
+ import re
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+
+ import torch
+
+ from ..models.decoder import TransformerDecoder
+ from ..models.encoder import TransformerEncoder
+
+ SPECIAL_TOKENS: Tuple[str, str, str, str] = ("<pad>", "<bos>", "<eos>", "<unk>")
+
+
+ def _normalize(text: str, lowercase: bool) -> str:
+     text = text.strip()
+     text = re.sub(r"\s+", " ", text)
+     if lowercase:
+         text = text.lower()
+     return text
+
+
+ def _basic_tokenize(text: str) -> List[str]:
+     return re.findall(r"\b\w+\b|[.,;:?!]", text)
+
+
+ class TransformerTokenizer:
+     """Minimal tokenizer that keeps vocabulary aligned with the custom transformer."""
+
+     def __init__(
+         self,
+         stoi: Dict[str, int],
+         itos: List[str],
+         specials: Sequence[str] = SPECIAL_TOKENS,
+         lowercase: bool = True,
+     ) -> None:
+         self.stoi = stoi
+         self.itos = itos
+         self.specials = tuple(specials)
+         self.lowercase = lowercase
+         self.pad_id = self._lookup(self.specials[0])
+         self.bos_id = self._lookup(self.specials[1])
+         self.eos_id = self._lookup(self.specials[2])
+         self.unk_id = self._lookup(self.specials[3])
+
+     @classmethod
+     def build(
+         cls,
+         texts: Iterable[str],
+         min_freq: int = 1,
+         lowercase: bool = True,
+         specials: Sequence[str] = SPECIAL_TOKENS,
+     ) -> "TransformerTokenizer":
+         counter: Counter[str] = Counter()
+         for text in texts:
+             normalized = _normalize(text, lowercase)
+             counter.update(_basic_tokenize(normalized))
+
+         ordered_specials = list(dict.fromkeys(specials))
+         itos: List[str] = ordered_specials.copy()
+         for token, freq in counter.most_common():
+             if freq < min_freq:
+                 continue
+             if token in itos:
+                 continue
+             itos.append(token)
+
+         stoi = {token: idx for idx, token in enumerate(itos)}
+         return cls(stoi=stoi, itos=itos, specials=ordered_specials, lowercase=lowercase)
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self.itos)
+
+     def tokenize(self, text: str) -> List[str]:
+         normalized = _normalize(text, self.lowercase)
+         return _basic_tokenize(normalized)
+
+     def encode(
+         self,
+         text: str,
+         add_special_tokens: bool = True,
+         max_length: Optional[int] = None,
+     ) -> List[int]:
+         tokens = self.tokenize(text)
+         pieces = [self.stoi.get(tok, self.unk_id) for tok in tokens]
+         if add_special_tokens:
+             pieces = [self.bos_id] + pieces + [self.eos_id]
+
+         if max_length is not None and len(pieces) > max_length:
+             if add_special_tokens and max_length >= 2:
+                 inner_max = max_length - 2
+                 trimmed = pieces[1:-1][:inner_max]
+                 pieces = [self.bos_id] + trimmed + [self.eos_id]
+             else:
+                 pieces = pieces[:max_length]
+         return pieces
+
+     def decode(self, ids: Sequence[int], skip_special_tokens: bool = True) -> str:
+         tokens: List[str] = []
+         for idx in ids:
+             if idx < 0 or idx >= len(self.itos):
+                 continue
+             token = self.itos[idx]
+             if skip_special_tokens and token in self.specials:
+                 continue
+             tokens.append(token)
+         return " ".join(tokens).strip()
+
+     def pad_batch(
+         self,
+         sequences: Sequence[Sequence[int]],
+         pad_to_length: Optional[int] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if not sequences:
+             raise ValueError("pad_batch requires at least one sequence")
+         if pad_to_length is None:
+             pad_to_length = max(len(seq) for seq in sequences)
+         padded: List[List[int]] = []
+         mask: List[List[int]] = []
+         for seq in sequences:
+             trimmed = list(seq[:pad_to_length])
+             pad_len = pad_to_length - len(trimmed)
+             padded.append(trimmed + [self.pad_id] * pad_len)
+             mask.append([1] * len(trimmed) + [0] * pad_len)
+         return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)
+
+     def save(self, path: Path) -> None:
+         payload = {
+             "itos": self.itos,
+             "specials": list(self.specials),
+             "lowercase": self.lowercase,
+         }
+         path.parent.mkdir(parents=True, exist_ok=True)
+         path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
+
+     @classmethod
+     def load(cls, path: Path) -> "TransformerTokenizer":
+         data = json.loads(path.read_text(encoding="utf-8"))
+         itos = list(data["itos"])
+         stoi = {token: idx for idx, token in enumerate(itos)}
+         specials = data.get("specials", list(SPECIAL_TOKENS))
+         lowercase = bool(data.get("lowercase", True))
+         return cls(stoi=stoi, itos=itos, specials=specials, lowercase=lowercase)
+
+     def _lookup(self, token: str) -> int:
+         if token not in self.stoi:
+             raise ValueError(f"token '{token}' missing from vocabulary")
+         return self.stoi[token]
+
+
+ @dataclass
+ class Batch:
+     input_ids: torch.Tensor
+     attention_mask: torch.Tensor
+     lengths: List[int]
+
+
+ class TextPreprocessor:
+     """Prepares text so it can flow directly into the custom transformer stack."""
+
+     def __init__(
+         self,
+         max_length: int = 512,
+         tokenizer: Optional[TransformerTokenizer] = None,
+         *,
+         min_freq: int = 1,
+         lowercase: bool = True,
+     ) -> None:
          self.max_length = max_length
+         self.min_freq = min_freq
+         self.lowercase = lowercase
+         self.tokenizer = tokenizer
+
      def clean_text(self, text: str) -> str:
+         return _normalize(text, self.lowercase)
+
+     def fit_tokenizer(self, texts: Iterable[str]) -> TransformerTokenizer:
+         cleaned = [self.clean_text(text) for text in texts]
+         self.tokenizer = TransformerTokenizer.build(
+             cleaned,
+             min_freq=self.min_freq,
+             lowercase=False,
          )
+         return self.tokenizer
+
+     def encode(self, text: str, *, add_special_tokens: bool = True) -> List[int]:
+         if self.tokenizer is None:
+             raise RuntimeError("Tokenizer not fitted")
+         cleaned = self.clean_text(text)
+         return self.tokenizer.encode(cleaned, add_special_tokens=add_special_tokens, max_length=self.max_length)
+
+     def batch_encode(self, texts: Sequence[str]) -> Batch:
+         if self.tokenizer is None:
+             raise RuntimeError("Tokenizer not fitted")
+         sequences = [self.encode(text) for text in texts]
+         lengths = [len(seq) for seq in sequences]
+         input_ids, attention_mask = self.tokenizer.pad_batch(sequences, pad_to_length=self.max_length)
+         return Batch(input_ids=input_ids, attention_mask=attention_mask, lengths=lengths)
+
+     def build_encoder(self, **encoder_kwargs) -> TransformerEncoder:
+         if self.tokenizer is None:
+             raise RuntimeError("Tokenizer not fitted")
+         return TransformerEncoder(
+             vocab_size=self.tokenizer.vocab_size,
+             max_len=self.max_length,
+             pad_token_id=self.tokenizer.pad_id,
+             **encoder_kwargs,
+         )
+
+     def build_decoder(self, **decoder_kwargs) -> TransformerDecoder:
+         if self.tokenizer is None:
+             raise RuntimeError("Tokenizer not fitted")
+         return TransformerDecoder(
+             vocab_size=self.tokenizer.vocab_size,
+             max_len=self.max_length,
+             pad_token_id=self.tokenizer.pad_id,
+             **decoder_kwargs,
+         )
+
+     def save_tokenizer(self, path: Path) -> None:
+         if self.tokenizer is None:
+             raise RuntimeError("Tokenizer not fitted")
+         self.tokenizer.save(path)
+
+     def load_tokenizer(self, path: Path) -> TransformerTokenizer:
+         self.tokenizer = TransformerTokenizer.load(path)
+         return self.tokenizer
+
+     def chunk_text(self, text: str, *, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
+         if chunk_size <= overlap:
+             raise ValueError("chunk_size must be larger than overlap")
+         words = self.clean_text(text).split()
+         chunks: List[str] = []
          start = 0
          while start < len(words):
+             end = min(start + chunk_size, len(words))
+             chunks.append(" ".join(words[start:end]))
              start += chunk_size - overlap
          return chunks

+     def save_book_chunks(
+         self,
+         input_path: Path,
+         out_dir: Path,
+         *,
+         chunk_size: int = 1000,
+         overlap: int = 100,
+     ) -> Path:
+         out_dir.mkdir(parents=True, exist_ok=True)
+         raw_text = input_path.read_text(encoding="utf-8", errors="ignore")
+         chunks = self.chunk_text(raw_text, chunk_size=chunk_size, overlap=overlap)
+         out_file = out_dir / f"{input_path.stem}.json"
+         out_file.write_text(json.dumps(chunks, ensure_ascii=False, indent=2), encoding="utf-8")
+         return out_file
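
A hedged usage sketch of the preprocessor defined above (the sample strings are placeholders; the tokenizer artifact for the demo later in this commit is produced with save_tokenizer):

    # Sketch only: fit a vocabulary, encode a batch, and chunk a long text.
    from src.data.preprocessing import TextPreprocessor

    pre = TextPreprocessor(max_length=64, min_freq=1, lowercase=True)
    pre.fit_tokenizer(["First training document.", "Second training document."])

    batch = pre.batch_encode(["A short text to encode."])
    print(batch.input_ids.shape, batch.attention_mask.shape, batch.lengths)
    print(len(pre.chunk_text("word " * 5000, chunk_size=1000, overlap=100)))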
src/inference/__init__.py CHANGED
@@ -0,0 +1,7 @@
+ """
+ Inference utilities for LexiMind.
+ """
+
+ from .baseline_summarizer import Summarizer, TransformerSummarizer
+
+ __all__ = ["Summarizer", "TransformerSummarizer"]
src/inference/baseline_summarizer.py CHANGED
@@ -1,222 +1,41 @@
- import os
- import json
- from typing import Any, List, Dict, Optional
- import torch
- from torch.utils.data import Dataset, DataLoader
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-
- class Summarizer:
-     def __init__(self, model_name: str = "t5-small", max_input: int = 512, max_output: int = 128, device: Optional[str] = None):
-         self.model_name = model_name
-         self.max_input = max_input
-         self.max_output = max_output
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-         self.device = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model.to(self.device)
-
-     def load_data(self, split: str = "train", limit: Optional[int] = None) -> List[Dict[str, str]]:
-         """
-         Load processed summarization data from JSON files.
-
-         Args:
-             split (str): Data split to load ('train', 'val', 'test')
-             limit (int): Maximum number of samples to load (None for all)
-
-         Returns:
-             list: List of dictionaries with 'article' and 'summary' keys
-         """
-         # Resolve to project root regardless of current working directory
-         root = os.path.dirname(os.path.dirname(__file__))
-         file_path = os.path.join(root, "data", "processed", "summarization", f"{split}.json")
-
-         if not os.path.exists(file_path):
-             raise FileNotFoundError(f"Data file not found: {file_path}")
-
-         with open(file_path, "r", encoding="utf-8") as f:
-             data = json.load(f)
-
-         if limit:
-             data = data[:limit]
-         return data
-
-     def encode(self, articles: List[str] | str, summaries: Optional[List[str] | str] = None):
-         if isinstance(articles, str):
-             articles = [articles]
-         if summaries is not None and isinstance(summaries, str):
-             summaries = [summaries]
-
-         inputs = self.tokenizer(
-             [f"summarize: {a}" for a in articles],
-             max_length=self.max_input,
-             truncation=True,
-             padding="max_length",
-             return_tensors="pt"
-         )
-
-         result = {
-             "input_ids": inputs.input_ids.to(self.device),
-             "attention_mask": inputs.attention_mask.to(self.device)
-         }
-
-         if summaries is not None:
-             labels = self.tokenizer(
-                 summaries,
-                 max_length=self.max_output,
-                 truncation=True,
-                 padding="max_length",
-                 return_tensors="pt"
-             ).input_ids
-             # Mask pad tokens in labels with -100 for loss
-             labels[labels == self.tokenizer.pad_token_id] = -100
-             result["labels"] = labels.to(self.device)
-         return result
-
-     def train(self, epochs: int = 3, batch_size: int = 4, train_limit: int = 2000, val_limit: int = 500, learning_rate: float = 5e-5):
-         train_data = self.load_data("train", limit=train_limit)
-         val_data = self.load_data("val", limit=val_limit)
-
-         train_ds = _SummarizationDataset(train_data, self.tokenizer, self.max_input, self.max_output)
-         val_ds = _SummarizationDataset(val_data, self.tokenizer, self.max_input, self.max_output) if val_data else None
-         train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
-         val_loader = DataLoader(val_ds, batch_size=batch_size) if val_ds else None
-
-         optim = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

-         self.model.train()
-         for epoch in range(epochs):
-             print(f"Epoch {epoch+1}/{epochs} - {len(train_loader)} batches", flush=True)
-             for i, batch in enumerate(train_loader, start=1):
-                 batch = {k: v.to(self.device) for k, v in batch.items()}
-                 outputs = self.model(**batch)
-                 loss = outputs.loss
-                 loss.backward()
-                 optim.step()
-                 optim.zero_grad()
-                 if (i % max(1, len(train_loader)//5 or 1)) == 0:
-                     print(f"  step {i}/{len(train_loader)} - loss {float(loss):.4f}", flush=True)
-
-             if val_loader:
-                 _ = self.evaluate(val_data[: min(100, len(val_data))])
-         print("Training complete.", flush=True)
-
-     def evaluate(self, val_data: List[Dict[str, str]]) -> float:
-         if not val_data:
-             return 0.0
-
-         ds = _SummarizationDataset(val_data, self.tokenizer, self.max_input, self.max_output)
-         loader = DataLoader(ds, batch_size=4)
-         self.model.eval()
-         total = 0.0
-         count = 0
-         with torch.no_grad():
-             for batch in loader:
-                 batch = {k: v.to(self.device) for k, v in batch.items()}
-                 outputs = self.model(**batch)
-                 total += float(outputs.loss) * batch["input_ids"].size(0)
-                 count += batch["input_ids"].size(0)
-         self.model.train()
-         return total / max(count, 1)
-
-     def summarize(self, text: str, max_length: Optional[int] = None, num_beams: int = 4) -> str:
-         if not text.strip():
-             return ""
-         inputs = self.tokenizer(
-             f"summarize: {text}",
-             return_tensors="pt",
-             max_length=self.max_input,
-             truncation=True,
-             padding=True
-         )
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-         with torch.no_grad():
-             summary_ids = self.model.generate(
-                 inputs["input_ids"],
-                 attention_mask=inputs.get("attention_mask"),
-                 max_length=max_length or self.max_output,
-                 num_beams=num_beams,
-                 length_penalty=2.0,
-                 early_stopping=True
-             )
-         return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
-
-     def save(self, path: str = "models/summarizer"):
-         """
-         Save the trained model and tokenizer.
-
-         Args:
-             path (str): Directory path to save the model
-         """
-         os.makedirs(path, exist_ok=True)
-         self.model.save_pretrained(path)
-         self.tokenizer.save_pretrained(path)
-
-     @classmethod
-     def load(cls, path: str = "models/summarizer"):
-         """
-         Load a pre-trained model from disk.
-
-         Args:
-             path (str): Directory path containing the saved model
-
-         Returns:
-             Summarizer: Loaded summarizer instance
-         """
-         obj = cls.__new__(cls)
-         obj.model_name = path
-         obj.max_input = 512
-         obj.max_output = 128
-         obj.tokenizer = AutoTokenizer.from_pretrained(path)
-         obj.model = AutoModelForSeq2SeqLM.from_pretrained(path)
-         obj.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         obj.model.to(obj.device)
-         return obj
-
- class _SummarizationDataset(Dataset):
-     def __init__(self, data: List[Dict[str, str]], tokenizer: Any, max_input: int, max_output: int):
-         self.data = data
-         self.tokenizer = tokenizer
-         self.max_input = max_input
-         self.max_output = max_output
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, idx: int):
-         item = self.data[idx]
-         inputs = self.tokenizer(
-             f"summarize: {item['article']}",
-             max_length=self.max_input,
-             truncation=True,
-             padding="max_length",
-             return_tensors="pt"
          )
-         labels = self.tokenizer(
-             item["summary"],
-             max_length=self.max_output,
-             truncation=True,
-             padding="max_length",
-             return_tensors="pt"
-         ).input_ids
-         labels[labels == self.tokenizer.pad_token_id] = -100
-         return {
-             "input_ids": inputs.input_ids.squeeze(0),
-             "attention_mask": inputs.attention_mask.squeeze(0),
-             "labels": labels.squeeze(0),
-         }
-
- if __name__ == "__main__":
-     print("Initializing summarizer...", flush=True)
-     summarizer = Summarizer(model_name="t5-small")
-     print("Starting a short training run...", flush=True)
-     summarizer.train(epochs=3, batch_size=2, train_limit=100, val_limit=50)
-     test_text = (
-         "The quick brown fox jumps over the lazy dog. This is a common "
-         "pangram used in typography and printing. It contains every letter of the "
-         "alphabet at least once, making it useful for testing fonts and keyboards."
-     )
-     print("Generating summary...", flush=True)
-     summary = summarizer.summarize(test_text)
-     print(f"\nOriginal text: {test_text}")
-     print(f"Summary: {summary}")
-     summarizer.save()
+ """Thin wrapper around the custom transformer summarizer."""

+ from __future__ import annotations
+ from typing import Any, Dict, Optional, Tuple
+ import torch
+ from ..api.inference import load_models
+
+
+ class TransformerSummarizer:
+     def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
+         models = load_models(config or {})
+         if not models.get("loaded"):
+             raise RuntimeError("load_models returned an unloaded model; check configuration")
+         self.model = models["mt"]
+         self.preprocessor = models["preprocessor"]
+         self.device = models["device"]
+
+     def summarize(
+         self,
+         text: str,
+         compression: float = 0.25,
+         collect_attn: bool = False,
+     ) -> Tuple[str, Optional[Dict[str, torch.Tensor]]]:
+         batch = self.preprocessor.batch_encode([text])
+         tokenizer = self.preprocessor.tokenizer
+         encoder = self.model.encoder
+         decoder = self.model.decoder
+         if tokenizer is None or encoder is None or decoder is None:
+             raise RuntimeError("Model components are missing; ensure encoder, decoder, and tokenizer are set")
+         input_ids = batch.input_ids.to(self.device)
+         memory = encoder(input_ids)
+         src_len = batch.lengths[0]
+         target_len = max(4, int(src_len * compression))
+         generated = decoder.greedy_decode(
+             memory,
+             max_len=min(self.preprocessor.max_length, target_len),
+             start_token_id=tokenizer.bos_id,
+             end_token_id=tokenizer.eos_id,
          )
+         summary = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
+         return summary.strip(), None if not collect_attn else {}
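
A hedged usage sketch of the wrapper above (the config keys mirror load_models; the artifact paths are assumptions, not created by this commit):

    # Sketch only: requires a saved tokenizer JSON and, optionally, a trained checkpoint.
    from src.inference.baseline_summarizer import TransformerSummarizer

    summarizer = TransformerSummarizer({
        "tokenizer_path": "artifacts/tokenizer.json",
        "checkpoint_path": "checkpoints/best.pt",
        "device": "cpu",
    })
    summary, _ = summarizer.summarize("Some long passage to compress ...", compression=0.2)
    print(summary)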
src/models/__init__.py CHANGED
@@ -0,0 +1,33 @@
+ """
+ LexiMind custom transformer models.
+
+ This package provides a from-scratch transformer implementation with:
+ - TransformerEncoder/TransformerDecoder
+ - MultiHeadAttention, FeedForward, PositionalEncoding
+ - Task heads: ClassificationHead, TokenClassificationHead, LMHead
+ - MultiTaskModel: composable wrapper for encoder/decoder + task heads
+ """
+
+ from .encoder import TransformerEncoder, TransformerEncoderLayer
+ from .decoder import TransformerDecoder, TransformerDecoderLayer, create_causal_mask
+ from .attention import MultiHeadAttention
+ from .feedforward import FeedForward
+ from .positional_encoding import PositionalEncoding
+ from .heads import ClassificationHead, TokenClassificationHead, LMHead, ProjectionHead
+ from .multitask import MultiTaskModel
+
+ __all__ = [
+     "TransformerEncoder",
+     "TransformerEncoderLayer",
+     "TransformerDecoder",
+     "TransformerDecoderLayer",
+     "create_causal_mask",
+     "MultiHeadAttention",
+     "FeedForward",
+     "PositionalEncoding",
+     "ClassificationHead",
+     "TokenClassificationHead",
+     "LMHead",
+     "ProjectionHead",
+     "MultiTaskModel",
+ ]
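
A minimal composition sketch, assuming only the constructor arguments already used by TextPreprocessor.build_encoder/build_decoder earlier in this commit (the vocab_size, max_len, and pad_token_id values here are placeholders; real code would derive them from a fitted tokenizer, and the constructors may accept further options not shown):

    # Sketch only: wiring the exported pieces together by hand.
    from src.models import TransformerEncoder, TransformerDecoder, MultiTaskModel

    encoder = TransformerEncoder(vocab_size=8000, max_len=512, pad_token_id=0)
    decoder = TransformerDecoder(vocab_size=8000, max_len=512, pad_token_id=0)
    model = MultiTaskModel(encoder=encoder, decoder=decoder)
    print(type(model).__name__)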
src/ui/streamlit_app.py ADDED
@@ -0,0 +1,108 @@
+ """
+ Streamlit prototype for LexiMind (summarization, emotion, topic).
+ Run from the repo root: streamlit run src/ui/streamlit_app.py
+ """
+ import streamlit as st
+ import numpy as np
+ import pandas as pd
+ import plotly.express as px
+ import plotly.figure_factory as ff
+
+ # Stable absolute import; ensure the repo root is on PYTHONPATH (running from the repo root is standard)
+ try:
+     from src.api.inference import load_models, summarize_text, classify_emotion, topic_for_text
+ except Exception as e:
+     st.error(f"Failed to import inference helpers: {e}")
+     raise
+
+ st.set_page_config(page_title="LexiMind demo", layout="wide")
+
+ MODEL_CONFIG = {
+     "checkpoint_path": "checkpoints/best.pt",  # change to your trained checkpoint
+     "tokenizer_path": "artifacts/tokenizer.json",  # JSON produced by TextPreprocessor.save_tokenizer
+     "device": "cpu",
+ }
+ try:
+     models = load_models(MODEL_CONFIG)
+ except Exception as exc:
+     st.error(f"Failed to load models: {exc}")
+     st.stop()
+
+ st.sidebar.title("LexiMind")
+ task = st.sidebar.selectbox("Task", ["Summarize", "Emotion", "Topic", "Search demo"])
+ compression = st.sidebar.slider("Compression (summary length)", 0.1, 1.0, 0.25)
+ show_attn = st.sidebar.checkbox("Show attention heatmap (collect_attn)", value=False)
+
+ st.sidebar.markdown("Demo controls")
+ sample_choice = st.sidebar.selectbox("Use sample text", ["None", "Gutenberg sample", "News sample"])
+
+ SAMPLES = {
+     "Gutenberg sample": (
+         "It was the best of times, it was the worst of times, it was the age of wisdom, "
+         "it was the age of foolishness..."
+     ),
+     "News sample": (
+         "Markets rallied today as tech stocks posted gains amid broad optimism over earnings..."
+     ),
+ }
+
+ st.title("LexiMind — Summarization, Emotion, Topic (Prototype)")
+
+ if sample_choice != "None":
+     input_text = st.text_area("Input text", value=SAMPLES[sample_choice], height=280)
+ else:
+     input_text = st.text_area("Input text", value="", height=280)
+
+ col1, col2 = st.columns([2, 1])
+
+ with col1:
+     st.subheader("Output")
+     if st.button("Run"):
+         if not input_text.strip():
+             st.warning("Enter some text or select a sample to run the model.")
+         else:
+             if task == "Summarize":
+                 summary, attn_data = summarize_text(input_text, compression=compression, collect_attn=show_attn, models=models)
+                 st.markdown("**Summary**")
+                 st.write(summary)
+                 if show_attn and attn_data is not None:
+                     st.markdown("**Attention heatmap (averaged heads)**")
+                     src_tokens = attn_data.get("src_tokens", None)
+                     tgt_tokens = attn_data.get("tgt_tokens", None)
+                     weights = attn_data.get("weights", None)
+                     if weights is not None:
+                         arr = np.array(weights)
+                         if arr.ndim == 4:
+                             arr = arr.mean(axis=(0, 1))
+                         elif arr.ndim == 3:
+                             arr = arr.mean(axis=0)
+                         fig = ff.create_annotated_heatmap(
+                             z=arr.tolist(),
+                             x=src_tokens if src_tokens else [f"tok{i}" for i in range(arr.shape[1])],
+                             y=tgt_tokens if tgt_tokens else [f"tok{i}" for i in range(arr.shape[0])],
+                             colorscale="Viridis",
+                         )
+                         st.plotly_chart(fig, use_container_width=True)
+                     else:
+                         st.info("Attention data not available from the model.")
+             elif task == "Emotion":
+                 probs, labels = classify_emotion(input_text, models=models)
+                 st.markdown("**Emotion predictions (multi-label probabilities)**")
+                 df = pd.DataFrame({"emotion": labels, "prob": probs})
+                 fig = px.bar(df, x="emotion", y="prob", color="prob", range_y=[0, 1])
+                 st.plotly_chart(fig, use_container_width=True)
+             elif task == "Topic":
+                 topic_id, topic_terms = topic_for_text(input_text, models=models)
+                 st.markdown("**Topic cluster**")
+                 st.write(f"Cluster ID: {topic_id}")
+                 st.write("Top terms:", ", ".join(topic_terms))
+             elif task == "Search demo":
+                 st.info("Search demo will be available when ingestion is run (see scripts).")
+
+ with col2:
+     st.subheader("Model & Info")
+     st.markdown(f"*Model loaded:* {'yes' if models.get('loaded', False) else 'no'}")
+     st.markdown(f"*Device:* {models.get('device', MODEL_CONFIG['device'])}")
+     st.markdown("**Notes**")
+     st.markdown("- Attention visualization depends on model support to return attention.")
+     st.markdown("- For long inputs the UI truncates tokens for heatmap clarity.")
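
A hedged setup sketch for the demo: one way to produce the artifacts/tokenizer.json referenced in MODEL_CONFIG (the corpus path is the Gutenberg file from download.py; the trained checkpoint itself is not created by this commit), followed by the launch command.

    # Sketch only: fit and save the tokenizer the demo expects, then run Streamlit.
    from pathlib import Path
    from src.data.preprocessing import TextPreprocessor

    corpus = Path("data/raw/books/pride_and_prejudice.txt").read_text(encoding="utf-8", errors="ignore")
    pre = TextPreprocessor(max_length=512)
    pre.fit_tokenizer(corpus.splitlines())
    pre.save_tokenizer(Path("artifacts/tokenizer.json"))

    # From the repo root (with the repo root on PYTHONPATH):
    #   streamlit run src/ui/streamlit_app.py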