Upload inference/test_forward_pass.py with huggingface_hub
Browse files- inference/test_forward_pass.py +117 -0
inference/test_forward_pass.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test NVFP4 model forward pass.
|
| 4 |
+
|
| 5 |
+
This tests that a single forward pass through the model works correctly.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from model import Transformer, ModelArgs
|
| 13 |
+
from generate import load_sharded_model, link_fp8_scales
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_forward_pass(
    config_path="/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json",
    ckpt_path="/mnt/models/deepseek-v3.2-nvfp4",
):
    """Run a single forward pass through the NVFP4 model and sanity-check the output.

    Loads the model config and sharded checkpoint, feeds a tiny random token
    batch through the model on CPU, then validates the returned logits:
    expected shape, absence of NaN/Inf, and a plausible value range.

    Args:
        config_path: Path to the ModelArgs JSON config file.
        ckpt_path: Directory containing the sharded model checkpoint.

    Returns:
        0 on success, 1 on any failure (suitable for passing to sys.exit).
    """
    print("\n" + "=" * 70)
    print("NVFP4 Forward Pass Test")
    print("=" * 70)
    print("Testing single forward pass with dummy input")
    print("Expected runtime: 2-10 minutes")
    print("=" * 70 + "\n")

    # Load config
    print("Loading config...")
    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(f" PASS: Config loaded: {args.n_layers} layers, dtype={args.dtype}\n")

    # Create model. Default dtype is set to bfloat16 so parameters are
    # materialized in bf16 rather than fp32.
    print("Creating model...")
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("cpu"):
        model = Transformer(args)
    print(" PASS: Model created\n")

    # Load weights
    print("Loading weights (this will take several minutes)...")
    load_sharded_model(model, ckpt_path)
    print(" PASS: Weights loaded\n")

    # Create a tiny dummy input: a single batch of random token ids.
    batch_size = 1
    seq_len = 4
    tokens = torch.randint(0, args.vocab_size, (batch_size, seq_len), device="cpu")

    print("Running forward pass...")
    print(f" Input: {tokens.shape}, tokens={tokens[0].tolist()}")

    try:
        with torch.inference_mode():
            logits = model(tokens, start_pos=0)

        print(f" Output: {logits.shape}, dtype={logits.dtype}")

        # Verify output shape. Explicit check instead of `assert`: asserts
        # are stripped under `python -O`, and this keeps shape mismatch
        # reporting consistent with the other FAIL-and-return-1 paths below.
        expected_shape = (batch_size, args.vocab_size)
        if logits.shape != expected_shape:
            print(f" FAIL: Expected {expected_shape}, got {logits.shape}")
            return 1
        print(f" PASS: Output shape correct: {logits.shape}")

        # Check for NaN/Inf
        has_nan = torch.isnan(logits).any().item()
        has_inf = torch.isinf(logits).any().item()

        if has_nan:
            print(" FAIL: ERROR: Output contains NaN!")
            return 1
        if has_inf:
            print(" FAIL: ERROR: Output contains Inf!")
            return 1

        print(" PASS: No NaN/Inf in output")

        # Check output statistics
        logits_min = logits.min().item()
        logits_max = logits.max().item()
        logits_mean = logits.mean().item()

        print("\n Output statistics:")
        print(f" Min: {logits_min:.3f}")
        print(f" Max: {logits_max:.3f}")
        print(f" Mean: {logits_mean:.3f}")

        # Sanity check: logits should be roughly in [-20, 20] range
        if abs(logits_mean) > 50 or logits_max > 100 or logits_min < -100:
            print(" WARN: WARNING: Logits have unusual range (possible issue)")
        else:
            print(" PASS: Logits in reasonable range")

        # Show the top-k predicted tokens as a quick qualitative signal.
        top_k = 5
        top_logits, top_indices = logits[0].topk(top_k)
        print(f"\n Top {top_k} predictions:")
        for i, (logit, idx) in enumerate(zip(top_logits, top_indices)):
            print(f" {i+1}. Token {idx.item()}: logit={logit.item():.3f}")

        print("\n" + "=" * 70)
        print("PASS: FORWARD PASS TEST PASSED")
        print("=" * 70)
        print("Forward pass completed successfully!")
        print("Model is producing valid outputs.")
        print("=" * 70)

        return 0

    except Exception as e:
        # Broad catch is intentional: this is the top-level test boundary,
        # and any failure should be reported with a traceback plus a
        # nonzero return code rather than crashing the harness.
        print(f"\nFAIL: FORWARD PASS FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1
# Script entry point: propagate the test's 0/1 result as the process
# exit code so CI and shell callers can detect pass/fail.
if __name__ == "__main__":
    sys.exit(test_forward_pass())