eousphoros committed
Commit 26741b2 · verified · 1 Parent(s): bc09fc1

Upload inference/generate.py with huggingface_hub

Files changed (1):
  inference/generate.py  +59 -1
inference/generate.py CHANGED
 
@@ -81,6 +81,14 @@ def load_sharded_model(model, ckpt_path):
     # Get unique shard files
     shard_files = sorted(set(weight_map.values()))
 
+    # Check memory before loading
+    try:
+        import psutil
+        mem = psutil.virtual_memory()
+        print(f"Memory: {mem.available / 1e9:.1f}GB available / {mem.total / 1e9:.1f}GB total ({mem.percent:.1f}% used)")
+    except ImportError:
+        pass  # psutil not required
+
     print(f"Loading {len(shard_files)} shards (streaming to GPU)...")
     model_state = model.state_dict()
     loaded_keys = set()
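
Note: the psutil check above only reports current usage. A minimal sketch of going one step further and warning when the shards may not fit in RAM (not part of this commit; ckpt_path and shard_files as in load_sharded_model):

    import os
    import psutil

    def warn_if_low_memory(ckpt_path, shard_files):
        # Sum the on-disk size of the shards: a rough lower bound on RAM needed
        needed = sum(os.path.getsize(os.path.join(ckpt_path, f)) for f in shard_files)
        available = psutil.virtual_memory().available
        if needed > available:
            print(f"Warning: shards total ~{needed / 1e9:.1f}GB but only "
                  f"{available / 1e9:.1f}GB RAM is available; loading may swap.")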
 
@@ -105,6 +113,9 @@ def load_sharded_model(model, ckpt_path):
         print(f"Warning: {len(missing)} missing keys in checkpoint")
         for k in list(missing)[:5]:
             print(f" - {k}")
+
+        # Reattach FP8 scales after loading
+        link_fp8_scales(model)
     else:
         # Fall back to single file
         single_file = os.path.join(ckpt_path, "model0-mp1.safetensors")
 
@@ -181,6 +192,48 @@ def generate(
     return completion_tokens
 
 
+def clear_system_cache():
+    """
+    Clear system cache to free memory (optional optimization).
+
+    This can help with large models by freeing cached memory.
+    Silently attempts cache clearing; failures are ignored.
+    """
+    try:
+        import subprocess
+        subprocess.run(
+            ['sudo', 'sh', '-c', 'echo 3 > /proc/sys/vm/drop_caches'],
+            check=False, capture_output=True, text=True, timeout=5
+        )
+    except Exception:
+        # Silently ignore if cache clearing fails
+        pass
+
+
+def link_fp8_scales(model):
+    """
+    Link FP8 scales to weight tensors after loading.
+
+    After load_state_dict(), FP8 weights lose their .scale attribute.
+    This function reattaches them.
+    """
+    from model import Linear, ColumnParallelLinear, RowParallelLinear
+
+    linked = 0
+    for name, module in model.named_modules():
+        if isinstance(module, (Linear, ColumnParallelLinear, RowParallelLinear)):
+            # Check if this is an FP8 layer
+            if hasattr(module, 'weight') and hasattr(module, 'scale'):
+                if module.weight is not None and module.scale is not None:
+                    if module.weight.dtype == torch.float8_e4m3fn:
+                        # Reattach scale as attribute
+                        module.weight.scale = module.scale
+                        linked += 1
+
+    if linked > 0:
+        print(f"✓ Linked scales for {linked} FP8 layers")
+
+
 def main(
     ckpt_path: str,
     config: str,
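
Note: link_fp8_scales is needed because .scale is an ad-hoc Python attribute on the weight tensor, not a registered entry in the state dict. A toy illustration (not part of this commit), assuming the loader swaps tensors in with load_state_dict(..., assign=True) rather than copying in place:

    import torch
    import torch.nn as nn

    lin = nn.Linear(4, 4, bias=False)
    lin.weight.scale = torch.tensor(0.5)   # ad-hoc attribute, invisible to state_dict()
    print(hasattr(lin.weight, "scale"))    # True

    # assign=True replaces the parameter object with the checkpoint tensor,
    # discarding any Python attributes attached to the old one
    lin.load_state_dict({"weight": nn.Parameter(torch.randn(4, 4))}, assign=True)
    print(hasattr(lin.weight, "scale"))    # False -> the scale must be re-linked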
 
@@ -214,6 +267,11 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
+
+    # Optionally clear cache to free memory before loading large model
+    if rank == 0:
+        clear_system_cache()
+
     print("Creating model on CPU (this may take a while)...")
     with torch.device("cpu"):
         model = Transformer(args)
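
Note: echo 3 > /proc/sys/vm/drop_caches asks the Linux kernel to drop the page cache plus dentries and inodes; it requires root, so the sudo call is a silent no-op on hosts without passwordless sudo. A hypothetical way to observe the effect (not part of this commit; Linux only):

    def meminfo_kb(field):
        # Read one field (reported in kB) from /proc/meminfo, e.g. "Cached"
        with open("/proc/meminfo") as f:
            for line in f:
                if line.startswith(field + ":"):
                    return int(line.split()[1])
        raise KeyError(field)

    before = meminfo_kb("Cached")
    clear_system_cache()
    after = meminfo_kb("Cached")
    print(f"Page cache: {before / 1e6:.2f}GB -> {after / 1e6:.2f}GB")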
 
@@ -221,7 +279,7 @@ def main(
     print("Loading model weights...")
     load_sharded_model(model, ckpt_path)
     model.eval()
-    print("I'm DeepSeek 👋")
+    print("DeepSeek V3.2 NVFP4 - Ready")
 
     if interactive:
         messages = []
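
Note: for reference, a typical invocation of this script; the flag names follow the upstream DeepSeek-V3 generate.py CLI and are an assumption, so they may differ in this fork:

    torchrun --nproc-per-node 1 generate.py \
        --ckpt-path /path/to/ckpt --config configs/config.json --interactive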