# zen-training / app_simple_backup.py
# Hanzo Dev
# Restore full configurable UI - user requested dropdowns back
# 333f111
"""
Zen VL Training - Simplified & Working
Just trains zen-vl-4b with our datasets
"""
import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, TrainingArguments, Trainer
from datasets import load_dataset
def train_zen_vl():
    """Run the one-button zen-vl-4b fine-tuning job, streaming progress logs.

    This is a generator meant to be bound to a UI event: after every log
    message it yields the complete log so far (newline-joined), so the bound
    output component always shows the full history.

    Steps: report GPU availability, load the base model and processor, stream
    a capped number of samples from each configured dataset, fine-tune with
    ``Trainer``, and push the result to the Hugging Face Hub.

    Yields:
        str: All log lines emitted so far, joined with newlines.

    Any exception raised along the way is caught and reported in the log
    (message plus traceback) instead of propagating to the caller.
    """
    # Function-scope imports keep this block self-contained.
    import traceback
    from itertools import islice

    base_model_id = "zenlm/zen-vl-4b-instruct"
    hub_model_id = "zenlm/zen-vl-4b-agent"

    logs = []

    def log(msg):
        """Print *msg*, record it, and yield the accumulated log text."""
        print(msg)
        logs.append(msg)
        yield "\n".join(logs)

    try:
        yield from log("🧘 Starting Zen VL 4B Training")
        yield from log("=" * 80)

        # GPU check
        has_gpu = torch.cuda.is_available()
        yield from log(f"🎮 GPU: {has_gpu}")
        if has_gpu:
            yield from log(f" {torch.cuda.get_device_name(0)}")
            yield from log(f" {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

        # TrainingArguments rejects bf16=True on hardware that cannot do bf16,
        # so only request it when the device actually supports it.
        use_bf16 = has_gpu and torch.cuda.is_bf16_supported()

        # Load model
        yield from log("\n📦 Loading zen-vl-4b-instruct...")
        # NOTE(review): AutoModel is used here; confirm the class resolved via
        # the repo's remote code returns a loss, which Trainer needs.
        model = AutoModel.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        # NOTE(review): the processor is loaded but not used anywhere below.
        processor = AutoProcessor.from_pretrained(base_model_id)
        yield from log("✅ Model loaded")

        # Load datasets
        yield from log("\n📚 Loading datasets...")
        all_data = []
        # (display name, hub dataset id, config name or None, sample cap)
        datasets_to_load = [
            ("ADP Synatra", "neulab/agent-data-collection", "synatra", 7500),
            ("ADP Code Feedback", "neulab/agent-data-collection", "code_feedback", 7500),
            ("ADP Go Browse", "neulab/agent-data-collection", "go-browse-wa", 7500),
            ("xLAM Function Calling", "Salesforce/xlam-function-calling-60k", None, 7500),
        ]
        for name, hf_id, config, max_samples in datasets_to_load:
            yield from log(f" Loading {name}...")
            try:
                # A falsy config means the dataset is loaded without a config name.
                source_args = (hf_id, config) if config else (hf_id,)
                ds = load_dataset(*source_args, split="train", streaming=True)
                # islice stops after max_samples rows without pulling an extra
                # one from the stream.
                samples = list(islice(ds, max_samples))
                all_data.extend(samples)
                yield from log(f" ✅ {len(samples)} samples")
            except Exception as e:
                # Best effort: one failing dataset must not abort the others.
                yield from log(f" ⚠️ Error: {e}")
        yield from log(f"\n✅ Total: {len(all_data)} samples")

        if not all_data:
            # Fail fast with a clear message instead of letting Trainer
            # complain about a missing train_dataset later on.
            yield from log("\n❌ ERROR: No training samples could be loaded; aborting before training.")
            return

        # Training
        yield from log("\n⚙️ Training Configuration:")
        yield from log(" Epochs: 3")
        yield from log(" Batch Size: 1")
        yield from log(" Learning Rate: 2e-5")
        yield from log(f" Output: {hub_model_id}")
        if not use_bf16:
            yield from log(" ⚠️ bf16 is not supported on this device; bf16 mixed precision disabled")

        training_args = TrainingArguments(
            output_dir="./zen-vl-output",
            num_train_epochs=3,
            per_device_train_batch_size=1,
            learning_rate=2e-5,
            logging_steps=10,
            save_steps=500,
            bf16=use_bf16,
            push_to_hub=True,
            hub_model_id=hub_model_id,
            report_to="none",
        )
        # NOTE(review): all_data is a plain list of raw dataset rows; nothing in
        # this function tokenizes them or applies `processor`. Confirm Trainer
        # can consume them as-is, or add preprocessing / a data collator.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=all_data,
        )

        yield from log("\n🔥 TRAINING STARTED")
        yield from log("=" * 80)
        result = trainer.train()

        yield from log("\n✅ TRAINING COMPLETED!")
        yield from log(f"📊 Final Loss: {result.training_loss:.4f}")
        yield from log(f"☁️ Uploading to {hub_model_id}...")
        trainer.push_to_hub()
        yield from log(f"\n🎉 SUCCESS! Model live at {hub_model_id}")
    except Exception as e:
        # Top-level boundary for the UI: surface the failure in the log stream
        # rather than raising into the event handler.
        yield from log(f"\n❌ ERROR: {str(e)}")
        yield from log(f"\n{traceback.format_exc()}")
# Simple interface: one button that streams the training log into a textbox.
with gr.Blocks(title="Zen VL Training") as demo:
    # Blank lines separate the fields: single newlines are not rendered as
    # line breaks by default, which would merge them into one paragraph.
    gr.Markdown("""
# 🧘 Zen VL 4B Training

Trains zen-vl-4b-instruct → zen-vl-4b-agent

**Datasets**: ADP (Synatra, Code Feedback, Go Browse) + xLAM (60k)

**Total**: ~30k samples

**Time**: ~6-8 hours on A10G

**Output**: zenlm/zen-vl-4b-agent
""")
    start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
    output = gr.Textbox(label="Training Logs", lines=30)
    # train_zen_vl is a generator; each yielded string replaces the textbox
    # content, so the log grows as training progresses.
    start_btn.click(train_zen_vl, outputs=output)
if __name__ == "__main__":
    # train_zen_vl is a generator handler. Gradio 3.x only allows generator
    # handlers when the queue is enabled; newer releases enable it by default,
    # where this call is harmless.
    demo.queue()
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)