OliverPerrin committed
Commit: ea3248a
Parent: 7aaf14d

Summarization fix

debug_heads.py ADDED
@@ -0,0 +1,60 @@
+
+import sys
+from pathlib import Path
+import torch
+import logging
+
+# Add project root to path
+PROJECT_ROOT = Path(__file__).resolve().parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.models.factory import ModelConfig
+from src.data.tokenization import Tokenizer, TokenizerConfig
+from src.models.factory import build_multitask_model
+from src.utils.io import load_state
+from src.utils.labels import load_label_metadata
+from src.inference.pipeline import InferencePipeline, InferenceConfig
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def debug_pipeline():
+    labels = load_label_metadata("artifacts/labels.json")
+    tokenizer = Tokenizer(TokenizerConfig(pretrained_model_name="artifacts/hf_tokenizer"))
+
+    for heads in [4, 8, 16]:
+        print(f"\n============================================")
+        print(f"Testing num_heads={heads}")
+        print(f"============================================")
+        try:
+            cfg = ModelConfig(num_attention_heads=heads)
+            model = build_multitask_model(
+                tokenizer,
+                num_emotions=labels.emotion_size,
+                num_topics=labels.topic_size,
+                config=cfg,
+            )
+            load_state(model, "checkpoints/best.pt")
+
+            # Tie weights (as per my previous fix)
+            if hasattr(model.decoder, "output_projection") and hasattr(model.decoder, "embedding"):
+                model.decoder.output_projection.weight = model.decoder.embedding.weight
+
+            pipeline = InferencePipeline(
+                model=model,
+                tokenizer=tokenizer,
+                config=InferenceConfig(device="cpu"),
+                emotion_labels=labels.emotion,
+                topic_labels=labels.topic,
+                device="cpu"
+            )
+
+            text = "Artificial intelligence is rapidly transforming the technology landscape."
+            summary = pipeline.summarize([text], max_length=20)
+            print(f"Summary: '{summary[0]}'")
+
+        except Exception as e:
+            print(f"Error: {e}")
+
+if __name__ == "__main__":
+    debug_pipeline()
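
The one non-obvious step in debug_heads.py is the weight-tying assignment. Below is a minimal sketch of what that assignment does, using a hypothetical ToyDecoder; the `embedding` / `output_projection` attribute names come from the project, everything else is illustrative.

# Minimal sketch of the weight tying done in debug_heads.py, on a toy
# decoder-like module (an assumption; the real decoder is the project's own).
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def __init__(self, vocab_size: int = 100, d_model: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)

decoder = ToyDecoder()
# Before tying, the two matrices are independent parameters.
print(decoder.output_projection.weight is decoder.embedding.weight)  # False

# Tie them: the output projection now reuses the embedding matrix, so a
# checkpoint trained with tied weights projects hidden states consistently.
decoder.output_projection.weight = decoder.embedding.weight
print(decoder.output_projection.weight is decoder.embedding.weight)  # True

This works because nn.Linear stores its weight as (out_features, in_features), so the (vocab_size, d_model) embedding matrix can be reused directly as the output projection.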
debug_summarization.py ADDED
@@ -0,0 +1,84 @@
+
+import sys
+from pathlib import Path
+import torch
+import logging
+
+# Add project root to path
+PROJECT_ROOT = Path(__file__).resolve().parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.inference.factory import create_inference_pipeline
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def debug_pipeline():
+    print("Loading pipeline...")
+    pipeline, _ = create_inference_pipeline(
+        tokenizer_dir="artifacts/hf_tokenizer/",
+        checkpoint_path="checkpoints/best.pt",
+        labels_path="artifacts/labels.json",
+    )
+
+    tokenizer = pipeline.tokenizer
+    print(f"BOS ID: {tokenizer.bos_token_id}")
+    print(f"EOS ID: {tokenizer.eos_token_id}")
+    print(f"PAD ID: {tokenizer.pad_token_id}")
+
+    text = "Artificial intelligence is rapidly transforming the technology landscape."
+
+    print("\n--- Input Analysis ---")
+    encoded = tokenizer.encode(text)
+    print(f"Encoded input: {encoded}")
+    print(f"Decoded input: {tokenizer.decode(encoded)}")
+
+    print("\n--- Model Generation Debug ---")
+    # Manually run the summarization steps
+    batch = pipeline.preprocessor.batch_encode([text])
+    batch = pipeline._batch_to_device(batch)
+
+    src_ids = batch.input_ids
+    src_mask = batch.attention_mask
+
+    print(f"Source IDs shape: {src_ids.shape}")
+    print(f"Source IDs: {src_ids}")
+
+    with torch.inference_mode():
+        encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
+        memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
+
+        # Try decoding with BOS as start
+        print("\n--- Decoding with BOS start ---")
+        generated_bos = pipeline.model.decoder.greedy_decode(
+            memory=memory,
+            max_len=20,
+            start_token_id=tokenizer.bos_token_id,
+            end_token_id=tokenizer.eos_token_id,
+            device=pipeline.device,
+            min_len=0
+        )
+        print(f"Generated IDs (BOS start): {generated_bos.tolist()}")
+        print(f"Decoded (BOS start): {tokenizer.decode_batch(generated_bos.tolist())}")
+
+        # Try decoding with [BOS, FirstContentToken] start
+        print("\n--- Decoding with [BOS, FirstContentToken] start ---")
+        bos_id = tokenizer.bos_token_id
+        first_content_id = src_ids[0, 1]  # Skip BOS in input
+        print(f"First content token ID: {first_content_id} ({tokenizer.decode([first_content_id])})")
+
+        generated = torch.tensor([[bos_id, first_content_id]], dtype=torch.long, device=pipeline.device)
+
+        for _ in range(20):
+            logits = pipeline.model.decoder.forward(generated, memory, collect_attn=False)
+            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+            generated = torch.cat([generated, next_token], dim=1)
+            if next_token.item() == tokenizer.eos_token_id:
+                break
+
+        print(f"Generated IDs ([BOS, Content] start): {generated.tolist()}")
+        print(f"Decoded ([BOS, Content] start): {tokenizer.decode_batch(generated.tolist())}")
+
+
+if __name__ == "__main__":
+    debug_pipeline()
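
The encoder_mask expression in this script is easy to misread. Here is a small self-contained sketch of how the (B, L) padding mask broadcasts into a (B, L, L) attention mask; the shapes are toy values, and the project's encoder is assumed to accept a boolean mask of this form.

# Sketch of the encoder_mask construction: position (i, j) is kept only when
# both query i and key j are real (non-pad) tokens.
import torch

src_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)     # (B=1, L=4), last token is padding
encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)  # (1, L, L)
print(encoder_mask.int())
# tensor([[[1, 1, 1, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 0],
#          [0, 0, 0, 0]]], dtype=torch.int32)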
inspect_checkpoint.py ADDED
@@ -0,0 +1,33 @@
+
+import torch
+import sys
+from pathlib import Path
+
+def inspect_checkpoint():
+    path = "checkpoints/best.pt"
+    print(f"Loading {path}...")
+    try:
+        state_dict = torch.load(path, map_location="cpu", weights_only=True)
+        print(f"Keys found: {len(state_dict)}")
+
+        print("\n--- Head Keys ---")
+        head_keys = [k for k in state_dict.keys() if "head" in k]
+        for k in sorted(head_keys):
+            print(k)
+
+        print("\n--- Decoder Keys (Sample) ---")
+        decoder_keys = [k for k in state_dict.keys() if "decoder" in k][:10]
+        for k in sorted(decoder_keys):
+            print(k)
+
+        print("\n--- Checking for Cross Attention ---")
+        if "decoder.layers.0.cross_attn.W_Q.weight" in state_dict:
+            print("Found decoder.layers.0.cross_attn.W_Q.weight")
+        else:
+            print("MISSING decoder.layers.0.cross_attn.W_Q.weight")
+
+    except Exception as e:
+        print(f"Failed to load: {e}")

+if __name__ == "__main__":
+    inspect_checkpoint()
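
For reference, the same key-filtering pattern on a toy model, under the assumption that checkpoints/best.pt is a flat state_dict saved via torch.save(model.state_dict(), ...), which is what weights_only=True expects.

# Self-contained illustration of the inspection pattern on a toy ModuleDict.
import torch
import torch.nn as nn

model = nn.ModuleDict({
    "decoder": nn.Linear(8, 8),
    "summarization_head": nn.Linear(8, 4),
})
torch.save(model.state_dict(), "toy_checkpoint.pt")

state_dict = torch.load("toy_checkpoint.pt", map_location="cpu", weights_only=True)
print(sorted(k for k in state_dict if "head" in k))
# ['summarization_head.bias', 'summarization_head.weight']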
src/inference/pipeline.py CHANGED
@@ -77,6 +77,16 @@ class InferencePipeline:
         memory = self.model.encoder(src_ids, mask=encoder_mask)
         # Force a minimum length to prevent immediate EOS
         min_len = 10
+
+        # Ban BOS, PAD, UNK from being generated
+        ban_token_ids = [
+            self.tokenizer.bos_token_id,
+            self.tokenizer.pad_token_id,
+            self.tokenizer.tokenizer.unk_token_id
+        ]
+        # Filter out None values just in case
+        ban_token_ids = [tid for tid in ban_token_ids if tid is not None]
+
         generated = self.model.decoder.greedy_decode(
             memory=memory,
             max_len=max_len,
@@ -84,6 +94,8 @@
             end_token_id=self.tokenizer.eos_token_id,
             device=self.device,
             min_len=min_len,
+            ban_token_ids=ban_token_ids,
+            no_repeat_ngram_size=3,
         )

         # Post-process to remove repetition if detected
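
A sketch of what the new ban_token_ids argument does downstream in greedy decoding: the banned columns of the next-step logits are pushed to -inf before argmax, so those ids can never be emitted. Toy logits below; the real ids come from the pipeline's tokenizer.

import torch

next_step_logits = torch.tensor([[2.5, 0.1, 3.0, 1.2]])  # (B=1, V=4)
ban_token_ids = [0, 2]                                    # e.g. BOS and PAD ids

next_step_logits = next_step_logits.clone()
next_step_logits[:, ban_token_ids] = float("-inf")
print(next_step_logits.argmax(dim=-1))  # tensor([3]) -- the best *allowed* token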
src/models/decoder.py CHANGED
@@ -221,6 +221,8 @@ class TransformerDecoder(nn.Module):
         device: Optional[torch.device] = None,
         *,
         min_len: Optional[int] = None,
+        ban_token_ids: Optional[List[int]] = None,
+        no_repeat_ngram_size: int = 0,
     ) -> torch.Tensor:
         """
         Naive greedy decoding: repeatedly run the decoder on the growing prefix.
@@ -237,9 +239,52 @@ class TransformerDecoder(nn.Module):
             logits = self.forward(generated, memory, collect_attn=False)  # (B, L, V)
             assert isinstance(logits, torch.Tensor)  # type narrowing
             next_step_logits = logits[:, -1, :]
+
+            # Apply constraints (min_len or ban_token_ids)
+            should_clone = False
             if end_token_id is not None and generated.size(1) < max(1, min_len):
+                should_clone = True
+            if ban_token_ids:
+                should_clone = True
+
+            # Check for n-gram repetition
+            if no_repeat_ngram_size > 0:
+                # We might need to clone if we find something to ban
+                pass
+
+            if should_clone:
                 next_step_logits = next_step_logits.clone()
+
+            if end_token_id is not None and generated.size(1) < max(1, min_len):
                 next_step_logits[:, end_token_id] = float("-inf")
+
+            if ban_token_ids:
+                next_step_logits[:, ban_token_ids] = float("-inf")
+
+            if no_repeat_ngram_size > 0:
+                # Calculate banned tokens based on n-grams
+                for b in range(B):
+                    gen_seq = generated[b].tolist()
+                    if len(gen_seq) < no_repeat_ngram_size - 1:
+                        continue
+
+                    prefix = tuple(gen_seq[-(no_repeat_ngram_size - 1):])
+                    banned_for_this_batch = set()
+
+                    # Scan history for prefix
+                    for i in range(len(gen_seq) - no_repeat_ngram_size + 1):
+                        window = tuple(gen_seq[i : i + no_repeat_ngram_size - 1])
+                        if window == prefix:
+                            # The token that followed this instance of the prefix
+                            if i + no_repeat_ngram_size - 1 < len(gen_seq):
+                                banned_for_this_batch.add(gen_seq[i + no_repeat_ngram_size - 1])
+
+                    if banned_for_this_batch:
+                        if not should_clone:
+                            next_step_logits = next_step_logits.clone()
+                            should_clone = True
+                        next_step_logits[b, list(banned_for_this_batch)] = float("-inf")
+
             next_token = next_step_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
             generated = torch.cat([generated, next_token], dim=1)
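
The no-repeat-ngram rule added to greedy_decode is the densest part of this commit. Below is a plain-Python restatement of the same rule, handy for sanity-checking it without a model; the function name and signature are illustrative, not part of the project.

# Given the tokens generated so far, return the ids that would complete an
# n-gram that has already appeared in the sequence.
def banned_next_tokens(gen_seq: list[int], no_repeat_ngram_size: int) -> set[int]:
    n = no_repeat_ngram_size
    if n <= 0 or len(gen_seq) < n - 1:
        return set()
    prefix = tuple(gen_seq[-(n - 1):])          # last (n-1) tokens
    banned: set[int] = set()
    for i in range(len(gen_seq) - n + 1):
        if tuple(gen_seq[i : i + n - 1]) == prefix:
            banned.add(gen_seq[i + n - 1])      # token that followed this prefix earlier
    return banned

# With no_repeat_ngram_size=3 and the history "7 8 9 ... 7 8", token 9 must be
# banned next, because the trigram (7, 8, 9) has already occurred.
print(banned_next_tokens([7, 8, 9, 4, 7, 8], 3))  # {9}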