OliverPerrin committed
Commit 374a07d · 1 Parent(s): b43ba56

Training run: dev config, 1-epoch results


Results:
- Topic accuracy: 85.76% (per-class F1: Sports 96%, World 87%, Business 81%, Sci/Tech 78%)
- Summarization ROUGE-like: 0.343, BLEU: 0.088
- Emotion F1 macro: 0.356

Updated inference pipeline, demo script, and evaluation outputs
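
For reference, a minimal sketch of reading the headline numbers above back out of the report this commit updates; the key names are taken from the outputs/evaluation_report.json diff below:

import json

# Key names as in the outputs/evaluation_report.json diff below.
with open("outputs/evaluation_report.json") as f:
    report = json.load(f)

print(f"Topic accuracy:   {report['topic']['accuracy']:.2%}")
print(f"ROUGE-like:       {report['summarization']['rouge_like']:.3f}")
print(f"BLEU:             {report['summarization']['bleu']:.3f}")
print(f"Emotion F1 macro: {report['emotion']['f1_macro']:.3f}")

# Per-class F1, as summarized in the bullet list above.
for label, scores in report["topic"]["classification_report"].items():
    if label != "macro avg":
        print(f"{label}: F1 {scores['f1-score']:.0%}")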

.gitignore CHANGED
@@ -11,6 +11,7 @@ build/
 
 # Virtual environments
 venv/
+.venv/
 env/
 ENV/
 
outputs/evaluation_report.json CHANGED
@@ -1,43 +1,43 @@
 {
   "split": "test",
   "summarization": {
-    "rouge_like": 0.031742493938280825,
-    "bleu": 0.0008530696741094626
+    "rouge_like": 0.3430426484440944,
+    "bleu": 0.0879515124653127
   },
   "emotion": {
-    "f1_macro": 0.42543327808380127
+    "f1_macro": 0.3558666706085205
   },
   "topic": {
-    "accuracy": 0.3325,
+    "accuracy": 0.8576315789473684,
     "classification_report": {
       "Business": {
-        "precision": 0.24772065955383124,
-        "recall": 0.6721052631578948,
-        "f1-score": 0.3620127569099929,
+        "precision": 0.7614165890027959,
+        "recall": 0.86,
+        "f1-score": 0.8077113198220465,
         "support": 1900
       },
       "Sci/Tech": {
-        "precision": 0.4942170818505338,
-        "recall": 0.5847368421052631,
-        "f1-score": 0.5356798457087754,
+        "precision": 0.8759791122715405,
+        "recall": 0.7063157894736842,
+        "f1-score": 0.782051282051282,
         "support": 1900
       },
       "Sports": {
-        "precision": 0.9473684210526315,
-        "recall": 0.018947368421052633,
-        "f1-score": 0.03715170278637771,
+        "precision": 0.9454638124362895,
+        "recall": 0.9763157894736842,
+        "f1-score": 0.9606421543241843,
         "support": 1900
       },
       "World": {
-        "precision": 0.6477987421383647,
-        "recall": 0.05421052631578947,
-        "f1-score": 0.10004856726566294,
+        "precision": 0.8607142857142858,
+        "recall": 0.8878947368421053,
+        "f1-score": 0.8740932642487047,
         "support": 1900
       },
       "macro avg": {
-        "precision": 0.5842762261488403,
-        "recall": 0.3325,
-        "f1-score": 0.2587232181677022,
+        "precision": 0.860893449856228,
+        "recall": 0.8576315789473684,
+        "f1-score": 0.8561245051115545,
         "support": 7600
       }
     }
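
A quick sanity check on the new report: the topic test split is balanced (support 1,900 per class), so macro-averaged recall equals overall accuracy, and the numbers above agree:

# With equal per-class support, macro recall equals overall accuracy.
recalls = {
    "Business": 0.86,
    "Sci/Tech": 0.7063157894736842,
    "Sports": 0.9763157894736842,
    "World": 0.8878947368421053,
}
macro_recall = sum(recalls.values()) / len(recalls)
print(macro_recall)  # 0.8576315789473684, matching "accuracy" in the report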
outputs/training_history.json CHANGED
@@ -1,21 +1,21 @@
 {
   "train_epoch_1": {
-    "summarization_loss": 3.6738915424346925,
-    "summarization_rouge_like": 0.3936604625654161,
-    "emotion_loss": 0.5655887125730514,
-    "emotion_f1": 0.02088333384692669,
-    "topic_loss": 1.2472841796875,
-    "topic_accuracy": 0.5795,
-    "total_loss": 5.486764434695244,
+    "summarization_loss": 3.67411927986145,
+    "summarization_rouge_like": 0.39456057390021504,
+    "emotion_loss": 0.5643834336996079,
+    "emotion_f1": 0.023809524163603782,
+    "topic_loss": 1.2467568359375,
+    "topic_accuracy": 0.587,
+    "total_loss": 5.485259549498558,
     "epoch": 1.0
   },
   "val_epoch_1": {
-    "summarization_loss": 3.24564736366272,
-    "summarization_rouge_like": 0.4398922732261946,
-    "emotion_loss": 0.4284175229072571,
+    "summarization_loss": 3.2498003482818603,
+    "summarization_rouge_like": 0.44230111155579444,
+    "emotion_loss": 0.4288424849510193,
     "emotion_f1": 0.0,
-    "topic_loss": 0.814755859375,
-    "topic_accuracy": 0.835,
+    "topic_loss": 0.807373046875,
+    "topic_accuracy": 0.85,
     "epoch": 1.0
   }
 }
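
A minimal sketch (not this repo's actual training loop) of how a history file shaped like the one above might be produced: one dict per train/val epoch, re-dumped after each epoch so an interrupted run still leaves valid JSON. The metric values here are placeholders:

import json

history: dict = {}
num_epochs = 1  # the dev config above ran a single epoch

for epoch in range(1, num_epochs + 1):
    # Placeholder metrics; real values would come from the epoch's loops.
    train_metrics = {"summarization_loss": 3.674, "topic_accuracy": 0.587}
    val_metrics = {"summarization_loss": 3.250, "topic_accuracy": 0.85}

    history[f"train_epoch_{epoch}"] = {**train_metrics, "epoch": float(epoch)}
    history[f"val_epoch_{epoch}"] = {**val_metrics, "epoch": float(epoch)}

    with open("outputs/training_history.json", "w") as f:
        json.dump(history, f, indent=2)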
scripts/demo_gradio.py CHANGED
@@ -262,7 +262,7 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline
     )
 
     with torch.inference_mode():
-        memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
+        memory = pipeline.model.encoder(src_ids, mask=encoder_mask)  # type: ignore
         target_enc = pipeline.tokenizer.batch_encode([summary])
         target_ids = target_enc["input_ids"].to(pipeline.device)
         target_mask = target_enc["attention_mask"].to(pipeline.device)
@@ -271,7 +271,7 @@ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline
         decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
         target_ids = target_ids[:, :target_len]
         memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
-        _, attn_list = pipeline.model.decoder(
+        _, attn_list = pipeline.model.decoder(  # type: ignore
             decoder_inputs,
             memory,
             memory_mask=memory_mask,
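
The two hunks here only suppress type-checker errors; the attention plotting itself is unchanged. For context, a hedged sketch of how a tensor like those in attn_list could be rendered as a heatmap — assuming entries of shape (batch, heads, tgt_len, src_len), which is an assumption about this decoder, not something the diff confirms:

import matplotlib.pyplot as plt
import torch

# Stand-in for attn_list[-1]; shape (batch, heads, tgt_len, src_len) assumed.
attn = torch.rand(1, 8, 12, 40)
weights = attn[0].mean(dim=0).cpu()  # average over heads -> (tgt_len, src_len)

fig, ax = plt.subplots()
im = ax.imshow(weights, aspect="auto", cmap="viridis")
ax.set_xlabel("source tokens")
ax.set_ylabel("summary tokens")
fig.colorbar(im, ax=ax)
fig.savefig("attention_heatmap.png")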
src/inference/pipeline.py CHANGED
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, fields, replace
-from typing import Iterable, List, Sequence
+from typing import Any, Iterable, List, Sequence, cast
 
 import torch
 import torch.nn.functional as F
@@ -75,11 +75,14 @@ class InferencePipeline:
                 "Model must expose encoder and decoder attributes for summarization."
             )
 
+        # Cast to Any to allow access to dynamic attributes encoder and decoder
+        model = cast(Any, self.model)
+
        with torch.inference_mode():
             encoder_mask = (
                 src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
             )
-            memory = self.model.encoder(src_ids, mask=encoder_mask)
+            memory = model.encoder(src_ids, mask=encoder_mask)
             min_len = 10
 
             # Ban BOS, PAD, UNK from being generated
@@ -92,7 +95,7 @@ class InferencePipeline:
             ban_token_ids.append(unk_id)
             ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
 
-            generated = self.model.decoder.greedy_decode(
+            generated = model.decoder.greedy_decode(
                 memory=memory,
                 max_len=max_len,
                 start_token_id=self.tokenizer.bos_token_id,
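
The change here swaps per-line suppressions for a single cast: self.model is presumably annotated as a base nn.Module, which declares no encoder or decoder attribute as far as a strict type checker is concerned. A standalone sketch of the pattern (the class and field names are illustrative, not the repo's):

from typing import Any, cast

import torch.nn as nn


class SummarizerPipeline:  # illustrative name, not the repo's class
    def __init__(self, model: nn.Module) -> None:
        self.model = model  # statically typed as the base nn.Module

    def encode(self, src_ids, encoder_mask):
        # `self.model.encoder` fails strict checking: nn.Module declares no
        # `encoder` attribute. One cast defers those checks to runtime for
        # every dynamic access below, instead of a `# type: ignore` per line.
        model = cast(Any, self.model)
        return model.encoder(src_ids, mask=encoder_mask)

Compared with the demo_gradio.py hunks above, this keeps the suppression at one site rather than sprinkling it across every dynamic attribute access.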