eousphoros commited on
Commit
93931fb
·
verified ·
1 Parent(s): 17a0e58

Upload inference/nvfp4_kernel.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference/nvfp4_kernel.py +17 -3
inference/nvfp4_kernel.py CHANGED
@@ -16,6 +16,7 @@ import triton
16
  import triton.language as tl
17
  from triton.tools.tensor_descriptor import TensorDescriptor
18
  from typing import Tuple, Optional
 
19
 
20
 
21
  # NVFP4 E2M1 lookup table for dequantization
@@ -24,6 +25,19 @@ NVFP4_LUT = torch.tensor([
24
  -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # negative values
25
  ], dtype=torch.float32)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Block size for NVFP4 (16 elements per scale)
28
  NVFP4_BLOCK_SIZE = 16
29
 
@@ -109,8 +123,8 @@ def dequantize_nvfp4(
109
  high = (packed >> 4) & 0x0F
110
  fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)
111
 
112
- # Lookup table dequantization
113
- lut = NVFP4_LUT.to(device=packed.device)
114
  tensor = lut[fp4_tensor.long()]
115
 
116
  # Apply dual-level scales
@@ -412,7 +426,7 @@ def test_nvfp4_gemm():
412
 
413
  # Compare
414
  error = (ref - out_deq).abs().mean()
415
- print(f" NVFP4 GEMM dequant test: mean abs error = {error:.6f}")
416
 
417
  return True
418
 
 
16
  import triton.language as tl
17
  from triton.tools.tensor_descriptor import TensorDescriptor
18
  from typing import Tuple, Optional
19
+ import functools
20
 
21
 
22
  # NVFP4 E2M1 lookup table for dequantization
 
25
  -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # negative values
26
  ], dtype=torch.float32)
27
 
28
+
29
+ @functools.lru_cache(maxsize=8)
30
+ def _get_nvfp4_lut(device_str: str) -> torch.Tensor:
31
+ """Get NVFP4 lookup table on specified device (cached).
32
+
33
+ Args:
34
+ device_str: Device string (e.g., 'cpu', 'cuda:0')
35
+
36
+ Returns:
37
+ NVFP4 lookup table on the specified device
38
+ """
39
+ return NVFP4_LUT.to(device=device_str)
40
+
41
  # Block size for NVFP4 (16 elements per scale)
42
  NVFP4_BLOCK_SIZE = 16
43
 
 
123
  high = (packed >> 4) & 0x0F
124
  fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)
125
 
126
+ # Lookup table dequantization (use cached LUT for efficiency)
127
+ lut = _get_nvfp4_lut(str(packed.device))
128
  tensor = lut[fp4_tensor.long()]
129
 
130
  # Apply dual-level scales
 
426
 
427
  # Compare
428
  error = (ref - out_deq).abs().mean()
429
+ print(f"PASS: NVFP4 GEMM dequant test: mean abs error = {error:.6f}")
430
 
431
  return True
432