# LayoutLM_train / app.py
# Author: aagamjtdev
# Revision: c24ab89 (5.45 kB)
import gradio as gr
import subprocess
import os
import sys
from datetime import datetime
# Script that performs the actual training; it is launched as a subprocess
# and is resolved relative to the current working directory.
TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"

# Where the trained weights are expected to appear. These values are meant
# to mirror the save location used inside LayoutLM_Train_Passage.py, so keep
# them in sync with that script.
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
def _resolve_upload_path(dataset_file):
    """Return the filesystem path of an uploaded file, or None if there is none.

    Depending on the Gradio version and the component's ``type`` setting, the
    upload arrives either as a plain path string or as a wrapper object that
    exposes the path through a ``.path`` or ``.name`` attribute.
    """
    if dataset_file is None:
        return None
    if isinstance(dataset_file, (str, os.PathLike)):
        return os.fspath(dataset_file)
    for attr in ("path", "name"):
        candidate = getattr(dataset_file, attr, None)
        if candidate:
            return str(candidate)
    return None


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """Run the training script on the uploaded dataset and stream its log.

    This is a generator. Every update -- including the final one -- is
    delivered with ``yield`` because values passed to ``return`` inside a
    generator are not forwarded to the output components.

    Args:
        dataset_file: The uploaded dataset (path string or upload wrapper).
        batch_size: Value forwarded as ``--batch_size``.
        epochs: Value forwarded as ``--epochs``.
        lr: Value forwarded as ``--lr``.
        max_len: Value forwarded as ``--max_len``.
        progress: Gradio progress tracker.

    Yields:
        ``(log_text, model_path)`` tuples. ``model_path`` is ``None`` until
        training succeeds and the model file exists at ``MODEL_FILE_PATH``.
    """
    # 1. Setup: create the output directory if it doesn't exist.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. Input validation: fail with a readable message instead of a traceback.
    input_path = _resolve_upload_path(dataset_file)
    if input_path is None:
        yield "❌ ERROR: Please upload a valid Label Studio JSON file.", None
        return

    # Popen launches the Python interpreter, not the script, so a missing
    # script would never raise FileNotFoundError; check for it explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. Ensure it is in the root directory of your Space.", None
        return

    try:
        # Numeric UI components may deliver floats (e.g. 512.0); the integer
        # options are normalised so the script receives "512", not "512.0".
        batch_size_arg = str(int(batch_size))
        epochs_arg = str(int(epochs))
        max_len_arg = str(int(max_len))
        float(lr)  # validation only; the original text form is forwarded
        lr_arg = str(lr)
    except (TypeError, ValueError):
        yield "❌ ERROR: Batch size, epochs, learning rate and max sequence length must be valid numbers.", None
        return

    progress(0.1, desc="Starting LayoutLMv3 Training...")
    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command (argument list, no shell involved).
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", batch_size_arg,
        "--epochs", epochs_arg,
        "--lr", lr_arg,
        "--max_len", max_len_arg,
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"

    # 4. Run the training script, merging stderr into stdout so the log
    #    shows everything in order, and stream it line by line.
    process = None
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )
        for line in process.stdout:
            log_output += line
            yield log_output, None  # Send partial log to Gradio output
        return_code = process.wait()
    except Exception as e:
        yield f"{log_output}\n\n❌ An unexpected error occurred: {e}", None
        return
    finally:
        if process is not None:
            # If streaming stopped early (error or the generator was closed),
            # do not leave the training process running in the background.
            if process.poll() is None:
                process.kill()
                process.wait()
            process.stdout.close()

    # 5. Report the outcome.
    if return_code != 0:
        log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
        yield log_output, None
        return

    log_output += "\n✅ TRAINING COMPLETE! Model saved."

    # 6. Offer the model for download only if the script actually wrote it.
    if os.path.exists(MODEL_FILE_PATH):
        log_output += f"\nModel path: {MODEL_FILE_PATH}"
        yield log_output, MODEL_FILE_PATH
    else:
        log_output += f"\n⚠️ WARNING: Training completed, but model file not found at expected path ({MODEL_FILE_PATH})."
        yield log_output, None
# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(
        "\n"
        "Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.\n"
        "**Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.\n"
    )

    with gr.Row():
        # Left column: dataset upload and hyperparameters.
        with gr.Column(scale=1):
            file_input = gr.File(
                label="1. Upload Label Studio JSON Dataset"
            )
            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")
            batch_size_input = gr.Slider(
                minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate (--lr)"
            )
            # precision=0 makes the component deliver an int; without it the
            # value arrives as a float and is forwarded as "512.0".
            max_len_input = gr.Number(
                value=512, precision=0, label="Max Sequence Length (--max_len)"
            )

        # Right column: trigger, live log and the resulting model file.
        with gr.Column(scale=2):
            train_button = gr.Button("🔥 Train Model", variant="primary")
            log_output = gr.Textbox(
                label="Training Log Output",
                lines=20,
                autoscroll=True,
                placeholder="Click 'Train Model' to start and see real-time logs..."
            )
            gr.Markdown("---")
            gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")
            # Only providing the download link for the saved .pth model file
            model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)

    # Wire the button to the training function; its two outputs feed the
    # log textbox and the download component, in that order.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_download]
    )
if __name__ == "__main__":
    # Start the web server only when this file is executed directly,
    # listening on every network interface at port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)