DeepSeek-V3.2-NVFP4 / inference /test_forward_pass.py
eousphoros's picture
Upload inference/test_forward_pass.py with huggingface_hub
bc09fc1 verified
#!/usr/bin/env python3
"""
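Smoke test for the NVFP4 model: build the Transformer on CPU, load the sharded
checkpoint weights, run a single forward pass on random token IDs, and check that
the logits have the expected shape and contain no NaN/Inf values.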
"""
import sys
import json
import torch
from model import Transformer, ModelArgs
from generate import load_sharded_model, link_fp8_scales
_DEFAULT_CONFIG_PATH = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
_DEFAULT_CKPT_PATH = "/mnt/models/deepseek-v3.2-nvfp4"


def _build_model(config_path, ckpt_path):
    """Load the JSON config, build the Transformer on CPU, and load its sharded weights.

    Sets torch's global default dtype to bfloat16 before constructing the model;
    the caller is responsible for restoring the previous default.

    Returns:
        ``(args, model)``: the parsed ``ModelArgs`` and the weight-loaded ``Transformer``.
    """
    print("Loading config...")
    with open(config_path, encoding="utf-8") as f:
        args = ModelArgs(**json.load(f))
    print(f" PASS: Config loaded: {args.n_layers} layers, dtype={args.dtype}\n")

    print("Creating model...")
    # Floating-point tensors created without an explicit dtype inside Transformer
    # will default to bfloat16.
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("cpu"):
        model = Transformer(args)
    print(" PASS: Model created\n")

    print("Loading weights (this will take several minutes)...")
    load_sharded_model(model, ckpt_path)
    print(" PASS: Weights loaded\n")
    return args, model


def _check_logits(logits, batch_size, vocab_size):
    """Print PASS/FAIL lines for the logits' shape and finiteness.

    The shape check is an explicit comparison rather than ``assert`` so it still
    runs under ``python -O``.

    Returns:
        True if the shape is ``(batch_size, vocab_size)`` and no value is NaN/Inf.
    """
    expected_shape = (batch_size, vocab_size)
    if tuple(logits.shape) != expected_shape:
        print(f" FAIL: ERROR: Expected {expected_shape}, got {logits.shape}")
        return False
    print(f" PASS: Output shape correct: {logits.shape}")

    if torch.isnan(logits).any().item():
        print(" FAIL: ERROR: Output contains NaN!")
        return False
    if torch.isinf(logits).any().item():
        print(" FAIL: ERROR: Output contains Inf!")
        return False
    print(" PASS: No NaN/Inf in output")
    return True


def _report_logits(logits, top_k=5):
    """Print min/max/mean of the logits, a range verdict, and the top-k tokens of row 0."""
    # Reduce in float32 in case the logits are a low-precision dtype such as
    # bfloat16, where the reduced scalar would otherwise be rounded.
    stats = logits.float()
    logits_min = stats.min().item()
    logits_max = stats.max().item()
    logits_mean = stats.mean().item()
    print("\n Output statistics:")
    print(f" Min: {logits_min:.3f}")
    print(f" Max: {logits_max:.3f}")
    print(f" Mean: {logits_mean:.3f}")

    # Loose heuristic bounds; exceeding them only prints a warning and does not
    # fail the test.
    if abs(logits_mean) > 50 or logits_max > 100 or logits_min < -100:
        print(" WARN: WARNING: Logits have unusual range (possible issue)")
    else:
        print(" PASS: Logits in reasonable range")

    top_logits, top_indices = logits[0].topk(top_k)
    print(f"\n Top {top_k} predictions:")
    for rank, (logit, idx) in enumerate(zip(top_logits, top_indices), start=1):
        print(f" {rank}. Token {idx.item()}: logit={logit.item():.3f}")


def test_forward_pass(
    *,
    config_path=_DEFAULT_CONFIG_PATH,
    ckpt_path=_DEFAULT_CKPT_PATH,
    batch_size=1,
    seq_len=4,
    seed=0,
):
    """Run one forward pass of the NVFP4 model on random tokens and sanity-check the logits.

    Args:
        config_path: JSON file whose keys are passed as keyword arguments to ``ModelArgs``.
        ckpt_path: Checkpoint location passed to ``load_sharded_model``.
        batch_size: Number of sequences in the dummy input.
        seq_len: Tokens per sequence in the dummy input.
        seed: Seed for the dummy token IDs, so reruns feed identical input.

    Returns:
        0 if the forward pass succeeds with correctly shaped, finite logits;
        1 on any failure, including errors while loading the config or weights.
    """
    print("\n" + "=" * 70)
    print("NVFP4 Forward Pass Test")
    print("=" * 70)
    print("Testing single forward pass with dummy input")
    print("Expected runtime: 2-10 minutes")
    print("=" * 70 + "\n")

    previous_dtype = torch.get_default_dtype()
    try:
        args, model = _build_model(config_path, ckpt_path)

        # A local generator keeps the input reproducible without touching the
        # global RNG state.
        generator = torch.Generator().manual_seed(seed)
        tokens = torch.randint(
            0, args.vocab_size, (batch_size, seq_len), generator=generator, device="cpu"
        )
        print("Running forward pass...")
        print(f" Input: {tokens.shape}, tokens={tokens[0].tolist()}")
        with torch.inference_mode():
            logits = model(tokens, start_pos=0)
        print(f" Output: {logits.shape}, dtype={logits.dtype}")

        if not _check_logits(logits, batch_size, args.vocab_size):
            return 1
        _report_logits(logits)
    except Exception as e:
        print(f"\nFAIL: FORWARD PASS FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        # Undo the global bfloat16 default so in-process callers are unaffected.
        torch.set_default_dtype(previous_dtype)

    print("\n" + "=" * 70)
    print("PASS: FORWARD PASS TEST PASSED")
    print("=" * 70)
    print("Forward pass completed successfully!")
    print("Model is producing valid outputs.")
    print("=" * 70)
    return 0
if __name__ == "__main__":
    # Propagate the test's 0/1 result as the process exit status.
    exit_status = test_forward_pass()
    sys.exit(exit_status)