#!/usr/bin/env python3
"""
Test NVFP4 model forward pass.

This tests that a single forward pass through the model works correctly.
"""

import sys
import json

import torch

from model import Transformer, ModelArgs
from generate import load_sharded_model, link_fp8_scales


def test_forward_pass():
    """Run a single forward pass through the model and sanity-check the output."""
    print("\n" + "=" * 70)
    print("NVFP4 Forward Pass Test")
    print("=" * 70)
    print("Testing single forward pass with dummy input")
    print("Expected runtime: 2-10 minutes")
    print("=" * 70 + "\n")

    # Load config
    print("Loading config...")
    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(f"  PASS: Config loaded: {args.n_layers} layers, dtype={args.dtype}\n")

    # Create model
    print("Creating model...")
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("cpu"):
        model = Transformer(args)
    print("  PASS: Model created\n")

    # Load weights
    print("Loading weights (this will take several minutes)...")
    ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
    load_sharded_model(model, ckpt_path)
    print("  PASS: Weights loaded\n")

    # Create dummy input
    batch_size = 1
    seq_len = 4
    tokens = torch.randint(0, args.vocab_size, (batch_size, seq_len), device="cpu")
    print("Running forward pass...")
    print(f"  Input: {tokens.shape}, tokens={tokens[0].tolist()}")

    try:
        with torch.inference_mode():
            logits = model(tokens, start_pos=0)

        print(f"  Output: {logits.shape}, dtype={logits.dtype}")

        # Verify output shape: the model returns logits for the last position only.
        expected_shape = (batch_size, args.vocab_size)
        assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}"
        print(f"  PASS: Output shape correct: {logits.shape}")

        # Check for NaN/Inf
        has_nan = torch.isnan(logits).any().item()
        has_inf = torch.isinf(logits).any().item()
        if has_nan:
            print("  FAIL: Output contains NaN!")
            return 1
        if has_inf:
            print("  FAIL: Output contains Inf!")
            return 1
        print("  PASS: No NaN/Inf in output")

        # Check output statistics
        logits_min = logits.min().item()
        logits_max = logits.max().item()
        logits_mean = logits.mean().item()
        print("\n  Output statistics:")
        print(f"    Min:  {logits_min:.3f}")
        print(f"    Max:  {logits_max:.3f}")
        print(f"    Mean: {logits_mean:.3f}")

        # Sanity check: logits should be roughly in the [-20, 20] range
        if abs(logits_mean) > 50 or logits_max > 100 or logits_min < -100:
            print("  WARN: Logits have unusual range (possible issue)")
        else:
            print("  PASS: Logits in reasonable range")

        # Inspect top predictions
        top_k = 5
        top_logits, top_indices = logits[0].topk(top_k)
        print(f"\n  Top {top_k} predictions:")
        for i, (logit, idx) in enumerate(zip(top_logits, top_indices)):
            print(f"    {i + 1}. Token {idx.item()}: logit={logit.item():.3f}")

        print("\n" + "=" * 70)
        print("PASS: FORWARD PASS TEST PASSED")
        print("=" * 70)
        print("Forward pass completed successfully!")
        print("Model is producing valid outputs.")
        print("=" * 70)
        return 0

    except Exception as e:
        print(f"\nFAIL: FORWARD PASS FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(test_forward_pass())
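
# ----------------------------------------------------------------------
# Usage note (a sketch; the script filename "test_forward_pass.py" and the
# launch commands below are assumptions, not confirmed by this repo):
#
#   python test_forward_pass.py
#
# If load_sharded_model initializes or expects a torch.distributed process
# group to map checkpoint shards to ranks, a multi-process launch may be
# required instead, e.g.:
#
#   torchrun --nproc-per-node=8 test_forward_pass.py
#
# The exit code (0 on success, 1 on failure) makes the script usable from
# CI or shell scripts via `$?`.
# ----------------------------------------------------------------------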