import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import os
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time

# ======================== Database module ========================
import pymysql
from configparser import ConfigParser
from contextlib import contextmanager
@contextmanager
def get_db_connection():
    """Yield a pymysql connection configured from db_config.ini and close it afterwards."""
    config = ConfigParser()
    config.read('db_config.ini')
    conn = pymysql.connect(
        host=config.get('mysql', 'host'),
        user=config.get('mysql', 'user'),
        password=config.get('mysql', 'password'),
        database=config.get('mysql', 'database'),
        port=config.getint('mysql', 'port', fallback=3306),
        charset=config.get('mysql', 'charset', fallback='utf8mb4'),
        cursorclass=pymysql.cursors.DictCursor
    )
    try:
        yield conn
    finally:
        conn.close()
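# Example db_config.ini expected by get_db_connection() above. The values are
# placeholders (assumed for illustration); adjust them to your environment.
#
#   [mysql]
#   host = localhost
#   user = kg_user
#   password = change-me
#   database = kg_db
#   port = 3306
#   charset = utf8mb4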
def save_to_db(table, data):
    """Insert one row (a dict of column -> value) into a whitelisted table."""
    try:
        # Whitelist the table name: identifiers cannot be bound as SQL parameters.
        valid_tables = ["entities", "relations"]
        if table not in valid_tables:
            raise ValueError(f"Invalid table: {table}")
        with get_db_connection() as conn:
            try:
                with conn.cursor() as cursor:
                    # Parameterized INSERT to avoid SQL injection on the values.
                    columns = ', '.join(data.keys())
                    placeholders = ', '.join(['%s'] * len(data))
                    sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
                    cursor.execute(sql, list(data.values()))
                conn.commit()
            except pymysql.Error as e:
                print(f"数据库错误: {e}")
                conn.rollback()
    except pymysql.Error as e:  # e.g. the connection itself could not be established
        print(f"数据库错误: {e}")
    except ValueError as e:  # invalid table name
        print(f"参数错误: {e}")
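# Assumed MySQL schema for the two whitelisted tables, inferred from the columns
# written by update_knowledge_graph() below; adjust to your actual DDL.
#
#   CREATE TABLE entities (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       text VARCHAR(255), type VARCHAR(32),
#       start_pos INT, end_pos INT, source VARCHAR(64)
#   );
#   CREATE TABLE relations (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       head_entity VARCHAR(255), tail_entity VARCHAR(255),
#       relation_type VARCHAR(64), source_text TEXT
#   );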
# ======================== Model loading ========================
NER_MODEL_NAME = "hfl/chinese-roberta-wwm-ext-large"
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
bert_ner_pipeline = pipeline(
    "ner",
    model=bert_ner_model,
    tokenizer=bert_tokenizer,
    aggregation_strategy="first"
)

# Map the model's raw entity labels to the label set used in this app.
LABEL_MAPPING = {
    "address": "LOC",
    "company": "ORG",
    "name": "PER",
    "organization": "ORG",
    "position": "TITLE"
}
chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
try:
    chatglm_model_name = "THUDM/chatglm-6b-int4"
    chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
    chatglm_model = AutoModel.from_pretrained(
        chatglm_model_name,
        trust_remote_code=True,
        device_map="cpu",
        torch_dtype=torch.float32
    ).eval()
    use_chatglm = True
    print("✅ 4-bit量化版ChatGLM加载成功")
except Exception as e:
    print(f"❌ ChatGLM加载失败: {e}")
# ======================== Knowledge graph structure ========================
knowledge_graph = {"entities": set(), "relations": set()}

def update_knowledge_graph(entities, relations):
    # Save entities (to the in-memory graph and to the database)
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            knowledge_graph["entities"].add((e['text'], e['type']))
            save_to_db('entities', {
                'text': e['text'],
                'type': e['type'],
                'start_pos': e.get('start', -1),
                'end_pos': e.get('end', -1),
                'source': 'user_input'
            })
    # Save relations
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
            save_to_db('relations', {
                'head_entity': r['head'],
                'tail_entity': r['tail'],
                'relation_type': r['relation'],
                'source_text': ''  # the source sentence could be attached here
            })

def visualize_kg_text():
    nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
    edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
    return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
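# Illustrative output of visualize_kg_text() once the graph has been populated
# (example content only, not tied to any particular input text):
#
#   📌 实体:
#   张三 (PER)
#   华为公司 (ORG)
#
#   📎 关系:
#   张三 --[任职于]-> 华为公司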
# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
    if not entities:
        return entities
    merged = [entities[0]]
    for entity in entities[1:]:
        last = merged[-1]
        # Merge adjacent entities of the same type
        if (entity["type"] == last["type"] and
                entity["start"] == last["end"]):
            last["text"] += entity["text"]
            last["end"] = entity["end"]
        else:
            merged.append(entity)
    return merged
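# Illustrative example (assumed values): two adjacent LOC fragments produced by
# the pipeline are merged into a single span.
#
#   merge_adjacent_entities([
#       {"text": "北", "type": "LOC", "start": 0, "end": 1},
#       {"text": "京", "type": "LOC", "start": 1, "end": 2},
#   ])
#   -> [{"text": "北京", "type": "LOC", "start": 0, "end": 2}]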
def ner(text, model_type="bert"):
    start_time = time.time()
    # ChatGLM-based NER
    if model_type == "chatglm" and use_chatglm:
        try:
            prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
            if isinstance(response, tuple):
                response = response[0]
            try:
                json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
                entities = json.loads(json_str)
                valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
                return valid_entities, time.time() - start_time
            except Exception as e:
                print(f"JSON解析失败: {e}")
                return [], time.time() - start_time
        except Exception as e:
            print(f"ChatGLM调用失败: {e}")
            return [], time.time() - start_time

    # BERT-based NER: split into 510-character chunks to stay within the 512-token limit
    text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)]
    raw_results = []
    for idx, chunk in enumerate(text_chunks):
        chunk_results = bert_ner_pipeline(chunk)
        # Shift offsets back to positions in the full text
        for r in chunk_results:
            r["start"] += idx * 510
            r["end"] += idx * 510
        raw_results.extend(chunk_results)
    entities = [{
        "text": r['word'].replace(' ', ''),
        "start": r['start'],
        "end": r['end'],
        "type": LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
    } for r in raw_results]
    entities = merge_adjacent_entities(entities)
    return entities, time.time() - start_time
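# Illustrative usage (assumed input/output; actual spans depend on the model):
#
#   entities, elapsed = ner("张三担任华为公司CEO", model_type="bert")
#   entities -> [{"text": "张三", "start": 0, "end": 2, "type": "PER"}, ...]
#   elapsed  -> wall-clock seconds as a float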
# ======================== Relation extraction (RE) ========================
def re_extract(entities, text):
    # Argument validation
    if not entities or not text:
        return []
    # Filter entity types (adjust to the business requirements)
    valid_entity_types = {"PER", "LOC", "ORG", "TITLE"}
    filtered_entities = [e for e in entities if e.get("type") in valid_entity_types]

    # --------------------- Single-entity case ---------------------
    if len(filtered_entities) == 1:
        single_relations = []
        ent = filtered_entities[0]
        # Rule 1: person-position detection
        if ent["type"] == "PER":
            position_keywords = ["CEO", "经理", "总监", "工程师", "教授"]
            for keyword in position_keywords:
                if keyword in text:
                    single_relations.append({
                        "head": ent["text"],
                        "tail": keyword,
                        "relation": "担任职位"
                    })
                    break
        # Rule 2: organization/location detection
        if ent["type"] in ["ORG", "LOC"]:
            location_verbs = ["位于", "坐落于", "地处"]
            for verb in location_verbs:
                if verb in text:
                    match = re.search(fr"{ent['text']}{verb}(.*?)[,。]", text)
                    if match:
                        single_relations.append({
                            "head": ent["text"],
                            "tail": match.group(1).strip(),
                            "relation": "位置"
                        })
                    break
        return single_relations

    # --------------------- Multi-entity relation extraction ---------------------
    relations = []
    # Option 1: extract relations with ChatGLM
    if use_chatglm and len(filtered_entities) >= 2:
        try:
            entity_list = [e["text"] for e in filtered_entities]
            prompt = f"""请分析以下文本中的实体关系,严格按照JSON列表格式返回:
文本内容:{text}
候选实体:{entity_list}
要求:
1. 只返回存在明确关系的实体对
2. 关系类型使用:属于、位于、任职于、合作、其他
3. 示例格式:[{{"head":"实体1", "tail":"实体2", "relation":"关系类型"}}]
请直接返回JSON,不要多余内容:"""
            response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.01)
            if isinstance(response, tuple):
                response = response[0]
            # Robust JSON parsing
            try:
                json_match = re.search(r'(\[.*?\])', response, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1)
                    json_str = re.sub(r'[\u201c\u201d]', '"', json_str)  # normalize curly quotes
                    json_str = re.sub(r'(?<!,)\n', '', json_str)  # keep only newlines that follow a comma
                    relations = json.loads(json_str)
                    # Validate the extracted relations
                    valid_relations = []
                    valid_rel_types = {"属于", "位于", "任职于", "合作", "其他"}
                    for rel in relations:
                        if (isinstance(rel, dict) and
                                rel.get("head") in entity_list and
                                rel.get("tail") in entity_list and
                                rel.get("relation") in valid_rel_types):
                            valid_relations.append(rel)
                    relations = valid_relations
            except Exception as e:
                print(f"[DEBUG] 关系解析失败: {str(e)}")
        except Exception as e:
            print(f"ChatGLM关系抽取异常: {str(e)}")

    # Option 2: rule-based fallback (model unavailable or no relations extracted)
    if len(relations) == 0:
        # Rule 1: A is located in B
        location_matches = re.finditer(r'([^\s,。]+?)(?:位于|坐落于|地处)([^\s,。]+)', text)
        for match in location_matches:
            head, tail = match.groups()
            relations.append({"head": head, "tail": tail, "relation": "位于"})
        # Rule 2: A belongs to B
        belong_matches = re.finditer(r'([^\s,。]+?)(属于|隶属于)([^\s,。]+)', text)
        for match in belong_matches:
            head, _, tail = match.groups()
            relations.append({"head": head, "tail": tail, "relation": "属于"})
        # Rule 3: person-organization relations
        person_org_pattern = r'([\u4e00-\u9fa5]{2,4})(现任|担任|就职于)([\u4e00-\u9fa5]+?公司|[\u4e00-\u9fa5]+?大学)'
        for match in re.finditer(person_org_pattern, text):
            head, _, tail = match.groups()
            relations.append({"head": head, "tail": tail, "relation": "任职于"})

    # Post-processing: deduplicate and validate against the recognized entities
    seen = set()
    final_relations = []
    for rel in relations:
        key = (rel["head"], rel["tail"], rel["relation"])
        if key not in seen:
            head_exists = any(e["text"] == rel["head"] for e in filtered_entities)
            tail_exists = any(e["text"] == rel["tail"] for e in filtered_entities)
            if head_exists and tail_exists:
                final_relations.append(rel)
            seen.add(key)
    return final_relations
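# Illustrative usage of the rule-based fallback path (assumed input/output; only
# reached when ChatGLM is unavailable or returns no valid relations):
#
#   ents = [{"text": "张三", "type": "PER", "start": 0, "end": 2},
#           {"text": "华为公司", "type": "ORG", "start": 5, "end": 9}]
#   re_extract(ents, "张三就职于华为公司")
#   -> [{"head": "张三", "tail": "华为公司", "relation": "任职于"}]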
# ======================== Main text-analysis pipeline ========================
def process_text(text, model_type="bert"):
    entities, duration = ner(text, model_type)
    relations = re_extract(entities, text)
    update_knowledge_graph(entities, relations)
    ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
    rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
    kg_text = visualize_kg_text()
    return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"

def process_file(file, model_type="bert"):
    try:
        with open(file.name, 'rb') as f:
            content = f.read()
        if len(content) > 5 * 1024 * 1024:
            return "❌ 文件太大", "", "", ""
        # Detect the encoding
        try:
            encoding = chardet.detect(content)['encoding'] or 'utf-8'
            text = content.decode(encoding)
        except UnicodeDecodeError:
            # Fall back to common Chinese encodings
            for enc in ['gb18030', 'utf-16', 'big5']:
                try:
                    text = content.decode(enc)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                return "❌ 编码解析失败", "", "", ""
        return process_text(text, model_type)
    except Exception as e:
        return f"❌ 文件处理错误: {str(e)}", "", "", ""
# ======================== Model evaluation and auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "text" in data:
        return [{"text": data["text"], "entities": [
            {"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
        ]}]
    elif isinstance(data, list):
        return data
    elif isinstance(data, dict) and "messages" in data:
        result = []
        for m in data.get("messages", []):
            if isinstance(m.get("text"), str):
                result.append({"text": m["text"], "entities": []})
            elif isinstance(m.get("text"), list):
                # Telegram exports may split a message into plain and formatted fragments
                txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
                result.append({"text": txt, "entities": []})
        return result
    return []
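# The converter above accepts three input shapes (sketch, based on its branches):
#   1. {"text": "...", "entities": [{"start": 0, "end": 2, ...}]}  -> wrapped in a one-item list
#   2. [{"text": "...", "entities": [...]}, ...]                   -> returned as-is
#   3. {"messages": [{"text": "..."} or {"text": [...]}, ...]}     -> Telegram chat export
# Anything else yields an empty list.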
def evaluate_ner_model(data, model_type):
    tp, fp, fn = 0, 0, 0
    POS_TOLERANCE = 1  # allowed boundary offset when matching spans
    for item in data:
        text = item["text"]
        # Normalize the gold annotations
        gold_entities = []
        for e in item.get("entities", []):
            if "text" in e and "type" in e:
                norm_type = LABEL_MAPPING.get(e["type"], e["type"])
                gold_entities.append({
                    "text": e["text"],
                    "type": norm_type,
                    "start": e.get("start", -1),
                    "end": e.get("end", -1)
                })
        # Run the model
        pred_entities, _ = ner(text, model_type)
        # Track which gold and predicted entities have been matched
        matched_gold = [False] * len(gold_entities)
        matched_pred = [False] * len(pred_entities)
        # Greedily match each prediction to an unmatched gold entity
        for p_idx, p in enumerate(pred_entities):
            for g_idx, g in enumerate(gold_entities):
                if not matched_gold[g_idx] and \
                        p["text"] == g["text"] and \
                        p["type"] == g["type"] and \
                        abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
                        abs(p["end"] - g["end"]) <= POS_TOLERANCE:
                    matched_gold[g_idx] = True
                    matched_pred[p_idx] = True
                    break
        # Accumulate the counts
        tp += sum(matched_pred)
        fp += len(pred_entities) - sum(matched_pred)
        fn += len(gold_entities) - sum(matched_gold)
    # Guard against division by zero
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return (f"Precision: {precision:.2f}\n"
            f"Recall: {recall:.2f}\n"
            f"F1: {f1:.2f}")
def auto_annotate(file, model_type):
    data = convert_telegram_json_to_eval_format(file.name)
    for item in data:
        ents, _ = ner(item["text"], model_type)
        item["entities"] = ents
    return json.dumps(data, ensure_ascii=False, indent=2)

def save_json(json_text):
    fname = f"auto_labeled_{int(time.time())}.json"
    with open(fname, "w", encoding="utf-8") as f:
        f.write(json_text)
    return fname
# ======================== Dataset import ========================
def import_dataset(path="D:/云边智算/暗语识别/filtered_results"):
    logs = []
    for filename in os.listdir(path):
        if filename.endswith('.json'):
            filepath = os.path.join(path, filename)
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Reuse the existing pipeline (also persists entities/relations to the database)
            process_text(data['text'])
            logs.append(f"已处理文件: {filename}")
            print(logs[-1])
    return "\n".join(logs) if logs else "未找到可导入的 JSON 文件"
# ======================== Gradio UI ========================
with gr.Blocks(css="""
.kg-graph {height: 500px; overflow-y: auto;}
.warning {color: #ff6b6b;}
""") as demo:
    gr.Markdown("# 🤖 聊天记录实体关系识别系统")

    with gr.Tab("📄 文本分析"):
        input_text = gr.Textbox(lines=6, label="输入文本")
        model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
        btn = gr.Button("开始分析")
        out1 = gr.Textbox(label="识别实体")
        out2 = gr.Textbox(label="识别关系")
        out3 = gr.Textbox(label="知识图谱")
        out4 = gr.Textbox(label="耗时")
        btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])

    with gr.Tab("🗂 文件分析"):
        file_input = gr.File(file_types=[".txt", ".json"])
        file_btn = gr.Button("上传并分析")
        fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
        # Reuses the model selector from the text-analysis tab
        file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])

    with gr.Tab("📊 模型评估"):
        eval_file = gr.File(label="上传标注 JSON")
        eval_model = gr.Radio(["bert", "chatglm"], value="bert")
        eval_btn = gr.Button("开始评估")
        eval_output = gr.Textbox(label="评估结果", lines=5)
        eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
                       [eval_file, eval_model], eval_output)

    with gr.Tab("✏️ 自动标注"):
        raw_file = gr.File(label="上传 Telegram 原始 JSON")
        auto_model = gr.Radio(["bert", "chatglm"], value="bert")
        auto_btn = gr.Button("自动标注")
        marked_texts = gr.Textbox(label="标注结果", lines=20)
        download_btn = gr.Button("💾 下载标注文件")
        download_file = gr.File(label="标注文件下载")
        auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
        download_btn.click(fn=save_json, inputs=marked_texts, outputs=download_file)

    with gr.Tab("📂 数据管理"):
        gr.Markdown("### 数据集导入")
        dataset_path = gr.Textbox(
            value="D:/云边智算/暗语识别/filtered_results",
            label="数据集路径"
        )
        import_btn = gr.Button("导入数据集到数据库")
        import_output = gr.Textbox(label="导入日志")
        # Pass the textbox as an input so the user-edited path (not the initial value) is used
        import_btn.click(fn=import_dataset, inputs=dataset_path, outputs=import_output)

demo.launch(server_name="0.0.0.0", server_port=7860)