"""AutoGDataset / app.py

Gradio app that extracts text from uploaded PDFs, splits it into chunks, and
prompts a Hugging Face text-generation model to build a Thai question-answer
dataset, saved as JSON and JSONL.
"""
import os
import re
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Default model to load when no preset or custom id is given
DEFAULT_MODEL = "HuggingFaceH4/zephyr-7b-beta"

# Module-level pipeline state, reused across requests
gen_pipe = None
tokenizer = None
current_model_id = None
def load_model(model_id: str, hf_token: Optional[str] = None):
    """Load a text-generation pipeline, reusing the cached one when possible."""
    global gen_pipe, tokenizer, current_model_id
    # Cache hit: the requested model is already loaded
    if current_model_id == model_id and gen_pipe is not None:
        return gen_pipe
    print(f"🔄 Loading model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto")
    # The model is already placed by device_map above; passing device_map to
    # pipeline() again is redundant and rejected by some transformers versions
    gen_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    current_model_id = model_id
    return gen_pipe
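# Hypothetical usage (not executed at import time): a second call with the same
# id is a cache hit and returns the existing pipeline instead of reloading.
#   pipe = load_model("HuggingFaceH4/zephyr-7b-beta")
#   pipe = load_model("HuggingFaceH4/zephyr-7b-beta")  # reuses gen_pipe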
def ensure_output_dir() -> str:
outdir = os.path.join(os.getcwd(), "outputs")
os.makedirs(outdir, exist_ok=True)
return outdir
def read_pdfs(files: List[gr.File]) -> Tuple[str, List[Dict[str, Any]]]:
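    """Extract text from each uploaded PDF.

    Returns (combined_text, docs): the whitespace-normalized text of all files
    joined together, plus a per-file/per-page breakdown.
    """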
docs = []
combined_text_parts: List[str] = []
for f in files:
path = f.name if hasattr(f, "name") else f
reader = PdfReader(path)
pages_text = []
for i, page in enumerate(reader.pages):
text = page.extract_text() or ""
text = re.sub(r"\s+", " ", text).strip()
if text:
pages_text.append({"page": i + 1, "text": text})
combined_text_parts.append(text)
docs.append({"file": os.path.basename(path), "pages": pages_text})
combined_text = "\n\n".join(combined_text_parts)
return combined_text, docs
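# Shape of the `docs` return value (illustrative, file name is hypothetical):
#   [{"file": "paper.pdf", "pages": [{"page": 1, "text": "..."}, ...]}]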
def chunk_text(text: str, chunk_size: int = 1500, overlap: int = 200, max_chunks: int = 5) -> List[str]:
    text = text.strip()
    if not text:
        return []
    chunks: List[str] = []
    start = 0
    n = len(text)
    while start < n and len(chunks) < max_chunks:
        end = min(start + chunk_size, n)
        chunks.append(text[start:end])
        if end >= n:
            break
        # Step back by `overlap`, but always advance, so that an overlap larger
        # than chunk_size cannot repeat the same window until max_chunks
        start = max(end - overlap, start + 1)
    return chunks
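# Worked example of the sliding window (toy sizes, not the UI defaults):
#   chunk_text("abcdefghij", chunk_size=5, overlap=2)
#   -> ["abcde", "defgh", "ghij"]
# Each new window starts `overlap` characters before the previous end.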
# Base prompt template (kept in Thai, since the app targets Thai PDFs).
# Rough translation: "You are a dataset-building assistant. Read this content
# and create {min_pairs} to {max_pairs} question-answer pairs. Return ONLY a
# JSON array of objects shaped {\"question\": str, \"answer\": str}."
DEFAULT_QA_PROMPT = (
    "คุณเป็นผู้ช่วยสร้างชุดข้อมูล อ่านเนื้อหานี้แล้วสร้างคำถาม-คำตอบ "
    "จำนวน {min_pairs} ถึง {max_pairs} คู่ "
    "ส่งคืน JSON array ที่มี objects รูปแบบ {{\"question\": str, \"answer\": str}} เท่านั้น\n\n"
    "เนื้อหา:\n{content}\n"
)
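# Example instantiation (toy values):
#   DEFAULT_QA_PROMPT.format(min_pairs=3, max_pairs=6, content="...")
# The doubled braces {{...}} keep the JSON object shape literal under str.format.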
def generate_dataset(files: List[gr.File],
task: str,
preset_model: str,
custom_model_id: str,
hf_token: str,
chunk_size: int,
overlap: int,
max_chunks: int,
max_new_tokens: int,
temperature: float,
min_pairs: int,
max_pairs: int):
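    """End-to-end run: read PDFs, chunk the text, prompt the model per chunk,
    and save the collected Q-A pairs as JSON and JSONL files."""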
    if not files:
        return "❌ กรุณาอัปโหลดไฟล์ PDF", None, None  # "Please upload a PDF file"
    # Gradio slider values may arrive as floats; cast before slicing/generation
    chunk_size, overlap, max_chunks = int(chunk_size), int(overlap), int(max_chunks)
    max_new_tokens = int(max_new_tokens)
    min_pairs, max_pairs = int(min_pairs), int(max_pairs)
    if min_pairs > max_pairs:
        min_pairs, max_pairs = max_pairs, min_pairs  # keep the requested range well-ordered
    task = re.sub(r"[^\w-]+", "_", (task or "").strip()) or "task"  # keep the output filename safe
    # Load the model (a custom id takes precedence over the preset)
    model_id = (custom_model_id or "").strip() or preset_model or DEFAULT_MODEL
    pipe = load_model(model_id, hf_token or None)
    # Read the PDFs and split the combined text into chunks
    full_text, _ = read_pdfs(files)
    chunks = chunk_text(full_text, chunk_size, overlap, max_chunks)
    if not chunks:
        return "❌ ไม่สามารถดึงข้อความจาก PDF", None, None  # "Could not extract text from the PDF"
    results = []
    for ch in chunks:
        prompt = DEFAULT_QA_PROMPT.format(
            min_pairs=min_pairs,
            max_pairs=max_pairs,
            content=ch
        )
        # return_full_text=False strips the echoed prompt, so the bracket search
        # below only scans the model's continuation, not the PDF content itself
        output = pipe(prompt,
                      max_new_tokens=max_new_tokens,
                      temperature=temperature,
                      do_sample=temperature > 0.0,
                      return_full_text=False)[0]["generated_text"]
        # Try to extract a JSON array from the model output
        start, end = output.find("["), output.rfind("]")
        if start != -1 and end != -1:
            try:
                data = json.loads(output[start:end + 1])
                if isinstance(data, list):
                    # Keep only well-formed question/answer objects
                    results.extend(
                        item for item in data
                        if isinstance(item, dict)
                        and "question" in item and "answer" in item
                    )
            except Exception:
                pass
    if not results:
        return "❌ ไม่สามารถสร้างข้อมูล JSON ได้", None, None  # "Could not produce any JSON data"
    # Save the outputs with a UTC timestamp in the filename
    outdir = ensure_output_dir()
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    json_path = os.path.join(outdir, f"dataset_{task}_{ts}.json")
    jsonl_path = os.path.join(outdir, f"dataset_{task}_{ts}.jsonl")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for item in results:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
    return f"✅ สร้างข้อมูลสำเร็จ {len(results)} รายการ", json_path, jsonl_path  # "Successfully created N records"
# ---------------- Gradio UI ----------------
# Preset choices must be causal LMs because load_model uses AutoModelForCausalLM;
# encoder-decoder models (e.g. google/flan-t5-large) would fail to load there
PRESET_MODELS = [
    DEFAULT_MODEL,
    "mistralai/Mistral-7B-Instruct-v0.2",
    "meta-llama/Llama-2-7b-chat-hf",
]
with gr.Blocks(title="Thai PDF → Dataset Generator") as demo:
gr.Markdown("# 📚 Thai Auto Dataset Generator")
with gr.Row():
        pdf_files = gr.File(label="อัปโหลด PDF", file_count="multiple", file_types=[".pdf"])  # label: "Upload PDF"
with gr.Row():
task = gr.Textbox(label="Task", value="QA")
preset_model = gr.Dropdown(label="Preset Model", choices=PRESET_MODELS, value=DEFAULT_MODEL)
custom_model_id = gr.Textbox(label="Custom Model ID", placeholder="org/model-name")
hf_token = gr.Textbox(label="HF Token", type="password")
with gr.Row():
max_new_tokens = gr.Slider(64, 1024, value=512, step=16, label="Max New Tokens")
temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
with gr.Row():
chunk_size = gr.Slider(500, 4000, value=1500, step=50, label="Chunk Size")
overlap = gr.Slider(0, 1000, value=200, step=50, label="Overlap")
max_chunks = gr.Slider(1, 20, value=5, step=1, label="Max Chunks")
with gr.Row():
min_pairs = gr.Slider(1, 10, value=3, step=1, label="Min Pairs")
max_pairs = gr.Slider(1, 12, value=6, step=1, label="Max Pairs")
generate_btn = gr.Button("🚀 Generate Dataset")
status = gr.Markdown()
out_json = gr.File(label="JSON")
out_jsonl = gr.File(label="JSONL")
generate_btn.click(
fn=generate_dataset,
inputs=[pdf_files, task, preset_model, custom_model_id, hf_token,
chunk_size, overlap, max_chunks, max_new_tokens, temperature,
min_pairs, max_pairs],
outputs=[status, out_json, out_jsonl]
)
if __name__ == "__main__":
demo.queue().launch()
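# Launch note (standard Gradio options, not specific to this app): for local
# testing outside a Space, something like
#   demo.queue().launch(server_name="0.0.0.0", server_port=7860)
# exposes the app on the local network.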