File size: 4,707 Bytes
333f111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
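Zen VL Training - simplified single-button Space.
Fine-tunes zenlm/zen-vl-4b-instruct on agent and function-calling datasets
(ADP Synatra / Code Feedback / Go Browse and xLAM) and pushes the result to
zenlm/zen-vl-4b-agent.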
"""

import traceback
from itertools import islice

import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, TrainingArguments, Trainer
from datasets import load_dataset

def train_zen_vl():
    """Fine-tune zen-vl-4b-instruct on agent datasets and push it to the Hub.

    This is a generator meant to be wired to a Gradio event. Every log line
    is printed and appended to a transcript, and the whole transcript
    (newline-joined) is yielded so the bound Textbox updates live.

    Steps:
      1. Report GPU availability.
      2. Load ``zenlm/zen-vl-4b-instruct`` and its processor.
      3. Stream up to ``max_samples`` examples from each entry of
         ``datasets_to_load``. A dataset that fails to load is logged and
         skipped.
      4. Abort with an error message if no samples were loaded. Otherwise
         train with ``Trainer`` and push to ``zenlm/zen-vl-4b-agent``.

    Exceptions are not propagated. They are written to the transcript
    together with the traceback so they are visible in the UI.

    Yields:
        str: the complete log accumulated so far.
    """
    logs = []

    def log(msg):
        # Echo to stdout and to the UI transcript. Used via ``yield from`` so
        # every message also pushes a UI update.
        print(msg)
        logs.append(msg)
        yield "\n".join(logs)

    # Single source of truth for the run configuration (used for both the
    # log output and TrainingArguments, so they cannot drift apart).
    base_model_id = "zenlm/zen-vl-4b-instruct"
    hub_model_id = "zenlm/zen-vl-4b-agent"
    num_epochs = 3
    batch_size = 1
    learning_rate = 2e-5

    try:
        yield from log("🧘 Starting Zen VL 4B Training")
        yield from log("=" * 80)

        # GPU check
        has_gpu = torch.cuda.is_available()
        yield from log(f"🎮 GPU: {has_gpu}")
        if has_gpu:
            yield from log(f"   {torch.cuda.get_device_name(0)}")
            yield from log(f"   {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
        # TrainingArguments rejects bf16=True when no bf16-capable GPU is
        # available, so only request bf16 mixed precision when one is.
        use_bf16 = has_gpu and torch.cuda.is_bf16_supported()

        # Load model
        yield from log("\n📦 Loading zen-vl-4b-instruct...")
        # NOTE(review): AutoModel commonly resolves to the bare backbone, whose
        # forward does not return a loss, and Trainer needs one. A class with a
        # generation head (e.g. an image-text-to-text / Vision2Seq auto class)
        # is probably required; confirm against the checkpoint's config.
        model = AutoModel.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        # NOTE(review): `processor` is loaded but not used anywhere below yet.
        processor = AutoProcessor.from_pretrained(base_model_id)
        yield from log("✅ Model loaded")

        # Load datasets
        yield from log("\n📚 Loading datasets...")
        all_data = []

        # (display name, HF repo id, config name or None, max samples)
        datasets_to_load = [
            ("ADP Synatra", "neulab/agent-data-collection", "synatra", 7500),
            ("ADP Code Feedback", "neulab/agent-data-collection", "code_feedback", 7500),
            ("ADP Go Browse", "neulab/agent-data-collection", "go-browse-wa", 7500),
            ("xLAM Function Calling", "Salesforce/xlam-function-calling-60k", None, 7500),
        ]

        for name, hf_id, config, max_samples in datasets_to_load:
            yield from log(f"  Loading {name}...")
            try:
                # streaming=True iterates lazily, so only the rows actually
                # consumed by islice are fetched rather than the whole split.
                if config:
                    ds = load_dataset(hf_id, config, split="train", streaming=True)
                else:
                    ds = load_dataset(hf_id, split="train", streaming=True)

                samples = list(islice(ds, max_samples))
                all_data.extend(samples)
                yield from log(f"    ✅ {len(samples)} samples")
            except Exception as e:
                # Best effort: one broken dataset is reported, not fatal.
                yield from log(f"    ⚠️  Error: {e}")

        yield from log(f"\n✅ Total: {len(all_data)} samples")

        if not all_data:
            # Trainer cannot train without data. Stop here with a clear message
            # instead of passing train_dataset=None and failing inside train().
            yield from log("\n❌ ERROR: no training samples were loaded; aborting.")
            return

        # Training
        yield from log("\n⚙️  Training Configuration:")
        yield from log(f"  Epochs: {num_epochs}")
        yield from log(f"  Batch Size: {batch_size}")
        yield from log(f"  Learning Rate: {learning_rate:g}")
        yield from log(f"  Output: {hub_model_id}")

        training_args = TrainingArguments(
            output_dir="./zen-vl-output",
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            logging_steps=10,
            save_steps=500,
            bf16=use_bf16,
            push_to_hub=True,
            hub_model_id=hub_model_id,
            report_to="none",
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            # NOTE(review): these are raw dataset rows with differing schemas
            # across sources. Nothing tokenizes them (`processor` is unused)
            # and no data_collator is supplied. TODO: add a formatting and
            # tokenization step before this can actually train.
            train_dataset=all_data,
        )

        yield from log("\n🔥 TRAINING STARTED")
        yield from log("=" * 80)

        result = trainer.train()

        yield from log("\n✅ TRAINING COMPLETED!")
        yield from log(f"📊 Final Loss: {result.training_loss:.4f}")
        yield from log(f"☁️  Uploading to {hub_model_id}...")

        trainer.push_to_hub()

        yield from log(f"\n🎉 SUCCESS! Model live at {hub_model_id}")

    except Exception as e:
        # Top-level boundary for the UI: surface the error and the traceback
        # in the transcript instead of crashing the event handler.
        yield from log(f"\n❌ ERROR: {e}")
        yield from log(f"\n{traceback.format_exc()}")

# Simple interface: one button that streams train_zen_vl's log into a textbox.
with gr.Blocks(title="Zen VL Training") as demo:
    # The summary is a bullet list so each item renders on its own line;
    # consecutive plain lines would be merged into one Markdown paragraph.
    gr.Markdown("""
    # 🧘 Zen VL 4B Training

    Trains zen-vl-4b-instruct → zen-vl-4b-agent

    - **Datasets**: ADP (Synatra, Code Feedback, Go Browse) + xLAM (60k), up to 7,500 samples each
    - **Total**: ~30k samples
    - **Time**: ~6-8 hours on A10G
    - **Output**: zenlm/zen-vl-4b-agent
    """)

    start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
    output = gr.Textbox(label="Training Logs", lines=30)

    # train_zen_vl is a generator: each yielded transcript replaces the
    # textbox contents as training progresses.
    start_btn.click(train_zen_vl, outputs=output)

if __name__ == "__main__":
    # Generator (streaming) handlers run through Gradio's queue; older Gradio
    # releases require enabling it explicitly, newer ones accept the call.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)