eousphoros committed on
Commit 2ad76b4 · verified · 1 Parent(s): 2657783

Upload tools/fp8_to_nvfp4_streaming.py with huggingface_hub

Files changed (1)
  1. tools/fp8_to_nvfp4_streaming.py +1290 -0
tools/fp8_to_nvfp4_streaming.py ADDED
@@ -0,0 +1,1290 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Streaming FP8 to NVFP4 Conversion for DeepSeek V3.2
4
+
5
+ Converts FP8 e4m3 quantized weights (128x128 block scales) to NVFP4 e2m1 format
6
+ (16-element block scales) via FP32 intermediates.
7
+
8
+ Target: vLLM-compatible checkpoint with compressed-tensors format.
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import torch
14
+ import gc
15
+ import re
16
+ import shutil
17
+ import time
18
+ import logging
19
+ from typing import Dict, Any, Optional, Tuple, List, Set
20
+ from pathlib import Path
21
+ from dataclasses import dataclass, field
22
+ from safetensors.torch import save_file as st_save_file
23
+ from safetensors import safe_open
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # ============================================================================
28
+ # NVFP4 E2M1 Constants (from TensorRT-Model-Optimizer nvfp4_tensor.py)
29
+ # ============================================================================
30
+
31
+ # E2M1 quantization boundaries for searchsorted
32
+ E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])
33
+
34
+ # E2M1 representable values (index 0-7 = positive, 8-15 = negative with sign bit)
35
+ E2M1_VALUES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])
36
+
37
+ # Maximum representable FP4 value
38
+ FP4_MAX = 6.0
39
+
40
+ # Maximum FP8 E4M3 value (for scale normalization)
41
+ FP8_E4M3_MAX = 448.0
42
+
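+ # Illustrative check (not used by the converter): the boundaries above are the
+ # midpoints between consecutive positive E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6},
+ # which is what lets torch.searchsorted() act as a nearest-value lookup further below.
+ def _check_e2m1_bounds() -> bool:
+     positive = E2M1_VALUES[:8]
+     midpoints = (positive[:-1] + positive[1:]) / 2
+     return torch.equal(midpoints, E2M1_BOUNDS)
+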
43
+ # ============================================================================
44
+ # Tensor Classification Patterns
45
+ # ============================================================================
46
+
47
+ # Patterns for tensors that should NOT be quantized (preserve in original dtype)
48
+ PRESERVE_PATTERNS = [
49
+ r"embed_tokens", # Embeddings
50
+ r"lm_head", # Output head
51
+ r"\.mlp\.gate\.", # MoE router gate (NOT gate_proj) - note: .gate. not .gate$
52
+ r"shared_experts\.gate\.", # Shared expert routing
53
+ r"shared_expert_gate", # Alternative naming
54
+ r"layernorm", # LayerNorm weights
55
+ r"_norm\.", # RMSNorm weights (input_layernorm, etc.)
56
+ r"\.norm\.", # Norm weights
57
+ r"\.bias$", # Bias terms
58
+ # V3.2 DSA-specific (CRITICAL):
59
+ r"indexer\.weights_proj", # Sparse pattern selector - MUST preserve!
60
+ r"indexer\.k_norm", # Indexer normalization
61
+ # Scale tensors (handled separately)
62
+ r"_scale_inv$", # FP8 scale_inv tensors
63
+ r"_scale$", # Scale tensors
64
+ r"_scale_2$", # Global scale tensors
65
+ ]
66
+
67
+ # Compile patterns for efficiency
68
+ PRESERVE_PATTERNS_COMPILED = [re.compile(p) for p in PRESERVE_PATTERNS]
69
+
70
+
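+ # Illustrative example (tensor names are assumptions, shown for documentation):
+ # the router pattern r"\.mlp\.gate\." catches the MoE router gate but not the
+ # expert gate_proj weights, which must still be quantized.
+ def _demo_preserve_patterns() -> None:
+     samples = [
+         "model.layers.3.mlp.gate.weight",                 # MoE router -> preserve
+         "model.layers.3.mlp.experts.0.gate_proj.weight",  # expert proj -> quantize
+         "model.layers.3.self_attn.indexer.weights_proj.weight",  # DSA indexer -> preserve
+     ]
+     for name in samples:
+         preserved = any(p.search(name) for p in PRESERVE_PATTERNS_COMPILED)
+         print(f"{name}: {'preserve' if preserved else 'quantize'}")
+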
71
+ # ============================================================================
72
+ # ShardedSafeTensorWriter (adapted from fp8_fp4_llmcompressor_streaming.py)
73
+ # ============================================================================
74
+
75
+ class ShardedSafeTensorWriter:
76
+ """
77
+ Stream tensors into numbered .safetensors shards and build a HF-style index JSON.
78
+ """
79
+ def __init__(self, out_dir: str, max_shard_size: str = "5GB"):
80
+ self.out_dir = os.path.abspath(out_dir)
81
+ os.makedirs(self.out_dir, exist_ok=True)
82
+ self.max_bytes = self._parse_size_to_bytes(max_shard_size)
83
+ self.curr_tensors: Dict[str, torch.Tensor] = {}
84
+ self.curr_bytes = 0
85
+ self.shard_idx = 1
86
+ self.weight_map: Dict[str, str] = {}
87
+ self.total_bytes = 0
88
+
89
+ def _parse_size_to_bytes(self, size_str: str) -> int:
90
+ size_str = size_str.upper().strip()
91
+ if size_str.endswith('GB'):
92
+ return int(float(size_str[:-2]) * 1024 * 1024 * 1024)
93
+ elif size_str.endswith('MB'):
94
+ return int(float(size_str[:-2]) * 1024 * 1024)
95
+ elif size_str.endswith('KB'):
96
+ return int(float(size_str[:-2]) * 1024)
97
+ else:
98
+ return int(size_str)
99
+
100
+ def _next_shard_name(self) -> str:
101
+ return f"model-{self.shard_idx:05d}.safetensors"
102
+
103
+ def _flush(self):
104
+ if not self.curr_tensors:
105
+ return
106
+ fname = self._next_shard_name()
107
+ path = os.path.join(self.out_dir, fname)
108
+ st_save_file(self.curr_tensors, path, metadata={"format": "nvfp4"})
109
+ logger.info(f" Saved shard {fname}: {len(self.curr_tensors)} tensors, {self.curr_bytes / 1e9:.2f} GB")
110
+ for k in self.curr_tensors.keys():
111
+ self.weight_map[k] = fname
112
+ self.total_bytes += self.curr_bytes
113
+ self.curr_tensors.clear()
114
+ self.curr_bytes = 0
115
+ self.shard_idx += 1
116
+
117
+ def add_tensor(self, name: str, tensor: torch.Tensor):
118
+ if tensor.device.type != "cpu":
119
+ tensor = tensor.to("cpu")
120
+ if not tensor.is_contiguous():
121
+ tensor = tensor.contiguous()
122
+ tbytes = tensor.element_size() * tensor.numel()
123
+ if self.curr_bytes > 0 and self.curr_bytes + tbytes > self.max_bytes:
124
+ self._flush()
125
+ self.curr_tensors[name] = tensor
126
+ self.curr_bytes += tbytes
127
+
128
+ def finalize(self) -> int:
129
+ self._flush()
130
+ index_path = os.path.join(self.out_dir, "model.safetensors.index.json")
131
+ index = {"metadata": {"total_size": self.total_bytes}, "weight_map": self.weight_map}
132
+ with open(index_path, "w") as f:
133
+ json.dump(index, f, indent=2)
134
+ logger.info(f"Finalized: {self.shard_idx - 1} shards, {self.total_bytes / 1e9:.2f} GB total")
135
+ return self.shard_idx - 1
136
+
137
+
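+ # Minimal usage sketch (output path and tensor names are placeholders): tensors are
+ # streamed in one at a time, a new shard is cut once the size budget would be
+ # exceeded, and finalize() writes model.safetensors.index.json.
+ def _demo_sharded_writer() -> None:
+     writer = ShardedSafeTensorWriter("/tmp/nvfp4-demo", max_shard_size="100MB")
+     writer.add_tensor("layer.0.weight", torch.zeros(1024, 512, dtype=torch.uint8))
+     writer.add_tensor("layer.0.weight_scale", torch.ones(1024, 64))
+     num_shards = writer.finalize()
+     print(f"wrote {num_shards} shard(s)")
+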
138
+ # ============================================================================
139
+ # Conversion Statistics
140
+ # ============================================================================
141
+
142
+ @dataclass
143
+ class ConversionStats:
144
+ """Track conversion statistics."""
145
+ total_tensors: int = 0
146
+ fp8_tensors: int = 0
147
+ # Primary conversions: FP8 tensors where we ran the full conversion logic
148
+ primary_conversions: int = 0
149
+ # MoE partner conversions: FP8 tensors converted as partners during joint scale computation
150
+ # These are cached during primary conversion and written when encountered in stream
151
+ moe_partner_conversions: int = 0
152
+ preserved_sensitive: int = 0
153
+ copied_unchanged: int = 0
154
+ total_params: int = 0
155
+ layers_processed: Set[str] = field(default_factory=set)
156
+ warnings: List[Dict] = field(default_factory=list)
157
+ errors: List[Dict] = field(default_factory=list)
158
+ start_time: float = 0
159
+ end_time: float = 0
160
+
161
+ @property
162
+ def total_nvfp4_tensors(self) -> int:
163
+ """Total FP8 tensors converted to NVFP4 (primary + partner)."""
164
+ return self.primary_conversions + self.moe_partner_conversions
165
+
166
+ def log_warning(self, key: str, reason: str):
167
+ self.warnings.append({"tensor": key, "reason": reason})
168
+
169
+ def log_error(self, key: str, error: str):
170
+ self.errors.append({"tensor": key, "error": error})
171
+
172
+
173
+ # ============================================================================
174
+ # FP8 Block Dequantization
175
+ # ============================================================================
176
+
177
+ def dequantize_fp8_block_to_fp32(
178
+ fp8_weight: torch.Tensor,
179
+ scale_inv: torch.Tensor,
180
+ block_size: int = 128,
181
+ device: Optional[torch.device] = None
182
+ ) -> torch.Tensor:
183
+ """
184
+ Dequantize FP8 e4m3 weight using block-wise scale_inv.
185
+
186
+ The DeepSeek FP8 format uses 128x128 blocks where each block
187
+ shares a single inverse scale factor.
188
+
189
+ Formula: fp32_weight = fp8_weight.to(float32) * scale_inv[block_i, block_j]
190
+
191
+ Reference: TensorRT-Model-Optimizer/examples/deepseek/ds_kernel.py:89-110
192
+
193
+ Args:
194
+ fp8_weight: FP8 e4m3 weight tensor [M, N]
195
+ scale_inv: Inverse scale tensor [M/block_size, N/block_size]
196
+ block_size: Block size (default 128)
197
+ device: Device to compute on (None = same as input)
198
+
199
+ Returns:
200
+ FP32 dequantized weight tensor [M, N]
201
+ """
202
+ if device is not None:
203
+ fp8_weight = fp8_weight.to(device)
204
+ scale_inv = scale_inv.to(device)
205
+
206
+ M, N = fp8_weight.shape
207
+
208
+ # Handle case where dimensions aren't divisible by block_size
209
+ M_blocks = (M + block_size - 1) // block_size
210
+ N_blocks = (N + block_size - 1) // block_size
211
+
212
+ # Validate scale_inv shape
213
+ expected_scale_shape = (M_blocks, N_blocks)
214
+ if scale_inv.shape != expected_scale_shape:
215
+ # Some weights have different scale shapes (e.g., per-row scaling)
216
+ if scale_inv.numel() == 1:
217
+ # Scalar scale
218
+ return fp8_weight.to(torch.float32) * scale_inv.item()
219
+ elif scale_inv.shape[0] == 1 or scale_inv.shape[1] == 1:
220
+ # Per-row or per-column scaling
221
+ return fp8_weight.to(torch.float32) * scale_inv.to(torch.float32)
222
+ else:
223
+ logger.warning(f"Unexpected scale_inv shape {scale_inv.shape} for weight {fp8_weight.shape}, expected {expected_scale_shape}")
224
+ # Try to broadcast
225
+ return fp8_weight.to(torch.float32) * scale_inv.to(torch.float32)
226
+
227
+ # Convert FP8 to FP32
228
+ fp32_weight = fp8_weight.to(torch.float32)
229
+
230
+ # If dimensions match exactly, use efficient block multiplication
231
+ if M % block_size == 0 and N % block_size == 0:
232
+ # Reshape to blocks: [M/bs, bs, N/bs, bs]
233
+ weight_blocks = fp32_weight.view(M_blocks, block_size, N_blocks, block_size)
234
+
235
+ # Apply scale: scale_inv[i, j] applies to weight_blocks[i, :, j, :]
236
+ # scale_inv shape: [M_blocks, N_blocks] -> [M_blocks, 1, N_blocks, 1]
237
+ scaled = weight_blocks * scale_inv[:, None, :, None].to(torch.float32)
238
+
239
+ # Reshape back
240
+ return scaled.view(M, N)
241
+ else:
242
+ # Handle non-divisible dimensions with padding
243
+ M_pad = M_blocks * block_size
244
+ N_pad = N_blocks * block_size
245
+
246
+ padded_weight = torch.zeros(M_pad, N_pad, dtype=torch.float32, device=fp32_weight.device)
247
+ padded_weight[:M, :N] = fp32_weight
248
+
249
+ weight_blocks = padded_weight.view(M_blocks, block_size, N_blocks, block_size)
250
+ scaled = weight_blocks * scale_inv[:, None, :, None].to(torch.float32)
251
+
252
+ return scaled.view(M_pad, N_pad)[:M, :N]
253
+
254
+
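+ # Worked example on synthetic data (assumes torch.float8_e4m3fn support): build a
+ # block-scaled FP8 matrix the same way DeepSeek checkpoints lay it out, then check
+ # that block dequantization recovers the original to within FP8 rounding error.
+ def _demo_fp8_block_dequant(block_size: int = 128) -> float:
+     torch.manual_seed(0)
+     w = torch.randn(2 * block_size, 2 * block_size)
+     blocks = w.view(2, block_size, 2, block_size)
+     amax = blocks.abs().amax(dim=(1, 3))          # one amax per 128x128 block -> [2, 2]
+     scale_inv = amax / FP8_E4M3_MAX               # dequantization multiplier per block
+     fp8 = (blocks / scale_inv[:, None, :, None]).view_as(w).to(torch.float8_e4m3fn)
+     recovered = dequantize_fp8_block_to_fp32(fp8, scale_inv, block_size=block_size)
+     return (recovered - w).abs().max().item()     # max abs error, set by FP8 rounding
+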
255
+ # ============================================================================
256
+ # NVFP4 Scale Computation
257
+ # ============================================================================
258
+
259
+ def compute_nvfp4_scales(
260
+ fp32_weight: torch.Tensor,
261
+ block_size: int = 16
262
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ """
264
+ Compute two-level NVFP4 scaling factors.
265
+
266
+ NVFP4 uses dual-level scaling:
267
+ 1. Per-tensor global scale (scale_2): amax / (6.0 * 448.0)
268
+ 2. Per-block scale: per_block_amax / (6.0 * scale_2)
269
+
270
+ Reference: TensorRT-Model-Optimizer nvfp4_tensor.py:94-97, 63-92
271
+
272
+ Args:
273
+ fp32_weight: FP32 weight tensor
274
+ block_size: Block size for per-block scaling (default 16)
275
+
276
+ Returns:
277
+ Tuple of:
278
+ - weight_scale: Per-block FP8 E4M3 scale [M, N/block_size]
279
+ - weight_scale_2: Per-tensor FP32 global scale (scalar tensor)
280
+ """
281
+ # Step 1: Compute per-tensor global scale (scale_2)
282
+ global_amax = fp32_weight.abs().max()
283
+ weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)
284
+
285
+ # Ensure non-zero scale (use abs comparison to avoid float precision issues)
286
+ if weight_scale_2.abs() < 1e-10:
287
+ weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=fp32_weight.device)
288
+
289
+ # Step 2: Compute per-block scale
290
+ original_shape = fp32_weight.shape
291
+
292
+ # Handle N dimension for block quantization
293
+ M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
294
+ N = fp32_weight.shape[-1]
295
+
296
+ # Pad N if not divisible by block_size
297
+ N_padded = ((N + block_size - 1) // block_size) * block_size
298
+ if N_padded != N:
299
+ if fp32_weight.dim() == 1:
300
+ padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
301
+ padded[:N] = fp32_weight
302
+ fp32_weight = padded
303
+ else:
304
+ padded = torch.zeros(*original_shape[:-1], N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
305
+ padded[..., :N] = fp32_weight
306
+ fp32_weight = padded
307
+
308
+ # Reshape to blocks along last dimension
309
+ if fp32_weight.dim() == 1:
310
+ weight_blocks = fp32_weight.view(-1, block_size)
311
+ else:
312
+ weight_blocks = fp32_weight.view(*original_shape[:-1], -1, block_size)
313
+
314
+ # Compute per-block amax
315
+ per_block_amax = weight_blocks.abs().amax(dim=-1) # [..., N/block_size]
316
+
317
+ # Per-block scale = per_block_amax / (6.0 * scale_2)
318
+ per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
319
+
320
+ # Clamp to avoid division by zero, set zeros to 1.0
321
+ per_block_scale = per_block_scale.clamp(min=1e-8)
322
+ per_block_scale[per_block_scale < 1e-7] = 1.0
323
+
324
+ # Convert to FP8 E4M3 (if available, otherwise keep as float32)
325
+ try:
326
+ weight_scale = per_block_scale.to(torch.float8_e4m3fn)
327
+ except (RuntimeError, TypeError):
328
+ # FP8 not supported on this device/PyTorch version
329
+ weight_scale = per_block_scale.to(torch.float32)
330
+
331
+ return weight_scale, weight_scale_2
332
+
333
+
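+ # Worked example of the two-level scaling (numbers chosen for illustration): if the
+ # tensor-wide amax is 2688 = 6.0 * 448.0 then weight_scale_2 = 1.0, and a block whose
+ # amax is 3.0 gets per-block scale 3.0 / (6.0 * 1.0) = 0.5, so that block's largest
+ # value lands exactly on the FP4 maximum of 6.0 after scaling.
+ def _demo_nvfp4_scales() -> None:
+     w = torch.zeros(1, 32)
+     w[0, 0] = 2688.0    # sets the global amax
+     w[0, 16] = 3.0      # amax of the second 16-element block
+     scale, scale_2 = compute_nvfp4_scales(w, block_size=16)
+     print(scale_2.item())             # 1.0
+     print(scale.to(torch.float32))    # [[448.0, 0.5]]
+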
334
+ # ============================================================================
335
+ # NVFP4 Quantization and Packing
336
+ # ============================================================================
337
+
338
+ def quantize_to_nvfp4_packed(
339
+ fp32_weight: torch.Tensor,
340
+ weight_scale: torch.Tensor,
341
+ weight_scale_2: torch.Tensor,
342
+ block_size: int = 16
343
+ ) -> torch.Tensor:
344
+ """
345
+ Quantize FP32 weight to NVFP4 packed uint8 format.
346
+
347
+ E2M1 values: {0, 0.5, 1, 1.5, 2, 3, 4, 6} with sign (16 total values)
348
+ Packing: (code[..., 1::2] << 4) | code[..., 0::2]
349
+
350
+ Reference: TensorRT-Model-Optimizer nvfp4_tensor.py:119-140, 224-227
351
+
352
+ Args:
353
+ fp32_weight: FP32 weight tensor
354
+ weight_scale: Per-block FP8 E4M3 scale
355
+ weight_scale_2: Per-tensor FP32 global scale
356
+ block_size: Block size (default 16)
357
+
358
+ Returns:
359
+ Packed uint8 tensor [M, N/2]
360
+ """
361
+ device = fp32_weight.device
362
+ original_shape = fp32_weight.shape
363
+ N = original_shape[-1]
364
+
365
+ # Pad N if not divisible by block_size
366
+ N_padded = ((N + block_size - 1) // block_size) * block_size
367
+ if N_padded != N:
368
+ if fp32_weight.dim() == 1:
369
+ padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=device)
370
+ padded[:N] = fp32_weight
371
+ fp32_weight = padded
372
+ else:
373
+ padded = torch.zeros(*original_shape[:-1], N_padded, dtype=fp32_weight.dtype, device=device)
374
+ padded[..., :N] = fp32_weight
375
+ fp32_weight = padded
376
+
377
+ # Reshape for block-wise processing
378
+ if fp32_weight.dim() == 1:
379
+ weight_blocks = fp32_weight.view(-1, block_size)
380
+ else:
381
+ weight_blocks = fp32_weight.view(*original_shape[:-1], -1, block_size)
382
+
383
+ # Compute combined scale and apply
384
+ # scaled_weight = weight / (scale * scale_2)
385
+ combined_scale = weight_scale.to(torch.float32) * weight_scale_2
386
+ scaled_weight = weight_blocks / combined_scale.unsqueeze(-1)
387
+
388
+ # Flatten back to original shape (with padding)
389
+ if fp32_weight.dim() == 1:
390
+ scaled_weight = scaled_weight.view(-1)
391
+ else:
392
+ scaled_weight = scaled_weight.view(*original_shape[:-1], -1)
393
+
394
+ # Get E2M1 bounds on device
395
+ e2m1_bounds = E2M1_BOUNDS.to(device)
396
+
397
+ # Extract sign bit and compute absolute values
398
+ sign_bit = (scaled_weight < 0).to(torch.uint8)
399
+ weight_abs = scaled_weight.abs()
400
+
401
+ # Find nearest E2M1 magnitude index (0-7) using searchsorted
402
+ # searchsorted returns index where value should be inserted
403
+ ord_idx = torch.searchsorted(e2m1_bounds, weight_abs, out_int32=True).to(torch.uint8)
404
+
405
+ # Handle rounding at boundary values (odd indices need special treatment)
406
+ # For values exactly at odd boundaries [0.75, 1.75, 2.5], round up
407
+ odd_bounds = e2m1_bounds[[1, 3, 5]] # [0.75, 1.75, 2.5]
408
+ equals_odd = torch.any(weight_abs.unsqueeze(-1) == odd_bounds, dim=-1).to(torch.uint8)
409
+
410
+ # Combine sign and ordinal: code = (sign << 3) | (ord + round_adjust)
411
+ fp4_codes = (sign_bit << 3) | (ord_idx + equals_odd)
412
+
413
+ # Ensure codes are in valid range [0, 15]
414
+ fp4_codes = fp4_codes.clamp(0, 15)
415
+
416
+ # Pack pairs of FP4 values into uint8
417
+ # Even indices in low nibble, odd indices in high nibble
418
+ packed = (fp4_codes[..., 1::2] << 4) | fp4_codes[..., 0::2]
419
+ packed = packed.to(torch.uint8)
420
+
421
+ return packed
422
+
423
+
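+ # Inverse mapping for inspection (a sketch; vLLM does this inside its own kernels):
+ # unpack the two 4-bit codes per byte, look them up in E2M1_VALUES, and re-apply the
+ # combined scale to get a dequantized approximation of the original weight
+ # (at the padded width, if N was not a multiple of block_size).
+ def unpack_nvfp4(packed: torch.Tensor, weight_scale: torch.Tensor,
+                  weight_scale_2: torch.Tensor, block_size: int = 16) -> torch.Tensor:
+     low = packed & 0x0F                                     # codes at even positions
+     high = (packed >> 4) & 0x0F                             # codes at odd positions
+     codes = torch.stack([low, high], dim=-1).flatten(-2)    # re-interleave to [..., N]
+     values = E2M1_VALUES.to(packed.device)[codes.long()]
+     blocks = values.view(*values.shape[:-1], -1, block_size)
+     combined = weight_scale.to(torch.float32) * weight_scale_2
+     return (blocks * combined.unsqueeze(-1)).flatten(-2)
+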
424
+ # ============================================================================
425
+ # Tensor Classification
426
+ # ============================================================================
427
+
428
+ def should_preserve_tensor(key: str) -> bool:
429
+ """
430
+ Check if a tensor should be preserved (not quantized).
431
+
432
+ Args:
433
+ key: Tensor name/key
434
+
435
+ Returns:
436
+ True if tensor should be preserved in original dtype
437
+ """
438
+ for pattern in PRESERVE_PATTERNS_COMPILED:
439
+ if pattern.search(key):
440
+ return True
441
+ return False
442
+
443
+
444
+ def is_fp8_weight(key: str, tensor: torch.Tensor) -> bool:
445
+ """
446
+ Check if a tensor is an FP8 quantized weight.
447
+
448
+ Args:
449
+ key: Tensor name
450
+ tensor: The tensor to check
451
+
452
+ Returns:
453
+ True if this is an FP8 weight that should be converted
454
+ """
455
+ # Check dtype
456
+ if tensor.dtype != torch.float8_e4m3fn:
457
+ return False
458
+
459
+ # Check it's a weight (not a scale or bias)
460
+ if not key.endswith('.weight'):
461
+ return False
462
+
463
+ # Check it's not a preserved tensor
464
+ if should_preserve_tensor(key):
465
+ return False
466
+
467
+ return True
468
+
469
+
470
+ # ============================================================================
471
+ # MoE Expert Pair Helper Functions
472
+ # ============================================================================
473
+
474
+ def get_moe_expert_pair_key(weight_key: str) -> Optional[str]:
475
+ """
476
+ Get the expert pair identifier for MoE gate_proj/up_proj weights.
477
+
478
+ For vLLM's fused MoE kernels, gate_proj (w1) and up_proj (w3) must share
479
+ the same weight_scale_2 because they're fused together.
480
+
481
+ Args:
482
+ weight_key: Tensor name (e.g., "model.layers.0.mlp.experts.5.gate_proj.weight")
483
+
484
+ Returns:
485
+ Expert pair key (e.g., "model.layers.0.mlp.experts.5") or None if not MoE weight
486
+ """
487
+ # Match MoE expert gate_proj or up_proj patterns
488
+ # Pattern: model.layers.{L}.mlp.experts.{E}.gate_proj.weight
489
+ # Pattern: model.layers.{L}.mlp.experts.{E}.up_proj.weight
490
+ moe_pattern = re.match(r'(model\.layers\.\d+\.mlp\.experts\.\d+)\.(gate_proj|up_proj)\.weight$', weight_key)
491
+ if moe_pattern:
492
+ return moe_pattern.group(1)
493
+
494
+ # Also match shared_experts pattern if present
495
+ shared_pattern = re.match(r'(model\.layers\.\d+\.mlp\.shared_experts)\.(gate_proj|up_proj)\.weight$', weight_key)
496
+ if shared_pattern:
497
+ return shared_pattern.group(1)
498
+
499
+ return None
500
+
501
+
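+ # Example (layer/expert indices are illustrative): gate_proj and up_proj of the same
+ # expert map to one pair key, so the converter can give them a shared weight_scale_2
+ # for vLLM's fused w1/w3 MoE kernel; down_proj is scaled independently.
+ def _demo_moe_pair_key() -> None:
+     print(get_moe_expert_pair_key("model.layers.0.mlp.experts.5.gate_proj.weight"))
+     # -> model.layers.0.mlp.experts.5
+     print(get_moe_expert_pair_key("model.layers.0.mlp.experts.5.up_proj.weight"))
+     # -> model.layers.0.mlp.experts.5 (same pair key as gate_proj)
+     print(get_moe_expert_pair_key("model.layers.0.mlp.experts.5.down_proj.weight"))
+     # -> None
+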
502
+ # ============================================================================
503
+ # Main Converter Class
504
+ # ============================================================================
505
+
506
+ class FP8ToNVFP4StreamingConverter:
507
+ """
508
+ Streaming FP8 to NVFP4 converter for DeepSeek V3.2.
509
+
510
+ Processes safetensor shards sequentially with GPU acceleration,
511
+ converting FP8 e4m3 weights to NVFP4 e2m1 format.
512
+ """
513
+
514
+ def __init__(
515
+ self,
516
+ model_path: str,
517
+ output_dir: str,
518
+ device: str = "cuda",
519
+ max_shard_size: str = "5GB",
520
+ fp8_block_size: int = 128,
521
+ nvfp4_block_size: int = 16
522
+ ):
523
+ """
524
+ Initialize the converter.
525
+
526
+ Args:
527
+ model_path: Path to source FP8 model
528
+ output_dir: Output directory for NVFP4 model
529
+ device: Device for computation (cuda or cpu)
530
+ max_shard_size: Maximum output shard size
531
+ fp8_block_size: FP8 quantization block size (default 128)
532
+ nvfp4_block_size: NVFP4 quantization block size (default 16)
533
+ """
534
+ self.model_path = Path(model_path)
535
+ self.output_dir = Path(output_dir)
536
+ self.device = torch.device(device if torch.cuda.is_available() else "cpu")
537
+ self.max_shard_size = max_shard_size
538
+ self.fp8_block_size = fp8_block_size
539
+ self.nvfp4_block_size = nvfp4_block_size
540
+
541
+ # Load model index
542
+ self.weight_map, self.shard_to_keys = self._load_index()
543
+
544
+ # Initialize statistics
545
+ self.stats = ConversionStats()
546
+
547
+ # Cache for cross-shard scale_inv tensors
548
+ self.scale_cache: Dict[str, torch.Tensor] = {}
549
+
550
+ # Cache for processed MoE weights (for streaming partner handling)
551
+ # When we process gate_proj, we also load up_proj, process both with joint scale,
552
+ # and cache up_proj's result here so we can skip it when we encounter it later
553
+ # Key: weight_key (e.g., "model.layers.0.mlp.experts.5.up_proj.weight")
554
+ # Value: Dict of converted tensors
555
+ self.moe_processed_cache: Dict[str, Dict[str, torch.Tensor]] = {}
556
+
557
+ # Build MoE pair mapping from index for efficient lookup
558
+ self.moe_pairs: Dict[str, Dict[str, str]] = self._build_moe_pair_map()
559
+
560
+ # Initialize writer
561
+ self.writer = ShardedSafeTensorWriter(str(self.output_dir), max_shard_size)
562
+
563
+ logger.info(f"Initialized FP8→NVFP4 converter")
564
+ logger.info(f" Source: {self.model_path}")
565
+ logger.info(f" Output: {self.output_dir}")
566
+ logger.info(f" Device: {self.device}")
567
+ logger.info(f" FP8 block size: {self.fp8_block_size}")
568
+ logger.info(f" NVFP4 block size: {self.nvfp4_block_size}")
569
+
570
+ def _load_index(self) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
571
+ """Load model index and build shard-to-keys mapping."""
572
+ index_path = self.model_path / "model.safetensors.index.json"
573
+
574
+ if not index_path.exists():
575
+ raise FileNotFoundError(f"Model index not found: {index_path}")
576
+
577
+ with open(index_path) as f:
578
+ index = json.load(f)
579
+
580
+ weight_map = index.get("weight_map", {})
581
+
582
+ # Build reverse mapping: shard -> list of keys
583
+ shard_to_keys: Dict[str, List[str]] = {}
584
+ for key, shard in weight_map.items():
585
+ if shard not in shard_to_keys:
586
+ shard_to_keys[shard] = []
587
+ shard_to_keys[shard].append(key)
588
+
589
+ logger.info(f"Loaded index: {len(weight_map)} tensors across {len(shard_to_keys)} shards")
590
+
591
+ return weight_map, shard_to_keys
592
+
593
+ def _build_moe_pair_map(self) -> Dict[str, Dict[str, str]]:
594
+ """
595
+ Build mapping of MoE gate_proj/up_proj pairs from the index file.
596
+
597
+ This is a lightweight operation that just scans tensor names without
598
+ loading any weights, enabling efficient streaming processing.
599
+
600
+ Returns:
601
+ Dict mapping pair_key -> {"gate_proj": full_key, "up_proj": full_key}
602
+ """
603
+ moe_pairs: Dict[str, Dict[str, str]] = {}
604
+
605
+ for weight_key in self.weight_map.keys():
606
+ pair_key = get_moe_expert_pair_key(weight_key)
607
+ if pair_key:
608
+ if pair_key not in moe_pairs:
609
+ moe_pairs[pair_key] = {}
610
+ if "gate_proj" in weight_key:
611
+ moe_pairs[pair_key]["gate_proj"] = weight_key
612
+ elif "up_proj" in weight_key:
613
+ moe_pairs[pair_key]["up_proj"] = weight_key
614
+
615
+ # Filter to complete pairs only
616
+ complete_pairs = {k: v for k, v in moe_pairs.items()
617
+ if "gate_proj" in v and "up_proj" in v}
618
+
619
+ logger.info(f"Found {len(complete_pairs)} MoE expert pairs (gate_proj + up_proj)")
620
+ return complete_pairs
621
+
622
+ def _load_weight_from_shard(
623
+ self,
624
+ weight_key: str
625
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
626
+ """
627
+ Load an FP8 weight and its scale_inv from the appropriate shard.
628
+
629
+ Uses the index to locate which shard contains the weight.
630
+
631
+ Args:
632
+ weight_key: Full tensor key (e.g., "model.layers.0.mlp.experts.5.up_proj.weight")
633
+
634
+ Returns:
635
+ Tuple of (fp8_weight, scale_inv) or None if not found
636
+ """
637
+ if weight_key not in self.weight_map:
638
+ return None
639
+
640
+ shard_name = self.weight_map[weight_key]
641
+ shard_path = self.model_path / shard_name
642
+
643
+ if not shard_path.exists():
644
+ logger.warning(f"Shard not found: {shard_path}")
645
+ return None
646
+
647
+ try:
648
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
649
+ shard_keys = list(f.keys())
650
+
651
+ if weight_key not in shard_keys:
652
+ return None
653
+
654
+ fp8_weight = f.get_tensor(weight_key)
655
+
656
+ # Get scale_inv (may be in this shard or another)
657
+ scale_inv = self._get_scale_inv(weight_key, shard_keys, f)
658
+ if scale_inv is None:
659
+ logger.warning(f"Missing scale_inv for {weight_key}")
660
+ return None
661
+
662
+ return fp8_weight, scale_inv
663
+ except Exception as e:
664
+ logger.warning(f"Failed to load {weight_key}: {e}")
665
+ return None
666
+
667
+ def _get_partner_key(self, weight_key: str) -> Optional[str]:
668
+ """
669
+ Get the partner key for an MoE gate_proj/up_proj weight.
670
+
671
+ Args:
672
+ weight_key: Full tensor key
673
+
674
+ Returns:
675
+ Partner weight key or None if not an MoE pair weight
676
+ """
677
+ pair_key = get_moe_expert_pair_key(weight_key)
678
+ if not pair_key or pair_key not in self.moe_pairs:
679
+ return None
680
+
681
+ pair = self.moe_pairs[pair_key]
682
+ if "gate_proj" in weight_key:
683
+ return pair.get("up_proj")
684
+ elif "up_proj" in weight_key:
685
+ return pair.get("gate_proj")
686
+ return None
687
+
688
+ def _get_scale_inv(
689
+ self,
690
+ weight_key: str,
691
+ current_shard_keys: List[str],
692
+ current_shard_file: Any # safetensors file handle from safe_open()
693
+ ) -> Optional[torch.Tensor]:
694
+ """
695
+ Get scale_inv tensor, loading from other shard if needed.
696
+
697
+ Uses the model index to find which shard contains the scale_inv
698
+ and loads it on demand. Caches loaded scales for efficiency.
699
+
700
+ Args:
701
+ weight_key: The weight tensor key (e.g., "model.layers.X.mlp.gate_proj.weight")
702
+ current_shard_keys: List of keys in the current shard
703
+ current_shard_file: Open safetensors file handle for current shard
704
+
705
+ Returns:
706
+ scale_inv tensor or None if not found
707
+ """
708
+ scale_key = weight_key.replace('.weight', '.weight_scale_inv')
709
+
710
+ # Fast path: check current shard first
711
+ if scale_key in current_shard_keys:
712
+ return current_shard_file.get_tensor(scale_key)
713
+
714
+ # Check cache
715
+ if scale_key in self.scale_cache:
716
+ return self.scale_cache[scale_key]
717
+
718
+ # Look up in index and load from correct shard
719
+ if scale_key in self.weight_map:
720
+ scale_shard = self.weight_map[scale_key]
721
+ scale_path = self.model_path / scale_shard
722
+
723
+ try:
724
+ with safe_open(scale_path, framework="pt", device="cpu") as f:
725
+ scale_inv = f.get_tensor(scale_key)
726
+ # Cache for future use (scales are small ~32KB each)
727
+ self.scale_cache[scale_key] = scale_inv
728
+ logger.debug(f"Loaded cross-shard scale_inv from {scale_shard}: {scale_key}")
729
+ return scale_inv
730
+ except Exception as e:
731
+ logger.warning(f"Failed to load scale_inv from {scale_shard}: {e}")
732
+ return None
733
+
734
+ return None
735
+
736
+ def _convert_fp8_to_nvfp4(
737
+ self,
738
+ key: str,
739
+ fp8_weight: torch.Tensor,
740
+ scale_inv: torch.Tensor
741
+ ) -> Dict[str, torch.Tensor]:
742
+ """
743
+ Convert a single FP8 weight to NVFP4 format.
744
+
745
+ For MoE gate_proj/up_proj weights, loads the partner weight on-demand
746
+ to compute a joint scale_2, ensuring vLLM's fused MoE kernels work correctly.
747
+ The partner's result is cached to avoid reprocessing.
748
+
749
+ Args:
750
+ key: Tensor name
751
+ fp8_weight: FP8 e4m3 weight tensor
752
+ scale_inv: FP8 inverse scale tensor
753
+
754
+ Returns:
755
+ Dict with converted tensors:
756
+ - key: packed NVFP4 weight
757
+ - key.replace('.weight', '.weight_scale'): per-block scale
758
+ - key.replace('.weight', '.weight_scale_2'): global scale
759
+ """
760
+ # Move to processing device
761
+ fp8_weight = fp8_weight.to(self.device)
762
+ scale_inv = scale_inv.to(self.device)
763
+
764
+ # Step 1: Dequantize FP8 to FP32
765
+ fp32_weight = dequantize_fp8_block_to_fp32(
766
+ fp8_weight, scale_inv, block_size=self.fp8_block_size
767
+ )
768
+
769
+ # Step 2: Compute NVFP4 scales
770
+ # Check if this is an MoE weight that needs shared scale_2 with partner
771
+ partner_key = self._get_partner_key(key)
772
+
773
+ if partner_key:
774
+ # MoE gate_proj/up_proj - need joint scale with partner
775
+ # Load partner weight on-demand
776
+ partner_data = self._load_weight_from_shard(partner_key)
777
+
778
+ if partner_data:
779
+ partner_fp8, partner_scale_inv = partner_data
780
+ partner_fp8 = partner_fp8.to(self.device)
781
+ partner_scale_inv = partner_scale_inv.to(self.device)
782
+
783
+ # Dequantize partner
784
+ partner_fp32 = dequantize_fp8_block_to_fp32(
785
+ partner_fp8, partner_scale_inv, block_size=self.fp8_block_size
786
+ )
787
+
788
+ # Compute joint amax and scale_2
789
+ my_amax = fp32_weight.abs().max()
790
+ partner_amax = partner_fp32.abs().max()
791
+ joint_amax = torch.max(my_amax, partner_amax)
792
+ joint_scale_2 = joint_amax / (FP4_MAX * FP8_E4M3_MAX)
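+ # Worked numbers (illustrative): with amax_gate = 1.2 and amax_up = 0.9 the shared
+ # weight_scale_2 = max(1.2, 0.9) / (6.0 * 448.0) ≈ 4.46e-4, and both projections'
+ # per-block scales are computed against this single global scale so the fused
+ # w1/w3 kernel can dequantize them together.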
793
+
794
+ # Ensure non-zero (use abs comparison to avoid float precision issues)
795
+ if joint_scale_2.abs() < 1e-10:
796
+ joint_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=self.device)
797
+
798
+ # Compute per-block scale for this weight using joint scale_2
799
+ weight_scale = self._compute_per_block_scale(fp32_weight, joint_scale_2)
800
+ weight_scale_2 = joint_scale_2
801
+
802
+ # Also convert partner and cache its result
803
+ partner_scale = self._compute_per_block_scale(partner_fp32, joint_scale_2)
804
+ partner_packed = quantize_to_nvfp4_packed(
805
+ partner_fp32, partner_scale, joint_scale_2, block_size=self.nvfp4_block_size
806
+ )
807
+
808
+ partner_base = partner_key.replace('.weight', '')
809
+ self.moe_processed_cache[partner_key] = {
810
+ f"{partner_base}.weight": partner_packed.cpu(),
811
+ f"{partner_base}.weight_scale": partner_scale.cpu(),
812
+ f"{partner_base}.weight_scale_2": joint_scale_2.cpu().view(1),
813
+ }
814
+
815
+ logger.debug(f"Computed joint scale_2 for {key} + {partner_key}: {joint_scale_2.item():.6e}")
816
+
817
+ # Cleanup partner tensors
818
+ del partner_fp32, partner_fp8, partner_scale_inv
819
+ else:
820
+ # Partner not found - use standard per-tensor scale
821
+ logger.warning(f"Partner {partner_key} not found for {key}, using independent scale")
822
+ weight_scale, weight_scale_2 = compute_nvfp4_scales(
823
+ fp32_weight, block_size=self.nvfp4_block_size
824
+ )
825
+ else:
826
+ # Non-MoE weight - standard per-tensor scale computation
827
+ weight_scale, weight_scale_2 = compute_nvfp4_scales(
828
+ fp32_weight, block_size=self.nvfp4_block_size
829
+ )
830
+
831
+ # Step 3: Quantize to NVFP4 packed format
832
+ packed_weight = quantize_to_nvfp4_packed(
833
+ fp32_weight, weight_scale, weight_scale_2, block_size=self.nvfp4_block_size
834
+ )
835
+
836
+ # Build output tensor names
837
+ base_name = key.replace('.weight', '')
838
+ result = {
839
+ f"{base_name}.weight": packed_weight.cpu(),
840
+ f"{base_name}.weight_scale": weight_scale.cpu(),
841
+ f"{base_name}.weight_scale_2": weight_scale_2.cpu().view(1),
842
+ }
843
+
844
+ # Update statistics - this is a "primary" conversion (not from MoE partner cache)
845
+ self.stats.primary_conversions += 1
846
+
847
+ # Free GPU memory
848
+ del fp32_weight
849
+ if torch.cuda.is_available():
850
+ torch.cuda.empty_cache()
851
+
852
+ return result
853
+
854
+ def _compute_per_block_scale(
855
+ self,
856
+ fp32_weight: torch.Tensor,
857
+ weight_scale_2: torch.Tensor
858
+ ) -> torch.Tensor:
859
+ """
860
+ Compute per-block scale given a fixed weight_scale_2.
861
+
862
+ Args:
863
+ fp32_weight: FP32 weight tensor
864
+ weight_scale_2: Global scale (FP32 scalar)
865
+
866
+ Returns:
867
+ Per-block FP8 E4M3 scale tensor
868
+ """
869
+ original_shape = fp32_weight.shape
870
+ N = fp32_weight.shape[-1]
871
+ block_size = self.nvfp4_block_size
872
+
873
+ # Pad N if not divisible by block_size
874
+ N_padded = ((N + block_size - 1) // block_size) * block_size
875
+ if N_padded != N:
876
+ if fp32_weight.dim() == 1:
877
+ padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
878
+ padded[:N] = fp32_weight
879
+ fp32_padded = padded
880
+ else:
881
+ padded = torch.zeros(*original_shape[:-1], N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
882
+ padded[..., :N] = fp32_weight
883
+ fp32_padded = padded
884
+ else:
885
+ fp32_padded = fp32_weight
886
+
887
+ # Reshape to blocks
888
+ if fp32_padded.dim() == 1:
889
+ weight_blocks = fp32_padded.view(-1, block_size)
890
+ else:
891
+ weight_blocks = fp32_padded.view(*original_shape[:-1], -1, block_size)
892
+
893
+ # Per-block amax
894
+ per_block_amax = weight_blocks.abs().amax(dim=-1)
895
+
896
+ # Per-block scale with the given scale_2
897
+ per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
898
+ per_block_scale = per_block_scale.clamp(min=1e-8)
899
+ per_block_scale[per_block_scale < 1e-7] = 1.0
900
+
901
+ # Convert to FP8 E4M3
902
+ try:
903
+ return per_block_scale.to(torch.float8_e4m3fn)
904
+ except (RuntimeError, TypeError):
905
+ return per_block_scale.to(torch.float32)
906
+
907
+ def process_shard(self, shard_name: str) -> int:
908
+ """
909
+ Process a single shard, converting FP8 weights to NVFP4.
910
+
911
+ Args:
912
+ shard_name: Name of the shard file
913
+
914
+ Returns:
915
+ Number of tensors processed
916
+ """
917
+ shard_path = self.model_path / shard_name
918
+
919
+ if not shard_path.exists():
920
+ logger.error(f"Shard not found: {shard_path}")
921
+ return 0
922
+
923
+ tensors_processed = 0
924
+
925
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
926
+ keys = list(f.keys())
927
+
928
+ # Process each tensor
929
+ for key in keys:
930
+ tensor = f.get_tensor(key)
931
+ self.stats.total_tensors += 1
932
+ self.stats.total_params += tensor.numel()
933
+
934
+ # Track layer (safely handle edge cases)
935
+ if '.layers.' in key:
936
+ parts = key.split('.layers.')
937
+ if len(parts) > 1 and '.' in parts[1]:
938
+ layer_num = parts[1].split('.')[0]
939
+ self.stats.layers_processed.add(layer_num)
940
+
941
+ # Skip scale_inv tensors (handled with weights)
942
+ if key.endswith('_scale_inv'):
943
+ continue
944
+
945
+ # Check if this is an FP8 weight to convert
946
+ if is_fp8_weight(key, tensor):
947
+ self.stats.fp8_tensors += 1
948
+
949
+ # Check if this weight was already processed as a partner
950
+ if key in self.moe_processed_cache:
951
+ # Use cached result from partner processing
952
+ # This tensor was converted when its MoE partner was processed
953
+ # (gate_proj and up_proj share weight_scale_2 for vLLM fused kernels)
954
+ cached = self.moe_processed_cache.pop(key) # Pop to free memory
955
+ for name, t in cached.items():
956
+ self.writer.add_tensor(name, t)
957
+ self.stats.moe_partner_conversions += 1
958
+ tensors_processed += 1
959
+ logger.debug(f"Using cached result for MoE partner: {key}")
960
+ continue
961
+
962
+ # Find corresponding scale_inv (with cross-shard lookup)
963
+ scale_inv = self._get_scale_inv(key, keys, f)
964
+
965
+ if scale_inv is not None:
966
+ try:
967
+ # Convert FP8 → NVFP4
968
+ converted = self._convert_fp8_to_nvfp4(key, tensor, scale_inv)
969
+
970
+ # Add to writer
971
+ for name, t in converted.items():
972
+ self.writer.add_tensor(name, t)
973
+
974
+ tensors_processed += 1
975
+
976
+ except Exception as e:
977
+ logger.error(f"Error converting {key}: {e}")
978
+ self.stats.log_error(key, str(e))
979
+ # Skip this tensor - preserving FP8 would create corrupt checkpoint
980
+ # vLLM expects NVFP4 format for all quantized weights
981
+ logger.warning(f"Skipping {key} due to conversion error - checkpoint may be incomplete")
982
+ else:
983
+ # Missing scale_inv - skip this tensor
984
+ # Preserving FP8 would create corrupt checkpoint
985
+ logger.warning(f"Missing scale_inv for {key} (not found in any shard) - skipping")
986
+ self.stats.log_warning(key, "missing_scale_inv")
987
+
988
+ elif should_preserve_tensor(key):
989
+ # Preserve sensitive tensors
990
+ self.writer.add_tensor(key, tensor)
991
+ self.stats.preserved_sensitive += 1
992
+ tensors_processed += 1
993
+
994
+ else:
995
+ # Copy other tensors unchanged (norms, biases, etc.)
996
+ self.writer.add_tensor(key, tensor)
997
+ self.stats.copied_unchanged += 1
998
+ tensors_processed += 1
999
+
1000
+ # Free memory
1001
+ del tensor
1002
+
1003
+ # Clear scale cache - scales from this shard won't be needed again
1004
+ # This prevents unbounded memory growth for large models
1005
+ self.scale_cache.clear()
1006
+
1007
+ # Garbage collection
1008
+ gc.collect()
1009
+ if torch.cuda.is_available():
1010
+ torch.cuda.empty_cache()
1011
+
1012
+ return tensors_processed
1013
+
1014
+ def generate_config(self) -> Dict[str, Any]:
1015
+ """Generate vLLM-compatible config.json with modelopt NVFP4 format."""
1016
+ # Load original config
1017
+ config_path = self.model_path / "config.json"
1018
+ with open(config_path) as f:
1019
+ config = json.load(f)
1020
+
1021
+ # Update quantization config for NVFP4 using modelopt format
1022
+ # This format is compatible with vLLM's modelopt_fp4 quantization handler
1023
+ # Reference: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-NVFP4/blob/main/config.json
1024
+ config["quantization_config"] = {
1025
+ "quant_method": "modelopt",
1026
+ "quant_algo": "NVFP4",
1027
+ "config_groups": {
1028
+ "group_0": {
1029
+ "targets": ["Linear"],
1030
+ "weights": {
1031
+ "num_bits": 4,
1032
+ "type": "float",
1033
+ "group_size": self.nvfp4_block_size,
1034
+ "dynamic": False
1035
+ },
1036
+ "input_activations": None
1037
+ }
1038
+ },
1039
+ "ignore": [
1040
+ "lm_head",
1041
+ "model.embed_tokens",
1042
+ "re:.*\\.mlp\\.gate$",
1043
+ "re:.*layernorm.*",
1044
+ "re:.*_norm.*",
1045
+ "re:.*indexer\\.weights_proj.*",
1046
+ "re:.*indexer\\.k_norm.*"
1047
+ ],
1048
+ "kv_cache_scheme": None,
1049
+ "original_format": {
1050
+ "quant_method": "fp8",
1051
+ "fmt": "e4m3",
1052
+ "scale_fmt": "ue8m0",
1053
+ "weight_block_size": [self.fp8_block_size, self.fp8_block_size]
1054
+ },
1055
+ "conversion_info": {
1056
+ "source": "fp8_e4m3",
1057
+ "target": "nvfp4_e2m1",
1058
+ "intermediate": "fp32",
1059
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
1060
+ }
1061
+ }
1062
+
1063
+ return config
1064
+
1065
+ def copy_auxiliary_files(self):
1066
+ """Copy tokenizer and other auxiliary files."""
1067
+ aux_files = [
1068
+ "tokenizer.json",
1069
+ "tokenizer_config.json",
1070
+ "special_tokens_map.json",
1071
+ "vocab.json",
1072
+ "merges.txt",
1073
+ "tokenizer.model",
1074
+ "generation_config.json"
1075
+ ]
1076
+
1077
+ for filename in aux_files:
1078
+ src = self.model_path / filename
1079
+ if src.exists():
1080
+ dst = self.output_dir / filename
1081
+ shutil.copy2(src, dst)
1082
+ logger.info(f"Copied {filename}")
1083
+
1084
+ # Copy encoding folder if exists (V3.2 specific)
1085
+ encoding_src = self.model_path / "encoding"
1086
+ if encoding_src.exists() and encoding_src.is_dir():
1087
+ encoding_dst = self.output_dir / "encoding"
1088
+ shutil.copytree(encoding_src, encoding_dst, dirs_exist_ok=True)
1089
+ logger.info("Copied encoding folder")
1090
+
1091
+ def generate_report(self) -> Dict[str, Any]:
1092
+ """Generate conversion report."""
1093
+ elapsed = self.stats.end_time - self.stats.start_time
1094
+
1095
+ report = {
1096
+ "conversion_summary": {
1097
+ "source_format": "FP8 E4M3 (DeepSeek block-quantized)",
1098
+ "target_format": "NVFP4 E2M1 (16-element blocks)",
1099
+ "intermediate_format": "FP32",
1100
+ "model": str(self.model_path),
1101
+ "output": str(self.output_dir),
1102
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
1103
+ "elapsed_seconds": round(elapsed, 2),
1104
+ "elapsed_minutes": round(elapsed / 60, 2)
1105
+ },
1106
+ "tensor_statistics": {
1107
+ "total_tensors": self.stats.total_tensors,
1108
+ "fp8_tensors_found": self.stats.fp8_tensors,
1109
+ "primary_conversions": self.stats.primary_conversions,
1110
+ "moe_partner_conversions": self.stats.moe_partner_conversions,
1111
+ "total_nvfp4_tensors": self.stats.total_nvfp4_tensors,
1112
+ "preserved_sensitive": self.stats.preserved_sensitive,
1113
+ "copied_unchanged": self.stats.copied_unchanged,
1114
+ "total_parameters": self.stats.total_params
1115
+ },
1116
+ "layer_statistics": {
1117
+ "layers_processed": len(self.stats.layers_processed),
1118
+ "layer_ids": sorted(self.stats.layers_processed, key=lambda x: int(x) if x.isdigit() else 0)
1119
+ },
1120
+ "output_statistics": {
1121
+ "output_shards": self.writer.shard_idx - 1,
1122
+ "output_size_gb": round(self.writer.total_bytes / 1e9, 2)
1123
+ },
1124
+ "issues": {
1125
+ "warnings": self.stats.warnings[:20],
1126
+ "errors": self.stats.errors[:20],
1127
+ "total_warnings": len(self.stats.warnings),
1128
+ "total_errors": len(self.stats.errors)
1129
+ }
1130
+ }
1131
+
1132
+ # Log truncation if applicable
1133
+ if len(self.stats.warnings) > 20:
1134
+ logger.info(f"Report truncated: showing 20 of {len(self.stats.warnings)} warnings")
1135
+ if len(self.stats.errors) > 20:
1136
+ logger.info(f"Report truncated: showing 20 of {len(self.stats.errors)} errors")
1137
+
1138
+ return report
1139
+
1140
+ def run(self) -> Dict[str, Any]:
1141
+ """
1142
+ Run the full conversion process.
1143
+
1144
+ Returns:
1145
+ Conversion report dictionary
1146
+ """
1147
+ logger.info("=" * 70)
1148
+ logger.info("Starting FP8 to NVFP4 Streaming Conversion")
1149
+ logger.info("=" * 70)
1150
+
1151
+ self.stats.start_time = time.time()
1152
+
1153
+ # Get sorted list of shards
1154
+ shard_names = sorted(self.shard_to_keys.keys())
1155
+ total_shards = len(shard_names)
1156
+
1157
+ logger.info(f"Processing {total_shards} shards...")
1158
+
1159
+ # Process each shard
1160
+ for idx, shard_name in enumerate(shard_names, 1):
1161
+ logger.info(f"\n[{idx}/{total_shards}] Processing {shard_name}")
1162
+ tensors = self.process_shard(shard_name)
1163
+ logger.info(f" Processed {tensors} tensors")
1164
+
1165
+ # Check for orphaned MoE cache entries (partner never encountered)
1166
+ if self.moe_processed_cache:
1167
+ orphan_count = len(self.moe_processed_cache)
1168
+ logger.warning(f"Found {orphan_count} orphaned MoE cache entries (partner weight never processed):")
1169
+ for key in list(self.moe_processed_cache.keys())[:5]:
1170
+ logger.warning(f" - {key}")
1171
+ if orphan_count > 5:
1172
+ logger.warning(f" ... and {orphan_count - 5} more")
1173
+ self.moe_processed_cache.clear()
1174
+
1175
+ # Finalize output
1176
+ logger.info("\nFinalizing output...")
1177
+ self.writer.finalize()
1178
+
1179
+ # Generate and save config
1180
+ logger.info("Generating config.json...")
1181
+ config = self.generate_config()
1182
+ config_path = self.output_dir / "config.json"
1183
+ with open(config_path, 'w') as f:
1184
+ json.dump(config, f, indent=2)
1185
+
1186
+ # Copy auxiliary files
1187
+ logger.info("Copying auxiliary files...")
1188
+ self.copy_auxiliary_files()
1189
+
1190
+ self.stats.end_time = time.time()
1191
+
1192
+ # Generate report
1193
+ report = self.generate_report()
1194
+
1195
+ # Save report
1196
+ report_path = self.output_dir / "conversion_report.json"
1197
+ with open(report_path, 'w') as f:
1198
+ json.dump(report, f, indent=2)
1199
+ logger.info(f"Saved conversion report: {report_path}")
1200
+
1201
+ # Print summary
1202
+ elapsed = self.stats.end_time - self.stats.start_time
1203
+ logger.info("\n" + "=" * 70)
1204
+ logger.info("Conversion Complete!")
1205
+ logger.info(f" Time: {elapsed / 60:.1f} minutes")
1206
+ logger.info(f" FP8 tensors found: {self.stats.fp8_tensors}")
1207
+ logger.info(f" Primary conversions: {self.stats.primary_conversions}")
1208
+ logger.info(f" MoE partner conversions: {self.stats.moe_partner_conversions}")
1209
+ logger.info(f" Total NVFP4 tensors: {self.stats.total_nvfp4_tensors}")
1210
+ logger.info(f" Tensors preserved: {self.stats.preserved_sensitive}")
1211
+ logger.info(f" Output shards: {self.writer.shard_idx - 1}")
1212
+ logger.info(f" Output size: {self.writer.total_bytes / 1e9:.2f} GB")
1213
+ logger.info(f" Output: {self.output_dir}")
1214
+ logger.info("=" * 70)
1215
+
1216
+ return report
1217
+
1218
+
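+ # Programmatic usage sketch (paths are placeholders): equivalent to the CLI entry
+ # point below, for driving a conversion from another script.
+ def _demo_run_converter() -> None:
+     logging.basicConfig(level=logging.INFO)
+     converter = FP8ToNVFP4StreamingConverter(
+         model_path="/mnt/models/deepseek-v3.2",        # FP8 source checkpoint
+         output_dir="/mnt/models/deepseek-v3.2-nvfp4",  # NVFP4 output directory
+         device="cuda",
+         max_shard_size="5GB",
+     )
+     converter.run()
+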
1219
+ # ============================================================================
1220
+ # Main Entry Point
1221
+ # ============================================================================
1222
+
1223
+ def main():
1224
+ import argparse
1225
+
1226
+ parser = argparse.ArgumentParser(
1227
+ description="Streaming FP8 to NVFP4 converter for DeepSeek V3.2"
1228
+ )
1229
+ parser.add_argument(
1230
+ "model_path",
1231
+ help="Path to FP8 model (e.g., /mnt/models/deepseek-v3.2)"
1232
+ )
1233
+ parser.add_argument(
1234
+ "--output_dir",
1235
+ default=None,
1236
+ help="Output directory (default: {model_path}-nvfp4)"
1237
+ )
1238
+ parser.add_argument(
1239
+ "--device",
1240
+ default="cuda",
1241
+ choices=["cuda", "cpu"],
1242
+ help="Device for computation (default: cuda)"
1243
+ )
1244
+ parser.add_argument(
1245
+ "--max_shard_size",
1246
+ default="5GB",
1247
+ help="Maximum output shard size (default: 5GB)"
1248
+ )
1249
+ parser.add_argument(
1250
+ "--fp8_block_size",
1251
+ type=int,
1252
+ default=128,
1253
+ help="FP8 quantization block size (default: 128)"
1254
+ )
1255
+ parser.add_argument(
1256
+ "--nvfp4_block_size",
1257
+ type=int,
1258
+ default=16,
1259
+ help="NVFP4 quantization block size (default: 16)"
1260
+ )
1261
+
1262
+ args = parser.parse_args()
1263
+
1264
+ # Default output directory
1265
+ if args.output_dir is None:
1266
+ args.output_dir = f"{args.model_path.rstrip('/')}-nvfp4"
1267
+
1268
+ # Set up logging
1269
+ logging.basicConfig(
1270
+ level=logging.INFO,
1271
+ format="%(asctime)s - %(levelname)s - %(message)s"
1272
+ )
1273
+
1274
+ # Create and run converter
1275
+ converter = FP8ToNVFP4StreamingConverter(
1276
+ model_path=args.model_path,
1277
+ output_dir=args.output_dir,
1278
+ device=args.device,
1279
+ max_shard_size=args.max_shard_size,
1280
+ fp8_block_size=args.fp8_block_size,
1281
+ nvfp4_block_size=args.nvfp4_block_size
1282
+ )
1283
+
1284
+ report = converter.run()
1285
+
1286
+ return report
1287
+
1288
+
1289
+ if __name__ == "__main__":
1290
+ main()
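+
+ # Example invocation (paths are illustrative):
+ #   python tools/fp8_to_nvfp4_streaming.py /mnt/models/deepseek-v3.2 \
+ #       --output_dir /mnt/models/deepseek-v3.2-nvfp4 --device cuda --max_shard_size 5GB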