eousphoros commited on
Commit
bc09fc1
·
verified ·
1 Parent(s): 2720e07

Upload inference/test_forward_pass.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference/test_forward_pass.py +117 -0
inference/test_forward_pass.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test NVFP4 model forward pass.
4
+
5
+ This tests that a single forward pass through the model works correctly.
6
+ """
7
+
8
+ import sys
9
+ import json
10
+ import torch
11
+
12
+ from model import Transformer, ModelArgs
13
+ from generate import load_sharded_model, link_fp8_scales
14
+
15
+
16
+ def test_forward_pass():
17
+ """Test single forward pass through the model."""
18
+ print("\n" + "=" * 70)
19
+ print("NVFP4 Forward Pass Test")
20
+ print("=" * 70)
21
+ print("Testing single forward pass with dummy input")
22
+ print("Expected runtime: 2-10 minutes")
23
+ print("=" * 70 + "\n")
24
+
25
+ # Load config
26
+ print("Loading config...")
27
+ config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
28
+ with open(config_path) as f:
29
+ args = ModelArgs(**json.load(f))
30
+ print(f" PASS: Config loaded: {args.n_layers} layers, dtype={args.dtype}\n")
31
+
32
+ # Create model
33
+ print("Creating model...")
34
+ torch.set_default_dtype(torch.bfloat16)
35
+ with torch.device("cpu"):
36
+ model = Transformer(args)
37
+ print(f" PASS: Model created\n")
38
+
39
+ # Load weights
40
+ print("Loading weights (this will take several minutes)...")
41
+ ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
42
+ load_sharded_model(model, ckpt_path)
43
+ print(f" PASS: Weights loaded\n")
44
+
45
+ # Create dummy input
46
+ batch_size = 1
47
+ seq_len = 4
48
+ tokens = torch.randint(0, args.vocab_size, (batch_size, seq_len), device="cpu")
49
+
50
+ print(f"Running forward pass...")
51
+ print(f" Input: {tokens.shape}, tokens={tokens[0].tolist()}")
52
+
53
+ try:
54
+ with torch.inference_mode():
55
+ logits = model(tokens, start_pos=0)
56
+
57
+ print(f" Output: {logits.shape}, dtype={logits.dtype}")
58
+
59
+ # Verify output shape
60
+ expected_shape = (batch_size, args.vocab_size)
61
+ assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}"
62
+ print(f" PASS: Output shape correct: {logits.shape}")
63
+
64
+ # Check for NaN/Inf
65
+ has_nan = torch.isnan(logits).any().item()
66
+ has_inf = torch.isinf(logits).any().item()
67
+
68
+ if has_nan:
69
+ print(f" FAIL: ERROR: Output contains NaN!")
70
+ return 1
71
+ if has_inf:
72
+ print(f" FAIL: ERROR: Output contains Inf!")
73
+ return 1
74
+
75
+ print(f" PASS: No NaN/Inf in output")
76
+
77
+ # Check output statistics
78
+ logits_min = logits.min().item()
79
+ logits_max = logits.max().item()
80
+ logits_mean = logits.mean().item()
81
+
82
+ print(f"\n Output statistics:")
83
+ print(f" Min: {logits_min:.3f}")
84
+ print(f" Max: {logits_max:.3f}")
85
+ print(f" Mean: {logits_mean:.3f}")
86
+
87
+ # Sanity check: logits should be roughly in [-20, 20] range
88
+ if abs(logits_mean) > 50 or logits_max > 100 or logits_min < -100:
89
+ print(f" WARN: WARNING: Logits have unusual range (possible issue)")
90
+ else:
91
+ print(f" PASS: Logits in reasonable range")
92
+
93
+ # Test top predictions
94
+ top_k = 5
95
+ top_logits, top_indices = logits[0].topk(top_k)
96
+ print(f"\n Top {top_k} predictions:")
97
+ for i, (logit, idx) in enumerate(zip(top_logits, top_indices)):
98
+ print(f" {i+1}. Token {idx.item()}: logit={logit.item():.3f}")
99
+
100
+ print("\n" + "=" * 70)
101
+ print("PASS: FORWARD PASS TEST PASSED")
102
+ print("=" * 70)
103
+ print("Forward pass completed successfully!")
104
+ print("Model is producing valid outputs.")
105
+ print("=" * 70)
106
+
107
+ return 0
108
+
109
+ except Exception as e:
110
+ print(f"\nFAIL: FORWARD PASS FAILED: {e}")
111
+ import traceback
112
+ traceback.print_exc()
113
+ return 1
114
+
115
+
116
+ if __name__ == "__main__":
117
+ sys.exit(test_forward_pass())