Spaces:
Paused
Paused
| """ | |
| Zen VL Training - Simplified & Working | |
| Just trains zen-vl-4b with our datasets | |
| """ | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoProcessor, TrainingArguments, Trainer | |
| from datasets import load_dataset | |
def _take_samples(hf_id, config, max_samples):
    """Return up to ``max_samples`` examples from a dataset's train split.

    The dataset is opened with ``streaming=True`` and iterated lazily.

    Args:
        hf_id: Dataset repository id passed to ``load_dataset``.
        config: Dataset configuration name, or ``None`` for the default.
        max_samples: Upper bound on the number of examples to collect.

    Returns:
        A list holding at most ``max_samples`` examples.
    """
    from itertools import islice

    if config:
        ds = load_dataset(hf_id, config, split="train", streaming=True)
    else:
        ds = load_dataset(hf_id, split="train", streaming=True)
    # islice stops right at the cap, so no extra example is pulled from the
    # stream just to discover that the limit was reached.
    return list(islice(ds, max_samples))


def train_zen_vl():
    """Simple one-button training for zen-vl-4b.

    Generator used as the UI callback: after every log message it yields the
    full log text accumulated so far (messages joined by newlines). Each
    message is also printed to stdout.

    Steps: report GPU availability, load the model and processor, collect up
    to 7500 streamed samples from each configured dataset, run ``Trainer``
    and push the result to ``zenlm/zen-vl-4b-agent``.

    Any exception raised along the way is caught at the top level and
    reported through the log (message plus traceback) instead of propagating.

    Yields:
        str: The cumulative log text.
    """
    import traceback

    logs = []

    def log(msg):
        # Generator helper so callers can write `yield from log(...)`:
        # records the message, then yields the whole log accumulated so far.
        print(msg)
        logs.append(msg)
        yield "\n".join(logs)

    try:
        yield from log("๐ง Starting Zen VL 4B Training")
        yield from log("=" * 80)

        # GPU check
        has_gpu = torch.cuda.is_available()
        yield from log(f"๐ฎ GPU: {has_gpu}")
        if has_gpu:
            yield from log(f" {torch.cuda.get_device_name(0)}")
            yield from log(f" {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

        # Load model
        yield from log("\n๐ฆ Loading zen-vl-4b-instruct...")
        # NOTE(review): `AutoModel` usually loads the bare backbone without a
        # task head, while `Trainer` needs the model to return a loss —
        # confirm this is the right auto-class for this checkpoint.
        model = AutoModel.from_pretrained(
            "zenlm/zen-vl-4b-instruct",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        # Same repo as the model, so use the same remote-code setting.
        processor = AutoProcessor.from_pretrained(
            "zenlm/zen-vl-4b-instruct",
            trust_remote_code=True
        )
        yield from log("โ Model loaded")

        # Load datasets
        yield from log("\n๐ Loading datasets...")
        all_data = []
        # (display name, hub dataset id, config name or None, sample cap)
        datasets_to_load = [
            ("ADP Synatra", "neulab/agent-data-collection", "synatra", 7500),
            ("ADP Code Feedback", "neulab/agent-data-collection", "code_feedback", 7500),
            ("ADP Go Browse", "neulab/agent-data-collection", "go-browse-wa", 7500),
            ("xLAM Function Calling", "Salesforce/xlam-function-calling-60k", None, 7500)
        ]
        for name, hf_id, config, max_samples in datasets_to_load:
            yield from log(f" Loading {name}...")
            # Best effort: a dataset that fails to load is reported and
            # skipped so the remaining ones can still be used.
            try:
                samples = _take_samples(hf_id, config, max_samples)
                all_data.extend(samples)
                yield from log(f" โ {len(samples)} samples")
            except Exception as e:
                yield from log(f" โ ๏ธ Error: {e}")
        yield from log(f"\nโ Total: {len(all_data)} samples")

        # Fail with a clear message before building the trainer instead of
        # handing it `train_dataset=None` and failing inside `train()`.
        if not all_data:
            raise RuntimeError(
                "No training samples could be loaded; aborting before training."
            )

        # Training
        yield from log("\nโ๏ธ Training Configuration:")
        yield from log(" Epochs: 3")
        yield from log(" Batch Size: 1")
        yield from log(" Learning Rate: 2e-5")
        yield from log(" Output: zenlm/zen-vl-4b-agent")
        training_args = TrainingArguments(
            output_dir="./zen-vl-output",
            num_train_epochs=3,
            per_device_train_batch_size=1,
            learning_rate=2e-5,
            logging_steps=10,
            save_steps=500,
            bf16=True,
            push_to_hub=True,
            hub_model_id="zenlm/zen-vl-4b-agent",
            report_to="none",
        )
        # NOTE(review): the samples are passed as raw dataset rows; no
        # tokenization/processor step or data collator is applied here and
        # `processor` is not used after loading — confirm this is intended.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=all_data,
        )
        yield from log("\n๐ฅ TRAINING STARTED")
        yield from log("=" * 80)
        result = trainer.train()
        yield from log("\nโ TRAINING COMPLETED!")
        yield from log(f"๐ Final Loss: {result.training_loss:.4f}")
        yield from log("โ๏ธ Uploading to zenlm/zen-vl-4b-agent...")
        trainer.push_to_hub()
        yield from log("\n๐ SUCCESS! Model live at zenlm/zen-vl-4b-agent")
    except Exception as e:
        # Top-level boundary: surface the failure in the UI log rather than
        # letting the exception escape the callback.
        yield from log(f"\nโ ERROR: {str(e)}")
        yield from log(f"\n{traceback.format_exc()}")
# Simple interface: one button that starts training and a textbox that shows
# the training log.
with gr.Blocks(title="Zen VL Training") as demo:
    # Static description of the run (datasets, size, expected duration, target repo).
    gr.Markdown("""
# ๐ง Zen VL 4B Training
Trains zen-vl-4b-instruct โ zen-vl-4b-agent
**Datasets**: ADP (Synatra, Code Feedback, Go Browse) + xLAM (60k)
**Total**: ~30k samples
**Time**: ~6-8 hours on A10G
**Output**: zenlm/zen-vl-4b-agent
""")
    start_btn = gr.Button("๐ Start Training", variant="primary", size="lg")
    output = gr.Textbox(label="Training Logs", lines=30)
    # `train_zen_vl` takes no inputs and is a generator: every value it yields
    # is the full log text so far and is sent to the `output` textbox.
    start_btn.click(train_zen_vl, outputs=output)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)