# DeepSeek-V3.2-NVFP4 / inference / test_nvfp4_kernel.py
# Uploaded to Hugging Face by eousphoros ("Upload inference/test_nvfp4_kernel.py
# with huggingface_hub"), commit e82122f.
#!/usr/bin/env python3
"""
Unit tests for NVFP4 kernel functions.
This tests dequantization and GEMM operations in isolation before
attempting full model inference.
"""
import sys
import torch
import torch.nn.functional as F
# Import from local inference directory
from nvfp4_kernel import (
dequantize_nvfp4,
nvfp4_gemm_dequant,
NVFP4_LUT,
NVFP4_BLOCK_SIZE
)
# Constants from quantization script
# Non-negative FP4 E2M1 magnitudes, indexed by their 3-bit magnitude code.
_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
# Largest magnitude representable in FP4 E2M1.
FP4_MAX = _E2M1_MAGNITUDES[-1]
# Largest finite magnitude representable in FP8 E4M3 (float8_e4m3fn).
FP8_E4M3_MAX = 448.0
# Rounding thresholds between adjacent E2M1 magnitudes (their midpoints),
# consumed by torch.searchsorted to map |x| onto a magnitude code 0-7.
E2M1_BOUNDS = torch.tensor(
    [(lo + hi) / 2 for lo, hi in zip(_E2M1_MAGNITUDES, _E2M1_MAGNITUDES[1:])],
    dtype=torch.float32,
)
def compute_nvfp4_scales(fp32_weight, block_size=16):
    """
    Compute two-level NVFP4 scaling factors.

    Simplified version for testing. The global scale ``weight_scale_2`` is
    ``amax / (FP4_MAX * FP8_E4M3_MAX)``, so every per-block scale
    ``block_amax / (FP4_MAX * weight_scale_2)`` is at most ``FP8_E4M3_MAX``
    and fits in FP8 E4M3.

    Args:
        fp32_weight: 1-D or 2-D floating-point tensor. Its last dimension is
            zero-padded up to a multiple of ``block_size`` before blocking.
        block_size: Number of consecutive elements along the last dimension
            that share one per-block scale.

    Returns:
        ``(weight_scale, weight_scale_2)``. ``weight_scale`` holds one scale
        per block -- shape ``(num_blocks,)`` for 1-D input or
        ``(M, num_blocks)`` for 2-D input -- as ``torch.float8_e4m3fn`` when
        that dtype is usable, otherwise float32. ``weight_scale_2`` is a
        0-dim tensor.
    """
    # Smallest positive FP8 E4M3 value (subnormal: 2**-6 * 1/8). Smaller
    # scales would round to 0 in the FP8 cast, and a zero block scale makes
    # the quantizer divide by zero.
    fp8_e4m3_min_positive = 2.0 ** -9

    # Global scale
    global_amax = fp32_weight.abs().max()
    weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)
    if weight_scale_2.abs() < 1e-10:
        # (Near-)all-zero tensor: keep the global scale usable as a divisor.
        weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=fp32_weight.device)

    # Per-block scale
    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Zero-pad the last dimension so it splits evenly into blocks.
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        if fp32_weight.dim() == 1:
            padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
            padded[:N] = fp32_weight
            fp32_weight = padded
        else:
            padded = torch.zeros(M, N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
            padded[:, :N] = fp32_weight
            fp32_weight = padded

    # Reshape to blocks
    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.view(-1, block_size)
    else:
        weight_blocks = fp32_weight.view(M, -1, block_size)

    # Per-block amax, scaled so the block max maps onto FP4_MAX after
    # multiplying by weight_scale_2.
    per_block_amax = weight_blocks.abs().amax(dim=-1)
    per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
    per_block_scale = per_block_scale.clamp(min=1e-8)

    # Convert to FP8 E4M3.
    try:
        weight_scale = per_block_scale.clamp(min=fp8_e4m3_min_positive).to(torch.float8_e4m3fn)
    except (AttributeError, RuntimeError, TypeError):
        # AttributeError: this PyTorch build has no torch.float8_e4m3fn.
        # RuntimeError/TypeError: the FP8 cast is unsupported here.
        weight_scale = per_block_scale.to(torch.float32)
    return weight_scale, weight_scale_2
def quantize_to_nvfp4_packed(fp32_weight, weight_scale, weight_scale_2, block_size=16):
    """
    Quantize FP32 weight to NVFP4 packed uint8 format.

    Simplified version for testing. Each element is divided by its block's
    combined scale ``weight_scale * weight_scale_2`` and mapped to a 4-bit
    code: bit 3 is the sign, bits 0-2 are the E2M1 magnitude index chosen
    with ``torch.searchsorted`` over ``E2M1_BOUNDS``. A magnitude that lands
    exactly on a threshold takes the smaller index, and magnitudes above the
    last threshold saturate to index 7. Two codes are packed per byte, with
    the even-indexed element in the low nibble.

    Args:
        fp32_weight: 1-D or 2-D float tensor. Its last dimension is
            zero-padded up to a multiple of ``block_size``.
        weight_scale: Per-block scales laid out like the blocked weight,
            e.g. as returned by ``compute_nvfp4_scales``.
        weight_scale_2: Global scale multiplied into every block scale.
        block_size: Elements per scaling block along the last dimension.

    Returns:
        uint8 tensor whose last dimension is half the (even-rounded) padded
        width.
    """
    dev = fp32_weight.device
    is_vector = fp32_weight.dim() == 1
    rows = 1 if is_vector else fp32_weight.shape[0]
    width = fp32_weight.shape[-1]

    # Zero-pad the last dimension up to a whole number of blocks (ceil-div).
    padded_width = -(-width // block_size) * block_size
    if padded_width != width:
        grown_shape = (padded_width,) if is_vector else (rows, padded_width)
        grown = torch.zeros(grown_shape, dtype=fp32_weight.dtype, device=dev)
        grown[..., :width] = fp32_weight
        fp32_weight = grown

    # Divide every block by its combined two-level scale, then flatten back.
    blocks = fp32_weight.view(-1, block_size) if is_vector else fp32_weight.view(rows, -1, block_size)
    block_scale = (weight_scale.to(torch.float32) * weight_scale_2).unsqueeze(-1)
    scaled = blocks / block_scale
    scaled = scaled.view(-1) if is_vector else scaled.view(rows, -1)

    # 4-bit code = sign in bit 3 | E2M1 magnitude index in bits 0-2.
    negative = (scaled < 0).to(torch.uint8)
    magnitude = torch.searchsorted(E2M1_BOUNDS.to(dev), scaled.abs()).to(torch.uint8)
    codes = (negative << 3) | magnitude

    # Packing needs an even count; that can only fail when block_size is odd.
    count = codes.shape[-1]
    if count % 2 != 0:
        even_shape = (count + 1,) if codes.dim() == 1 else (rows, count + 1)
        evened = torch.zeros(even_shape, dtype=torch.uint8, device=dev)
        evened[..., :count] = codes
        codes = evened

    # Even-indexed element -> low nibble, odd-indexed element -> high nibble.
    low_nibbles = codes[..., 0::2]
    high_nibbles = codes[..., 1::2]
    return (high_nibbles << 4) | low_nibbles
def test_dequant_lookup_table():
    """Test 1: Verify NVFP4 lookup table values are correct.

    Codes 0-7 must decode to the non-negative E2M1 magnitudes, and codes
    8-15 to the same magnitudes with the sign bit set.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("Test 1: NVFP4 Lookup Table")
    print(banner)

    magnitudes = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    # Negating 0.0 yields -0.0, matching the signed-zero code 8.
    expected = magnitudes + [-m for m in magnitudes]

    lut_len = len(NVFP4_LUT)
    assert lut_len == 16, f"LUT should have 16 entries, got {lut_len}"
    for idx, want in enumerate(expected):
        got = NVFP4_LUT[idx]
        assert abs(got - want) < 1e-6, f"LUT[{idx}] = {got}, expected {want}"

    values = NVFP4_LUT.tolist()
    print(f" PASS: Lookup table correct: {values[:8]}")
    print(f" {values[8:]}")
    print(" PASS: Test 1 PASSED\n")
def test_dequant_simple():
    """Test 2: Simple dequantization with known values.

    Dequantizes one hand-packed row of 16 FP4 codes with unit scales and
    checks the output shape, then each element against the expected E2M1
    magnitude.
    """
    print("=" * 70)
    print("Test 2: Simple Dequantization")
    print("=" * 70)
    # Create simple test case: packed values representing [0, 1.0, 2.0, 3.0, ...]
    # Codes: 0=0.0, 2=1.0, 4=2.0, 5=3.0, 6=4.0, 7=6.0
    # Pack: (high << 4) | low, where the low nibble is the earlier element.
    packed = torch.tensor([
        # (low, high) nibbles: (0,2) (4,5) (6,7) then zeros
        # -> 16 values [0,1, 2,3, 4,6, 0,0, 0,0, 0,0, 0,0, 0,0]
        [0x20, 0x54, 0x76, 0x00, 0x00, 0x00, 0x00, 0x00],
    ], dtype=torch.uint8)
    # Uniform scales for simplicity: one block of 16 values.
    scale = torch.ones(1, 1, dtype=torch.float8_e4m3fn)
    scale_2 = torch.tensor([1.0], dtype=torch.float32)
    result = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    print(f" Packed: {packed[0].tolist()}")
    print(f" Scales: scale={scale.shape}, scale_2={scale_2.item()}")
    print(f" Result shape: {result.shape}")
    print(f" Result values: {result[0].tolist()}")
    expected_values = [0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # Check the shape first: zip() below stops at the shorter sequence, so a
    # wrong output width would otherwise pass unnoticed.
    expected_shape = (1, len(expected_values))
    assert tuple(result.shape) == expected_shape, (
        f"Result shape mismatch: {tuple(result.shape)} != {expected_shape}"
    )
    for i, (val, expected) in enumerate(zip(result[0].tolist(), expected_values)):
        assert abs(val - expected) < 0.01, f"Position {i}: got {val}, expected {expected}"
    print(" PASS: Dequantization correct")
    print(" PASS: Test 2 PASSED\n")
def test_quantize_dequantize_roundtrip():
    """Test 3: Quantize then dequantize, check error is acceptable.

    Quantizes a seeded random [64, 256] tensor with this file's reference
    quantizer, dequantizes it with the kernel's ``dequantize_nvfp4``, checks
    both shapes, and bounds the mean absolute and mean relative
    reconstruction error.
    """
    print("=" * 70)
    print("Test 3: Quantization-Dequantization Roundtrip")
    print("=" * 70)
    M, N = 64, 256
    torch.manual_seed(42)
    # Standard normal times 2 (std 2.0); with this many samples the extremes
    # are roughly +/-8. The absolute range matters little, because the
    # two-level scales map each block's max to about FP4_MAX.
    fp32_weight = torch.randn(M, N, dtype=torch.float32) * 2.0
    print(f" Input shape: {fp32_weight.shape}")
    print(f" Input range: [{fp32_weight.min():.3f}, {fp32_weight.max():.3f}]")
    # Compute scales
    scale, scale_2 = compute_nvfp4_scales(fp32_weight, block_size=16)
    print(f" Scale shape: {scale.shape}, scale_2: {scale_2.item():.6e}")
    # Quantize: two 4-bit codes per byte halves the last dimension.
    packed = quantize_to_nvfp4_packed(fp32_weight, scale, scale_2, block_size=16)
    print(f" Packed shape: {packed.shape} (expected [{M}, {N//2}])")
    expected_packed = (M, N // 2)
    assert packed.shape == expected_packed, (
        f"Packed shape mismatch: {tuple(packed.shape)} != {expected_packed}"
    )
    # Dequantize
    dequantized = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    print(f" Dequantized shape: {dequantized.shape}")
    assert dequantized.shape == (M, N), (
        f"Dequantized shape mismatch: {tuple(dequantized.shape)} != {(M, N)}"
    )
    # Compute error
    error = (fp32_weight - dequantized).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    # An element that rounds to 0 contributes a relative error of about 1, so
    # this mean stays bounded even for inputs near zero.
    relative_error = (error / (fp32_weight.abs() + 1e-8)).mean().item()
    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Mean relative error: {relative_error:.6f}")
    # For 4-bit quantization, we expect some error but should be reasonable
    assert mean_error < 1.0, f"Mean error too high: {mean_error}"
    assert relative_error < 0.5, f"Relative error too high: {relative_error}"
    print(" PASS: Roundtrip error acceptable for 4-bit quantization")
    print(" PASS: Test 3 PASSED\n")
def test_gemm_shapes():
    """Test 4: NVFP4 GEMM with various shapes.

    For each (M, N, K), multiplies a random bf16 activation [M, K] by a
    freshly quantized NVFP4 weight [N, K] and checks that the output is
    [M, N] and contains no NaN or Inf.
    """
    separator = "=" * 70
    print(separator)
    print("Test 4: NVFP4 GEMM Shape Tests")
    print(separator)

    shapes = (
        (32, 64, 128),    # Small
        (128, 256, 512),  # Medium
        (64, 512, 256),   # Asymmetric
    )
    for rows, cols, depth in shapes:
        print(f"\n Testing GEMM: [{rows}, {depth}] @ [{cols}, {depth}].T = [{rows}, {cols}]")

        activation = torch.randn(rows, depth, dtype=torch.bfloat16)
        dense_weight = torch.randn(cols, depth, dtype=torch.float32) * 2.0
        block_scale, global_scale = compute_nvfp4_scales(dense_weight, block_size=16)
        packed = quantize_to_nvfp4_packed(dense_weight, block_scale, global_scale, block_size=16)
        print(f" Input: {activation.shape}, Weight: {packed.shape}")
        print(f" Scales: {block_scale.shape}, {global_scale.shape}")

        out = nvfp4_gemm_dequant(activation, packed, block_scale, global_scale)
        print(f" Output: {out.shape}")
        assert out.shape == (rows, cols), f"Output shape mismatch: {out.shape} != ({rows}, {cols})"
        # NaN and Inf are checked separately so the failure says which one.
        assert not torch.isnan(out).any(), "Output contains NaN"
        assert not torch.isinf(out).any(), "Output contains Inf"
        print(" PASS: Shape correct, no NaN/Inf")

    print("\n PASS: All GEMM shape tests passed")
    print(" PASS: Test 4 PASSED\n")
def test_gemm_correctness():
    """Test 5: Verify NVFP4 GEMM output is close to reference.

    Compares ``nvfp4_gemm_dequant`` against ``F.linear`` using the
    unquantized weight cast to bf16, so the measured error includes the
    4-bit weight quantization error. Asserts a bound on the mean absolute
    error and on the normalized L2 error ``||nvfp4 - ref|| / ||ref||``.
    """
    print("=" * 70)
    print("Test 5: NVFP4 GEMM Correctness")
    print("=" * 70)
    M, N, K = 64, 128, 256
    # Seed so the error statistics, and therefore pass/fail, are reproducible
    # when this test runs on its own (not only after Test 3 seeded the RNG).
    torch.manual_seed(0)
    # Create test tensors
    x = torch.randn(M, K, dtype=torch.bfloat16)
    weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 1.5
    # Quantize weight
    scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
    packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)
    # Run NVFP4 GEMM
    result_nvfp4 = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)
    # Reference GEMM with the unquantized weight, cast to bf16 to match x.
    result_reference = F.linear(x, weight_fp32.to(torch.bfloat16))
    print(f" NVFP4 GEMM output: {result_nvfp4.shape}, dtype={result_nvfp4.dtype}")
    print(f" Reference output: {result_reference.shape}, dtype={result_reference.dtype}")
    # Compute error
    reference = result_reference.float()
    diff = result_nvfp4.float() - reference
    error = diff.abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    # Reported only: outputs that happen to be near zero make this element-wise
    # ratio heavy-tailed, so asserting on it makes the outcome seed-dependent.
    relative_error = (error / (reference.abs() + 1e-8)).mean().item()
    # Error energy relative to the reference's energy; robust to near-zero outputs.
    normalized_error = (
        torch.linalg.vector_norm(diff)
        / torch.linalg.vector_norm(reference).clamp_min(1e-8)
    ).item()
    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Mean relative error: {relative_error:.6f}")
    print(f" Normalized L2 error: {normalized_error:.6f}")
    # Due to 4-bit quantization, expect significant error but not catastrophic
    assert mean_error < 5.0, f"Mean error too high: {mean_error}"
    assert normalized_error < 0.3, f"Normalized L2 error too high: {normalized_error}"
    print(" PASS: NVFP4 GEMM output reasonably close to reference")
    print(" PASS: Test 5 PASSED\n")
def main():
    """Run all NVFP4 kernel unit tests.

    Tests run in a fixed order and the first failure stops the run. Any
    failure is reported with a traceback rather than re-raised.

    Returns:
        0 if every test passed; 1 if a test failed an assertion or raised
        any other exception.
    """
    import traceback

    rule = "=" * 70
    print("\n" + rule)
    print("NVFP4 Kernel Unit Tests")
    print(rule)
    print("Testing NVFP4 quantization/dequantization and GEMM operations")
    print("Expected runtime: < 30 seconds")
    print(rule)

    suite = (
        test_dequant_lookup_table,
        test_dequant_simple,
        test_quantize_dequantize_roundtrip,
        test_gemm_shapes,
        test_gemm_correctness,
    )
    try:
        for run_test in suite:
            run_test()
        # Summary
        summary = (
            rule,
            "PASS: ALL TESTS PASSED",
            rule,
            "NVFP4 kernel functions are working correctly!",
            "Ready to proceed with full model testing.",
            rule,
        )
        for line in summary:
            print(line)
        return 0
    except AssertionError as exc:
        print(f"\nFAIL: TEST FAILED: {exc}")
        traceback.print_exc()
        return 1
    except Exception as exc:
        # Top-level boundary: report anything unexpected as a failed run.
        print(f"\nFAIL: UNEXPECTED ERROR: {exc}")
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    exit_status = main()
    sys.exit(exit_status)