#!/usr/bin/env python3
"""
Test NVFP4 model token generation.

This tests autoregressive token generation (5 tokens only for speed).
"""
import sys
import json

import torch
from transformers import AutoTokenizer

from model import Transformer, ModelArgs
from generate import load_sharded_model, link_fp8_scales
from encoding_dsv32 import encode_messages, eos_token


def test_minimal_generation():
    """Test generating 5 tokens autoregressively."""
    print("\n" + "=" * 70)
    print("NVFP4 Minimal Generation Test")
    print("=" * 70)
    print("Testing autoregressive generation (5 tokens)")
    print("Expected runtime: 1-10 minutes")
    print("=" * 70 + "\n")

    # Load config
    print("Loading config...")
    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"
    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(" PASS: Config loaded\n")

    # Create model
    print("Creating model...")
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("cpu"):
        model = Transformer(args)
    print(" PASS: Model created\n")

    # Load weights
    print("Loading weights...")
    ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"
    load_sharded_model(model, ckpt_path)
    print(" PASS: Weights loaded\n")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    print(f" PASS: Tokenizer loaded (vocab size: {len(tokenizer)})\n")

    # Prepare prompt using DeepSeek V3.2 message encoding
    user_message = "Hello"
    print(f"User message: '{user_message}'")
    messages = [{"role": "user", "content": user_message}]
    # Use proper DeepSeek V3.2 encoding (thinking_mode="chat" for no reasoning)
    prompt_str = encode_messages(messages, thinking_mode="chat")
    print(f"Encoded prompt string (first 100 chars): '{prompt_str[:100]}...'")
    prompt_tokens = tokenizer.encode(prompt_str, add_special_tokens=False)
    print(f"Encoded tokens: {len(prompt_tokens)} tokens")
    print(f"First 10 tokens: {prompt_tokens[:10]}\n")
    tokens = torch.tensor([prompt_tokens], dtype=torch.long, device="cpu")

    # Get EOS token ID
    eos_id = tokenizer.convert_tokens_to_ids(eos_token)
    print(f"EOS token: '{eos_token}' -> ID {eos_id}\n")

    # Generate tokens
    max_new_tokens = 5
    print(f"Generating {max_new_tokens} tokens...")
    print("-" * 70)

    generated_tokens = []
    prev_pos = 0

    try:
        with torch.inference_mode():
            for step in range(max_new_tokens):
                print(f"\nStep {step+1}/{max_new_tokens}:")

                # Forward pass (start_pos lets the model reuse its KV cache for
                # the already-processed prefix)
                logits = model(tokens[:, prev_pos:], start_pos=prev_pos)

                # Sample next token (argmax for deterministic output)
                next_token = logits.argmax(dim=-1)
                next_token_id = next_token.item()
                generated_tokens.append(next_token_id)

                # Decode token
                decoded = tokenizer.decode([next_token_id])
                print(f" Generated token {next_token_id}: '{decoded}'")

                # Check for EOS
                if next_token_id == eos_id:
                    print(" PASS: Reached EOS token, stopping generation")
                    break

                # Check for issues
                if torch.isnan(logits).any():
                    print(f" FAIL: ERROR: NaN in logits at step {step+1}")
                    return 1
                if torch.isinf(logits).any():
                    print(f" FAIL: ERROR: Inf in logits at step {step+1}")
                    return 1

                # Append to sequence
                tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
                prev_pos = tokens.shape[1] - 1

                # Show the text generated so far (new tokens only, prompt excluded)
                generated_text = tokenizer.decode(generated_tokens)
                print(f" Generated so far: '{generated_text}'")

        print("\n" + "-" * 70)

        # Final output
        full_text = tokenizer.decode(tokens[0].tolist())
        generated_text = tokenizer.decode(generated_tokens)

        print("\nPASS: Generation completed successfully!")
        print("\nResults:")
        print(f" User message: '{user_message}'")
print(f" Generated: '{generated_text}'") print(f" Full text: '{full_text}'") print(f" Generated tokens: {generated_tokens}") # Basic sanity check if len(generated_tokens) != max_new_tokens: print(f"\nWARN: WARNING: Expected {max_new_tokens} tokens, got {len(generated_tokens)}") print("\n" + "=" * 70) print("PASS: GENERATION TEST PASSED") print("=" * 70) print("Token generation working correctly!") print("Ready for full interactive inference.") print("=" * 70) return 0 except Exception as e: print(f"\nFAIL: GENERATION FAILED: {e}") import traceback traceback.print_exc() return 1 if __name__ == "__main__": sys.exit(test_minimal_generation())