|
|
|
|
|
""" |
|
|
Demonstration script for using Sonar Core 1 models from Hugging Face Hub. |
|
|
Shows how to download and use both VNTC and UTS2017_Bank pre-trained models. |
|
|
""" |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
import joblib |
|
|
|
|
|
|
|
|
def predict_text(model, text): |
|
|
"""Make prediction on a single text (consistent with inference.py)""" |
|
|
try: |
|
|
probabilities = model.predict_proba([text])[0] |
|
|
|
|
|
|
|
|
top_indices = probabilities.argsort()[-3:][::-1] |
|
|
top_predictions = [] |
|
|
for idx in top_indices: |
|
|
category = model.classes_[idx] |
|
|
prob = probabilities[idx] |
|
|
top_predictions.append((category, prob)) |
|
|
|
|
|
|
|
|
prediction = top_predictions[0][0] |
|
|
confidence = top_predictions[0][1] |
|
|
|
|
|
return prediction, confidence, top_predictions |
|
|
except Exception as e: |
|
|
print(f"Error making prediction: {e}") |
|
|
return None, 0, [] |
|
|
|
|
|
|
|
|
def load_model_from_hub(model_type="vntc"): |
|
|
"""Load the pre-trained model from Hugging Face Hub |
|
|
|
|
|
Args: |
|
|
model_type: "vntc" for news classification or "uts2017_bank" for banking text |
|
|
""" |
|
|
if model_type == "vntc": |
|
|
filename = "vntc_classifier_20250927_161550.joblib" |
|
|
print("Downloading VNTC (Vietnamese News) model from Hugging Face Hub...") |
|
|
elif model_type == "uts2017_bank": |
|
|
filename = "uts2017_bank_classifier_20250927_161733.joblib" |
|
|
print("Downloading UTS2017_Bank (Vietnamese Banking) model from Hugging Face Hub...") |
|
|
else: |
|
|
raise ValueError("model_type must be 'vntc' or 'uts2017_bank'") |
|
|
|
|
|
model_path = hf_hub_download("undertheseanlp/sonar_core_1", filename) |
|
|
print(f"Model downloaded to: {model_path}") |
|
|
|
|
|
print("Loading model...") |
|
|
model = joblib.load(model_path) |
|
|
return model |
|
|
|
|
|
|
|
|
def predict_vntc_examples(model): |
|
|
"""Demonstrate predictions on VNTC (news) examples""" |
|
|
print("\n" + "="*60) |
|
|
print("VIETNAMESE NEWS CLASSIFICATION EXAMPLES (VNTC)") |
|
|
print("="*60) |
|
|
|
|
|
|
|
|
examples = [ |
|
|
("Chính trị & Xã hội", "Chính phủ đã thông qua nghị định mới về chính sách xã hội"), |
|
|
("Đời sống", "Xu hướng ăn uống lành mạnh đang được nhiều người quan tâm"), |
|
|
("Khoa học", "Các nhà khoa học đã phát hiện ra loại vi khuẩn mới"), |
|
|
("Kinh doanh", "Thị trường chứng khoán có nhiều biến động trong tuần qua"), |
|
|
("Pháp luật", "Luật an toàn giao thông sẽ có hiệu lực từ tháng sau"), |
|
|
("Sức khỏe", "Tiêm vaccine phòng chống COVID-19 đã đạt tỷ lệ cao"), |
|
|
("Thế giới", "Hội nghị thượng đỉnh quốc tế sẽ diễn ra tại Geneva"), |
|
|
("Thể thao", "Đội tuyển bóng đá Việt Nam giành chiến thắng 2-0"), |
|
|
("Văn hóa", "Lễ hội truyền thống sẽ được tổ chức vào cuối tuần"), |
|
|
("Vi tính", "Công nghệ trí tuệ nhân tạo đang phát triển mạnh mẽ") |
|
|
] |
|
|
|
|
|
print("Testing Vietnamese news classification:") |
|
|
print("-" * 60) |
|
|
|
|
|
for expected_category, text in examples: |
|
|
try: |
|
|
prediction, confidence, top_predictions = predict_text(model, text) |
|
|
|
|
|
if prediction: |
|
|
print(f"Text: {text}") |
|
|
print(f"Expected: {expected_category}") |
|
|
print(f"Predicted: {prediction}") |
|
|
print(f"Confidence: {confidence:.3f}") |
|
|
|
|
|
|
|
|
print("Top 3 predictions:") |
|
|
for i, (category, prob) in enumerate(top_predictions, 1): |
|
|
print(f" {i}. {category}: {prob:.3f}") |
|
|
|
|
|
print("-" * 60) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error predicting '{text[:30]}...': {e}") |
|
|
print("-" * 60) |
|
|
|
|
|
|
|
|
def predict_uts2017_examples(model): |
|
|
"""Demonstrate predictions on UTS2017_Bank examples""" |
|
|
print("\n" + "="*60) |
|
|
print("VIETNAMESE BANKING TEXT CLASSIFICATION EXAMPLES (UTS2017_Bank)") |
|
|
print("="*60) |
|
|
|
|
|
|
|
|
examples = [ |
|
|
("ACCOUNT", "Tôi muốn mở tài khoản tiết kiệm mới"), |
|
|
("CARD", "Thẻ tín dụng của tôi bị khóa, làm sao để mở lại?"), |
|
|
("CUSTOMER_SUPPORT", "Tôi cần hỗ trợ về dịch vụ ngân hàng"), |
|
|
("DISCOUNT", "Có chương trình giảm giá nào cho khách hàng không?"), |
|
|
("INTEREST_RATE", "Lãi suất tiết kiệm hiện tại là bao nhiều?"), |
|
|
("INTERNET_BANKING", "Làm thế nào để đăng ký internet banking?"), |
|
|
("LOAN", "Tôi muốn vay mua nhà với lãi suất ưu đãi"), |
|
|
("MONEY_TRANSFER", "Chi phí chuyển tiền ra nước ngoài là bao nhiều?"), |
|
|
("OTHER", "Tôi có câu hỏi về dịch vụ khác"), |
|
|
("PAYMENT", "Thanh toán hóa đơn điện nước qua ngân hàng"), |
|
|
("PROMOTION", "Khuyến mãi tháng này có gì hấp dẫn?"), |
|
|
("SAVING", "Gói tiết kiệm nào có lãi suất cao nhất?"), |
|
|
("SECURITY", "Bảo mật tài khoản ngân hàng như thế nào?"), |
|
|
("TRADEMARK", "Ngân hàng ACB có uy tín không?") |
|
|
] |
|
|
|
|
|
print("Testing Vietnamese banking text classification:") |
|
|
print("-" * 60) |
|
|
|
|
|
for expected_category, text in examples: |
|
|
try: |
|
|
prediction, confidence, top_predictions = predict_text(model, text) |
|
|
|
|
|
if prediction: |
|
|
print(f"Text: {text}") |
|
|
print(f"Expected: {expected_category}") |
|
|
print(f"Predicted: {prediction}") |
|
|
print(f"Confidence: {confidence:.3f}") |
|
|
|
|
|
|
|
|
print("Top 3 predictions:") |
|
|
for i, (category, prob) in enumerate(top_predictions, 1): |
|
|
print(f" {i}. {category}: {prob:.3f}") |
|
|
|
|
|
print("-" * 60) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error predicting '{text}': {e}") |
|
|
print("-" * 60) |
|
|
|
|
|
|
|
|
def interactive_mode(model, model_type): |
|
|
"""Interactive mode for testing custom text""" |
|
|
dataset_name = "VNTC (News)" if model_type == "vntc" else "UTS2017_Bank (Banking)" |
|
|
print("\n" + "="*60) |
|
|
print(f"INTERACTIVE MODE - {dataset_name.upper()} CLASSIFICATION") |
|
|
print("="*60) |
|
|
print("Enter Vietnamese text to classify (type 'quit' to exit):") |
|
|
|
|
|
while True: |
|
|
try: |
|
|
user_input = input("\nText: ").strip() |
|
|
|
|
|
if user_input.lower() in ['quit', 'exit', 'q']: |
|
|
break |
|
|
|
|
|
if not user_input: |
|
|
continue |
|
|
|
|
|
prediction, confidence, top_predictions = predict_text(model, user_input) |
|
|
|
|
|
if prediction: |
|
|
print(f"Predicted category: {prediction}") |
|
|
print(f"Confidence: {confidence:.3f}") |
|
|
|
|
|
|
|
|
print("Top 3 predictions:") |
|
|
for i, (category, prob) in enumerate(top_predictions, 1): |
|
|
print(f" {i}. {category}: {prob:.3f}") |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
print("\nExiting...") |
|
|
break |
|
|
except Exception as e: |
|
|
print(f"Error: {e}") |
|
|
|
|
|
|
|
|
def simple_usage_examples(): |
|
|
"""Show simple usage examples for HuggingFace Hub models""" |
|
|
print("\n" + "="*60) |
|
|
print("HUGGINGFACE HUB USAGE EXAMPLES") |
|
|
print("="*60) |
|
|
|
|
|
print("Code examples:") |
|
|
print(""" |
|
|
# VNTC Model (Vietnamese News Classification) |
|
|
from huggingface_hub import hf_hub_download |
|
|
import joblib |
|
|
|
|
|
# Download and load VNTC model from HuggingFace Hub |
|
|
vntc_model = joblib.load( |
|
|
hf_hub_download("undertheseanlp/sonar_core_1", "vntc_classifier_20250927_161550.joblib") |
|
|
) |
|
|
|
|
|
# Make prediction on news text |
|
|
news_text = "Đội tuyển bóng đá Việt Nam giành chiến thắng" |
|
|
prediction = vntc_model.predict([news_text])[0] |
|
|
print(f"News category: {prediction}") |
|
|
|
|
|
# UTS2017_Bank Model (Vietnamese Banking Text Classification) |
|
|
# Download and load UTS2017_Bank model from HuggingFace Hub |
|
|
bank_model = joblib.load( |
|
|
hf_hub_download("undertheseanlp/sonar_core_1", "uts2017_bank_classifier_20250927_161733.joblib") |
|
|
) |
|
|
|
|
|
# Make prediction on banking text |
|
|
bank_text = "Tôi muốn mở tài khoản tiết kiệm" |
|
|
prediction = bank_model.predict([bank_text])[0] |
|
|
print(f"Banking category: {prediction}") |
|
|
|
|
|
# For local file inference, use inference.py instead |
|
|
""") |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""Main demonstration function""" |
|
|
print("Sonar Core 1 - Dual Model Hugging Face Hub Usage") |
|
|
print("=" * 60) |
|
|
|
|
|
try: |
|
|
|
|
|
simple_usage_examples() |
|
|
|
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("TESTING VNTC MODEL (Vietnamese News Classification)") |
|
|
print("="*60) |
|
|
|
|
|
vntc_model = load_model_from_hub("vntc") |
|
|
predict_vntc_examples(vntc_model) |
|
|
|
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("TESTING UTS2017_BANK MODEL (Vietnamese Banking Text Classification)") |
|
|
print("="*60) |
|
|
|
|
|
bank_model = load_model_from_hub("uts2017_bank") |
|
|
predict_uts2017_examples(bank_model) |
|
|
|
|
|
|
|
|
try: |
|
|
import sys |
|
|
if hasattr(sys, 'ps1') or sys.stdin.isatty(): |
|
|
print("\nAvailable interactive modes:") |
|
|
print("1. VNTC (News) classification") |
|
|
print("2. UTS2017_Bank (Banking) classification") |
|
|
|
|
|
choice = input("\nSelect model for interactive mode (1/2) or 'n' to skip: ").strip() |
|
|
|
|
|
if choice == "1": |
|
|
interactive_mode(vntc_model, "vntc") |
|
|
elif choice == "2": |
|
|
interactive_mode(bank_model, "uts2017_bank") |
|
|
|
|
|
except (EOFError, OSError): |
|
|
print("\nInteractive mode not available in this environment.") |
|
|
print("Run this script in a regular terminal to use interactive mode.") |
|
|
|
|
|
print("\nDemonstration complete!") |
|
|
print("\nBoth models are now available on Hugging Face Hub:") |
|
|
print("- VNTC (News): vntc_classifier_20250927_161550.joblib") |
|
|
print("- UTS2017_Bank (Banking): uts2017_bank_classifier_20250927_161733.joblib") |
|
|
|
|
|
except ImportError: |
|
|
print("Error: huggingface_hub is required. Install with:") |
|
|
print(" pip install huggingface_hub") |
|
|
except Exception as e: |
|
|
print(f"Error loading model: {e}") |
|
|
print("\nMake sure you have internet connection and try again.") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |