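"""Inference script for DeepSeek V3.2 (NVFP4 checkpoint).

Loads a HuggingFace-format sharded safetensors checkpoint into the local DeepSeek
`Transformer` (model.py), remapping HuggingFace parameter names to DeepSeek names,
and runs either interactive chat or batch generation over a prompt file.

Typical usage (paths are placeholders):
    python generate.py --ckpt-path /path/to/checkpoint --config /path/to/config.json --interactive
For multi-GPU runs, launch with a tool such as torchrun, which sets the WORLD_SIZE, RANK,
and LOCAL_RANK environment variables this script reads.
"""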
import os
import re
import json
from argparse import ArgumentParser
from typing import List

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_file

from model import Transformer, ModelArgs
from encoding_dsv32 import encode_messages, eos_token, thinking_end_token


def hf_to_deepseek_key(hf_key: str) -> str:
"""Convert HuggingFace checkpoint key to DeepSeek model key."""
key = hf_key
# Strip "model." prefix
if key.startswith("model."):
key = key[6:]
# Embedding
key = key.replace("embed_tokens.", "embed.")
# Final norm and head
key = key.replace("lm_head.", "head.")
# Attention projections
key = key.replace(".self_attn.", ".attn.")
key = key.replace(".q_a_proj.", ".wq_a.")
key = key.replace(".q_b_proj.", ".wq_b.")
key = key.replace(".q_a_layernorm.", ".q_norm.")
key = key.replace(".kv_a_proj_with_mqa.", ".wkv_a.")
key = key.replace(".kv_b_proj.", ".wkv_b.")
key = key.replace(".kv_a_layernorm.", ".kv_norm.")
key = key.replace(".o_proj.", ".wo.")
    # Indexer attention keys (wk, wq_b, k_norm, weights_proj) already use DeepSeek names; no renaming needed
# Layer norms
key = key.replace(".input_layernorm.", ".attn_norm.")
key = key.replace(".post_attention_layernorm.", ".ffn_norm.")
# MLP (dense layers)
key = key.replace(".mlp.gate_proj.", ".ffn.w1.")
key = key.replace(".mlp.up_proj.", ".ffn.w3.")
key = key.replace(".mlp.down_proj.", ".ffn.w2.")
# MoE (uses "ffn" module name in model, not "moe")
key = key.replace(".mlp.shared_experts.gate_proj.", ".ffn.shared_experts.w1.")
key = key.replace(".mlp.shared_experts.up_proj.", ".ffn.shared_experts.w3.")
key = key.replace(".mlp.shared_experts.down_proj.", ".ffn.shared_experts.w2.")
key = key.replace(".mlp.experts.", ".ffn.experts.")
key = key.replace(".mlp.gate.weight", ".ffn.gate.weight")
key = key.replace(".mlp.gate.e_score_correction_bias", ".ffn.gate.bias")
# Expert weights
key = re.sub(r"\.ffn\.experts\.(\d+)\.gate_proj\.", r".ffn.experts.\1.w1.", key)
key = re.sub(r"\.ffn\.experts\.(\d+)\.up_proj\.", r".ffn.experts.\1.w3.", key)
key = re.sub(r"\.ffn\.experts\.(\d+)\.down_proj\.", r".ffn.experts.\1.w2.", key)
return key


def load_sharded_model(model, ckpt_path):
"""Load model weights from sharded safetensors files using index."""
index_path = os.path.join(ckpt_path, "model.safetensors.index.json")
if os.path.exists(index_path):
# Load from sharded format using index
with open(index_path) as f:
index = json.load(f)
weight_map = index["weight_map"]
# Get unique shard files
shard_files = sorted(set(weight_map.values()))
# Check memory before loading
try:
import psutil
mem = psutil.virtual_memory()
print(f"Memory: {mem.available / 1e9:.1f}GB available / {mem.total / 1e9:.1f}GB total ({mem.percent:.1f}% used)")
except ImportError:
pass # psutil not required
print(f"Loading {len(shard_files)} shards (streaming to GPU)...")
model_state = model.state_dict()
loaded_keys = set()
for i, shard_file in enumerate(shard_files):
shard_path = os.path.join(ckpt_path, shard_file)
print(f" [{i+1}/{len(shard_files)}] {shard_file}", end="", flush=True)
shard_dict = load_file(shard_path, device="cpu")
# Copy matching tensors to model (with key mapping)
matched = 0
for hf_key, tensor in shard_dict.items():
key = hf_to_deepseek_key(hf_key)
if key in model_state:
model_state[key].copy_(tensor)
loaded_keys.add(key)
matched += 1
print(f" ({matched} tensors)")
missing = set(model_state.keys()) - loaded_keys
if missing:
print(f"Warning: {len(missing)} missing keys in checkpoint")
for k in list(missing)[:5]:
print(f" - {k}")
# Reattach FP8 scales after loading
link_fp8_scales(model)
    else:
        # Fall back to a single-file checkpoint
        single_file = os.path.join(ckpt_path, "model0-mp1.safetensors")
        print(f"Loading single file: {single_file}")
        state_dict = load_file(single_file, device="cuda")
        model.load_state_dict(state_dict, strict=False)
        # Reattach FP8 scales here as well, since load_state_dict() drops the .scale attribute
        link_fp8_scales(model)


def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
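    # Exponential-race sampling: dividing the probabilities by i.i.d. Exponential(1) noise and
    # taking the argmax draws an index from the categorical distribution (Gumbel-max equivalent).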
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
device = next(model.parameters()).device
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device=device)
prompt_mask = tokens != -1
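    # Decode incrementally: each forward() call sees only the tokens added since prev_pos plus the
    # start position, relying on the model's cache for earlier positions. Positions still covered
    # by a prompt are overwritten with the prompt token (prompt_mask) rather than the sample.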
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens


def clear_system_cache():
"""
Clear system cache to free memory (optional optimization).
This can help with large models by freeing cached memory.
Silently attempts cache clearing; failures are ignored.
"""
try:
import subprocess
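        # "echo 3 > /proc/sys/vm/drop_caches" asks the Linux kernel to drop the page cache,
        # dentries, and inodes; it requires root (hence sudo) and failures are ignored below.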
subprocess.run(
['sudo', 'sh', '-c', 'echo 3 > /proc/sys/vm/drop_caches'],
check=False, capture_output=True, text=True, timeout=5
)
except Exception:
# Silently ignore if cache clearing fails
pass


def link_fp8_scales(model):
"""
Link FP8 scales to weight tensors after loading.
After load_state_dict(), FP8 weights lose their .scale attribute.
This function reattaches them.
"""
from model import Linear, ColumnParallelLinear, RowParallelLinear
linked = 0
for name, module in model.named_modules():
if isinstance(module, (Linear, ColumnParallelLinear, RowParallelLinear)):
# Check if this is an FP8 layer
if hasattr(module, 'weight') and hasattr(module, 'scale'):
if module.weight is not None and module.scale is not None:
if module.weight.dtype == torch.float8_e4m3fn:
# Reattach scale as attribute
module.weight.scale = module.scale
linked += 1
if linked > 0:
print(f"✓ Linked scales for {linked} FP8 layers")


def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(96)  # tuned for a 96-thread host; adjust to match the local CPU
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
# Optionally clear cache to free memory before loading large model
if rank == 0:
clear_system_cache()
print("Creating model on CPU (this may take a while)...")
with torch.device("cpu"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("Loading model weights...")
load_sharded_model(model, ckpt_path)
model.eval()
print("DeepSeek V3.2 NVFP4 - Ready")
if interactive:
messages = []
# Get eos token id
eos_id = tokenizer.convert_tokens_to_ids(eos_token)
thinking_end_id = tokenizer.convert_tokens_to_ids(thinking_end_token)
print(f"EOS token: {eos_token!r} -> {eos_id}")
while True:
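            # Rank 0 reads from stdin and broadcasts the prompt so all ranks generate in lockstep.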
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
# Use DeepSeek V3.2 custom encoding (thinking_mode="chat" for no reasoning)
prompt_str = encode_messages(messages, thinking_mode="chat")
prompt_tokens = tokenizer.encode(prompt_str, add_special_tokens=False)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, eos_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
# Strip thinking end token if present
if completion.startswith(thinking_end_token):
completion = completion[len(thinking_end_token):]
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
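        # Batch mode: the input file holds one prompt per block, with prompts separated by blank lines.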
with open(input_file) as f:
prompts = f.read().split("\n\n")
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
eos_id = tokenizer.convert_tokens_to_ids(eos_token)
# Use DeepSeek V3.2 custom encoding
prompt_tokens = [
tokenizer.encode(
encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat"),
add_special_tokens=False
)
for prompt in prompts
]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, eos_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
# Strip thinking end token if present
if completion.startswith(thinking_end_token):
completion = completion[len(thinking_end_token):]
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()


if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.6.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)