OliverPerrin committed
Commit 7aaf14d · 1 Parent(s): 10efa63

Summarization fix

scripts/demo_gradio.py CHANGED
@@ -126,7 +126,8 @@ def predict(text: str, compression: int):
     logger.info("Generating summary with max length %s", max_len)

     summary = pipeline.summarize([text], max_length=max_len)[0].strip()
-    emotions = pipeline.predict_emotions([text])[0]
+    # Use a higher threshold to filter out weak/wrong predictions on out-of-domain text
+    emotions = pipeline.predict_emotions([text], threshold=0.6)[0]
     topic = pipeline.predict_topics([text])[0]

     fallback_summary = None
@@ -451,10 +452,16 @@ def load_rouge_metrics():
     )

     table = pd.DataFrame(rows, columns=columns) if rows else empty
+
+    # Clean up path for display
+    display_path = str(ROUGE_REPORT_PATH)
+    if "/app/" in display_path:
+        display_path = display_path.replace("/app/", "/LexiMind/")
+
     metadata = {
         "num_examples": report.get("num_examples"),
         "config": report.get("config"),
-        "report_path": str(ROUGE_REPORT_PATH),
+        "report_path": display_path,
         "last_updated": datetime.fromtimestamp(ROUGE_REPORT_PATH.stat().st_mtime).isoformat(),
     }
     return table, metadata
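Note: the `threshold=0.6` argument is passed through to `predict_emotions`, which presumably applies a per-label probability cutoff in the multi-label emotion classifier. A minimal sketch of that kind of filtering, assuming sigmoid scores per label (the helper name `filter_emotions` and the label set are illustrative, not taken from the repository):

import torch

def filter_emotions(label_probs: torch.Tensor, labels: list[str], threshold: float = 0.6) -> list[str]:
    # Keep only labels whose probability clears the cutoff; raising the
    # threshold from a typical 0.5 to 0.6 drops low-confidence predictions.
    keep = (label_probs >= threshold).nonzero(as_tuple=True)[0]
    return [labels[i] for i in keep.tolist()]

# Example: only "joy" survives the 0.6 cutoff.
probs = torch.tensor([0.72, 0.55, 0.10])
print(filter_emotions(probs, ["joy", "surprise", "anger"]))  # ['joy']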
src/inference/pipeline.py CHANGED
@@ -75,8 +75,8 @@ class InferencePipeline:
         with torch.inference_mode():
             encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
             memory = self.model.encoder(src_ids, mask=encoder_mask)
-            # Relax min_len to avoid forcing repetition if the model wants to stop
-            min_len = 0
+            # Force a minimum length to prevent immediate EOS
+            min_len = 10
             generated = self.model.decoder.greedy_decode(
                 memory=memory,
                 max_len=max_len,
@@ -86,11 +86,18 @@ class InferencePipeline:
                 min_len=min_len,
             )

-        # If the first token is EOS, it means empty generation.
-        # Try forcing a different start token if that happens, or just accept it.
-        # For now, we just decode.
+        # Post-process to remove repetition if detected
+        decoded_list = self.tokenizer.decode_batch(generated.tolist())
+        final_summaries = []
+        for summary in decoded_list:
+            # Simple repetition check: if the string starts with a repeated pattern
+            # "TextText" -> "Text" == "Text"
+            if len(summary) > 20 and summary[:4] == summary[4:8]:
+                final_summaries.append("")  # Fallback to empty if garbage
+            else:
+                final_summaries.append(summary)

-        return self.tokenizer.decode_batch(generated.tolist())
+        return final_summaries

     def predict_emotions(
         self,
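Note: the repetition guard only catches degenerate outputs whose first four characters repeat immediately (e.g. "TextText..."); loops with a longer period still pass through, and flagged strings fall back to the empty summary. A standalone sketch of the same heuristic, with illustrative inputs:

def is_prefix_repetition(summary: str) -> bool:
    # Same check as the one added to pipeline.py: strings longer than 20
    # characters whose first 4 characters repeat right away are treated as garbage.
    return len(summary) > 20 and summary[:4] == summary[4:8]

print(is_prefix_repetition("The The The The The The"))           # True  -> replaced with ""
print(is_prefix_repetition("The committee approved the plan."))  # False -> kept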