Spaces:
Running
Running
OliverPerrin
committed on
Commit
·
374a07d
1
Parent(s):
b43ba56
Training run: dev config 1 epoch results
Browse files

Results:
- Topic accuracy: 85.76% (Sports 96%, World 87%, Business 81%, Sci/Tech 78%)
- Summarization ROUGE-like: 0.343, BLEU: 0.088
- Emotion F1 macro: 0.356
Updated inference pipeline, demo script, and evaluation outputs
- .gitignore +1 -0
- outputs/evaluation_report.json +19 -19
- outputs/training_history.json +12 -12
- scripts/demo_gradio.py +2 -2
- src/inference/pipeline.py +6 -3
.gitignore
CHANGED
|
@@ -11,6 +11,7 @@ build/
|
|
| 11 |
|
| 12 |
# Virtual environments
|
| 13 |
venv/
|
|
|
|
| 14 |
env/
|
| 15 |
ENV/
|
| 16 |
|
|
|
|
| 11 |
|
| 12 |
# Virtual environments
|
| 13 |
venv/
|
| 14 |
+
.venv/
|
| 15 |
env/
|
| 16 |
ENV/
|
| 17 |
|
outputs/evaluation_report.json
CHANGED
|
@@ -1,43 +1,43 @@
|
|
| 1 |
{
|
| 2 |
"split": "test",
|
| 3 |
"summarization": {
|
| 4 |
-
"rouge_like": 0.
|
| 5 |
-
"bleu": 0.
|
| 6 |
},
|
| 7 |
"emotion": {
|
| 8 |
-
"f1_macro": 0.
|
| 9 |
},
|
| 10 |
"topic": {
|
| 11 |
-
"accuracy": 0.
|
| 12 |
"classification_report": {
|
| 13 |
"Business": {
|
| 14 |
-
"precision": 0.
|
| 15 |
-
"recall": 0.
|
| 16 |
-
"f1-score": 0.
|
| 17 |
"support": 1900
|
| 18 |
},
|
| 19 |
"Sci/Tech": {
|
| 20 |
-
"precision": 0.
|
| 21 |
-
"recall": 0.
|
| 22 |
-
"f1-score": 0.
|
| 23 |
"support": 1900
|
| 24 |
},
|
| 25 |
"Sports": {
|
| 26 |
-
"precision": 0.
|
| 27 |
-
"recall": 0.
|
| 28 |
-
"f1-score": 0.
|
| 29 |
"support": 1900
|
| 30 |
},
|
| 31 |
"World": {
|
| 32 |
-
"precision": 0.
|
| 33 |
-
"recall": 0.
|
| 34 |
-
"f1-score": 0.
|
| 35 |
"support": 1900
|
| 36 |
},
|
| 37 |
"macro avg": {
|
| 38 |
-
"precision": 0.
|
| 39 |
-
"recall": 0.
|
| 40 |
-
"f1-score": 0.
|
| 41 |
"support": 7600
|
| 42 |
}
|
| 43 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"split": "test",
|
| 3 |
"summarization": {
|
| 4 |
+
"rouge_like": 0.3430426484440944,
|
| 5 |
+
"bleu": 0.0879515124653127
|
| 6 |
},
|
| 7 |
"emotion": {
|
| 8 |
+
"f1_macro": 0.3558666706085205
|
| 9 |
},
|
| 10 |
"topic": {
|
| 11 |
+
"accuracy": 0.8576315789473684,
|
| 12 |
"classification_report": {
|
| 13 |
"Business": {
|
| 14 |
+
"precision": 0.7614165890027959,
|
| 15 |
+
"recall": 0.86,
|
| 16 |
+
"f1-score": 0.8077113198220465,
|
| 17 |
"support": 1900
|
| 18 |
},
|
| 19 |
"Sci/Tech": {
|
| 20 |
+
"precision": 0.8759791122715405,
|
| 21 |
+
"recall": 0.7063157894736842,
|
| 22 |
+
"f1-score": 0.782051282051282,
|
| 23 |
"support": 1900
|
| 24 |
},
|
| 25 |
"Sports": {
|
| 26 |
+
"precision": 0.9454638124362895,
|
| 27 |
+
"recall": 0.9763157894736842,
|
| 28 |
+
"f1-score": 0.9606421543241843,
|
| 29 |
"support": 1900
|
| 30 |
},
|
| 31 |
"World": {
|
| 32 |
+
"precision": 0.8607142857142858,
|
| 33 |
+
"recall": 0.8878947368421053,
|
| 34 |
+
"f1-score": 0.8740932642487047,
|
| 35 |
"support": 1900
|
| 36 |
},
|
| 37 |
"macro avg": {
|
| 38 |
+
"precision": 0.860893449856228,
|
| 39 |
+
"recall": 0.8576315789473684,
|
| 40 |
+
"f1-score": 0.8561245051115545,
|
| 41 |
"support": 7600
|
| 42 |
}
|
| 43 |
}
|
outputs/training_history.json
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
{
|
| 2 |
"train_epoch_1": {
|
| 3 |
-
"summarization_loss": 3.
|
| 4 |
-
"summarization_rouge_like": 0.
|
| 5 |
-
"emotion_loss": 0.
|
| 6 |
-
"emotion_f1": 0.
|
| 7 |
-
"topic_loss": 1.
|
| 8 |
-
"topic_accuracy": 0.
|
| 9 |
-
"total_loss": 5.
|
| 10 |
"epoch": 1.0
|
| 11 |
},
|
| 12 |
"val_epoch_1": {
|
| 13 |
-
"summarization_loss": 3.
|
| 14 |
-
"summarization_rouge_like": 0.
|
| 15 |
-
"emotion_loss": 0.
|
| 16 |
"emotion_f1": 0.0,
|
| 17 |
-
"topic_loss": 0.
|
| 18 |
-
"topic_accuracy": 0.
|
| 19 |
"epoch": 1.0
|
| 20 |
}
|
| 21 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"train_epoch_1": {
|
| 3 |
+
"summarization_loss": 3.67411927986145,
|
| 4 |
+
"summarization_rouge_like": 0.39456057390021504,
|
| 5 |
+
"emotion_loss": 0.5643834336996079,
|
| 6 |
+
"emotion_f1": 0.023809524163603782,
|
| 7 |
+
"topic_loss": 1.2467568359375,
|
| 8 |
+
"topic_accuracy": 0.587,
|
| 9 |
+
"total_loss": 5.485259549498558,
|
| 10 |
"epoch": 1.0
|
| 11 |
},
|
| 12 |
"val_epoch_1": {
|
| 13 |
+
"summarization_loss": 3.2498003482818603,
|
| 14 |
+
"summarization_rouge_like": 0.44230111155579444,
|
| 15 |
+
"emotion_loss": 0.4288424849510193,
|
| 16 |
"emotion_f1": 0.0,
|
| 17 |
+
"topic_loss": 0.807373046875,
|
| 18 |
+
"topic_accuracy": 0.85,
|
| 19 |
"epoch": 1.0
|
| 20 |
}
|
| 21 |
}
|
scripts/demo_gradio.py
CHANGED
|
@@ -262,7 +262,7 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
|
|
| 262 |
)
|
| 263 |
|
| 264 |
with torch.inference_mode():
|
| 265 |
-
memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
|
| 266 |
target_enc = pipeline.tokenizer.batch_encode([summary])
|
| 267 |
target_ids = target_enc["input_ids"].to(pipeline.device)
|
| 268 |
target_mask = target_enc["attention_mask"].to(pipeline.device)
|
|
@@ -271,7 +271,7 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipelin
|
|
| 271 |
decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
|
| 272 |
target_ids = target_ids[:, :target_len]
|
| 273 |
memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
|
| 274 |
-
_, attn_list = pipeline.model.decoder(
|
| 275 |
decoder_inputs,
|
| 276 |
memory,
|
| 277 |
memory_mask=memory_mask,
|
|
|
|
| 262 |
)
|
| 263 |
|
| 264 |
with torch.inference_mode():
|
| 265 |
+
memory = pipeline.model.encoder(src_ids, mask=encoder_mask) # type: ignore
|
| 266 |
target_enc = pipeline.tokenizer.batch_encode([summary])
|
| 267 |
target_ids = target_enc["input_ids"].to(pipeline.device)
|
| 268 |
target_mask = target_enc["attention_mask"].to(pipeline.device)
|
|
|
|
| 271 |
decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
|
| 272 |
target_ids = target_ids[:, :target_len]
|
| 273 |
memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
|
| 274 |
+
_, attn_list = pipeline.model.decoder( # type: ignore
|
| 275 |
decoder_inputs,
|
| 276 |
memory,
|
| 277 |
memory_mask=memory_mask,
|
src/inference/pipeline.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from dataclasses import dataclass, fields, replace
|
| 6 |
-
from typing import Iterable, List, Sequence
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
|
@@ -75,11 +75,14 @@ class InferencePipeline:
|
|
| 75 |
"Model must expose encoder and decoder attributes for summarization."
|
| 76 |
)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
with torch.inference_mode():
|
| 79 |
encoder_mask = (
|
| 80 |
src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
|
| 81 |
)
|
| 82 |
-
memory =
|
| 83 |
min_len = 10
|
| 84 |
|
| 85 |
# Ban BOS, PAD, UNK from being generated
|
|
@@ -92,7 +95,7 @@ class InferencePipeline:
|
|
| 92 |
ban_token_ids.append(unk_id)
|
| 93 |
ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
|
| 94 |
|
| 95 |
-
generated =
|
| 96 |
memory=memory,
|
| 97 |
max_len=max_len,
|
| 98 |
start_token_id=self.tokenizer.bos_token_id,
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from dataclasses import dataclass, fields, replace
|
| 6 |
+
from typing import Any, Iterable, List, Sequence, cast
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
|
|
|
| 75 |
"Model must expose encoder and decoder attributes for summarization."
|
| 76 |
)
|
| 77 |
|
| 78 |
+
# Cast to Any to allow access to dynamic attributes encoder and decoder
|
| 79 |
+
model = cast(Any, self.model)
|
| 80 |
+
|
| 81 |
with torch.inference_mode():
|
| 82 |
encoder_mask = (
|
| 83 |
src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
|
| 84 |
)
|
| 85 |
+
memory = model.encoder(src_ids, mask=encoder_mask)
|
| 86 |
min_len = 10
|
| 87 |
|
| 88 |
# Ban BOS, PAD, UNK from being generated
|
|
|
|
| 95 |
ban_token_ids.append(unk_id)
|
| 96 |
ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
|
| 97 |
|
| 98 |
+
generated = model.decoder.greedy_decode(
|
| 99 |
memory=memory,
|
| 100 |
max_len=max_len,
|
| 101 |
start_token_id=self.tokenizer.bos_token_id,
|