Upload inference/nvfp4_kernel.py with huggingface_hub
Browse files- inference/nvfp4_kernel.py +17 -3
inference/nvfp4_kernel.py
CHANGED
|
@@ -16,6 +16,7 @@ import triton
|
|
| 16 |
import triton.language as tl
|
| 17 |
from triton.tools.tensor_descriptor import TensorDescriptor
|
| 18 |
from typing import Tuple, Optional
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
# NVFP4 E2M1 lookup table for dequantization
|
|
@@ -24,6 +25,19 @@ NVFP4_LUT = torch.tensor([
|
|
| 24 |
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # negative values
|
| 25 |
], dtype=torch.float32)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Block size for NVFP4 (16 elements per scale)
|
| 28 |
NVFP4_BLOCK_SIZE = 16
|
| 29 |
|
|
@@ -109,8 +123,8 @@ def dequantize_nvfp4(
|
|
| 109 |
high = (packed >> 4) & 0x0F
|
| 110 |
fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)
|
| 111 |
|
| 112 |
-
# Lookup table dequantization
|
| 113 |
-
lut =
|
| 114 |
tensor = lut[fp4_tensor.long()]
|
| 115 |
|
| 116 |
# Apply dual-level scales
|
|
@@ -412,7 +426,7 @@ def test_nvfp4_gemm():
|
|
| 412 |
|
| 413 |
# Compare
|
| 414 |
error = (ref - out_deq).abs().mean()
|
| 415 |
-
print(f"
|
| 416 |
|
| 417 |
return True
|
| 418 |
|
|
|
|
| 16 |
import triton.language as tl
|
| 17 |
from triton.tools.tensor_descriptor import TensorDescriptor
|
| 18 |
from typing import Tuple, Optional
|
| 19 |
+
import functools
|
| 20 |
|
| 21 |
|
| 22 |
# NVFP4 E2M1 lookup table for dequantization
|
|
|
|
| 25 |
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # negative values
|
| 26 |
], dtype=torch.float32)
|
| 27 |
|
| 28 |
+
|
| 29 |
+
@functools.lru_cache(maxsize=8)
def _get_nvfp4_lut(device_str: str) -> torch.Tensor:
    """Return the NVFP4 lookup table placed on the requested device.

    Results are memoized per device string (at most 8 distinct strings are
    retained), so repeated calls with the same ``device_str`` hand back the
    same tensor object instead of converting the table again. Treat the
    result as read-only: an in-place modification would be visible to every
    later caller that hits the cache.

    Args:
        device_str: Device identifier in string form (e.g., 'cpu', 'cuda:0').

    Returns:
        The module-level ``NVFP4_LUT`` converted with ``Tensor.to`` for the
        given device.
    """
    lut_on_device = NVFP4_LUT.to(device=device_str)
    return lut_on_device
|
| 40 |
+
|
| 41 |
# Block size for NVFP4 (16 elements per scale)
|
| 42 |
NVFP4_BLOCK_SIZE = 16
|
| 43 |
|
|
|
|
| 123 |
high = (packed >> 4) & 0x0F
|
| 124 |
fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)
|
| 125 |
|
| 126 |
+
# Lookup table dequantization (use cached LUT for efficiency)
|
| 127 |
+
lut = _get_nvfp4_lut(str(packed.device))
|
| 128 |
tensor = lut[fp4_tensor.long()]
|
| 129 |
|
| 130 |
# Apply dual-level scales
|
|
|
|
| 426 |
|
| 427 |
# Compare
|
| 428 |
error = (ref - out_deq).abs().mean()
|
| 429 |
+
print(f"PASS: NVFP4 GEMM dequant test: mean abs error = {error:.6f}")
|
| 430 |
|
| 431 |
return True
|
| 432 |
|