File size: 5,450 Bytes
dc56cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24ab89
 
dc56cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6130c96
dc56cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
import subprocess
import os
import sys
from datetime import datetime

# Training script launched as a child process. It is given as a relative path,
# so it is resolved against the current working directory of this app.
TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"

# --- Checkpoint location, mirroring where LayoutLM_Train_Passage.py saves its model ---
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)


# ----------------------------------------------------------------


def _resolve_upload_path(dataset_file):
    """Return the local path of the uploaded dataset, or None when nothing usable was uploaded.

    The upload may be a plain path (``str`` / ``os.PathLike``) or an object that
    exposes ``.path`` or ``.name``. All of these forms are accepted, so the
    handler does not crash on whichever one the ``gr.File`` component delivers.
    """
    if dataset_file is None:
        return None
    if isinstance(dataset_file, (str, os.PathLike)):
        return os.fspath(dataset_file) or None
    for attr in ("path", "name"):
        candidate = getattr(dataset_file, attr, None)
        if isinstance(candidate, (str, os.PathLike)):
            return os.fspath(candidate) or None
    return None


def _build_command(input_path, batch_size, epochs, lr, max_len):
    """Build the argv list that runs TRAINING_SCRIPT in ``train`` mode.

    Integer options go through ``int()`` because numeric Gradio inputs can be
    floats (e.g. ``512.0``), and ``str(512.0)`` would put ``"512.0"`` on the
    command line. The script presumably parses these options as ints, which
    would reject that value.

    Raises:
        TypeError, ValueError, OverflowError: if a value is missing or not a
        finite number.
    """
    return [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(int(batch_size)),
        "--epochs", str(int(epochs)),
        "--lr", str(float(lr)),
        "--max_len", str(int(max_len)),
    ]


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Run TRAINING_SCRIPT on the uploaded dataset and stream its output to the UI.

    This is a generator, and every UI update is *yielded* as a
    ``(log_text, model_file)`` pair. Gradio only consumes yielded values; a
    value given to ``return`` inside a generator never reaches the outputs.
    That is why the final log and the download path are yielded as well.

    Args:
        dataset_file: The uploaded Label Studio JSON file (path string or file object).
        batch_size, epochs, max_len: Integer hyperparameters. Float values such as
            ``512.0`` are converted with ``int()`` before going on the command line.
        lr: Learning rate, forwarded as ``--lr``.
        progress: Gradio progress tracker.

    Yields:
        ``(log_text, None)`` while training runs. The final pair carries
        MODEL_FILE_PATH as its second item if training succeeded and the
        checkpoint exists; otherwise ``None``.
    """
    input_path = _resolve_upload_path(dataset_file)
    if not input_path:
        yield "❌ ERROR: Please upload a valid Label Studio JSON file.", None
        return

    # Check up front: Popen would NOT raise FileNotFoundError for a missing script,
    # because the Python interpreter itself exists and merely exits with an error.
    if not os.path.isfile(TRAINING_SCRIPT):
        missing_msg = (
            f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. "
            "Ensure it is in the root directory of your Space."
        )
        yield missing_msg, None
        return

    try:
        command = _build_command(input_path, batch_size, epochs, lr, max_len)
    except (TypeError, ValueError, OverflowError) as exc:
        yield f"❌ ERROR: Invalid training parameter ({exc}).", None
        return

    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
    progress(0.1, desc="Starting LayoutLMv3 Training...")

    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None

    # The child's stdout is a pipe, not a terminal, so Python would block-buffer it
    # and the "live" log would arrive in large chunks. PYTHONUNBUFFERED makes lines
    # arrive as they are printed. UTF-8 on both ends plus errors="replace" keeps odd
    # bytes in the log from raising UnicodeDecodeError mid-run.
    child_env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"}

    process = None
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            env=child_env,
        )

        # Stream logs: re-yield the whole accumulated log after every line.
        for line in iter(process.stdout.readline, ""):
            log_output += line
            yield log_output, None

        return_code = process.wait()
    except Exception as exc:
        log_output += f"\n\n❌ An unexpected error occurred: {exc}"
        yield log_output, None
        return
    finally:
        # Runs on success, on error, and when the generator is closed early
        # (presumably the UI cancelling the run). Never leave an orphaned
        # training process or an open pipe behind.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            process.stdout.close()

    if return_code != 0:
        log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
        yield log_output, None
        return

    log_output += "\n✅ TRAINING COMPLETE!"
    if os.path.exists(MODEL_FILE_PATH):
        # Only claim the model was saved once the checkpoint is actually on disk.
        log_output += f" Model saved.\nModel path: {MODEL_FILE_PATH}"
        yield log_output, MODEL_FILE_PATH
    else:
        log_output += (
            "\n⚠️ WARNING: Training completed, but model file not found at expected path "
            f"({MODEL_FILE_PATH})."
        )
        yield log_output, None


# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    # Folder and file name come from the module constants so the note cannot
    # drift out of sync with where the handler looks for the checkpoint.
    gr.Markdown(
        f"""
        Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.

        **Note:** The trained model is saved in the **`{MODEL_OUTPUT_DIR}/`** folder as **`{MODEL_FILE_NAME}`**.
        """
    )

    with gr.Row():
        # Left column: dataset upload and hyperparameters.
        with gr.Column(scale=1):
            # Only offer .json files in the upload dialog.
            file_input = gr.File(
                label="1. Upload Label Studio JSON Dataset",
                file_types=[".json"],
            )

            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")

            batch_size_input = gr.Slider(
                minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate (--lr)"
            )
            # precision=0 makes the component deliver an int (512) rather than a
            # float (512.0), which is the form an integer CLI option expects.
            max_len_input = gr.Number(
                value=512, precision=0, label="Max Sequence Length (--max_len)"
            )

        # Right column: launch button, streamed log, and the resulting checkpoint.
        with gr.Column(scale=2):
            train_button = gr.Button("🔥 Train Model", variant="primary")

            log_display = gr.Textbox(
                label="Training Log Output",
                lines=20,
                autoscroll=True,
                placeholder="Click 'Train Model' to start and see real-time logs..."
            )

            gr.Markdown("---")
            gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")

            # Only providing the download link for the saved .pth model file
            model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)

    # train_model yields (log text, model file) pairs; they map positionally onto
    # the two components listed in `outputs`.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_display, model_download],
    )

if __name__ == "__main__":
    # Listen on every network interface (0.0.0.0), not just localhost, at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)