Upload inference/generate.py with huggingface_hub

inference/generate.py  CHANGED  (+59, -1)
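
The commit title indicates the file was pushed with the huggingface_hub client rather than through the web UI. For reference, a minimal sketch of that kind of upload (the repo id below is a placeholder, not taken from this commit):

    # Hypothetical upload sketch; repo_id is a placeholder, not the real repository.
    from huggingface_hub import HfApi

    api = HfApi()  # uses the token from HF_TOKEN or a prior `huggingface-cli login`
    api.upload_file(
        path_or_fileobj="inference/generate.py",
        path_in_repo="inference/generate.py",
        repo_id="<user>/<repo>",
        commit_message="Upload inference/generate.py with huggingface_hub",
    )
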
@@ -81,6 +81,14 @@ def load_sharded_model(model, ckpt_path):
         # Get unique shard files
         shard_files = sorted(set(weight_map.values()))
 
+        # Check memory before loading
+        try:
+            import psutil
+            mem = psutil.virtual_memory()
+            print(f"Memory: {mem.available / 1e9:.1f}GB available / {mem.total / 1e9:.1f}GB total ({mem.percent:.1f}% used)")
+        except ImportError:
+            pass  # psutil not required
+
         print(f"Loading {len(shard_files)} shards (streaming to GPU)...")
         model_state = model.state_dict()
         loaded_keys = set()
@@ -105,6 +113,9 @@ def load_sharded_model(model, ckpt_path):
             print(f"Warning: {len(missing)} missing keys in checkpoint")
            for k in list(missing)[:5]:
                print(f"  - {k}")
+
+        # Reattach FP8 scales after loading
+        link_fp8_scales(model)
     else:
         # Fall back to single file
         single_file = os.path.join(ckpt_path, "model0-mp1.safetensors")
@@ -181,6 +192,48 @@ def generate(
     return completion_tokens
 
 
+def clear_system_cache():
+    """
+    Clear system cache to free memory (optional optimization).
+
+    This can help with large models by freeing cached memory.
+    Silently attempts cache clearing; failures are ignored.
+    """
+    try:
+        import subprocess
+        subprocess.run(
+            ['sudo', 'sh', '-c', 'echo 3 > /proc/sys/vm/drop_caches'],
+            check=False, capture_output=True, text=True, timeout=5
+        )
+    except Exception:
+        # Silently ignore if cache clearing fails
+        pass
+
+
+def link_fp8_scales(model):
+    """
+    Link FP8 scales to weight tensors after loading.
+
+    After load_state_dict(), FP8 weights lose their .scale attribute.
+    This function reattaches them.
+    """
+    from model import Linear, ColumnParallelLinear, RowParallelLinear
+
+    linked = 0
+    for name, module in model.named_modules():
+        if isinstance(module, (Linear, ColumnParallelLinear, RowParallelLinear)):
+            # Check if this is an FP8 layer
+            if hasattr(module, 'weight') and hasattr(module, 'scale'):
+                if module.weight is not None and module.scale is not None:
+                    if module.weight.dtype == torch.float8_e4m3fn:
+                        # Reattach scale as attribute
+                        module.weight.scale = module.scale
+                        linked += 1
+
+    if linked > 0:
+        print(f"✓ Linked scales for {linked} FP8 layers")
+
+
 def main(
     ckpt_path: str,
     config: str,
@@ -214,6 +267,11 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
+
+    # Optionally clear cache to free memory before loading large model
+    if rank == 0:
+        clear_system_cache()
+
     print("Creating model on CPU (this may take a while)...")
     with torch.device("cpu"):
         model = Transformer(args)
@@ -221,7 +279,7 @@ def main(
     print("Loading model weights...")
     load_sharded_model(model, ckpt_path)
     model.eval()
-    print("
+    print("DeepSeek V3.2 NVFP4 - Ready")
 
     if interactive:
         messages = []
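
Note on the FP8 scale relinking: in the DeepSeek-V3 style inference code this script follows, FP8 weights carry their block scales as a Python attribute (weight.scale) that the linear() helper reads at forward time; the loading path here drops that attribute, so link_fp8_scales() restores it. A rough sketch of that dependency, assuming the upstream weight_dequant helper from kernel.py (these names are assumptions, not part of this commit):

    # Sketch only: illustrates why weight.scale must exist after loading.
    import torch
    import torch.nn.functional as F
    from kernel import weight_dequant  # assumed upstream helper

    def linear(x, weight, bias=None):
        if weight.element_size() > 1:
            # bf16/fp32 weights are used directly
            return F.linear(x, weight, bias)
        # FP8 weights: dequantize with the block scales stored on weight.scale,
        # which is exactly the attribute link_fp8_scales() reattaches
        return F.linear(x, weight_dequant(weight, weight.scale), bias)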