|
|
|
|
|
""" |
|
|
Test NVFP4 model forward pass. |
|
|
|
|
|
This tests that a single forward pass through the model works correctly. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import json |
|
|
import torch |
|
|
|
|
|
from model import Transformer, ModelArgs |
|
|
from generate import load_sharded_model, link_fp8_scales |
|
|
|
|
|
|
|
|
def _forward_and_check(model, tokens, expected_shape):
    """Run one inference-mode forward pass and validate the logits.

    Args:
        model: Model object, invoked as ``model(tokens, start_pos=0)``.
        tokens: Token tensor passed to the model.
        expected_shape: Tuple that ``logits.shape`` must equal.

    Returns:
        0 when the shape matches and the output has no NaN/Inf values;
        1 when the output contains NaN/Inf or when any exception
        (including the shape mismatch) is raised inside this helper.
    """
    try:
        with torch.inference_mode():
            logits = model(tokens, start_pos=0)

        print(f" Output: {logits.shape}, dtype={logits.dtype}")

        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would silently skip this check.
        if logits.shape != expected_shape:
            raise AssertionError(
                f"Expected {expected_shape}, got {logits.shape}"
            )
        print(f" PASS: Output shape correct: {logits.shape}")

        has_nan = torch.isnan(logits).any().item()
        has_inf = torch.isinf(logits).any().item()

        if has_nan:
            print(" FAIL: ERROR: Output contains NaN!")
            return 1
        if has_inf:
            print(" FAIL: ERROR: Output contains Inf!")
            return 1

        print(" PASS: No NaN/Inf in output")

        logits_min = logits.min().item()
        logits_max = logits.max().item()
        logits_mean = logits.mean().item()

        print("\n Output statistics:")
        print(f" Min: {logits_min:.3f}")
        print(f" Max: {logits_max:.3f}")
        print(f" Mean: {logits_mean:.3f}")

        # An unusual range is reported as a warning only; it does not fail
        # the test.
        if abs(logits_mean) > 50 or logits_max > 100 or logits_min < -100:
            print(" WARN: WARNING: Logits have unusual range (possible issue)")
        else:
            print(" PASS: Logits in reasonable range")

        top_k = 5
        top_logits, top_indices = logits[0].topk(top_k)
        print(f"\n Top {top_k} predictions:")
        for rank, (logit, idx) in enumerate(zip(top_logits, top_indices), start=1):
            print(f" {rank}. Token {idx.item()}: logit={logit.item():.3f}")

        print("\n" + "=" * 70)
        print("PASS: FORWARD PASS TEST PASSED")
        print("=" * 70)
        print("Forward pass completed successfully!")
        print("Model is producing valid outputs.")
        print("=" * 70)

        return 0

    except Exception as e:
        # Top-level boundary of the test: report the failure with a full
        # traceback and convert it into a non-zero exit code.
        print(f"\nFAIL: FORWARD PASS FAILED: {e}")
        import traceback

        traceback.print_exc()
        return 1


def test_forward_pass():
    """Test single forward pass through the model.

    Loads the model configuration and sharded weights from hard-coded
    paths under ``/mnt/models/deepseek-v3.2-nvfp4``, builds the model on
    CPU, feeds it a random ``(1, 4)`` batch of token ids, and checks that
    the logits have shape ``(batch_size, vocab_size)`` and contain no
    NaN/Inf values. Progress and diagnostics are printed to stdout.

    The global PyTorch default dtype is switched to ``bfloat16`` while the
    model is built and run, and is restored to its previous value before
    this function returns.

    Returns:
        0 if the forward pass and all checks succeed, 1 otherwise.

    Raises:
        Exception: Errors raised while reading the config, constructing
            the model, or loading weights are not caught here and
            propagate to the caller.
    """
    print("\n" + "=" * 70)
    print("NVFP4 Forward Pass Test")
    print("=" * 70)
    print("Testing single forward pass with dummy input")
    print("Expected runtime: 2-10 minutes")
    print("=" * 70 + "\n")

    print("Loading config...")
    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(f" PASS: Config loaded: {args.n_layers} layers, dtype={args.dtype}\n")

    print("Creating model...")
    # set_default_dtype mutates process-wide state; remember the old value
    # so it can be restored and does not leak into other code or tests.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        with torch.device("cpu"):
            model = Transformer(args)
        print(" PASS: Model created\n")

        print("Loading weights (this will take several minutes)...")
        ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
        load_sharded_model(model, ckpt_path)
        print(" PASS: Weights loaded\n")

        batch_size = 1
        seq_len = 4
        tokens = torch.randint(
            0, args.vocab_size, (batch_size, seq_len), device="cpu"
        )

        print("Running forward pass...")
        print(f" Input: {tokens.shape}, tokens={tokens[0].tolist()}")

        expected_shape = (batch_size, args.vocab_size)
        return _forward_and_check(model, tokens, expected_shape)
    finally:
        # Restore only after the forward pass, in case the model relies on
        # the default dtype at run time.
        torch.set_default_dtype(previous_dtype)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use the test's 0/1 return value as the process exit status.
    exit_code = test_forward_pass()
    sys.exit(exit_code)
|
|
|