Spaces:
Running
Running
File size: 5,450 Bytes
dc56cce c24ab89 dc56cce 6130c96 dc56cce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import subprocess
import os
import sys
from datetime import datetime
# --- Where the training script writes its checkpoint ---
# These values mirror the save location used by LayoutLM_Train_Passage.py.
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)

# Script launched as a subprocess to perform the actual training.
TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"
# ----------------------------------------------------------------
def _resolve_upload_path(uploaded):
    """Return the filesystem path for a Gradio file upload, or None if there is none.

    Accepts a plain path (``str`` / ``os.PathLike``) as well as wrapper objects
    that expose the temporary file location through a ``path`` or ``name``
    attribute.
    """
    if uploaded is None:
        return None
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    return getattr(uploaded, "path", None) or getattr(uploaded, "name", None)


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Run the training script in a subprocess and stream its output to Gradio.

    This function is a generator. Every item it yields is a
    ``(log_text, model_path_or_None)`` tuple matching the two Gradio outputs.
    Intermediate items carry the log accumulated so far; the last item carries
    the final status and, on success, the path of the saved model file.
    Final results are yielded rather than returned because a generator's
    ``return`` value is never delivered to the UI.

    Args:
        dataset_file: The uploaded dataset, as a path or an upload wrapper object.
        batch_size: Value forwarded as ``--batch_size`` (coerced to int).
        epochs: Value forwarded as ``--epochs`` (coerced to int).
        lr: Value forwarded as ``--lr`` (coerced to float).
        max_len: Value forwarded as ``--max_len`` (coerced to int).
        progress: Gradio progress tracker used to announce the start of training.

    Yields:
        Tuples of ``(log_text, model_file_path or None)``.
    """
    # 1. Validate inputs up front so the user gets a message instead of a traceback.
    input_path = _resolve_upload_path(dataset_file)
    if not input_path:
        yield "❌ ERROR: Please upload a valid Label Studio JSON file.", None
        return

    if not os.path.isfile(TRAINING_SCRIPT):
        # Popen only launches the interpreter, so a missing script would never
        # raise FileNotFoundError there; check for it explicitly.
        yield f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
        return

    try:
        # Numeric widgets may deliver floats (e.g. 512.0); normalise so the
        # command line carries clean integer values.
        hyperparameters = [
            "--batch_size", str(int(batch_size)),
            "--epochs", str(int(epochs)),
            "--lr", str(float(lr)),
            "--max_len", str(int(max_len)),
        ]
    except (TypeError, ValueError):
        yield "❌ ERROR: Batch size, epochs, learning rate and max length must all be numbers.", None
        return

    # 2. Setup: create the output directory if it doesn't exist.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    progress(0.1, desc="Starting LayoutLMv3 Training...")
    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command.
    # "-u" keeps the child's stdout unbuffered so log lines arrive as they are printed.
    command = [
        sys.executable,
        "-u",
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        *hyperparameters,
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"

    # 4. Run the training script and stream its combined stdout/stderr.
    process = None
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable bytes must not abort the log stream
            bufsize=1,
        )
        for line in process.stdout:
            log_output += line
            yield log_output, None  # Send partial log to Gradio output
        process.stdout.close()
        return_code = process.wait()
    except Exception as e:
        log_output += f"\n\n❌ An unexpected error occurred: {e}"
        yield log_output, None
        return
    finally:
        # If we leave early (error, or the generator is closed by the caller),
        # do not leave the training process running in the background.
        if process is not None and process.poll() is None:
            process.kill()
            process.wait()

    # 5. Report the outcome.
    if return_code != 0:
        log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
        yield log_output, None
        return

    log_output += "\n✅ TRAINING COMPLETE! Model saved."

    # 6. Offer the saved model for download if it is where the script should have put it.
    if os.path.exists(MODEL_FILE_PATH):
        log_output += f"\nModel path: {MODEL_FILE_PATH}"
        yield log_output, MODEL_FILE_PATH
    else:
        log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
        yield log_output, None
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
# Left column: dataset upload and hyperparameters.
# Right column: train button, streamed log, and the downloadable model file.
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(
        """
        Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.
        **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="1. Upload Label Studio JSON Dataset"
            )
            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")
            batch_size_input = gr.Slider(
                minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate (--lr)"
            )
            # precision=0 makes the component deliver an int, so the value is
            # passed to the script as "512" rather than "512.0".
            max_len_input = gr.Number(
                value=512, precision=0, label="Max Sequence Length (--max_len)"
            )
        with gr.Column(scale=2):
            train_button = gr.Button("🔥 Train Model", variant="primary")
            log_output = gr.Textbox(
                label="Training Log Output",
                lines=20,
                autoscroll=True,
                placeholder="Click 'Train Model' to start and see real-time logs..."
            )
            gr.Markdown("---")
            gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
            # Only providing the download link for the saved .pth model file
            model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)
    # Wire the button: inputs are passed positionally to train_model, and each
    # (log, file) tuple it produces updates the log box and the download slot.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_download]
    )
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 when run as a script.
    demo.launch(server_port=7860, server_name="0.0.0.0")