"""
Unit tests for NVFP4 kernel functions.

This tests dequantization and GEMM operations in isolation before
attempting full model inference.
"""

import sys

import torch
import torch.nn.functional as F

from nvfp4_kernel import (
    dequantize_nvfp4,
    nvfp4_gemm_dequant,
    NVFP4_LUT,
    NVFP4_BLOCK_SIZE,
)

# NVFP4 format constants: E2M1 magnitudes top out at 6.0, and per-block
# scales are stored in FP8 E4M3, whose largest finite value is 448.0.
FP4_MAX = 6.0
FP8_E4M3_MAX = 448.0

# Decision boundaries between adjacent E2M1 magnitudes {0, 0.5, 1, 1.5, 2,
# 3, 4, 6}; each bound is the midpoint of a neighboring pair, used by
# searchsorted during quantization.
E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], dtype=torch.float32)
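

# A minimal scalar sketch of how one element decodes, assuming the
# sign-magnitude layout that test 1 below verifies (bit 3 = sign, bits 0-2 =
# E2M1 magnitude index). For orientation only; dequantize_nvfp4 is the
# kernel under test, not this helper.
def _decode_one_nvfp4(code: int, block_scale: float, weight_scale_2: float) -> float:
    magnitudes = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 grid
    sign = -1.0 if code & 0x8 else 1.0
    return sign * magnitudes[code & 0x7] * block_scale * weight_scale_2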


def compute_nvfp4_scales(fp32_weight, block_size=16):
    """
    Compute two-level NVFP4 scaling factors.
    Simplified version for testing.
    """
    # Level 2: one FP32 scale for the whole tensor, chosen so that per-block
    # scales land inside the representable FP8 E4M3 range.
    global_amax = fp32_weight.abs().max()
    weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)

    if weight_scale_2.abs() < 1e-10:
        weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=fp32_weight.device)

    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Zero-pad the last dimension up to a multiple of block_size.
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        if fp32_weight.dim() == 1:
            padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
            padded[:N] = fp32_weight
            fp32_weight = padded
        else:
            padded = torch.zeros(M, N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
            padded[:, :N] = fp32_weight
            fp32_weight = padded

    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.view(-1, block_size)
    else:
        weight_blocks = fp32_weight.view(M, -1, block_size)

    # Level 1: one scale per block, normalized by the tensor-wide scale so
    # it fits in FP8 E4M3.
    per_block_amax = weight_blocks.abs().amax(dim=-1)
    per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
    per_block_scale = per_block_scale.clamp(min=1e-8)

    # Store block scales in FP8 E4M3 when the build supports it.
    try:
        weight_scale = per_block_scale.to(torch.float8_e4m3fn)
    except (RuntimeError, TypeError):
        weight_scale = per_block_scale.to(torch.float32)

    return weight_scale, weight_scale_2
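
# Worked example of the two-level recipe above, with hand-picked numbers:
# a tensor whose global amax is 268.8 gets
#     weight_scale_2 = 268.8 / (6.0 * 448.0) = 0.1
# and a block whose own amax is 3.0 then gets
#     per_block_scale = 3.0 / (6.0 * 0.1) = 5.0,
# so dividing that block by 5.0 * 0.1 = 0.5 maps its largest element exactly
# onto the E2M1 endpoint 6.0.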


def quantize_to_nvfp4_packed(fp32_weight, weight_scale, weight_scale_2, block_size=16):
    """
    Quantize FP32 weight to NVFP4 packed uint8 format.
    Simplified version for testing.
    """
    device = fp32_weight.device
    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Zero-pad the last dimension up to a multiple of block_size, mirroring
    # compute_nvfp4_scales.
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        if fp32_weight.dim() == 1:
            padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=device)
            padded[:N] = fp32_weight
            fp32_weight = padded
        else:
            padded = torch.zeros(M, N_padded, dtype=fp32_weight.dtype, device=device)
            padded[:, :N] = fp32_weight
            fp32_weight = padded

    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.view(-1, block_size)
    else:
        weight_blocks = fp32_weight.view(M, -1, block_size)

    # Normalize each block by its combined (block * tensor) scale so values
    # fall inside the E2M1 range [-6, 6].
    combined_scale = weight_scale.to(torch.float32) * weight_scale_2
    scaled_weight = weight_blocks / combined_scale.unsqueeze(-1)

    if fp32_weight.dim() == 1:
        scaled_weight = scaled_weight.view(-1)
    else:
        scaled_weight = scaled_weight.view(M, -1)

    e2m1_bounds = E2M1_BOUNDS.to(device)

    # Encode sign-magnitude: bit 3 is the sign, bits 0-2 index the E2M1 grid.
    sign_bit = (scaled_weight < 0).to(torch.uint8)
    weight_abs = scaled_weight.abs()

    # searchsorted against the midpoint boundaries yields the nearest-grid
    # magnitude index in 0..7 (values beyond 5.0 saturate to index 7 = 6.0).
    magnitude_code = torch.searchsorted(e2m1_bounds, weight_abs)

    code = (sign_bit << 3) | magnitude_code.to(torch.uint8)

    # Pad to an even element count so codes can be packed two per byte.
    N_current = code.shape[-1]
    if N_current % 2 != 0:
        if code.dim() == 1:
            padded = torch.zeros(N_current + 1, dtype=torch.uint8, device=device)
            padded[:N_current] = code
            code = padded
        else:
            padded = torch.zeros(M, N_current + 1, dtype=torch.uint8, device=device)
            padded[:, :N_current] = code
            code = padded

    # Pack pairs of codes: even-indexed element in the low nibble,
    # odd-indexed element in the high nibble.
    if code.dim() == 1:
        packed = (code[1::2] << 4) | code[0::2]
    else:
        packed = (code[:, 1::2] << 4) | code[:, 0::2]

    return packed
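

# Plain-Python sketch of the inverse of the packing step above, for
# orientation only (the real unpacking lives inside dequantize_nvfp4):
def _unpack_codes(byte: int) -> tuple:
    """Split one packed byte into (even_code, odd_code): the even-indexed
    element sits in the low nibble, the odd-indexed one in the high nibble."""
    return byte & 0x0F, (byte >> 4) & 0x0F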


def test_dequant_lookup_table():
    """Test 1: Verify NVFP4 lookup table values are correct."""
    print("\n" + "=" * 70)
    print("Test 1: NVFP4 Lookup Table")
    print("=" * 70)

    # Codes 0-7 are the positive E2M1 grid; codes 8-15 mirror them with the
    # sign bit set.
    expected = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

    assert len(NVFP4_LUT) == 16, f"LUT should have 16 entries, got {len(NVFP4_LUT)}"

    for i, (actual, expected_val) in enumerate(zip(NVFP4_LUT, expected)):
        assert abs(actual - expected_val) < 1e-6, f"LUT[{i}] = {actual}, expected {expected_val}"

    print(f" PASS: Lookup table correct: {NVFP4_LUT.tolist()[:8]}")
    print(f" {NVFP4_LUT.tolist()[8:]}")
    print(" PASS: Test 1 PASSED\n")


def test_dequant_simple():
    """Test 2: Simple dequantization with known values."""
    print("=" * 70)
    print("Test 2: Simple Dequantization")
    print("=" * 70)

    # One row of 8 packed bytes = 16 elements, low nibble first. With unit
    # scales, 0x20 decodes to [0.0, 1.0], 0x54 to [2.0, 3.0], and 0x76 to
    # [4.0, 6.0]; the trailing 0x00 bytes decode to zeros.
    packed = torch.tensor([
        [0x20, 0x54, 0x76, 0x00, 0x00, 0x00, 0x00, 0x00],
    ], dtype=torch.uint8)

    # Unit scales so the output equals the raw LUT values.
    scale = torch.ones(1, 1, dtype=torch.float8_e4m3fn)
    scale_2 = torch.tensor([1.0], dtype=torch.float32)

    result = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)

    print(f" Packed: {packed[0].tolist()}")
    print(f" Scales: scale={scale.shape}, scale_2={scale_2.item()}")
    print(f" Result shape: {result.shape}")
    print(f" Result values: {result[0].tolist()}")

    expected_values = [0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    for i, (val, expected) in enumerate(zip(result[0].tolist(), expected_values)):
        assert abs(val - expected) < 0.01, f"Position {i}: got {val}, expected {expected}"

    print(" PASS: Dequantization correct")
    print(" PASS: Test 2 PASSED\n")


def test_quantize_dequantize_roundtrip():
    """Test 3: Quantize then dequantize, check error is acceptable."""
    print("=" * 70)
    print("Test 3: Quantization-Dequantization Roundtrip")
    print("=" * 70)

    M, N = 64, 256
    torch.manual_seed(42)
    fp32_weight = torch.randn(M, N, dtype=torch.float32) * 2.0

    print(f" Input shape: {fp32_weight.shape}")
    print(f" Input range: [{fp32_weight.min():.3f}, {fp32_weight.max():.3f}]")

    scale, scale_2 = compute_nvfp4_scales(fp32_weight, block_size=16)
    print(f" Scale shape: {scale.shape}, scale_2: {scale_2.item():.6e}")

    packed = quantize_to_nvfp4_packed(fp32_weight, scale, scale_2, block_size=16)
    print(f" Packed shape: {packed.shape} (expected [{M}, {N//2}])")
    assert packed.shape == (M, N // 2), "Packed shape mismatch"

    dequantized = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    print(f" Dequantized shape: {dequantized.shape}")
    assert dequantized.shape == (M, N), "Dequantized shape mismatch"

    # Measure elementwise reconstruction error against the original weight.
    error = (fp32_weight - dequantized).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    relative_error = (error / (fp32_weight.abs() + 1e-8)).mean().item()

    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Mean relative error: {relative_error:.6f}")

    assert mean_error < 1.0, f"Mean error too high: {mean_error}"
    assert relative_error < 0.5, f"Relative error too high: {relative_error}"

    print(" PASS: Roundtrip error acceptable for 4-bit quantization")
    print(" PASS: Test 3 PASSED\n")


def test_gemm_shapes():
    """Test 4: NVFP4 GEMM with various shapes."""
    print("=" * 70)
    print("Test 4: NVFP4 GEMM Shape Tests")
    print("=" * 70)

    # (M, N, K) triples: x is [M, K], weight is [N, K], output is [M, N].
    test_cases = [
        (32, 64, 128),
        (128, 256, 512),
        (64, 512, 256),
    ]

    for M, N, K in test_cases:
        print(f"\n Testing GEMM: [{M}, {K}] @ [{N}, {K}].T = [{M}, {N}]")

        x = torch.randn(M, K, dtype=torch.bfloat16)

        weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 2.0
        scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
        packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

        print(f" Input: {x.shape}, Weight: {packed_weight.shape}")
        print(f" Scales: {scale.shape}, {scale_2.shape}")

        result = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

        print(f" Output: {result.shape}")
        assert result.shape == (M, N), f"Output shape mismatch: {result.shape} != ({M}, {N})"

        assert not torch.isnan(result).any(), "Output contains NaN"
        assert not torch.isinf(result).any(), "Output contains Inf"

        print(" PASS: Shape correct, no NaN/Inf")

    print("\n PASS: All GEMM shape tests passed")
    print(" PASS: Test 4 PASSED\n")


def test_gemm_correctness():
    """Test 5: Verify NVFP4 GEMM output is close to reference."""
    print("=" * 70)
    print("Test 5: NVFP4 GEMM Correctness")
    print("=" * 70)

    M, N, K = 64, 128, 256

    x = torch.randn(M, K, dtype=torch.bfloat16)
    weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 1.5

    # Quantize the weight, then compare the NVFP4 GEMM against a bf16
    # F.linear on the original weight.
    scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
    packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

    result_nvfp4 = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

    result_reference = F.linear(x, weight_fp32.to(torch.bfloat16))

    print(f" NVFP4 GEMM output: {result_nvfp4.shape}, dtype={result_nvfp4.dtype}")
    print(f" Reference output: {result_reference.shape}, dtype={result_reference.dtype}")

    error = (result_nvfp4.float() - result_reference.float()).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    relative_error = (error / (result_reference.float().abs() + 1e-8)).mean().item()

    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Mean relative error: {relative_error:.6f}")

    # Loose bounds: quantization noise accumulates over K=256 products, so
    # these catch gross kernel bugs (wrong scales, wrong nibble order), not
    # small accuracy regressions.
    assert mean_error < 5.0, f"Mean error too high: {mean_error}"
    assert relative_error < 1.0, f"Relative error too high: {relative_error}"

    print(" PASS: NVFP4 GEMM output reasonably close to reference")
    print(" PASS: Test 5 PASSED\n")


def main():
    """Run all NVFP4 kernel unit tests."""
    print("\n" + "=" * 70)
    print("NVFP4 Kernel Unit Tests")
    print("=" * 70)
    print("Testing NVFP4 quantization/dequantization and GEMM operations")
    print("Expected runtime: < 30 seconds")
    print("=" * 70)

    try:
        test_dequant_lookup_table()
        test_dequant_simple()
        test_quantize_dequantize_roundtrip()
        test_gemm_shapes()
        test_gemm_correctness()

        print("=" * 70)
        print("PASS: ALL TESTS PASSED")
        print("=" * 70)
        print("NVFP4 kernel functions are working correctly!")
        print("Ready to proceed with full model testing.")
        print("=" * 70)

        return 0

    except AssertionError as e:
        print(f"\nFAIL: TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1
    except Exception as e:
        print(f"\nFAIL: UNEXPECTED ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())