"""
Integration test for NVFP4 model loading.

This tests that the model can be loaded from sharded safetensors
and that all weights have correct shapes and flags.
"""

import json
import sys

import torch

from model import Transformer, ModelArgs
from generate import load_sharded_model


def clear_cache():
    """Drop the OS page cache to free memory (requires sudo)."""
    print("Clearing system cache...")
    try:
        import subprocess
        result = subprocess.run(
            ['sudo', 'sh', '-c', 'echo 3 > /proc/sys/vm/drop_caches'],
            check=False, capture_output=True, text=True
        )
        # check=False never raises on a sudo failure, so inspect the
        # return code instead of unconditionally reporting success.
        if result.returncode == 0:
            print(" PASS: Cache cleared\n")
        else:
            print(f" WARN: Could not clear cache: {result.stderr.strip()}\n")
    except Exception as e:
        print(f" WARN: Could not clear cache: {e}\n")


def check_memory():
    """Print memory usage and return available memory in GB (None if psutil is missing)."""
    try:
        import psutil
        mem = psutil.virtual_memory()
        print(f"Memory: {mem.available / 1e9:.1f}GB available / {mem.total / 1e9:.1f}GB total")
        print(f" {mem.percent:.1f}% used\n")
        return mem.available / 1e9
    except ImportError:
        print("psutil not available, skipping memory check\n")
        return None


def test_config_loading():
    """Test 1: Load and validate the model config."""
    print("=" * 70)
    print("Test 1: Load Model Config")
    print("=" * 70)

    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"

    print(f" Loading config from: {config_path}")
    with open(config_path) as f:
        config_dict = json.load(f)

    args = ModelArgs(**config_dict)

    print(" Model parameters:")
    print(f" - vocab_size: {args.vocab_size:,}")
    print(f" - dim: {args.dim}")
    print(f" - n_layers: {args.n_layers}")
    print(f" - n_routed_experts: {args.n_routed_experts}")
    print(f" - dtype: {args.dtype}")

    assert args.dtype == "nvfp4", f"Expected dtype='nvfp4', got '{args.dtype}'"
    assert args.n_layers == 61, f"Expected 61 layers, got {args.n_layers}"

    print(" PASS: Config loaded successfully")
    print(" PASS: Test 1 PASSED\n")

    return args


def test_model_creation(args):
    """Test 2: Create the model instance."""
    print("=" * 70)
    print("Test 2: Create Model Instance")
    print("=" * 70)

    print(f" Creating Transformer model with dtype={args.dtype}...")
    print(" (This may take 1-2 minutes)")

    torch.set_default_dtype(torch.bfloat16)

    with torch.device("cpu"):
        model = Transformer(args)

    total_params = sum(p.numel() for p in model.parameters())
    total_buffers = sum(b.numel() for b in model.buffers())

    print(" Model created:")
    print(f" - Parameters: {total_params / 1e9:.2f}B")
    print(f" - Buffers: {total_buffers / 1e9:.2f}B")
    print(f" - Total: {(total_params + total_buffers) / 1e9:.2f}B elements")

    assert hasattr(model, 'embed'), "Model missing embed layer"
    assert hasattr(model, 'layers'), "Model missing layers"
    assert len(model.layers) == args.n_layers, f"Expected {args.n_layers} layers, got {len(model.layers)}"

    print(" PASS: Model structure correct")
    print(" PASS: Test 2 PASSED\n")

    return model
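

# Side note (a sketch, not used by the tests): if you only need to validate
# the model's structure rather than load real weights, constructing it on
# PyTorch's meta device skips the slow CPU allocation above. This assumes
# Transformer.__init__ performs no data-dependent tensor ops, which meta
# tensors cannot evaluate.
def create_structure_only(args):
    """Build a Transformer whose tensors occupy no real memory."""
    with torch.device("meta"):
        return Transformer(args)

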
def test_weight_loading(model):
    """Test 3: Load weights from the sharded checkpoint."""
    print("=" * 70)
    print("Test 3: Load Weights from Checkpoint")
    print("=" * 70)

    ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"

    print(f" Loading from: {ckpt_path}")
    print(" (This will take 5-15 minutes for the full model)")
    print(" Progress will be shown shard-by-shard...\n")

    load_sharded_model(model, ckpt_path)

    print("\n PASS: Weights loaded successfully")
    print(" PASS: Test 3 PASSED\n")

    return model
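

# For orientation, a generic sketch of what loading sharded safetensors
# involves. The real logic lives in generate.load_sharded_model; the glob
# pattern and the strict=False partial-load below are assumptions for
# illustration, not the project's confirmed behavior.
def load_sharded_sketch(model, ckpt_path):
    """Copy tensors from each *.safetensors shard into the model."""
    import glob
    from safetensors.torch import load_file

    for shard in sorted(glob.glob(f"{ckpt_path}/*.safetensors")):
        state = load_file(shard)  # dict: tensor name -> tensor
        model.load_state_dict(state, strict=False)  # each shard holds a subset

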
def test_nvfp4_layers(model):
    """Test 4: Verify NVFP4 layers have the correct structure."""
    print("=" * 70)
    print("Test 4: Verify NVFP4 Layer Structure")
    print("=" * 70)

    nvfp4_layers = []
    total_layers = 0

    for name, module in model.named_modules():
        if hasattr(module, '_nvfp4_mode') and hasattr(module, 'weight'):
            total_layers += 1
            if module._nvfp4_mode:
                nvfp4_layers.append((name, module))

    print(f" Found {len(nvfp4_layers)} NVFP4 layers out of {total_layers} total linear layers")

    if len(nvfp4_layers) == 0:
        print(" WARN: No NVFP4 layers found!")
        print(" This might indicate a dtype configuration issue")
        return

    print("\n Inspecting first 5 NVFP4 layers:")
    for i, (name, module) in enumerate(nvfp4_layers[:5]):
        weight = module.weight
        weight_scale = getattr(module, 'weight_scale', None)
        weight_scale_2 = getattr(module, 'weight_scale_2', None)

        print(f"\n [{i+1}] {name}:")
        print(f" weight: {weight.shape}, dtype={weight.dtype}")

        # NVFP4 packs two 4-bit values into each uint8 byte, so the stored
        # weight holds K/2 bytes per row; recover the logical K for checks.
        N, K_half = weight.shape
        K = K_half * 2

        if weight_scale is not None:
            print(f" weight_scale: {weight_scale.shape}, dtype={weight_scale.dtype}")
            # One scale per 16-element group along the K dimension.
            expected_scale_shape = (N, K // 16)
            if weight_scale.shape != expected_scale_shape:
                print(f" WARN: Expected scale shape {expected_scale_shape}, got {weight_scale.shape}")
            else:
                print(" PASS: Scale shape correct")
        else:
            print(" WARN: weight_scale not found!")

        if weight_scale_2 is not None:
            print(f" weight_scale_2: {weight_scale_2.shape}, dtype={weight_scale_2.dtype}, value={weight_scale_2.item():.6e}")
            if weight_scale_2.shape != torch.Size([1]):
                print(f" WARN: Expected scale_2 shape [1], got {weight_scale_2.shape}")
            else:
                print(" PASS: Scale_2 shape correct")
        else:
            print(" WARN: weight_scale_2 not found!")

        assert weight.dtype == torch.uint8, f"Weight should be uint8, got {weight.dtype}"

    print("\n PASS: NVFP4 layers have correct structure")
    print(" PASS: Test 4 PASSED\n")
def test_weight_statistics(model):
    """Test 5: Check weight statistics to verify they are not all zeros or corrupted."""
    print("=" * 70)
    print("Test 5: Weight Statistics")
    print("=" * 70)

    nvfp4_count = 0
    zero_count = 0
    checked = 0

    for name, module in model.named_modules():
        if getattr(module, '_nvfp4_mode', False):
            nvfp4_count += 1

            if checked < 10:
                weight = module.weight
                weight_scale = getattr(module, 'weight_scale', None)
                weight_scale_2 = getattr(module, 'weight_scale_2', None)

                # A zero byte means both packed FP4 codes are 0, so the
                # fraction of zero bytes is a cheap corruption heuristic.
                num_zeros = (weight == 0).sum().item()
                total_elems = weight.numel()
                zero_percent = 100.0 * num_zeros / total_elems

                if checked == 0:
                    print(f"\n Sample layer: {name}")
                    print(f" Weight zeros: {zero_percent:.1f}%")
                    if weight_scale is not None:
                        scale_min = weight_scale.to(torch.float32).min().item()
                        scale_max = weight_scale.to(torch.float32).max().item()
                        print(f" Scale range: [{scale_min:.6e}, {scale_max:.6e}]")
                    if weight_scale_2 is not None:
                        print(f" Scale_2: {weight_scale_2.item():.6e}")

                if zero_percent > 95:
                    zero_count += 1
                    print(f" WARN: {name} has {zero_percent:.1f}% zeros (possibly corrupted)")

                checked += 1

    print(f"\n Checked {checked} NVFP4 layers:")
    print(f" - Total NVFP4 layers: {nvfp4_count}")
    print(f" - Layers with >95% zeros: {zero_count}")

    if zero_count > checked // 2:
        print(" WARN: Many layers appear corrupted (too many zeros)")
    else:
        print(" PASS: Weight statistics look reasonable")

    print(" PASS: Test 5 PASSED\n")
def main():
    """Run all model loading tests."""
    print("\n" + "=" * 70)
    print("NVFP4 Model Loading Integration Test")
    print("=" * 70)
    print("This test will load the full 671B parameter model")
    print("Expected runtime: 5-20 minutes")
    print("Memory required: ~400GB")
    print("=" * 70 + "\n")

    available_gb = check_memory()
    if available_gb is not None and available_gb < 350:
        print(f"WARN: Only {available_gb:.1f}GB available")
        print(" Model may not fit in memory. Consider clearing cache.")
        user_input = input(" Continue anyway? (y/n): ")
        if user_input.lower() != 'y':
            print(" Aborted by user")
            return 1

    user_input = input("Clear system cache before loading? (recommended) (y/n): ")
    if user_input.lower() == 'y':
        clear_cache()
        check_memory()

    try:
        args = test_config_loading()
        model = test_model_creation(args)
        model = test_weight_loading(model)
        test_nvfp4_layers(model)
        test_weight_statistics(model)

        print("=" * 70)
        print("PASS: ALL TESTS PASSED")
        print("=" * 70)
        print("Model loaded successfully with correct NVFP4 structure!")
        print("Ready for forward pass testing.")
        print("=" * 70)

        print("\nKeeping model in memory for next test...")
        print("Run test_forward_pass.py in the same Python session to reuse the loaded model")

        return 0

    except Exception as e:
        print(f"\nFAIL: TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__": |
|
|
sys.exit(main()) |
|
|
|