import os
import re
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_file
from model import Transformer, ModelArgs
from encoding_dsv32 import encode_messages, eos_token, thinking_end_token

def hf_to_deepseek_key(hf_key: str) -> str:
"""Convert HuggingFace checkpoint key to DeepSeek model key."""
key = hf_key
# Strip "model." prefix
if key.startswith("model."):
key = key[6:]
# Embedding
key = key.replace("embed_tokens.", "embed.")
# Final norm and head
key = key.replace("lm_head.", "head.")
# Attention projections
key = key.replace(".self_attn.", ".attn.")
key = key.replace(".q_a_proj.", ".wq_a.")
key = key.replace(".q_b_proj.", ".wq_b.")
key = key.replace(".q_a_layernorm.", ".q_norm.")
key = key.replace(".kv_a_proj_with_mqa.", ".wkv_a.")
key = key.replace(".kv_b_proj.", ".wkv_b.")
key = key.replace(".kv_a_layernorm.", ".kv_norm.")
key = key.replace(".o_proj.", ".wo.")
    # Indexer attention: the checkpoint already uses the model's module names
    # (indexer.wk, indexer.wq_b, indexer.k_norm, indexer.weights_proj),
    # so no renaming is needed here.
# Layer norms
key = key.replace(".input_layernorm.", ".attn_norm.")
key = key.replace(".post_attention_layernorm.", ".ffn_norm.")
# MLP (dense layers)
key = key.replace(".mlp.gate_proj.", ".ffn.w1.")
key = key.replace(".mlp.up_proj.", ".ffn.w3.")
key = key.replace(".mlp.down_proj.", ".ffn.w2.")
# MoE (uses "ffn" module name in model, not "moe")
key = key.replace(".mlp.shared_experts.gate_proj.", ".ffn.shared_experts.w1.")
key = key.replace(".mlp.shared_experts.up_proj.", ".ffn.shared_experts.w3.")
key = key.replace(".mlp.shared_experts.down_proj.", ".ffn.shared_experts.w2.")
key = key.replace(".mlp.experts.", ".ffn.experts.")
key = key.replace(".mlp.gate.weight", ".ffn.gate.weight")
key = key.replace(".mlp.gate.e_score_correction_bias", ".ffn.gate.bias")
# Expert weights
key = re.sub(r"\.ffn\.experts\.(\d+)\.gate_proj\.", r".ffn.experts.\1.w1.", key)
key = re.sub(r"\.ffn\.experts\.(\d+)\.up_proj\.", r".ffn.experts.\1.w3.", key)
key = re.sub(r"\.ffn\.experts\.(\d+)\.down_proj\.", r".ffn.experts.\1.w2.", key)
return key

def load_sharded_model(model, ckpt_path):
"""Load model weights from sharded safetensors files using index."""
index_path = os.path.join(ckpt_path, "model.safetensors.index.json")
if os.path.exists(index_path):
# Load from sharded format using index
with open(index_path) as f:
index = json.load(f)
weight_map = index["weight_map"]
# Get unique shard files
shard_files = sorted(set(weight_map.values()))
# Check memory before loading
try:
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.available / 1e9:.1f}GB available / {mem.total / 1e9:.1f}GB total ({mem.percent:.1f}% used)")
except ImportError:
pass # psutil not required
print(f"Loading {len(shard_files)} shards (streaming to GPU)...")
model_state = model.state_dict()
loaded_keys = set()
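        # Stream one shard at a time: load it to CPU, copy matching tensors into the
        # pre-allocated parameters, then let the shard dict be freed on the next
        # iteration so peak host memory stays roughly one shard above the model size.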
for i, shard_file in enumerate(shard_files):
shard_path = os.path.join(ckpt_path, shard_file)
print(f" [{i+1}/{len(shard_files)}] {shard_file}", end="", flush=True)
shard_dict = load_file(shard_path, device="cpu")
# Copy matching tensors to model (with key mapping)
matched = 0
for hf_key, tensor in shard_dict.items():
key = hf_to_deepseek_key(hf_key)
if key in model_state:
model_state[key].copy_(tensor)
loaded_keys.add(key)
matched += 1
print(f" ({matched} tensors)")
missing = set(model_state.keys()) - loaded_keys
if missing:
print(f"Warning: {len(missing)} missing keys in checkpoint")
for k in list(missing)[:5]:
print(f" - {k}")
# Reattach FP8 scales after loading
link_fp8_scales(model)
else:
        # Fall back to a single pre-converted file whose keys already use the
        # DeepSeek naming, so no key mapping is applied.
        single_file = os.path.join(ckpt_path, "model0-mp1.safetensors")
        print(f"Loading single file: {single_file}")
        state_dict = load_file(single_file, device="cpu")
        model.load_state_dict(state_dict, strict=False)
        # load_state_dict() replaces the weight tensors, so reattach FP8 scales
        # on this path as well.
        link_fp8_scales(model)

def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
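    # Sampling via an exponential race: argmax(p_i / E_i) with E_i ~ Exp(1) is
    # equivalent to the Gumbel-max trick and draws a token from the categorical
    # distribution defined by probs.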
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)

@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
device = next(model.parameters()).device
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device=device)
prompt_mask = tokens != -1
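    # The first iteration prefills up to the shortest prompt in one forward pass;
    # every later iteration feeds a single position, with prev_pos letting the
    # model reuse its cached keys/values instead of recomputing earlier positions.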
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
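        # Positions still covered by a longer prompt keep their original token
        # (prompt_mask marks them); only true generation slots receive sampled output.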
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens

def clear_system_cache():
"""
Clear system cache to free memory (optional optimization).
This can help with large models by freeing cached memory.
Silently attempts cache clearing; failures are ignored.
"""
try:
import subprocess
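        # Best-effort page-cache drop (needs passwordless sudo); frees memory held
        # by the OS file cache before the large safetensors reads.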
subprocess.run(
['sudo', 'sh', '-c', 'echo 3 > /proc/sys/vm/drop_caches'],
check=False, capture_output=True, text=True, timeout=5
)
except Exception:
# Silently ignore if cache clearing fails
pass

def link_fp8_scales(model):
"""
Link FP8 scales to weight tensors after loading.
After load_state_dict(), FP8 weights lose their .scale attribute.
This function reattaches them.
"""
from model import Linear, ColumnParallelLinear, RowParallelLinear
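    # The quantization scale is stored on the layer (module.scale); attach a
    # reference to the weight tensor so code that reads weight.scale after loading
    # still works.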
linked = 0
for name, module in model.named_modules():
if isinstance(module, (Linear, ColumnParallelLinear, RowParallelLinear)):
# Check if this is an FP8 layer
if hasattr(module, 'weight') and hasattr(module, 'scale'):
if module.weight is not None and module.scale is not None:
if module.weight.dtype == torch.float8_e4m3fn:
# Reattach scale as attribute
module.weight.scale = module.scale
linked += 1
if linked > 0:
print(f"✓ Linked scales for {linked} FP8 layers")

def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(os.cpu_count() or 1)  # use all available CPU cores
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
# Optionally clear cache to free memory before loading large model
if rank == 0:
clear_system_cache()
print("Creating model on CPU (this may take a while)...")
with torch.device("cpu"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("Loading model weights...")
load_sharded_model(model, ckpt_path)
model.eval()
print("DeepSeek V3.2 NVFP4 - Ready")
if interactive:
messages = []
# Get eos token id
eos_id = tokenizer.convert_tokens_to_ids(eos_token)
thinking_end_id = tokenizer.convert_tokens_to_ids(thinking_end_token)
print(f"EOS token: {eos_token!r} -> {eos_id}")
while True:
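            # Only rank 0 reads stdin; the prompt is broadcast so every rank runs
            # the same (collective) generation step.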
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
# Use DeepSeek V3.2 custom encoding (thinking_mode="chat" for no reasoning)
prompt_str = encode_messages(messages, thinking_mode="chat")
prompt_tokens = tokenizer.encode(prompt_str, add_special_tokens=False)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, eos_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
# Strip thinking end token if present
if completion.startswith(thinking_end_token):
completion = completion[len(thinking_end_token):]
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
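        # Batch mode: the input file contains one prompt per blank-line-separated block.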
with open(input_file) as f:
prompts = f.read().split("\n\n")
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
eos_id = tokenizer.convert_tokens_to_ids(eos_token)
# Use DeepSeek V3.2 custom encoding
prompt_tokens = [
tokenizer.encode(
encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat"),
add_special_tokens=False
)
for prompt in prompts
]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, eos_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
# Strip thinking end token if present
if completion.startswith(thinking_end_token):
completion = completion[len(thinking_end_token):]
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()

if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.6.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)