# Petite-LLM-3 / test_float16_compatibility.py
# File-view page header captured with the source: author Tonic, commit 109031b
# ("adds flash attention"), 3.42 kB (raw / history / blame views).
#!/usr/bin/env python3
"""
Test script for float16 compatibility with pre-quantized model
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
# Emit INFO-and-above records via logging.basicConfig (a no-op if the root
# logger already has handlers), then grab this module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def _select_dtypes(device):
    """Return the torch dtypes to exercise on *device*.

    float32 and float16 on CUDA; float32 only on CPU.
    """
    if device == "cuda":
        return [torch.float32, torch.float16]
    return [torch.float32]


def _load_tokenizer(model_id):
    """Load the tokenizer for *model_id*, defaulting its pad id to the EOS id."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token_id is None:
        # No pad token configured: reuse EOS so generate() is given a real pad id.
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def _load_model(model_id, device, dtype):
    """Load *model_id* as a causal LM with weights in *dtype*, placed for *device*."""
    model_kwargs = {
        # "auto" delegates weight placement on GPU to transformers; CPU is explicit.
        "device_map": "auto" if device == "cuda" else "cpu",
        "torch_dtype": dtype,
        # NOTE: executes model code shipped in the Hub repo; only use trusted ids.
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }
    return AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)


def _generate_completion(model, tokenizer, prompt):
    """Sample up to 50 new tokens for *prompt* and return only the generated text."""
    # Put the inputs where the model's parameters live instead of hard-coding .cuda().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output_ids = model.generate(
            inputs["input_ids"],
            max_new_tokens=50,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            attention_mask=inputs["attention_mask"],
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            cache_implementation="static",
        )
    # Decode only the tokens produced after the prompt. Slicing the decoded text by
    # len(prompt) breaks whenever decode() does not reproduce the prompt verbatim.
    new_tokens = output_ids[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _run_dtype(model_id, tokenizer, device, dtype, prompt):
    """Load the model in *dtype*, generate once and log the outcome; raises on failure."""
    logger.info(f"Loading model with {dtype}...")
    model = _load_model(model_id, device, dtype)
    logger.info("Generating response...")
    completion = _generate_completion(model, tokenizer, prompt)
    logger.info(f"✅ {dtype} test successful!")
    logger.info(f"Input: {prompt}")
    logger.info(f"Output: {completion}")
    if device == "cuda":
        memory_used = torch.cuda.memory_allocated() / 1024**3
        logger.info(f"GPU Memory used: {memory_used:.2f} GB")
    logger.info(f"Model dtype: {model.dtype}")


def test_float16_compatibility(
    *,
    model_id="Tonic/petite-elle-L-aime-3-sft",
    prompt="Bonjour, comment allez-vous?",
    raise_on_failure=False,
):
    """Load the model in each supported dtype and check that generation works.

    Runs float32 and float16 on CUDA, float32 only on CPU. Each dtype is tried
    independently: a failure is logged with its traceback and the sweep
    continues with the next dtype.

    Args:
        model_id: Hugging Face Hub id of the checkpoint to test.
        prompt: Text passed to ``generate()``.
        raise_on_failure: If true, raise ``RuntimeError`` after the sweep when any
            dtype failed. Defaults to False (log only). With the default, this
            function never raises, so pytest reports it as passing even when
            every dtype fails.

    Raises:
        RuntimeError: only when ``raise_on_failure`` is true and a dtype failed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Testing float16 compatibility on device: {device}")

    tokenizer = None
    failed = []
    for dtype in _select_dtypes(device):
        logger.info(f"\nTesting with dtype: {dtype}")
        try:
            if tokenizer is None:
                # The tokenizer does not depend on dtype: load once, reuse after.
                # Loading inside the try keeps a load error from aborting the sweep.
                tokenizer = _load_tokenizer(model_id)
            _run_dtype(model_id, tokenizer, device, dtype, prompt)
        except Exception as e:
            # Best-effort sweep: log with traceback, remember it, try the next dtype.
            logger.exception(f"❌ {dtype} test failed: {e}")
            failed.append(dtype)
        finally:
            # Also runs after a failure, so cached CUDA blocks freed by the dropped
            # model are returned before the next dtype is loaded.
            if device == "cuda":
                torch.cuda.empty_cache()

    if failed:
        names = ", ".join(str(dtype) for dtype in failed)
        logger.error(f"Failed dtypes: {names}")
        if raise_on_failure:
            raise RuntimeError(f"float16 compatibility test failed for: {names}")
if __name__ == "__main__":
    # Script entry point: run the dtype sweep only when executed directly,
    # never as a side effect of importing this module.
    test_float16_compatibility()