kvn420 commited on
Commit
8067fe5
·
verified ·
1 Parent(s): 7638a3c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +451 -0
app.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import (
5
+ AutoTokenizer, AutoModel, AutoProcessor,
6
+ AutoModelForCausalLM, TrainingArguments, Trainer,
7
+ DataCollatorForLanguageModeling
8
+ )
9
+ from datasets import Dataset, load_dataset, concatenate_datasets
10
+ import json
11
+ import os
12
+ import requests
13
+ from PIL import Image
14
+ import librosa
15
+ import cv2
16
+ import numpy as np
17
+ from pathlib import Path
18
+ import logging
19
+ from typing import Dict, List, Optional, Union
20
+ import time
21
+ from huggingface_hub import HfApi, list_datasets_in_collection
22
+ import tempfile
23
+ import shutil
24
+
25
+ # Configuration du logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ class MultimodalTrainer:
30
+ def __init__(self):
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ self.current_model = None
33
+ self.current_tokenizer = None
34
+ self.current_processor = None
35
+ self.training_data = []
36
+ self.hf_api = HfApi()
37
+
38
+ def load_model(self, model_name: str, model_type: str = "causal"):
39
+ """Charge un modèle depuis Hugging Face"""
40
+ try:
41
+ logger.info(f"Chargement du modèle: {model_name}")
42
+
43
+ if model_type == "causal":
44
+ self.current_model = AutoModelForCausalLM.from_pretrained(
45
+ model_name,
46
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
47
+ device_map="auto" if torch.cuda.is_available() else None,
48
+ trust_remote_code=True
49
+ )
50
+ else:
51
+ self.current_model = AutoModel.from_pretrained(
52
+ model_name,
53
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
54
+ device_map="auto" if torch.cuda.is_available() else None,
55
+ trust_remote_code=True
56
+ )
57
+
58
+ # Charge le tokenizer et processor
59
+ try:
60
+ self.current_tokenizer = AutoTokenizer.from_pretrained(
61
+ model_name, trust_remote_code=True
62
+ )
63
+ except:
64
+ logger.warning("Tokenizer non trouvé, utilisation d'un tokenizer par défaut")
65
+
66
+ try:
67
+ self.current_processor = AutoProcessor.from_pretrained(
68
+ model_name, trust_remote_code=True
69
+ )
70
+ except:
71
+ logger.warning("Processor non trouvé")
72
+
73
+ return f"✅ Modèle {model_name} chargé avec succès!"
74
+
75
+ except Exception as e:
76
+ error_msg = f"❌ Erreur lors du chargement: {str(e)}"
77
+ logger.error(error_msg)
78
+ return error_msg
79
+
80
+ def load_collection_datasets(self, collection_url: str):
81
+ """Charge tous les datasets d'une collection HF"""
82
+ try:
83
+ # Extrait l'ID de la collection depuis l'URL
84
+ collection_id = collection_url.split("/")[-1]
85
+
86
+ # Liste les datasets de la collection
87
+ collection_items = list_datasets_in_collection(collection_id)
88
+
89
+ datasets_info = []
90
+ loaded_datasets = []
91
+
92
+ for item in collection_items:
93
+ try:
94
+ dataset_name = item.id
95
+ dataset = load_dataset(dataset_name, split='train', streaming=False)
96
+ loaded_datasets.append(dataset)
97
+ datasets_info.append(f"✅ {dataset_name}: {len(dataset)} exemples")
98
+ logger.info(f"Dataset chargé: {dataset_name}")
99
+ except Exception as e:
100
+ datasets_info.append(f"❌ {dataset_name}: {str(e)}")
101
+ logger.error(f"Erreur dataset {dataset_name}: {e}")
102
+
103
+ # Combine tous les datasets
104
+ if loaded_datasets:
105
+ combined_dataset = concatenate_datasets(loaded_datasets)
106
+ self.training_data = combined_dataset
107
+
108
+ result = f"📊 Collection chargée!\n" + "\n".join(datasets_info)
109
+ result += f"\n\n🔢 Total combiné: {len(self.training_data)} exemples"
110
+
111
+ return result
112
+
113
+ except Exception as e:
114
+ error_msg = f"❌ Erreur collection: {str(e)}"
115
+ logger.error(error_msg)
116
+ return error_msg
117
+
118
+ def load_single_dataset(self, dataset_name: str, split: str = "train"):
119
+ """Charge un dataset individuel"""
120
+ try:
121
+ dataset = load_dataset(dataset_name, split=split)
122
+
123
+ if hasattr(self, 'training_data') and self.training_data:
124
+ # Combine avec les données existantes
125
+ self.training_data = concatenate_datasets([self.training_data, dataset])
126
+ else:
127
+ self.training_data = dataset
128
+
129
+ return f"✅ Dataset {dataset_name} ajouté! Total: {len(self.training_data)} exemples"
130
+
131
+ except Exception as e:
132
+ error_msg = f"❌ Erreur dataset: {str(e)}"
133
+ logger.error(error_msg)
134
+ return error_msg
135
+
136
+ def process_multimodal_data(self, example):
137
+ """Traite les données multimodales pour l'entraînement"""
138
+ processed = {}
139
+
140
+ # Traitement du texte
141
+ if 'text' in example:
142
+ if self.current_tokenizer:
143
+ tokens = self.current_tokenizer(
144
+ example['text'],
145
+ truncation=True,
146
+ padding=True,
147
+ max_length=512,
148
+ return_tensors="pt"
149
+ )
150
+ processed.update(tokens)
151
+
152
+ # Traitement des images
153
+ if 'image' in example:
154
+ try:
155
+ if isinstance(example['image'], str):
156
+ # URL ou chemin
157
+ if example['image'].startswith('http'):
158
+ response = requests.get(example['image'])
159
+ image = Image.open(io.BytesIO(response.content))
160
+ else:
161
+ image = Image.open(example['image'])
162
+ else:
163
+ image = example['image']
164
+
165
+ if self.current_processor:
166
+ image_inputs = self.current_processor(
167
+ images=image, return_tensors="pt"
168
+ )
169
+ processed.update(image_inputs)
170
+
171
+ except Exception as e:
172
+ logger.warning(f"Erreur traitement image: {e}")
173
+
174
+ # Traitement audio
175
+ if 'audio' in example:
176
+ try:
177
+ if isinstance(example['audio'], str):
178
+ audio_data, sr = librosa.load(example['audio'], sr=16000)
179
+ else:
180
+ audio_data = example['audio']
181
+ sr = 16000
182
+
183
+ # Conversion basique pour l'exemple
184
+ processed['audio'] = torch.tensor(audio_data).unsqueeze(0)
185
+
186
+ except Exception as e:
187
+ logger.warning(f"Erreur traitement audio: {e}")
188
+
189
+ return processed
190
+
191
+ def start_training(self,
192
+ output_dir: str,
193
+ num_epochs: int = 3,
194
+ learning_rate: float = 5e-5,
195
+ batch_size: int = 4,
196
+ save_steps: int = 500):
197
+ """Lance l'entraînement du modèle"""
198
+
199
+ if not self.current_model:
200
+ return "❌ Aucun modèle chargé!"
201
+
202
+ if not self.training_data:
203
+ return "❌ Aucune donnée d'entraînement!"
204
+
205
+ try:
206
+ # Préparation des données
207
+ logger.info("Préparation des données...")
208
+
209
+ # Arguments d'entraînement
210
+ training_args = TrainingArguments(
211
+ output_dir=output_dir,
212
+ num_train_epochs=num_epochs,
213
+ per_device_train_batch_size=batch_size,
214
+ learning_rate=learning_rate,
215
+ logging_steps=50,
216
+ save_steps=save_steps,
217
+ eval_steps=save_steps,
218
+ warmup_steps=100,
219
+ fp16=torch.cuda.is_available(),
220
+ dataloader_num_workers=2,
221
+ remove_unused_columns=False,
222
+ report_to=None # Désactive wandb/tensorboard
223
+ )
224
+
225
+ # Data collator
226
+ data_collator = DataCollatorForLanguageModeling(
227
+ tokenizer=self.current_tokenizer,
228
+ mlm=False
229
+ ) if self.current_tokenizer else None
230
+
231
+ # Trainer
232
+ trainer = Trainer(
233
+ model=self.current_model,
234
+ args=training_args,
235
+ train_dataset=self.training_data,
236
+ data_collator=data_collator,
237
+ )
238
+
239
+ # Lance l'entraînement
240
+ logger.info("🚀 Début de l'entraînement...")
241
+ trainer.train()
242
+
243
+ # Sauvegarde
244
+ trainer.save_model()
245
+ if self.current_tokenizer:
246
+ self.current_tokenizer.save_pretrained(output_dir)
247
+
248
+ return f"✅ Entraînement terminé! Modèle sauvegardé dans {output_dir}"
249
+
250
+ except Exception as e:
251
+ error_msg = f"❌ Erreur entraînement: {str(e)}"
252
+ logger.error(error_msg)
253
+ return error_msg
254
+
255
+ def get_model_info(self):
256
+ """Retourne les informations du modèle actuel"""
257
+ if not self.current_model:
258
+ return "Aucun modèle chargé"
259
+
260
+ info = f"📋 Modèle actuel:\n"
261
+ info += f"Type: {type(self.current_model).__name__}\n"
262
+ info += f"Device: {next(self.current_model.parameters()).device}\n"
263
+
264
+ # Compte les paramètres
265
+ total_params = sum(p.numel() for p in self.current_model.parameters())
266
+ trainable_params = sum(p.numel() for p in self.current_model.parameters() if p.requires_grad)
267
+
268
+ info += f"Paramètres totaux: {total_params:,}\n"
269
+ info += f"Paramètres entraînables: {trainable_params:,}\n"
270
+
271
+ if hasattr(self, 'training_data') and self.training_data:
272
+ info += f"\n📊 Données: {len(self.training_data)} exemples"
273
+
274
+ return info
275
+
276
+ # Initialisation du trainer
277
+ trainer = MultimodalTrainer()
278
+
279
+ # Interface Gradio
280
+ def create_interface():
281
+ with gr.Blocks(title="🔥 Multimodal Training Hub", theme=gr.themes.Soft()) as app:
282
+
283
+ gr.Markdown("""
284
+ # 🔥 Multimodal Training Hub
285
+ ### Entraînez vos modèles multimodaux avec facilité!
286
+
287
+ Supporté: Texte 📝 • Images 🖼️ • Audio 🎵 • Vidéo 🎬
288
+ """)
289
+
290
+ with gr.Tab("🤖 Modèle"):
291
+ with gr.Row():
292
+ with gr.Column():
293
+ model_input = gr.Textbox(
294
+ label="Nom du modèle HuggingFace",
295
+ placeholder="kvn420/Tenro_V4.1",
296
+ value="kvn420/Tenro_V4.1"
297
+ )
298
+ model_type = gr.Dropdown(
299
+ label="Type de modèle",
300
+ choices=["causal", "base"],
301
+ value="causal"
302
+ )
303
+ load_model_btn = gr.Button("🔄 Charger le modèle", variant="primary")
304
+
305
+ with gr.Column():
306
+ model_status = gr.Textbox(
307
+ label="Status du modèle",
308
+ interactive=False,
309
+ lines=8
310
+ )
311
+
312
+ load_model_btn.click(
313
+ trainer.load_model,
314
+ inputs=[model_input, model_type],
315
+ outputs=model_status
316
+ )
317
+
318
+ with gr.Tab("📊 Données"):
319
+ with gr.Row():
320
+ with gr.Column():
321
+ gr.Markdown("### 📦 Collection HuggingFace")
322
+ collection_input = gr.Textbox(
323
+ label="URL de la collection",
324
+ placeholder="https://huggingface.co/collections/kvn420/op-67aa4430ba254a4ff0689742"
325
+ )
326
+ load_collection_btn = gr.Button("📥 Charger collection", variant="secondary")
327
+
328
+ gr.Markdown("### 📝 Dataset individuel")
329
+ dataset_input = gr.Textbox(
330
+ label="Nom du dataset",
331
+ placeholder="microsoft/coco"
332
+ )
333
+ dataset_split = gr.Textbox(
334
+ label="Split",
335
+ value="train"
336
+ )
337
+ load_dataset_btn = gr.Button("➕ Ajouter dataset", variant="secondary")
338
+
339
+ with gr.Column():
340
+ data_status = gr.Textbox(
341
+ label="Status des données",
342
+ interactive=False,
343
+ lines=12
344
+ )
345
+
346
+ load_collection_btn.click(
347
+ trainer.load_collection_datasets,
348
+ inputs=collection_input,
349
+ outputs=data_status
350
+ )
351
+
352
+ load_dataset_btn.click(
353
+ trainer.load_single_dataset,
354
+ inputs=[dataset_input, dataset_split],
355
+ outputs=data_status
356
+ )
357
+
358
+ with gr.Tab("🏋️ Entraînement"):
359
+ with gr.Row():
360
+ with gr.Column():
361
+ output_dir = gr.Textbox(
362
+ label="Dossier de sortie",
363
+ value="./trained_model"
364
+ )
365
+
366
+ with gr.Row():
367
+ num_epochs = gr.Number(
368
+ label="Époques",
369
+ value=3,
370
+ minimum=1
371
+ )
372
+ batch_size = gr.Number(
373
+ label="Batch size",
374
+ value=4,
375
+ minimum=1
376
+ )
377
+
378
+ with gr.Row():
379
+ learning_rate = gr.Number(
380
+ label="Learning rate",
381
+ value=5e-5,
382
+ step=1e-6
383
+ )
384
+ save_steps = gr.Number(
385
+ label="Save steps",
386
+ value=500,
387
+ minimum=100
388
+ )
389
+
390
+ train_btn = gr.Button("🚀 Lancer l'entraînement", variant="primary", size="lg")
391
+
392
+ with gr.Column():
393
+ training_status = gr.Textbox(
394
+ label="Status de l'entraînement",
395
+ interactive=False,
396
+ lines=8
397
+ )
398
+
399
+ info_btn = gr.Button("ℹ️ Info modèle")
400
+ model_info = gr.Textbox(
401
+ label="Informations du modèle",
402
+ interactive=False,
403
+ lines=6
404
+ )
405
+
406
+ train_btn.click(
407
+ trainer.start_training,
408
+ inputs=[output_dir, num_epochs, learning_rate, batch_size, save_steps],
409
+ outputs=training_status
410
+ )
411
+
412
+ info_btn.click(
413
+ trainer.get_model_info,
414
+ outputs=model_info
415
+ )
416
+
417
+ with gr.Tab("📚 Aide"):
418
+ gr.Markdown("""
419
+ ## 🚀 Guide d'utilisation
420
+
421
+ ### 1. Charger un modèle
422
+ - Entrez le nom d'un modèle HuggingFace (ex: `kvn420/Tenro_V4.1`)
423
+ - Choisissez le type (causal pour génération, base pour embedding)
424
+ - Cliquez sur "Charger le modèle"
425
+
426
+ ### 2. Ajouter des données
427
+ **Collection:** Chargez tous les datasets d'une collection HF
428
+ **Dataset individuel:** Ajoutez un dataset spécifique
429
+
430
+ ### 3. Entraîner
431
+ - Configurez les paramètres d'entraînement
432
+ - Lancez l'entraînement avec "🚀 Lancer l'entraînement"
433
+
434
+ ### 📋 Formats supportés
435
+ - **Texte:** Colonnes `text`, `prompt`, `conversation`
436
+ - **Images:** Colonnes `image`, `images` (URLs ou chemins)
437
+ - **Audio:** Colonnes `audio` (fichiers audio)
438
+ - **Vidéo:** Colonnes `video` (fichiers vidéo)
439
+
440
+ ### ⚡ Conseils
441
+ - Utilisez un GPU pour l'entraînement (T4, A10G recommandé)
442
+ - Ajustez le batch_size selon votre mémoire GPU
443
+ - Sauvegardez régulièrement avec save_steps
444
+ """)
445
+
446
+ return app
447
+
448
+ # Lancement de l'application
449
+ if __name__ == "__main__":
450
+ app = create_interface()
451
+ app.launch(share=True, server_name="0.0.0.0", server_port=7860)