Train / app.py
kvn420's picture
Create app.py
8067fe5 verified
raw
history blame
17.4 kB
import gradio as gr
import torch
import torch.nn as nn
from transformers import (
AutoTokenizer, AutoModel, AutoProcessor,
AutoModelForCausalLM, TrainingArguments, Trainer,
DataCollatorForLanguageModeling
)
from datasets import Dataset, load_dataset, concatenate_datasets
import json
import os
import requests
from PIL import Image
import librosa
import cv2
import numpy as np
from pathlib import Path
import logging
from typing import Dict, List, Optional, Union
import time
from huggingface_hub import HfApi, list_datasets_in_collection
import tempfile
import shutil
# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalTrainer:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.current_model = None
self.current_tokenizer = None
self.current_processor = None
self.training_data = []
self.hf_api = HfApi()
def load_model(self, model_name: str, model_type: str = "causal"):
"""Charge un modèle depuis Hugging Face"""
try:
logger.info(f"Chargement du modèle: {model_name}")
if model_type == "causal":
self.current_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
else:
self.current_model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
# Charge le tokenizer et processor
try:
self.current_tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
)
except:
logger.warning("Tokenizer non trouvé, utilisation d'un tokenizer par défaut")
try:
self.current_processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True
)
except:
logger.warning("Processor non trouvé")
return f"✅ Modèle {model_name} chargé avec succès!"
except Exception as e:
error_msg = f"❌ Erreur lors du chargement: {str(e)}"
logger.error(error_msg)
return error_msg
def load_collection_datasets(self, collection_url: str):
"""Charge tous les datasets d'une collection HF"""
try:
# Extrait l'ID de la collection depuis l'URL
collection_id = collection_url.split("/")[-1]
# Liste les datasets de la collection
collection_items = list_datasets_in_collection(collection_id)
datasets_info = []
loaded_datasets = []
for item in collection_items:
try:
dataset_name = item.id
dataset = load_dataset(dataset_name, split='train', streaming=False)
loaded_datasets.append(dataset)
datasets_info.append(f"✅ {dataset_name}: {len(dataset)} exemples")
logger.info(f"Dataset chargé: {dataset_name}")
except Exception as e:
datasets_info.append(f"❌ {dataset_name}: {str(e)}")
logger.error(f"Erreur dataset {dataset_name}: {e}")
# Combine tous les datasets
if loaded_datasets:
combined_dataset = concatenate_datasets(loaded_datasets)
self.training_data = combined_dataset
result = f"📊 Collection chargée!\n" + "\n".join(datasets_info)
result += f"\n\n🔢 Total combiné: {len(self.training_data)} exemples"
return result
except Exception as e:
error_msg = f"❌ Erreur collection: {str(e)}"
logger.error(error_msg)
return error_msg
def load_single_dataset(self, dataset_name: str, split: str = "train"):
"""Charge un dataset individuel"""
try:
dataset = load_dataset(dataset_name, split=split)
if hasattr(self, 'training_data') and self.training_data:
# Combine avec les données existantes
self.training_data = concatenate_datasets([self.training_data, dataset])
else:
self.training_data = dataset
return f"✅ Dataset {dataset_name} ajouté! Total: {len(self.training_data)} exemples"
except Exception as e:
error_msg = f"❌ Erreur dataset: {str(e)}"
logger.error(error_msg)
return error_msg
def process_multimodal_data(self, example):
"""Traite les données multimodales pour l'entraînement"""
processed = {}
# Traitement du texte
if 'text' in example:
if self.current_tokenizer:
tokens = self.current_tokenizer(
example['text'],
truncation=True,
padding=True,
max_length=512,
return_tensors="pt"
)
processed.update(tokens)
# Traitement des images
if 'image' in example:
try:
if isinstance(example['image'], str):
# URL ou chemin
if example['image'].startswith('http'):
response = requests.get(example['image'])
image = Image.open(io.BytesIO(response.content))
else:
image = Image.open(example['image'])
else:
image = example['image']
if self.current_processor:
image_inputs = self.current_processor(
images=image, return_tensors="pt"
)
processed.update(image_inputs)
except Exception as e:
logger.warning(f"Erreur traitement image: {e}")
# Traitement audio
if 'audio' in example:
try:
if isinstance(example['audio'], str):
audio_data, sr = librosa.load(example['audio'], sr=16000)
else:
audio_data = example['audio']
sr = 16000
# Conversion basique pour l'exemple
processed['audio'] = torch.tensor(audio_data).unsqueeze(0)
except Exception as e:
logger.warning(f"Erreur traitement audio: {e}")
return processed
def start_training(self,
output_dir: str,
num_epochs: int = 3,
learning_rate: float = 5e-5,
batch_size: int = 4,
save_steps: int = 500):
"""Lance l'entraînement du modèle"""
if not self.current_model:
return "❌ Aucun modèle chargé!"
if not self.training_data:
return "❌ Aucune donnée d'entraînement!"
try:
# Préparation des données
logger.info("Préparation des données...")
# Arguments d'entraînement
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=learning_rate,
logging_steps=50,
save_steps=save_steps,
eval_steps=save_steps,
warmup_steps=100,
fp16=torch.cuda.is_available(),
dataloader_num_workers=2,
remove_unused_columns=False,
report_to=None # Désactive wandb/tensorboard
)
# Data collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.current_tokenizer,
mlm=False
) if self.current_tokenizer else None
# Trainer
trainer = Trainer(
model=self.current_model,
args=training_args,
train_dataset=self.training_data,
data_collator=data_collator,
)
# Lance l'entraînement
logger.info("🚀 Début de l'entraînement...")
trainer.train()
# Sauvegarde
trainer.save_model()
if self.current_tokenizer:
self.current_tokenizer.save_pretrained(output_dir)
return f"✅ Entraînement terminé! Modèle sauvegardé dans {output_dir}"
except Exception as e:
error_msg = f"❌ Erreur entraînement: {str(e)}"
logger.error(error_msg)
return error_msg
def get_model_info(self):
"""Retourne les informations du modèle actuel"""
if not self.current_model:
return "Aucun modèle chargé"
info = f"📋 Modèle actuel:\n"
info += f"Type: {type(self.current_model).__name__}\n"
info += f"Device: {next(self.current_model.parameters()).device}\n"
# Compte les paramètres
total_params = sum(p.numel() for p in self.current_model.parameters())
trainable_params = sum(p.numel() for p in self.current_model.parameters() if p.requires_grad)
info += f"Paramètres totaux: {total_params:,}\n"
info += f"Paramètres entraînables: {trainable_params:,}\n"
if hasattr(self, 'training_data') and self.training_data:
info += f"\n📊 Données: {len(self.training_data)} exemples"
return info
# Initialisation du trainer
trainer = MultimodalTrainer()
# Interface Gradio
def create_interface():
with gr.Blocks(title="🔥 Multimodal Training Hub", theme=gr.themes.Soft()) as app:
gr.Markdown("""
# 🔥 Multimodal Training Hub
### Entraînez vos modèles multimodaux avec facilité!
Supporté: Texte 📝 • Images 🖼️ • Audio 🎵 • Vidéo 🎬
""")
with gr.Tab("🤖 Modèle"):
with gr.Row():
with gr.Column():
model_input = gr.Textbox(
label="Nom du modèle HuggingFace",
placeholder="kvn420/Tenro_V4.1",
value="kvn420/Tenro_V4.1"
)
model_type = gr.Dropdown(
label="Type de modèle",
choices=["causal", "base"],
value="causal"
)
load_model_btn = gr.Button("🔄 Charger le modèle", variant="primary")
with gr.Column():
model_status = gr.Textbox(
label="Status du modèle",
interactive=False,
lines=8
)
load_model_btn.click(
trainer.load_model,
inputs=[model_input, model_type],
outputs=model_status
)
with gr.Tab("📊 Données"):
with gr.Row():
with gr.Column():
gr.Markdown("### 📦 Collection HuggingFace")
collection_input = gr.Textbox(
label="URL de la collection",
placeholder="https://huggingface.co/collections/kvn420/op-67aa4430ba254a4ff0689742"
)
load_collection_btn = gr.Button("📥 Charger collection", variant="secondary")
gr.Markdown("### 📝 Dataset individuel")
dataset_input = gr.Textbox(
label="Nom du dataset",
placeholder="microsoft/coco"
)
dataset_split = gr.Textbox(
label="Split",
value="train"
)
load_dataset_btn = gr.Button("➕ Ajouter dataset", variant="secondary")
with gr.Column():
data_status = gr.Textbox(
label="Status des données",
interactive=False,
lines=12
)
load_collection_btn.click(
trainer.load_collection_datasets,
inputs=collection_input,
outputs=data_status
)
load_dataset_btn.click(
trainer.load_single_dataset,
inputs=[dataset_input, dataset_split],
outputs=data_status
)
with gr.Tab("🏋️ Entraînement"):
with gr.Row():
with gr.Column():
output_dir = gr.Textbox(
label="Dossier de sortie",
value="./trained_model"
)
with gr.Row():
num_epochs = gr.Number(
label="Époques",
value=3,
minimum=1
)
batch_size = gr.Number(
label="Batch size",
value=4,
minimum=1
)
with gr.Row():
learning_rate = gr.Number(
label="Learning rate",
value=5e-5,
step=1e-6
)
save_steps = gr.Number(
label="Save steps",
value=500,
minimum=100
)
train_btn = gr.Button("🚀 Lancer l'entraînement", variant="primary", size="lg")
with gr.Column():
training_status = gr.Textbox(
label="Status de l'entraînement",
interactive=False,
lines=8
)
info_btn = gr.Button("ℹ️ Info modèle")
model_info = gr.Textbox(
label="Informations du modèle",
interactive=False,
lines=6
)
train_btn.click(
trainer.start_training,
inputs=[output_dir, num_epochs, learning_rate, batch_size, save_steps],
outputs=training_status
)
info_btn.click(
trainer.get_model_info,
outputs=model_info
)
with gr.Tab("📚 Aide"):
gr.Markdown("""
## 🚀 Guide d'utilisation
### 1. Charger un modèle
- Entrez le nom d'un modèle HuggingFace (ex: `kvn420/Tenro_V4.1`)
- Choisissez le type (causal pour génération, base pour embedding)
- Cliquez sur "Charger le modèle"
### 2. Ajouter des données
**Collection:** Chargez tous les datasets d'une collection HF
**Dataset individuel:** Ajoutez un dataset spécifique
### 3. Entraîner
- Configurez les paramètres d'entraînement
- Lancez l'entraînement avec "🚀 Lancer l'entraînement"
### 📋 Formats supportés
- **Texte:** Colonnes `text`, `prompt`, `conversation`
- **Images:** Colonnes `image`, `images` (URLs ou chemins)
- **Audio:** Colonnes `audio` (fichiers audio)
- **Vidéo:** Colonnes `video` (fichiers vidéo)
### ⚡ Conseils
- Utilisez un GPU pour l'entraînement (T4, A10G recommandé)
- Ajustez le batch_size selon votre mémoire GPU
- Sauvegardez régulièrement avec save_steps
""")
return app
# Lancement de l'application
if __name__ == "__main__":
app = create_interface()
app.launch(share=True, server_name="0.0.0.0", server_port=7860)