sonar_core_1 / use_this_model.py
Vu Anh
Clean up training runs and enhance model export functionality
742fa4d
#!/usr/bin/env python3
"""
Demonstration script for using Sonar Core 1 models from Hugging Face Hub.
Shows how to download and use both VNTC and UTS2017_Bank pre-trained models.
"""
from huggingface_hub import hf_hub_download
import joblib
def predict_text(model, text):
"""Make prediction on a single text (consistent with inference.py)"""
try:
probabilities = model.predict_proba([text])[0]
# Get top 3 predictions sorted by probability
top_indices = probabilities.argsort()[-3:][::-1]
top_predictions = []
for idx in top_indices:
category = model.classes_[idx]
prob = probabilities[idx]
top_predictions.append((category, prob))
# The prediction should be the top category
prediction = top_predictions[0][0]
confidence = top_predictions[0][1]
return prediction, confidence, top_predictions
except Exception as e:
print(f"Error making prediction: {e}")
return None, 0, []
def load_model_from_hub(model_type="vntc"):
"""Load the pre-trained model from Hugging Face Hub
Args:
model_type: "vntc" for news classification or "uts2017_bank" for banking text
"""
if model_type == "vntc":
filename = "vntc_classifier_20250927_161550.joblib"
print("Downloading VNTC (Vietnamese News) model from Hugging Face Hub...")
elif model_type == "uts2017_bank":
filename = "uts2017_bank_classifier_20250927_161733.joblib"
print("Downloading UTS2017_Bank (Vietnamese Banking) model from Hugging Face Hub...")
else:
raise ValueError("model_type must be 'vntc' or 'uts2017_bank'")
model_path = hf_hub_download("undertheseanlp/sonar_core_1", filename)
print(f"Model downloaded to: {model_path}")
print("Loading model...")
model = joblib.load(model_path)
return model
def predict_vntc_examples(model):
"""Demonstrate predictions on VNTC (news) examples"""
print("\n" + "="*60)
print("VIETNAMESE NEWS CLASSIFICATION EXAMPLES (VNTC)")
print("="*60)
# Vietnamese news examples for different categories
examples = [
("Chính trị & Xã hội", "Chính phủ đã thông qua nghị định mới về chính sách xã hội"),
("Đời sống", "Xu hướng ăn uống lành mạnh đang được nhiều người quan tâm"),
("Khoa học", "Các nhà khoa học đã phát hiện ra loại vi khuẩn mới"),
("Kinh doanh", "Thị trường chứng khoán có nhiều biến động trong tuần qua"),
("Pháp luật", "Luật an toàn giao thông sẽ có hiệu lực từ tháng sau"),
("Sức khỏe", "Tiêm vaccine phòng chống COVID-19 đã đạt tỷ lệ cao"),
("Thế giới", "Hội nghị thượng đỉnh quốc tế sẽ diễn ra tại Geneva"),
("Thể thao", "Đội tuyển bóng đá Việt Nam giành chiến thắng 2-0"),
("Văn hóa", "Lễ hội truyền thống sẽ được tổ chức vào cuối tuần"),
("Vi tính", "Công nghệ trí tuệ nhân tạo đang phát triển mạnh mẽ")
]
print("Testing Vietnamese news classification:")
print("-" * 60)
for expected_category, text in examples:
try:
prediction, confidence, top_predictions = predict_text(model, text)
if prediction:
print(f"Text: {text}")
print(f"Expected: {expected_category}")
print(f"Predicted: {prediction}")
print(f"Confidence: {confidence:.3f}")
# Show top 3 predictions
print("Top 3 predictions:")
for i, (category, prob) in enumerate(top_predictions, 1):
print(f" {i}. {category}: {prob:.3f}")
print("-" * 60)
except Exception as e:
print(f"Error predicting '{text[:30]}...': {e}")
print("-" * 60)
def predict_uts2017_examples(model):
"""Demonstrate predictions on UTS2017_Bank examples"""
print("\n" + "="*60)
print("VIETNAMESE BANKING TEXT CLASSIFICATION EXAMPLES (UTS2017_Bank)")
print("="*60)
# Vietnamese banking examples for different categories
examples = [
("ACCOUNT", "Tôi muốn mở tài khoản tiết kiệm mới"),
("CARD", "Thẻ tín dụng của tôi bị khóa, làm sao để mở lại?"),
("CUSTOMER_SUPPORT", "Tôi cần hỗ trợ về dịch vụ ngân hàng"),
("DISCOUNT", "Có chương trình giảm giá nào cho khách hàng không?"),
("INTEREST_RATE", "Lãi suất tiết kiệm hiện tại là bao nhiều?"),
("INTERNET_BANKING", "Làm thế nào để đăng ký internet banking?"),
("LOAN", "Tôi muốn vay mua nhà với lãi suất ưu đãi"),
("MONEY_TRANSFER", "Chi phí chuyển tiền ra nước ngoài là bao nhiều?"),
("OTHER", "Tôi có câu hỏi về dịch vụ khác"),
("PAYMENT", "Thanh toán hóa đơn điện nước qua ngân hàng"),
("PROMOTION", "Khuyến mãi tháng này có gì hấp dẫn?"),
("SAVING", "Gói tiết kiệm nào có lãi suất cao nhất?"),
("SECURITY", "Bảo mật tài khoản ngân hàng như thế nào?"),
("TRADEMARK", "Ngân hàng ACB có uy tín không?")
]
print("Testing Vietnamese banking text classification:")
print("-" * 60)
for expected_category, text in examples:
try:
prediction, confidence, top_predictions = predict_text(model, text)
if prediction:
print(f"Text: {text}")
print(f"Expected: {expected_category}")
print(f"Predicted: {prediction}")
print(f"Confidence: {confidence:.3f}")
# Show top 3 predictions
print("Top 3 predictions:")
for i, (category, prob) in enumerate(top_predictions, 1):
print(f" {i}. {category}: {prob:.3f}")
print("-" * 60)
except Exception as e:
print(f"Error predicting '{text}': {e}")
print("-" * 60)
def interactive_mode(model, model_type):
"""Interactive mode for testing custom text"""
dataset_name = "VNTC (News)" if model_type == "vntc" else "UTS2017_Bank (Banking)"
print("\n" + "="*60)
print(f"INTERACTIVE MODE - {dataset_name.upper()} CLASSIFICATION")
print("="*60)
print("Enter Vietnamese text to classify (type 'quit' to exit):")
while True:
try:
user_input = input("\nText: ").strip()
if user_input.lower() in ['quit', 'exit', 'q']:
break
if not user_input:
continue
prediction, confidence, top_predictions = predict_text(model, user_input)
if prediction:
print(f"Predicted category: {prediction}")
print(f"Confidence: {confidence:.3f}")
# Show top 3 predictions
print("Top 3 predictions:")
for i, (category, prob) in enumerate(top_predictions, 1):
print(f" {i}. {category}: {prob:.3f}")
except KeyboardInterrupt:
print("\nExiting...")
break
except Exception as e:
print(f"Error: {e}")
def simple_usage_examples():
"""Show simple usage examples for HuggingFace Hub models"""
print("\n" + "="*60)
print("HUGGINGFACE HUB USAGE EXAMPLES")
print("="*60)
print("Code examples:")
print("""
# VNTC Model (Vietnamese News Classification)
from huggingface_hub import hf_hub_download
import joblib
# Download and load VNTC model from HuggingFace Hub
vntc_model = joblib.load(
hf_hub_download("undertheseanlp/sonar_core_1", "vntc_classifier_20250927_161550.joblib")
)
# Make prediction on news text
news_text = "Đội tuyển bóng đá Việt Nam giành chiến thắng"
prediction = vntc_model.predict([news_text])[0]
print(f"News category: {prediction}")
# UTS2017_Bank Model (Vietnamese Banking Text Classification)
# Download and load UTS2017_Bank model from HuggingFace Hub
bank_model = joblib.load(
hf_hub_download("undertheseanlp/sonar_core_1", "uts2017_bank_classifier_20250927_161733.joblib")
)
# Make prediction on banking text
bank_text = "Tôi muốn mở tài khoản tiết kiệm"
prediction = bank_model.predict([bank_text])[0]
print(f"Banking category: {prediction}")
# For local file inference, use inference.py instead
""")
def main():
"""Main demonstration function"""
print("Sonar Core 1 - Dual Model Hugging Face Hub Usage")
print("=" * 60)
try:
# Show simple usage examples
simple_usage_examples()
# Test VNTC model
print("\n" + "="*60)
print("TESTING VNTC MODEL (Vietnamese News Classification)")
print("="*60)
vntc_model = load_model_from_hub("vntc")
predict_vntc_examples(vntc_model)
# Test UTS2017_Bank model
print("\n" + "="*60)
print("TESTING UTS2017_BANK MODEL (Vietnamese Banking Text Classification)")
print("="*60)
bank_model = load_model_from_hub("uts2017_bank")
predict_uts2017_examples(bank_model)
# Check if we're in an interactive environment
try:
import sys
if hasattr(sys, 'ps1') or sys.stdin.isatty():
print("\nAvailable interactive modes:")
print("1. VNTC (News) classification")
print("2. UTS2017_Bank (Banking) classification")
choice = input("\nSelect model for interactive mode (1/2) or 'n' to skip: ").strip()
if choice == "1":
interactive_mode(vntc_model, "vntc")
elif choice == "2":
interactive_mode(bank_model, "uts2017_bank")
except (EOFError, OSError):
print("\nInteractive mode not available in this environment.")
print("Run this script in a regular terminal to use interactive mode.")
print("\nDemonstration complete!")
print("\nBoth models are now available on Hugging Face Hub:")
print("- VNTC (News): vntc_classifier_20250927_161550.joblib")
print("- UTS2017_Bank (Banking): uts2017_bank_classifier_20250927_161733.joblib")
except ImportError:
print("Error: huggingface_hub is required. Install with:")
print(" pip install huggingface_hub")
except Exception as e:
print(f"Error loading model: {e}")
print("\nMake sure you have internet connection and try again.")
if __name__ == "__main__":
main()