"""
Zen VL Training - Simplified & Working
Just trains zen-vl-4b with our datasets
"""
import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, TrainingArguments, Trainer
from datasets import load_dataset
def train_zen_vl():
    """Simple one-button training for zen-vl-4b"""
    logs = []

    def log(msg):
        print(msg)
        logs.append(msg)
        yield "\n".join(logs)
    try:
        yield from log("🧠 Starting Zen VL 4B Training")
        yield from log("=" * 80)

        # GPU check
        has_gpu = torch.cuda.is_available()
        yield from log(f"🎮 GPU: {has_gpu}")
        if has_gpu:
            yield from log(f" {torch.cuda.get_device_name(0)}")
            yield from log(f" {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
        # Load model
        yield from log("\n📦 Loading zen-vl-4b-instruct...")
        model = AutoModel.from_pretrained(
            "zenlm/zen-vl-4b-instruct",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(
            "zenlm/zen-vl-4b-instruct", trust_remote_code=True
        )
        yield from log("✅ Model loaded")
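
        # The processor typically bundles the tokenizer and image preprocessor
        # for the model; this simplified script loads it but does not yet apply
        # it to the raw examples collected below.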
        # Load datasets
        yield from log("\n📊 Loading datasets...")
        all_data = []
        datasets_to_load = [
            ("ADP Synatra", "neulab/agent-data-collection", "synatra", 7500),
            ("ADP Code Feedback", "neulab/agent-data-collection", "code_feedback", 7500),
            ("ADP Go Browse", "neulab/agent-data-collection", "go-browse-wa", 7500),
            ("xLAM Function Calling", "Salesforce/xlam-function-calling-60k", None, 7500)
        ]
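
        # Each dataset is opened in streaming mode and truncated to max_samples
        # examples, so nothing has to be downloaded in full.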
        for name, hf_id, config, max_samples in datasets_to_load:
            yield from log(f" Loading {name}...")
            try:
                if config:
                    ds = load_dataset(hf_id, config, split="train", streaming=True)
                else:
                    ds = load_dataset(hf_id, split="train", streaming=True)
                samples = []
                for i, example in enumerate(ds):
                    if i >= max_samples:
                        break
                    samples.append(example)
                all_data.extend(samples)
                yield from log(f" ✅ {len(samples)} samples")
            except Exception as e:
                yield from log(f" ⚠️ Error: {e}")

        yield from log(f"\n✅ Total: {len(all_data)} samples")
        # Training
        yield from log("\n⚙️ Training Configuration:")
        yield from log(" Epochs: 3")
        yield from log(" Batch Size: 1")
        yield from log(" Learning Rate: 2e-5")
        yield from log(" Output: zenlm/zen-vl-4b-agent")

        training_args = TrainingArguments(
            output_dir="./zen-vl-output",
            num_train_epochs=3,
            per_device_train_batch_size=1,
            learning_rate=2e-5,
            logging_steps=10,
            save_steps=500,
            bf16=True,
            push_to_hub=True,
            hub_model_id="zenlm/zen-vl-4b-agent",
            report_to="none",
        )
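
        # NOTE: Trainer's default data collator expects already-tokenized
        # features (input_ids, labels, ...); the raw examples gathered above
        # would still need to be converted with the processor for this step to
        # run end to end.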
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=all_data if len(all_data) > 0 else None,
        )
yield from log("\n๐ฅ TRAINING STARTED")
yield from log("=" * 80)
result = trainer.train()
yield from log("\nโ
TRAINING COMPLETED!")
yield from log(f"๐ Final Loss: {result.training_loss:.4f}")
yield from log("โ๏ธ Uploading to zenlm/zen-vl-4b-agent...")
trainer.push_to_hub()
yield from log("\n๐ SUCCESS! Model live at zenlm/zen-vl-4b-agent")
except Exception as e:
yield from log(f"\nโ ERROR: {str(e)}")
import traceback
yield from log(f"\n{traceback.format_exc()}")
# Simple interface
with gr.Blocks(title="Zen VL Training") as demo:
    gr.Markdown("""
# 🧠 Zen VL 4B Training

Trains zen-vl-4b-instruct → zen-vl-4b-agent

**Datasets**: ADP (Synatra, Code Feedback, Go Browse) + xLAM (60k)
**Total**: ~30k samples
**Time**: ~6-8 hours on A10G
**Output**: zenlm/zen-vl-4b-agent
""")
    start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
    output = gr.Textbox(label="Training Logs", lines=30)
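
    # train_zen_vl is a generator, so Gradio streams each yielded log snapshot
    # into the Textbox as training progresses.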
    start_btn.click(train_zen_vl, outputs=output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)