"""Zen VL Training - simplified one-button fine-tuning app.

Fine-tunes ``zenlm/zen-vl-4b-instruct`` on a mix of agent and
function-calling datasets and pushes the result to
``zenlm/zen-vl-4b-agent``. The Gradio UI streams the training log into a
textbox while the job runs.
"""

import itertools
import json
import traceback

import gradio as gr
import torch
from datasets import load_dataset
from transformers import (
    AutoModel,
    AutoProcessor,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

BASE_MODEL_ID = "zenlm/zen-vl-4b-instruct"
HUB_MODEL_ID = "zenlm/zen-vl-4b-agent"
OUTPUT_DIR = "./zen-vl-output"

NUM_EPOCHS = 3
BATCH_SIZE = 1
LEARNING_RATE = 2e-5
# Upper bound on tokens per training example so long agent trajectories
# cannot blow up activation memory.
MAX_SEQ_LEN = 2048

# (display name, HF dataset id, config name or None, max rows to take)
DATASETS = (
    ("ADP Synatra", "neulab/agent-data-collection", "synatra", 7500),
    ("ADP Code Feedback", "neulab/agent-data-collection", "code_feedback", 7500),
    ("ADP Go Browse", "neulab/agent-data-collection", "go-browse-wa", 7500),
    ("xLAM Function Calling", "Salesforce/xlam-function-calling-60k", None, 7500),
)


def _load_model():
    """Load the base checkpoint with a generation head so Trainer gets a loss.

    ``AutoModel`` usually resolves to the headless backbone, whose forward
    pass does not compute a loss from ``labels``; Trainer then cannot train.
    NOTE(review): assumes the checkpoint is registered for an image-text-to-
    text / vision2seq head (true for Qwen-VL-style models) — confirm.
    """
    try:
        from transformers import AutoModelForImageTextToText as auto_cls
    except ImportError:  # transformers releases that predate this auto class
        from transformers import AutoModelForVision2Seq as auto_cls
    return auto_cls.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )


def _stream_samples(hf_id, config, max_samples):
    """Return up to ``max_samples`` rows from the ``train`` split of a HF dataset.

    Streaming mode avoids downloading the full dataset just to take a prefix.
    """
    args = (hf_id, config) if config else (hf_id,)
    ds = load_dataset(*args, split="train", streaming=True)
    return list(itertools.islice(ds, max_samples))


def _example_to_text(example):
    """Render one raw dataset row as a single training string.

    The source datasets have different schemas, so rows are serialised as
    JSON rather than mapped field by field (they could not share one table
    otherwise). ``default=str`` deliberately stringifies any non-JSON value
    instead of raising.
    TODO(review): per-dataset chat-template formatting would give better
    training signal than raw JSON.
    """
    return json.dumps(example, ensure_ascii=False, default=str)


def _tokenize(tokenizer, text):
    """Tokenize ``text`` (truncated to MAX_SEQ_LEN) into the fields the collator needs."""
    enc = tokenizer(text, truncation=True, max_length=MAX_SEQ_LEN)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}


def train_zen_vl():
    """Run the full fine-tuning job, yielding the accumulated log after each step.

    Gradio treats this generator as a streaming handler: every yielded value
    replaces the contents of the log textbox. Any exception is caught at this
    top-level boundary and reported in the log rather than crashing the app.
    """
    logs = []

    def log(msg):
        # Echo to stdout (Space logs) and return the whole log for the UI.
        print(msg)
        logs.append(msg)
        return "\n".join(logs)

    try:
        yield log("🧘 Starting Zen VL 4B Training")
        yield log("=" * 80)

        # GPU check
        has_gpu = torch.cuda.is_available()
        yield log(f"🎮 GPU: {has_gpu}")
        if has_gpu:
            yield log(f"   {torch.cuda.get_device_name(0)}")
            yield log(f"   {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

        # Load model
        yield log("\n📦 Loading zen-vl-4b-instruct...")
        model = _load_model()
        # Same checkpoint as the model, so it needs the same remote-code trust.
        processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
        yield log("✅ Model loaded")

        # Load datasets (best effort: a failing dataset is skipped, not fatal)
        yield log("\n📚 Loading datasets...")
        texts = []
        for name, hf_id, config, max_samples in DATASETS:
            yield log(f"   Loading {name}...")
            try:
                samples = _stream_samples(hf_id, config, max_samples)
            except Exception as e:
                yield log(f"   ⚠️ Error: {e}")
                continue
            texts.extend(_example_to_text(example) for example in samples)
            yield log(f"   ✅ {len(samples)} samples")

        yield log(f"\n✅ Total: {len(texts)} samples")
        if not texts:
            # Trainer.train() cannot run without a dataset; stop with a clear message.
            yield log("\n❌ ERROR: no training samples were loaded; aborting.")
            return

        # Tokenize: Trainer needs token ids, not raw dataset rows.
        yield log("\n🔤 Tokenizing samples...")
        # A processor wraps a tokenizer; a text-only checkpoint may return the
        # tokenizer itself from AutoProcessor.
        tokenizer = getattr(processor, "tokenizer", processor)
        if tokenizer.pad_token is None:
            # The collator pads batches, which requires a pad token.
            tokenizer.pad_token = tokenizer.eos_token
        train_dataset = [_tokenize(tokenizer, text) for text in texts]

        # Training
        yield log("\n⚙️ Training Configuration:")
        yield log(f"   Epochs: {NUM_EPOCHS}")
        yield log(f"   Batch Size: {BATCH_SIZE}")
        yield log(f"   Learning Rate: {LEARNING_RATE}")
        yield log(f"   Output: {HUB_MODEL_ID}")

        # NOTE(review): full fine-tuning of a 4B model with AdamW may exceed a
        # single 24 GB GPU; consider LoRA / gradient checkpointing — unverified.
        training_args = TrainingArguments(
            output_dir=OUTPUT_DIR,
            num_train_epochs=NUM_EPOCHS,
            per_device_train_batch_size=BATCH_SIZE,
            learning_rate=LEARNING_RATE,
            logging_steps=10,
            save_steps=500,
            # Enable bf16 mixed precision only where the GPU supports it; it
            # may be rejected on CPU or on pre-Ampere GPUs.
            bf16=has_gpu and torch.cuda.is_bf16_supported(),
            push_to_hub=True,
            hub_model_id=HUB_MODEL_ID,
            report_to="none",
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            # mlm=False: causal-LM objective; labels are copied from input_ids
            # with padding positions masked out.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )

        yield log("\n🔥 TRAINING STARTED")
        yield log("=" * 80)

        result = trainer.train()

        yield log("\n✅ TRAINING COMPLETED!")
        yield log(f"📊 Final Loss: {result.training_loss:.4f}")
        yield log(f"☁️ Uploading to {HUB_MODEL_ID}...")

        trainer.push_to_hub()

        yield log(f"\n🎉 SUCCESS! Model live at {HUB_MODEL_ID}")

    except Exception as e:  # top-level boundary: surface the failure in the UI log
        yield log(f"\n❌ ERROR: {e}")
        yield log(f"\n{traceback.format_exc()}")


# Simple interface
with gr.Blocks(title="Zen VL Training") as demo:
    gr.Markdown(
        """
        # 🧘 Zen VL 4B Training

        Trains zen-vl-4b-instruct → zen-vl-4b-agent

        **Datasets**: ADP (Synatra, Code Feedback, Go Browse) + xLAM function calling (7.5k of 60k)
        **Total**: ~30k samples
        **Time**: ~6-8 hours on A10G
        **Output**: zenlm/zen-vl-4b-agent
        """
    )

    start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
    output = gr.Textbox(label="Training Logs", lines=30)

    # train_zen_vl is a generator, so the textbox updates as it yields.
    start_btn.click(train_zen_vl, outputs=output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)