Nattapong Tapachoom commited on
Commit
084df26
·
1 Parent(s): b410a7f

Refactor app.py to improve model loading and PDF processing; update dataset generation logic and enhance UI components

Browse files
Files changed (1) hide show
  1. app.py +110 -640
app.py CHANGED
@@ -6,28 +6,30 @@ from datetime import datetime
6
  from typing import List, Dict, Any, Tuple
7
 
8
  import gradio as gr
 
9
 
10
- # Detect if OAuth is available (enabled on Spaces when hf_oauth: true)
11
- OAUTH_AVAILABLE = bool(os.getenv("OAUTH_CLIENT_ID"))
12
 
13
- # Require login to use the app. Defaults to on only when OAuth is available.
14
- _default_require = "1" if OAUTH_AVAILABLE else "0"
15
- REQUIRE_LOGIN = os.getenv("REQUIRE_LOGIN", _default_require).strip().lower() in ("1", "true", "yes", "y")
16
 
17
- try:
18
- from pypdf import PdfReader
19
- except Exception: # pragma: no cover - lazy import warning only
20
- PdfReader = None # type: ignore
21
 
22
- # LangChain components
23
- try:
24
- from langchain_core.prompts import PromptTemplate
25
- from langchain_core.output_parsers import JsonOutputParser
26
- from langchain_huggingface import HuggingFaceEndpoint
27
- except Exception:
28
- PromptTemplate = None # type: ignore
29
- JsonOutputParser = None # type: ignore
30
- HuggingFaceEndpoint = None # type: ignore
 
 
 
31
 
32
 
33
  def ensure_output_dir() -> str:
@@ -37,11 +39,6 @@ def ensure_output_dir() -> str:
37
 
38
 
39
  def read_pdfs(files: List[gr.File]) -> Tuple[str, List[Dict[str, Any]]]:
40
- if not files:
41
- return "", []
42
- if PdfReader is None:
43
- raise RuntimeError("pypdf is not installed. Please add it to requirements.txt or pip install pypdf.")
44
-
45
  docs = []
46
  combined_text_parts: List[str] = []
47
  for f in files:
@@ -49,11 +46,7 @@ def read_pdfs(files: List[gr.File]) -> Tuple[str, List[Dict[str, Any]]]:
49
  reader = PdfReader(path)
50
  pages_text = []
51
  for i, page in enumerate(reader.pages):
52
- try:
53
- text = page.extract_text() or ""
54
- except Exception:
55
- text = ""
56
- # Normalize whitespace
57
  text = re.sub(r"\s+", " ", text).strip()
58
  if text:
59
  pages_text.append({"page": i + 1, "text": text})
@@ -63,665 +56,142 @@ def read_pdfs(files: List[gr.File]) -> Tuple[str, List[Dict[str, Any]]]:
63
  return combined_text, docs
64
 
65
 
66
- def chunk_text(text: str, chunk_size: int = 1500, overlap: int = 200, max_chunks: int = 5) -> List[Dict[str, Any]]:
67
  text = text.strip()
68
  if not text:
69
  return []
70
- chunks: List[Dict[str, Any]] = []
71
  start = 0
72
  n = len(text)
73
  while start < n and len(chunks) < max_chunks:
74
  end = min(start + chunk_size, n)
75
  chunk = text[start:end]
76
- # try to end on a sentence boundary
77
- if end < n:
78
- m = re.search(r"[\.!?]\s", text[end - 200:end] if end - 200 > start else text[start:end])
79
- if m:
80
- end = start + (m.end())
81
- chunk = text[start:end]
82
- chunks.append({"index": len(chunks), "start": start, "end": end, "text": chunk})
83
  if end >= n:
84
  break
85
  start = max(end - overlap, 0)
86
- if start == end: # safety
87
- start += 1
88
  return chunks
89
 
90
 
91
- DEFAULT_QA_PROMPT_TMPL = (
92
- 'คุณเป็นผู้สร้างชุดข้อมูลที่เป็นประโยชน์ อ่านเนื้อหาที่ให้มาและสร้างคู่คำถาม-คำตอบที่มีคุณภาพสูงและตรงตามข้อเท็จจริง จำนวน {min_pairs} ถึง {max_pairs} คู่ '
93
- 'ส่งคืนเฉพาะ JSON array ที่มี objects ในรูปแบบ {{"question": str, "answer": str}} เท่านั้น ไม่ต้องใส่ข้อความเพิ่มเติม คำอธิบาย หรือ code fences\n\n'
94
- 'เนื้อหา:\n{content}\n'
 
 
95
  )
96
 
97
- TASK_TEMPLATES: Dict[str, str] = {
98
- "QA": DEFAULT_QA_PROMPT_TMPL,
99
- "Summarization": (
100
- 'สรุปเนื้อหาต่อไปนี้เป็นบทสรุปที่กระชับ จำนวน {min_pairs} ถึง {max_pairs} บทสรุป โดยครอบคลุมข้อมูลสำคัญ '
101
- 'ส่งคืนเฉพาะ JSON array ที่มี objects ในรูปแบบ {{"summary": str}} เท่านั้น ไม่ต้องมีข้อความเพิ่มเติม\n\n'
102
- 'เนื้อหา:\n{content}\n'
103
- ),
104
- "Keywords": (
105
- 'แยกคำสำคัญหรือวลีสำคัญจากเนื้อหา จำนวน {min_pairs} ถึง {max_pairs} คำ '
106
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"keyword": str}} เท่านั้น ไม่ต้องมีข้อความเพิ่มเติม\n\n'
107
- 'เนื้อหา:\n{content}\n'
108
- ),
109
- "NER": (
110
- 'แยกเอนทิตีที่มีชื่อเฉพาะจากเนื้อหา ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"text": str, "label": str, "start": int, "end": int}} '
111
- 'ป้ายกำกับควรเป็นประเภทมาตรฐาน เช่น PER (บุคคล), ORG (องค์กร), LOC (สถานที่), MISC (อื่นๆ){ner_labels_clause}\n\n'
112
- 'เนื้อหา:\n{content}\n'
113
- ),
114
- "Classification": (
115
- 'จำแนกเนื้อหาตามป้ายกำกับต่อไปนี้: {labels} {multi_label_clause} '
116
- 'ส่งคืนเฉพาะ JSON array ที่มี objects ในรูปแบบ {{"labels": [str], "rationale": str}} เท่านั้น ไม่ต้องมีข้อความเพิ่มเติม\n\n'
117
- 'เนื้อหา:\n{content}\n'
118
- ),
119
- "MCQ": (
120
- 'สร้างคำถามแบบเลือกตอบจากเนื้อหา จำนวน {min_pairs} ถึง {max_pairs} ข้อ แต่ละข้อมี {num_options} ตัวเลือก '
121
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"question": str, "options": [str], "answer_index": int}} เท่านั้น ไม่ต้องมีข้อความเพิ่มเติม\n\n'
122
- 'เนื้อหา:\n{content}\n'
123
- ),
124
- "True/False": (
125
- 'สร้างข้อความจริง/เท็จที่อิงจากเนื้อหาเท่านั้น จำนวน {min_pairs} ถึง {max_pairs} ข้อความ '
126
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"statement": str, "answer": bool, "explanation": str}} เท่านั้น ไม่ต้องมีข้อความเพิ่มเติม\n\n'
127
- 'เนื้อหา:\n{content}\n'
128
- ),
129
- "Translation": (
130
- 'แปลเนื้อหาเป็น{target_language} สร้างคู่ประโยคแบบคู่ขนาน จำนวน {min_pairs} ถึง {max_pairs} คู่ '
131
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"source": str, "target": str}} เท่านั้น ไม่ต้องมีข้อความเพิ่มเติม\n\n'
132
- 'เนื้อหา:\n{content}\n'
133
- ),
134
- "RLHF": (
135
- 'สร้างข้อมูลสำหรับ Reinforcement Learning from Human Feedback (RLHF) จากเนื้อหานี้ '
136
- 'สร้างคำถามและการตอบสนองหลายแบบ พร้อมคะแนนความต้องการของมนุษย์ จำนวน {min_pairs} ถึง {max_pairs} ชุด '
137
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"prompt": str, "responses": [str], "scores": [float], "preferred_response": str}} เท่านั้น\n\n'
138
- 'เนื้อหา:\n{content}\n'
139
- ),
140
- "DPO": (
141
- 'สร้างข้อมูลสำหรับ Direct Preference Optimization (DPO) จากเนื้อหานี้ '
142
- 'สร้างคำถามพร้อมการตอบสนองที่ดีและไม่ดี จำนวน {min_pairs} ถึง {max_pairs} คู่ '
143
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"prompt": str, "chosen": str, "rejected": str, "reason": str}} เท่านั้น\n\n'
144
- 'เนื้อหา:\n{content}\n'
145
- ),
146
- "Instruction_Following": (
147
- 'สร้างคำสั่งและการตอบสนองสำหรับการฝึกการทำตามคำสั่ง จำนวน {min_pairs} ถึง {max_pairs} คู่ '
148
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"instruction": str, "input": str, "output": str, "difficulty": str}} เท่านั้น\n\n'
149
- 'เนื้อหา:\n{content}\n'
150
- ),
151
- "Constitutional_AI": (
152
- 'สร้างข้อมูลสำหรับ Constitutional AI โดยสร้างคำถามที่อาจมีปัญหาทางจริยธรรมและคำตอบที่เหมาะสม '
153
- 'จำนวน {min_pairs} ถึง {max_pairs} คู่ '
154
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"problematic_prompt": str, "constitutional_response": str, "principle": str}} เท่านั้น\n\n'
155
- 'เนื้อหา:\n{content}\n'
156
- ),
157
- "Chain_of_Thought": (
158
- 'สร้างตัวอย่างการคิดแบบขั้นตอน (Chain of Thought) จากเนื้อหา จำนวน {min_pairs} ถึง {max_pairs} ตัวอย่าง '
159
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"problem": str, "thinking_steps": [str], "final_answer": str}} เท่านั้น\n\n'
160
- 'เนื้อหา:\n{content}\n'
161
- ),
162
- "Dialogue": (
163
- 'สร้างบทสนทนาระหว่างผู้ใช้และผู้ช่วย AI จากเนื้อหา จำนวน {min_pairs} ถึง {max_pairs} บทสนทนา '
164
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"dialogue": [{{"role": str, "content": str}}], "context": str}} เท่านั้น\n\n'
165
- 'เนื้อหา:\n{content}\n'
166
- ),
167
- "Thai_Culture": (
168
- 'สร้างคำถาม-คำตอบเกี่ยวกับวัฒนธรรมไทยจากเนื้อหา เน้นความเข้าใจภาษาไทยและบริบททางวัฒนธรรม '
169
- 'จำนวน {min_pairs} ถึง {max_pairs} คู่ '
170
- 'ส่งคืนเฉพาะ JSON array ของ objects ที่มี {{"question_th": str, "answer_th": str, "cultural_context": str}} เท่านั้น\n\n'
171
- 'เนื้อหา:\n{content}\n'
172
- ),
173
- }
174
-
175
-
176
- def extract_json_array(text: str) -> List[Dict[str, Any]]:
177
- if not text:
178
- return []
179
- # Remove code fences
180
- text = re.sub(r"```[a-zA-Z]*", "```", text)
181
- text = text.replace("```", "")
182
- # Find first [ ... ] block
183
- start = text.find("[")
184
- end = text.rfind("]")
185
- if start != -1 and end != -1 and end > start:
186
- candidate = text[start : end + 1]
187
- else:
188
- candidate = text
189
- try:
190
- data = json.loads(candidate)
191
- if isinstance(data, list):
192
- # normalize
193
- norm = []
194
- for item in data:
195
- if not isinstance(item, dict):
196
- continue
197
- q = str(item.get("question", "").strip())
198
- a = str(item.get("answer", "").strip())
199
- if q and a:
200
- norm.append({"question": q, "answer": a})
201
- return norm
202
- except Exception:
203
- pass
204
- return []
205
-
206
-
207
- def build_langchain(model_id: str, hf_token: str | None, max_new_tokens: int, temperature: float, template: str):
208
- if any(x is None for x in [PromptTemplate, JsonOutputParser, HuggingFaceEndpoint]):
209
- raise RuntimeError("langchain, langchain-community, and langchain-huggingface are required. Please add to requirements.txt.")
210
- # Prompt
211
- prompt = PromptTemplate.from_template(template)
212
- # Model wrapper (Hugging Face Inference API)
213
- llm = HuggingFaceEndpoint(
214
- model=model_id,
215
- token=hf_token,
216
- task="text-generation",
217
- max_new_tokens=max_new_tokens,
218
- temperature=temperature,
219
- do_sample=temperature > 0.0,
220
- )
221
- parser = JsonOutputParser()
222
- chain = prompt | llm | parser
223
- return chain
224
-
225
-
226
- def get_task_template(task: str, custom_instruction: str | None) -> str:
227
- base = TASK_TEMPLATES.get(task, DEFAULT_QA_PROMPT_TMPL)
228
- if custom_instruction and custom_instruction.strip():
229
- # Allow user to override fully, but ensure {content} is present
230
- if "{content}" not in custom_instruction:
231
- custom_instruction = custom_instruction.strip() + "\n\nContent:\n{content}\n"
232
- return custom_instruction
233
- return base
234
-
235
-
236
- def normalize_items(task: str, data: Any) -> List[Dict[str, Any]]:
237
- # Convert model output to list[dict] per task
238
- items: List[Dict[str, Any]] = []
239
- if data is None:
240
- return items
241
- if isinstance(data, str):
242
- data = extract_json_array(data)
243
- if isinstance(data, dict):
244
- # handle wrappers like {"items": [...]}
245
- if "items" in data and isinstance(data["items"], list):
246
- data = data["items"]
247
- else:
248
- data = [data]
249
- if isinstance(data, list):
250
- # keywords may be list[str]
251
- if task == "Keywords" and data and all(isinstance(x, str) for x in data):
252
- return [{"keyword": x} for x in data if x]
253
- for el in data:
254
- if isinstance(el, dict):
255
- items.append(el)
256
- # Validate per-task required fields and normalize variants
257
- norm: List[Dict[str, Any]] = []
258
- for it in items:
259
- if task == "QA":
260
- q = str(it.get("question", "")).strip()
261
- a = str(it.get("answer", "")).strip()
262
- if q and a:
263
- norm.append({"question": q, "answer": a})
264
- elif task == "Summarization":
265
- s = str(it.get("summary", "")).strip()
266
- if s:
267
- norm.append({"summary": s})
268
- elif task == "Keywords":
269
- k = it.get("keyword")
270
- if isinstance(k, str) and k.strip():
271
- norm.append({"keyword": k.strip()})
272
- elif isinstance(it.get("keywords"), list):
273
- for kw in it["keywords"]:
274
- if isinstance(kw, str) and kw.strip():
275
- norm.append({"keyword": kw.strip()})
276
- elif task == "NER":
277
- txt = it.get("text")
278
- label = it.get("label")
279
- start = it.get("start")
280
- end = it.get("end")
281
- if isinstance(txt, str) and isinstance(label, str) and isinstance(start, int) and isinstance(end, int):
282
- norm.append({"text": txt, "label": label, "start": start, "end": end})
283
- elif isinstance(it.get("entities"), list):
284
- for ent in it["entities"]:
285
- if all(k in ent for k in ("text", "label", "start", "end")):
286
- norm.append({
287
- "text": str(ent.get("text", "")),
288
- "label": str(ent.get("label", "")),
289
- "start": int(ent.get("start", 0)),
290
- "end": int(ent.get("end", 0)),
291
- })
292
- elif task == "Classification":
293
- labels = it.get("labels")
294
- if isinstance(labels, str):
295
- labels = [labels]
296
- if isinstance(labels, list):
297
- labels = [str(x).strip() for x in labels if str(x).strip()]
298
- rationale = str(it.get("rationale", "")).strip()
299
- if labels:
300
- norm.append({"labels": labels, "rationale": rationale})
301
- elif task == "MCQ":
302
- q = it.get("question")
303
- options = it.get("options")
304
- answer_index = it.get("answer_index")
305
- answer = it.get("answer")
306
- if isinstance(options, list) and all(isinstance(o, str) for o in options) and isinstance(q, str):
307
- if isinstance(answer_index, int):
308
- idx = answer_index
309
- elif isinstance(answer, str) and answer in options:
310
- idx = options.index(answer)
311
- else:
312
- continue
313
- norm.append({"question": q, "options": options, "answer_index": idx})
314
- elif task == "True/False":
315
- st = it.get("statement")
316
- ans = it.get("answer")
317
- expl = it.get("explanation", "")
318
- if isinstance(st, str):
319
- if isinstance(ans, bool):
320
- val = ans
321
- elif isinstance(ans, str):
322
- val = ans.strip().lower() in ("true", "t", "yes", "1")
323
- else:
324
- continue
325
- norm.append({"statement": st, "answer": val, "explanation": str(expl)})
326
- elif task == "Translation":
327
- src = it.get("source")
328
- tgt = it.get("target")
329
- if isinstance(src, str) and isinstance(tgt, str) and src.strip() and tgt.strip():
330
- norm.append({"source": src, "target": tgt})
331
- elif task == "RLHF":
332
- prompt = it.get("prompt")
333
- responses = it.get("responses")
334
- scores = it.get("scores")
335
- preferred = it.get("preferred_response")
336
- if isinstance(prompt, str) and isinstance(responses, list) and isinstance(scores, list):
337
- norm.append({
338
- "prompt": prompt,
339
- "responses": responses,
340
- "scores": scores,
341
- "preferred_response": str(preferred) if preferred else ""
342
- })
343
- elif task == "DPO":
344
- prompt = it.get("prompt")
345
- chosen = it.get("chosen")
346
- rejected = it.get("rejected")
347
- reason = it.get("reason", "")
348
- if isinstance(prompt, str) and isinstance(chosen, str) and isinstance(rejected, str):
349
- norm.append({
350
- "prompt": prompt,
351
- "chosen": chosen,
352
- "rejected": rejected,
353
- "reason": str(reason)
354
- })
355
- elif task == "Instruction_Following":
356
- instruction = it.get("instruction")
357
- input_text = it.get("input", "")
358
- output = it.get("output")
359
- difficulty = it.get("difficulty", "medium")
360
- if isinstance(instruction, str) and isinstance(output, str):
361
- norm.append({
362
- "instruction": instruction,
363
- "input": str(input_text),
364
- "output": output,
365
- "difficulty": str(difficulty)
366
- })
367
- elif task == "Constitutional_AI":
368
- problematic = it.get("problematic_prompt")
369
- constitutional = it.get("constitutional_response")
370
- principle = it.get("principle", "")
371
- if isinstance(problematic, str) and isinstance(constitutional, str):
372
- norm.append({
373
- "problematic_prompt": problematic,
374
- "constitutional_response": constitutional,
375
- "principle": str(principle)
376
- })
377
- elif task == "Chain_of_Thought":
378
- problem = it.get("problem")
379
- steps = it.get("thinking_steps")
380
- answer = it.get("final_answer")
381
- if isinstance(problem, str) and isinstance(steps, list) and isinstance(answer, str):
382
- norm.append({
383
- "problem": problem,
384
- "thinking_steps": steps,
385
- "final_answer": answer
386
- })
387
- elif task == "Dialogue":
388
- dialogue = it.get("dialogue")
389
- context = it.get("context", "")
390
- if isinstance(dialogue, list):
391
- norm.append({
392
- "dialogue": dialogue,
393
- "context": str(context)
394
- })
395
- elif task == "Thai_Culture":
396
- question_th = it.get("question_th")
397
- answer_th = it.get("answer_th")
398
- cultural_context = it.get("cultural_context", "")
399
- if isinstance(question_th, str) and isinstance(answer_th, str):
400
- norm.append({
401
- "question_th": question_th,
402
- "answer_th": answer_th,
403
- "cultural_context": str(cultural_context)
404
- })
405
- return norm
406
-
407
-
408
- def generate_dataset(
409
- user_profile: Any | None,
410
- files: List[gr.File],
411
- task: str,
412
- preset_model: str,
413
- custom_model_id: str,
414
- hf_token: str,
415
- chunk_size: int,
416
- overlap: int,
417
- max_chunks: int,
418
- max_new_tokens: int,
419
- temperature: float,
420
- custom_instruction: str,
421
- min_pairs: int,
422
- max_pairs: int,
423
- class_labels_text: str,
424
- multi_label: bool,
425
- target_language: str,
426
- num_options: int,
427
- ner_labels_text: str,
428
- ):
429
- # Enforce login if required
430
- if REQUIRE_LOGIN and not user_profile:
431
- return "กรุณาเข้าสู่ระบบก่อนเพื่อสร้างชุดข้อมูล", None, None
432
-
433
- # Read and chunk
434
- full_text, _docs = read_pdfs(files)
435
- chunks = chunk_text(full_text, chunk_size=chunk_size, overlap=overlap, max_chunks=max_chunks)
436
  if not chunks:
437
- return "ไม่สามารถแยกข้อความจากไฟล์ PDF ได้", None, None
438
-
439
- model_id = (custom_model_id or "").strip() or preset_model
440
- # Prepare template per task
441
- base_template = get_task_template(task, custom_instruction)
442
- # enrich template with conditional clauses
443
- ner_clause = ""
444
- if ner_labels_text.strip():
445
- ner_clause = f" (limit to: {ner_labels_text.strip()})"
446
- base_template = base_template.replace("{ner_labels_clause}", ner_clause)
447
- if "{labels}" in base_template:
448
- labels_text = class_labels_text.strip() or "[]"
449
- base_template = base_template.replace("{labels}", labels_text)
450
- if "{multi_label_clause}" in base_template:
451
- base_template = base_template.replace("{multi_label_clause}", " Allow multiple labels." if multi_label else " Choose a single best label.")
452
- if "{num_options}" in base_template:
453
- base_template = base_template.replace("{num_options}", str(int(num_options)))
454
- try:
455
- chain = build_langchain(model_id, hf_token or None, max_new_tokens, temperature, base_template)
456
- except Exception as e:
457
- return f"ข้อผิดพลาดในการเตรียม LangChain: {e}", None, None
458
-
459
- results: List[Dict[str, Any]] = []
460
  for ch in chunks:
461
- try:
462
- variables = {"content": ch["text"], "min_pairs": min_pairs, "max_pairs": max_pairs}
463
- if "{target_language}" in base_template:
464
- variables["target_language"] = target_language or "English"
465
- data = chain.invoke(variables)
466
- items = normalize_items(task, data)
467
- except Exception:
468
- # If parser fails, try best-effort extraction on raw string
 
 
 
 
 
469
  try:
470
- raw = (PromptTemplate.from_template(base_template) | HuggingFaceEndpoint(model=model_id, token=hf_token, task="text-generation")).invoke(variables) # type: ignore
471
- items = normalize_items(task, raw)
 
472
  except Exception:
473
- items = []
474
-
475
- for it in items:
476
- # Enrich with context and task
477
- it["context"] = (ch["text"][:500] + ("..." if len(ch["text"]) > 500 else ""))
478
- it["task"] = task
479
- results.append(it)
480
 
481
  if not results:
482
- return f"โมเดลไม่ได้ส่งคืนข้อมูลที่ถูกต้องสำหรับงาน {task} ลองปรับ prompt หรือโมเดล", None, None
483
-
484
- # Deduplicate per task key
485
- unique: List[Dict[str, Any]] = []
486
- seen = set()
487
- def key_of(item: Dict[str, Any]) -> str:
488
- if task == "QA":
489
- return (item.get("question") or "").strip().lower()
490
- if task == "Summarization":
491
- return (item.get("summary") or "").strip().lower()
492
- if task == "Keywords":
493
- return (item.get("keyword") or "").strip().lower()
494
- if task == "NER":
495
- return f"{item.get('text')}|{item.get('label')}|{item.get('start')}|{item.get('end')}"
496
- if task == "Classification":
497
- return ",".join(sorted([str(x).lower() for x in item.get("labels", [])]))
498
- if task == "MCQ":
499
- return (item.get("question") or "").strip().lower()
500
- if task == "True/False":
501
- return (item.get("statement") or "").strip().lower()
502
- if task == "Translation":
503
- return f"{item.get('source')}|{item.get('target')}"
504
- if task == "RLHF":
505
- return (item.get("prompt") or "").strip().lower()
506
- if task == "DPO":
507
- return (item.get("prompt") or "").strip().lower()
508
- if task == "Instruction_Following":
509
- return (item.get("instruction") or "").strip().lower()
510
- if task == "Constitutional_AI":
511
- return (item.get("problematic_prompt") or "").strip().lower()
512
- if task == "Chain_of_Thought":
513
- return (item.get("problem") or "").strip().lower()
514
- if task == "Dialogue":
515
- dialogue = item.get("dialogue", [])
516
- if dialogue and isinstance(dialogue, list):
517
- return str(dialogue[0].get("content", "")).strip().lower()
518
- return ""
519
- if task == "Thai_Culture":
520
- return (item.get("question_th") or "").strip().lower()
521
- return json.dumps(item, ensure_ascii=False)
522
- for r in results:
523
- k = key_of(r)
524
- if k and k not in seen:
525
- unique.append(r)
526
- seen.add(k)
527
-
528
- # Save to outputs
529
  outdir = ensure_output_dir()
530
  ts = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
531
- safe_task = task.lower().replace("/", "-").replace(" ", "_")
532
- json_path = os.path.join(outdir, f"dataset_{safe_task}_{ts}.json")
533
- jsonl_path = os.path.join(outdir, f"dataset_{safe_task}_{ts}.jsonl")
534
  with io.open(json_path, "w", encoding="utf-8") as f:
535
- json.dump(unique, f, ensure_ascii=False, indent=2)
536
  with io.open(jsonl_path, "w", encoding="utf-8") as f:
537
- for item in unique:
538
  f.write(json.dumps(item, ensure_ascii=False) + "\n")
539
 
540
- return f"สร้างข้อมูลสำเร็จ {len(unique)} รายการสำหรับงาน: {task} 🎉", json_path, jsonl_path
541
 
542
 
 
543
  PRESET_MODELS = [
544
- # Thai-capable models
545
- "openthaigpt/openthaigpt-1.0.0-alpha-7b-chat",
546
- "scb10x/llama-3-typhoon-v1.5-8b-instruct",
547
- "airesearch/wangchanberta-base-att-spm-uncased",
548
-
549
- # Multilingual models good for Thai
550
- "google/mt5-large",
551
- "microsoft/mdeberta-v3-base",
552
- "facebook/xglm-7.5B",
553
- "microsoft/DialoGPT-medium",
554
-
555
- # General powerful models
556
- "HuggingFaceH4/zephyr-7b-beta",
557
  "mistralai/Mistral-7B-Instruct-v0.2",
558
- "google/flan-t5-large",
559
  "meta-llama/Llama-2-7b-chat-hf",
560
- "microsoft/DialoGPT-large",
561
  ]
562
 
 
 
563
 
564
- with gr.Blocks(title="AutoGDataset Thai - PDF to Dataset Generator") as demo:
565
- gr.Markdown("""
566
- # AutoGDataset Thai 🇹🇭
567
- สร้างชุดข้อมูล (Dataset) ภาษาไทยจากไฟล์ PDF โดยใช้ LangChain กับโมเดล Hugging Face
568
-
569
- **คุณสมบัติ:**
570
- - รองรับงานหลากหลายประเภท: QA, RLHF, DPO, Constitutional AI และอื่นๆ
571
- - เน้นการสร้างข้อมูลภาษาไทยคุณภาพสูง
572
- - รองรับโมเดลภาษาไทยและ multilingual models
573
- - สามารถปรับแต่ง prompt เพื่อเพิ่มประสิทธิภาพ
574
-
575
- เลือกโมเดลที่มีอยู่หรือระบุ repo id ที่กำหนดเอง ระบุ `HF_TOKEN` หากจำเป็นสำหรับโมเดล
576
- """)
577
-
578
- # Login requirement (Hugging Face OAuth via Gradio LoginButton when available)
579
- user_state = gr.State(value=None)
580
- effective_require_login = bool(REQUIRE_LOGIN and OAUTH_AVAILABLE)
581
  with gr.Row():
582
- login_info = gr.Markdown(
583
- value=(
584
- "กรุณาเข้าสู่ระบบด้วยบัญชี Hugging Face เพื่อใช้งานแอป"
585
- if effective_require_login
586
- else (
587
- "การเข้าสู่ระบบเป็นทางเลือก" if OAUTH_AVAILABLE else "ไม่ได้ตั้งค่าการเข้าสู่ระบบ OAuth ในการติดตั้งนี้"
588
- )
589
- ),
590
- elem_id="login-info",
591
- )
592
- if OAUTH_AVAILABLE:
593
- with gr.Row():
594
- login_btn = gr.LoginButton(value="เข้าสู่ระบบด้วย Hugging Face")
595
 
596
  with gr.Row():
597
- pdf_files = gr.File(label="อัปโหลดไฟล์ PDF", file_count="multiple", file_types=[".pdf"])
598
-
599
- with gr.Group():
600
- with gr.Row():
601
- task = gr.Dropdown(
602
- label="งานที่ต้องการ (Task Type)",
603
- choices=[
604
- "QA",
605
- "Summarization",
606
- "Keywords",
607
- "NER",
608
- "Classification",
609
- "MCQ",
610
- "True/False",
611
- "Translation",
612
- "RLHF",
613
- "DPO",
614
- "Instruction_Following",
615
- "Constitutional_AI",
616
- "Chain_of_Thought",
617
- "Dialogue",
618
- "Thai_Culture",
619
- ],
620
- value="Thai_Culture",
621
- )
622
- with gr.Row():
623
- preset_model = gr.Dropdown(label="โมเดลที่กำหนดไว้ (Preset Model)", choices=PRESET_MODELS, value=PRESET_MODELS[0])
624
- custom_model_id = gr.Textbox(label="รหัสโมเดลกำหนดเอง (ไม่บังคับ)", placeholder="org/model-name")
625
- with gr.Row():
626
- hf_token = gr.Textbox(label="HF Token", type="password", value=os.environ.get("HF_TOKEN", ""), placeholder="hf_xxx (จำเป็นสำหรับหลายโมเดล)")
627
- with gr.Row():
628
- max_new_tokens = gr.Slider(64, 1024, value=512, step=16, label="จำนวน Token สูงสุด")
629
- temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="อุณหภูมิ (ความสร้างสรรค์)")
630
-
631
- with gr.Accordion("การตั้งค่าขั้นสูง (Advanced Settings)", open=False):
632
- with gr.Row():
633
- chunk_size = gr.Slider(500, 4000, value=1500, step=50, label="ขนาดส่วนข้อความ (ตัวอักษร)")
634
- overlap = gr.Slider(0, 1000, value=200, step=50, label="การทับซ้อน (ตัวอักษร)")
635
- max_chunks = gr.Slider(1, 40, value=5, step=1, label="จำนวนส่วนสูงสุด")
636
- with gr.Row():
637
- min_pairs = gr.Slider(1, 10, value=3, step=1, label="คู่ข้อมูลต่ำสุด/ส่วน")
638
- max_pairs = gr.Slider(1, 12, value=6, step=1, label="คู่ข้อมูลสูงสุด/ส่วน")
639
- custom_instruction = gr.Textbox(
640
- label="คำสั่งกำหนดเอง (ไม่บังคับ)",
641
- lines=3,
642
- placeholder="แทนที่คำสั่งเริ่มต้น ต้องส่งคืน JSON array บริสุทธิ์ตามโครงสร้างงาน",
643
- value="สร้างข้อมูลภาษาไทยคุณภาพสูงที่เข้าใจบริบททางวัฒนธรรมไทย ใช้ภาษาไทยที่เป็นธรรมชาติและเหมาะสมกับเนื้อหา"
644
- )
645
-
646
- # Task-specific controls
647
- classification_labels = gr.Textbox(label="ป้ายกำกับการจำแนก (คั่นด้วยคอมมา)", visible=False)
648
- multi_label = gr.Checkbox(label="อนุญาตหลายป้ายกำกับ", value=False, visible=False)
649
- target_language = gr.Textbox(label="ภาษาเป้าหมาย (การแปล)", value="ไทย", visible=False)
650
- num_options = gr.Slider(3, 6, value=4, step=1, label="ตัวเลือก MCQ", visible=False)
651
- ner_labels = gr.Textbox(label="ป้ายกำกับ NER (คั่นด้วยคอมมา, ไม่บังคับ)", visible=False)
652
 
653
- generate_btn = gr.Button("สร้างชุดข้อมูล (Generate Dataset)", variant="primary", interactive=(not effective_require_login))
 
 
654
 
655
  with gr.Row():
656
- status = gr.Markdown()
 
 
 
657
  with gr.Row():
658
- out_json = gr.File(label="ดาวน์โหลด JSON")
659
- out_jsonl = gr.File(label="ดาวน์โหลด JSONL")
660
-
661
- # Toggle visibility for task-specific controls
662
- def _switch_task(t: str):
663
- is_cls = t == "Classification"
664
- is_tr = t == "Translation"
665
- is_mcq = t == "MCQ"
666
- is_ner = t == "NER"
667
- return (
668
- gr.update(visible=is_cls), # classification_labels
669
- gr.update(visible=is_cls), # multi_label
670
- gr.update(visible=is_tr), # target_language
671
- gr.update(visible=is_mcq), # num_options
672
- gr.update(visible=is_ner), # ner_labels
673
- )
674
 
675
- task.change(_switch_task, inputs=task, outputs=[classification_labels, multi_label, target_language, num_options, ner_labels])
 
 
 
676
 
677
  generate_btn.click(
678
  fn=generate_dataset,
679
- inputs=[
680
- user_state,
681
- pdf_files,
682
- task,
683
- preset_model,
684
- custom_model_id,
685
- hf_token,
686
- chunk_size,
687
- overlap,
688
- max_chunks,
689
- max_new_tokens,
690
- temperature,
691
- custom_instruction,
692
- min_pairs,
693
- max_pairs,
694
- classification_labels,
695
- multi_label,
696
- target_language,
697
- num_options,
698
- ner_labels,
699
- ],
700
- outputs=[status, out_json, out_jsonl],
701
- show_progress=True,
702
- api_name="generate",
703
  )
704
 
705
- if OAUTH_AVAILABLE:
706
- def _on_login(user):
707
- try:
708
- username = None
709
- if isinstance(user, dict):
710
- username = user.get("username") or user.get("name")
711
- if not username and hasattr(user, "username"):
712
- username = getattr(user, "username")
713
- msg = f"เข้าสู่ระบบแล้วในนาม @{username}" if username else "เข้าสู่ระบบแล้ว"
714
- except Exception:
715
- msg = "เข้าสู่ระบบแล้ว"
716
- return user, gr.update(value=msg), gr.update(interactive=True)
717
-
718
- # Enable Generate button after login and store user profile
719
- if hasattr(login_btn, "login"):
720
- login_btn.login(_on_login, inputs=None, outputs=[user_state, login_info, generate_btn])
721
- else:
722
- # In local/dev without OAuth routing, clicking will mock-login
723
- login_btn.click(lambda: ("local_user", gr.update(value="เข้าสู่ระบบแล้ว (ภายในเครื่อง)"), gr.update(interactive=True)), inputs=None, outputs=[user_state, login_info, generate_btn])
724
-
725
  if __name__ == "__main__":
726
- # For local runs
727
- demo.queue().launch()
 
6
  from typing import List, Dict, Any, Tuple
7
 
8
  import gradio as gr
9
+ from pypdf import PdfReader
10
 
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
12
 
13
+ # โหลดโมเดลเริ่มต้น (default)
14
+ DEFAULT_MODEL = "HuggingFaceH4/zephyr-7b-beta"
 
15
 
16
+ # สร้าง pipeline global
17
+ gen_pipe = None
18
+ tokenizer = None
19
+ current_model_id = None
20
 
21
+
22
+ def load_model(model_id: str, hf_token: str = None):
23
+ global gen_pipe, tokenizer, current_model_id
24
+ if current_model_id == model_id and gen_pipe is not None:
25
+ return gen_pipe
26
+
27
+ print(f"🔄 Loading model: {model_id}")
28
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
29
+ model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token, device_map="auto")
30
+ gen_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
31
+ current_model_id = model_id
32
+ return gen_pipe
33
 
34
 
35
  def ensure_output_dir() -> str:
 
39
 
40
 
41
  def read_pdfs(files: List[gr.File]) -> Tuple[str, List[Dict[str, Any]]]:
 
 
 
 
 
42
  docs = []
43
  combined_text_parts: List[str] = []
44
  for f in files:
 
46
  reader = PdfReader(path)
47
  pages_text = []
48
  for i, page in enumerate(reader.pages):
49
+ text = page.extract_text() or ""
 
 
 
 
50
  text = re.sub(r"\s+", " ", text).strip()
51
  if text:
52
  pages_text.append({"page": i + 1, "text": text})
 
56
  return combined_text, docs
57
 
58
 
59
+ def chunk_text(text: str, chunk_size: int = 1500, overlap: int = 200, max_chunks: int = 5) -> List[str]:
60
  text = text.strip()
61
  if not text:
62
  return []
63
+ chunks: List[str] = []
64
  start = 0
65
  n = len(text)
66
  while start < n and len(chunks) < max_chunks:
67
  end = min(start + chunk_size, n)
68
  chunk = text[start:end]
69
+ chunks.append(chunk)
 
 
 
 
 
 
70
  if end >= n:
71
  break
72
  start = max(end - overlap, 0)
 
 
73
  return chunks
74
 
75
 
76
+ # เทมเพลต prompt พื้นฐาน
77
+ DEFAULT_QA_PROMPT = (
78
+ "คุณเป็นผู้ช่วยสร้างชุดข้อมูล อ่านเนื้อหานี้แล้วสร้างคำถาม-คำตอบ "
79
+ "จำนวน {min_pairs} ถึ�� {max_pairs} คู่ "
80
+ "ส่งคืน JSON array ที่มี objects รูปแบบ {{\"question\": str, \"answer\": str}} เท่านั้น\n\n"
81
+ "เนื้อหา:\n{content}\n"
82
  )
83
 
84
+
85
+ def generate_dataset(files: List[gr.File],
86
+ task: str,
87
+ preset_model: str,
88
+ custom_model_id: str,
89
+ hf_token: str,
90
+ chunk_size: int,
91
+ overlap: int,
92
+ max_chunks: int,
93
+ max_new_tokens: int,
94
+ temperature: float,
95
+ min_pairs: int,
96
+ max_pairs: int):
97
+ if not files:
98
+ return "❌ กรุณาอัปโหลดไฟล์ PDF", None, None
99
+
100
+ # โหลดโมเดล
101
+ model_id = (custom_model_id or "").strip() or preset_model or DEFAULT_MODEL
102
+ pipe = load_model(model_id, hf_token or None)
103
+
104
+ # อ่าน PDF และตัดเป็น chunk
105
+ full_text, _ = read_pdfs(files)
106
+ chunks = chunk_text(full_text, chunk_size, overlap, max_chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if not chunks:
108
+ return " ไม่สามารถดึงข้อความจาก PDF", None, None
109
+
110
+ results = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  for ch in chunks:
112
+ prompt = DEFAULT_QA_PROMPT.format(
113
+ min_pairs=min_pairs,
114
+ max_pairs=max_pairs,
115
+ content=ch
116
+ )
117
+ output = pipe(prompt,
118
+ max_new_tokens=max_new_tokens,
119
+ temperature=temperature,
120
+ do_sample=temperature > 0.0)[0]["generated_text"]
121
+
122
+ # พยายาม extract JSON
123
+ start, end = output.find("["), output.rfind("]")
124
+ if start != -1 and end != -1:
125
  try:
126
+ data = json.loads(output[start:end + 1])
127
+ if isinstance(data, list):
128
+ results.extend(data)
129
  except Exception:
130
+ pass
 
 
 
 
 
 
131
 
132
  if not results:
133
+ return " ไม่สามารถสร้างข้อมูล JSON ได้", None, None
134
+
135
+ # Save output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  outdir = ensure_output_dir()
137
  ts = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
138
+ json_path = os.path.join(outdir, f"dataset_{task}_{ts}.json")
139
+ jsonl_path = os.path.join(outdir, f"dataset_{task}_{ts}.jsonl")
140
+
141
  with io.open(json_path, "w", encoding="utf-8") as f:
142
+ json.dump(results, f, ensure_ascii=False, indent=2)
143
  with io.open(jsonl_path, "w", encoding="utf-8") as f:
144
+ for item in results:
145
  f.write(json.dumps(item, ensure_ascii=False) + "\n")
146
 
147
+ return f"สร้างข้อมูลสำเร็จ {len(results)} รายการ", json_path, jsonl_path
148
 
149
 
150
+ # ---------------- Gradio UI ----------------
151
  PRESET_MODELS = [
152
+ DEFAULT_MODEL,
 
 
 
 
 
 
 
 
 
 
 
 
153
  "mistralai/Mistral-7B-Instruct-v0.2",
 
154
  "meta-llama/Llama-2-7b-chat-hf",
155
+ "google/flan-t5-large"
156
  ]
157
 
158
+ with gr.Blocks(title="Thai PDF → Dataset Generator") as demo:
159
+ gr.Markdown("# 📚 Thai Auto Dataset Generator")
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  with gr.Row():
162
+ pdf_files = gr.File(label="อัปโหลด PDF", file_count="multiple", file_types=[".pdf"])
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  with gr.Row():
165
+ task = gr.Textbox(label="Task", value="QA")
166
+ preset_model = gr.Dropdown(label="Preset Model", choices=PRESET_MODELS, value=DEFAULT_MODEL)
167
+ custom_model_id = gr.Textbox(label="Custom Model ID", placeholder="org/model-name")
168
+ hf_token = gr.Textbox(label="HF Token", type="password")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ with gr.Row():
171
+ max_new_tokens = gr.Slider(64, 1024, value=512, step=16, label="Max New Tokens")
172
+ temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
173
 
174
  with gr.Row():
175
+ chunk_size = gr.Slider(500, 4000, value=1500, step=50, label="Chunk Size")
176
+ overlap = gr.Slider(0, 1000, value=200, step=50, label="Overlap")
177
+ max_chunks = gr.Slider(1, 20, value=5, step=1, label="Max Chunks")
178
+
179
  with gr.Row():
180
+ min_pairs = gr.Slider(1, 10, value=3, step=1, label="Min Pairs")
181
+ max_pairs = gr.Slider(1, 12, value=6, step=1, label="Max Pairs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ generate_btn = gr.Button("🚀 Generate Dataset")
184
+ status = gr.Markdown()
185
+ out_json = gr.File(label="JSON")
186
+ out_jsonl = gr.File(label="JSONL")
187
 
188
  generate_btn.click(
189
  fn=generate_dataset,
190
+ inputs=[pdf_files, task, preset_model, custom_model_id, hf_token,
191
+ chunk_size, overlap, max_chunks, max_new_tokens, temperature,
192
+ min_pairs, max_pairs],
193
+ outputs=[status, out_json, out_jsonl]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  if __name__ == "__main__":
197
+ demo.queue().launch()