import os
import subprocess
import sys
from datetime import datetime

import gradio as gr

# The name of your existing training script
TRAINING_SCRIPT = "LayoutLM_Train_Passage.py"

# Where LayoutLM_Train_Passage.py writes its checkpoint; must stay in sync
# with the path used inside that script.
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_crf_passage.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)


def _resolve_upload_path(dataset_file):
    """Return the on-disk path of an uploaded file, or None if none is usable.

    Accepts a plain path (str / PathLike) as well as an object exposing the
    path through a ``path`` or ``name`` attribute, so the callback does not
    depend on one particular representation of the upload.
    """
    if dataset_file is None:
        return None
    if isinstance(dataset_file, (str, os.PathLike)):
        return os.fspath(dataset_file) or None
    for attr in ("path", "name"):
        candidate = getattr(dataset_file, attr, None)
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """Run the training script in a subprocess and stream its log to the UI.

    This function is a generator: every update, including the final one, is
    delivered with ``yield``. A value passed to ``return`` inside a generator
    is not an output, so the final status must be yielded as well.

    Args:
        dataset_file: The uploaded dataset (path string or upload object).
        batch_size: Value forwarded as ``--batch_size``.
        epochs: Value forwarded as ``--epochs``.
        lr: Value forwarded as ``--lr``.
        max_len: Value forwarded as ``--max_len``.
        progress: Gradio progress tracker.

    Yields:
        ``(log_text, model_path)`` tuples. ``model_path`` is ``None`` until
        training succeeds and the model file exists at ``MODEL_FILE_PATH``.
    """
    # 1. Validate inputs before doing any work.
    input_path = _resolve_upload_path(dataset_file)
    if input_path is None:
        yield "❌ ERROR: Please upload a valid Label Studio JSON file.", None
        return

    # Popen only raises FileNotFoundError when the *interpreter* is missing,
    # so a missing script has to be detected explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield (
            f"❌ ERROR: The training script '{TRAINING_SCRIPT}' was not found. "
            "Ensure it is in the root directory of your Space."
        ), None
        return

    try:
        # Numeric widgets may deliver floats (e.g. 512.0); the integer-valued
        # options are normalised so the command line carries "512", not "512.0".
        batch_size = int(batch_size)
        epochs = int(epochs)
        max_len = int(max_len)
        lr = float(lr)
    except (TypeError, ValueError):
        yield "❌ ERROR: Batch size, epochs, learning rate and max length must all be numbers.", None
        return

    # 2. Setup: create the output directory if it doesn't exist.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    progress(0.1, desc="Starting LayoutLMv3 Training...")
    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 3. Construct the subprocess command (argument list, no shell).
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(batch_size),
        "--epochs", str(epochs),
        "--lr", str(lr),
        "--max_len", str(max_len),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None

    process = None
    try:
        # 4. Run the training script and capture its output. The child's
        # stdout is a pipe, which Python block-buffers by default;
        # PYTHONUNBUFFERED makes lines arrive as they are printed.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env={**os.environ, "PYTHONUNBUFFERED": "1"},
        )

        # Stream logs in real time.
        for line in process.stdout:
            log_output += line
            yield log_output, None

        return_code = process.wait()
    except Exception as e:
        yield f"{log_output}\n❌ An unexpected error occurred: {e}", None
        return
    finally:
        # Also reached when the generator is closed early (e.g. the request
        # is cancelled): make sure the child does not outlive the request.
        if process is not None:
            if process.stdout is not None:
                process.stdout.close()
            if process.poll() is None:
                process.terminate()
                process.wait()

    # 5. Report the outcome.
    if return_code != 0:
        log_output += f"\n\n❌ TRAINING FAILED with return code {return_code}. Check logs above."
        yield log_output, None
        return

    log_output += "\n✅ TRAINING COMPLETE! Model saved."

    # 6. Offer the model for download only if it is where the script saves it.
    if os.path.exists(MODEL_FILE_PATH):
        log_output += f"\nModel path: {MODEL_FILE_PATH}"
        yield log_output, MODEL_FILE_PATH
    else:
        log_output += (
            "\n⚠️ WARNING: Training completed, but model file not found at "
            f"expected path ({MODEL_FILE_PATH})."
        )
        yield log_output, None


# --- Gradio Interface Setup (using Blocks for a nicer layout) ---
with gr.Blocks(title="LayoutLMv3 Fine-Tuning App") as demo:
    gr.Markdown("# 🚀 LayoutLMv3 Fine-Tuning on Hugging Face Spaces")
    gr.Markdown(
        """
        Upload your Label Studio JSON file, set your hyperparameters, and click **Train Model** to fine-tune the LayoutLMv3 model using your script.

        **Note:** The trained model is saved in the **`checkpoints/`** folder as **`layoutlmv3_crf_passage.pth`**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="1. Upload Label Studio JSON Dataset"
            )

            gr.Markdown("---")
            gr.Markdown("### ⚙️ Training Parameters")

            batch_size_input = gr.Slider(
                minimum=1, maximum=32, step=1, value=4, label="Batch Size (--batch_size)"
            )
            epochs_input = gr.Slider(
                minimum=1, maximum=20, step=1, value=5, label="Epochs (--epochs)"
            )
            lr_input = gr.Number(
                value=5e-5, label="Learning Rate (--lr)"
            )
            max_len_input = gr.Number(
                value=512, label="Max Sequence Length (--max_len)"
            )

        with gr.Column(scale=2):
            train_button = gr.Button("🔥 Train Model", variant="primary")

            log_output = gr.Textbox(
                label="Training Log Output",
                lines=20,
                autoscroll=True,
                placeholder="Click 'Train Model' to start and see real-time logs..."
            )

            gr.Markdown("---")
            gr.Markdown(f"### 🎉 Trained Model Output (Saved to `{MODEL_OUTPUT_DIR}/`)")

            # Only providing the download link for the saved .pth model file
            model_download = gr.File(label=f"Trained Model File ({MODEL_FILE_NAME})", interactive=False)

    # Define the action when the button is clicked
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_download]
    )

if __name__ == "__main__":
    # Generator callbacks stream through the request queue; enabling it
    # explicitly keeps streaming working where it is not on by default.
    demo.queue()
    demo.launch(server_port=7860, server_name="0.0.0.0")