Commit 7aaf14d
Parent(s): 10efa63

Summarization fix
Files changed:
- scripts/demo_gradio.py  +9 -2
- src/inference/pipeline.py  +13 -6
scripts/demo_gradio.py

@@ -126,7 +126,8 @@ def predict(text: str, compression: int):
     logger.info("Generating summary with max length %s", max_len)
 
     summary = pipeline.summarize([text], max_length=max_len)[0].strip()
-    …
+    # Use a higher threshold to filter out weak/wrong predictions on out-of-domain text
+    emotions = pipeline.predict_emotions([text], threshold=0.6)[0]
     topic = pipeline.predict_topics([text])[0]
 
     fallback_summary = None
@@ -451,10 +452,16 @@ def load_rouge_metrics():
     )
 
     table = pd.DataFrame(rows, columns=columns) if rows else empty
+
+    # Clean up path for display
+    display_path = str(ROUGE_REPORT_PATH)
+    if "/app/" in display_path:
+        display_path = display_path.replace("/app/", "/LexiMind/")
+
     metadata = {
         "num_examples": report.get("num_examples"),
         "config": report.get("config"),
-        "report_path": …
+        "report_path": display_path,
         "last_updated": datetime.fromtimestamp(ROUGE_REPORT_PATH.stat().st_mtime).isoformat(),
     }
     return table, metadata
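A note on the `predict_emotions` change: only the call site appears in this diff, so the classifier's internals are not shown. The sketch below is a minimal illustration, assuming a multi-label head that returns one probability per emotion label; `filter_emotions` and the example scores are hypothetical, not code from this repo.

```python
# Illustrative only -- not the repo's predict_emotions implementation.
# Assumes the model yields a probability per emotion label.
from typing import Dict, List

def filter_emotions(scores: Dict[str, float], threshold: float = 0.6) -> List[str]:
    """Keep only labels whose probability clears the threshold."""
    return [label for label, score in scores.items() if score >= threshold]

scores = {"joy": 0.91, "surprise": 0.55, "anger": 0.12}
print(filter_emotions(scores, threshold=0.5))  # ['joy', 'surprise']
print(filter_emotions(scores, threshold=0.6))  # ['joy']
```

Raising the bar to 0.6 drops borderline labels like the 0.55 "surprise" above, which is the effect the new comment describes for out-of-domain text.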
src/inference/pipeline.py

@@ -75,8 +75,8 @@ class InferencePipeline:
         with torch.inference_mode():
             encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
             memory = self.model.encoder(src_ids, mask=encoder_mask)
-            # …
-            min_len = …
+            # Force a minimum length to prevent immediate EOS
+            min_len = 10
             generated = self.model.decoder.greedy_decode(
                 memory=memory,
                 max_len=max_len,
@@ -86,11 +86,18 @@
                 min_len=min_len,
             )
 
-            # …
-            …
-            …
+            # Post-process to remove repetition if detected
+            decoded_list = self.tokenizer.decode_batch(generated.tolist())
+            final_summaries = []
+            for summary in decoded_list:
+                # Simple repetition check: if the string starts with a repeated pattern
+                # "TextText" -> "Text" == "Text"
+                if len(summary) > 20 and summary[:4] == summary[4:8]:
+                    final_summaries.append("")  # Fallback to empty if garbage
+                else:
+                    final_summaries.append(summary)
 
-            return …
+            return final_summaries
 
     def predict_emotions(
         self,
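The new `min_len = 10` is passed straight into `greedy_decode`, whose internals are not part of this diff. A common way decoders enforce a minimum length is to suppress the EOS logit until enough steps have been emitted; the sketch below shows that idea under that assumption only, with `mask_eos_before_min_len` and `decoder_step` as hypothetical names rather than this project's API.

```python
# Sketch of one way a greedy decoder can honor min_len: make EOS unpickable
# until the sequence has reached the minimum length. An assumption, not the
# repo's greedy_decode.
import torch

def mask_eos_before_min_len(logits: torch.Tensor, step: int, min_len: int, eos_id: int) -> torch.Tensor:
    """Set the EOS logit to -inf while step < min_len so argmax cannot select it."""
    if step < min_len:
        logits = logits.clone()
        logits[:, eos_id] = float("-inf")
    return logits

# Schematic greedy loop:
# for step in range(max_len):
#     logits = decoder_step(memory, tokens)            # [batch, vocab]
#     logits = mask_eos_before_min_len(logits, step, min_len=10, eos_id=eos_id)
#     next_token = logits.argmax(dim=-1)
```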
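The repetition guard added to `summarize` is deliberately cheap: it only compares the first four characters of a decoded summary against the next four. A quick standalone demonstration of what that check catches and what it lets through (`looks_repetitive` is just the committed condition wrapped in a function for illustration):

```python
def looks_repetitive(summary: str) -> bool:
    # Same condition as the diff: long enough, and chars 0-3 equal chars 4-7.
    return len(summary) > 20 and summary[:4] == summary[4:8]

print(looks_repetitive("the the the the the the"))            # True  ("the " == "the ")
print(looks_repetitive("The cat sat on the mat yesterday"))   # False ("The " != "cat ")
print(looks_repetitive("Once upon a time Once upon a time"))  # False -- longer loops slip through
```

Longer repeated phrases escape the check, so it trades recall for simplicity; flagged outputs are replaced with an empty string, matching the "Fallback to empty if garbage" comment in the diff.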