# wechat-ner-re / app.py
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModel
import gradio as gr
import re
import os
import json
import chardet
from sklearn.metrics import precision_score, recall_score, f1_score
import time
# ======================== Database module ========================
import pymysql
from configparser import ConfigParser
from contextlib import contextmanager
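# The DB helpers below read a `db_config.ini` located next to this script. A minimal sketch of the
# expected layout (section and keys match the `config.get` calls below; the values are placeholders):
#
#   [mysql]
#   host = localhost
#   user = wechat_ner
#   password = change-me
#   database = wechat_ner
#   port = 3306
#   charset = utf8mb4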
@contextmanager
def get_db_connection():
config = ConfigParser()
config.read('db_config.ini')
conn = pymysql.connect(
host=config.get('mysql', 'host'),
user=config.get('mysql', 'user'),
password=config.get('mysql', 'password'),
database=config.get('mysql', 'database'),
port=config.getint('mysql', 'port', fallback=3306),
charset=config.get('mysql', 'charset', fallback='utf8mb4'),
cursorclass=pymysql.cursors.DictCursor
)
try:
yield conn
finally:
conn.close()
def save_to_db(table, data):
    try:
        # Table-name whitelist (table names cannot be bound as query parameters)
        valid_tables = ["entities", "relations"]
        if table not in valid_tables:
            raise ValueError(f"Invalid table: {table}")
        # get_db_connection() is a context manager, so enter it with `with` to obtain the connection
        with get_db_connection() as conn:
            with conn.cursor() as cursor:
                # Parameterized query for the values to avoid SQL injection
                columns = ', '.join(data.keys())
                placeholders = ', '.join(['%s'] * len(data))
                sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
                cursor.execute(sql, list(data.values()))
            conn.commit()
    except pymysql.Error as e:  # database-level errors; the uncommitted transaction is discarded on close
        print(f"Database error: {e}")
    except ValueError as e:  # invalid table name
        print(f"Parameter error: {e}")
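# save_to_db() assumes the two target tables already exist. A minimal schema sketch that matches the
# column names used here and in update_knowledge_graph() (types and sizes are assumptions, adjust as needed):
#
#   CREATE TABLE entities (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       text VARCHAR(255), type VARCHAR(32),
#       start_pos INT, end_pos INT, source VARCHAR(64)
#   ) DEFAULT CHARSET=utf8mb4;
#
#   CREATE TABLE relations (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       head_entity VARCHAR(255), tail_entity VARCHAR(255),
#       relation_type VARCHAR(64), source_text TEXT
#   ) DEFAULT CHARSET=utf8mb4;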
# ======================== Model loading ========================
NER_MODEL_NAME = "hfl/chinese-roberta-wwm-ext-large"
bert_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
bert_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
bert_ner_pipeline = pipeline(
"ner",
model=bert_ner_model,
tokenizer=bert_tokenizer,
aggregation_strategy="first"
)
LABEL_MAPPING = {
"address": "LOC",
"company": "ORG",
"name": "PER",
"organization": "ORG",
"position": "TITLE"
}
chatglm_model, chatglm_tokenizer = None, None
use_chatglm = False
try:
chatglm_model_name = "THUDM/chatglm-6b-int4"
chatglm_tokenizer = AutoTokenizer.from_pretrained(chatglm_model_name, trust_remote_code=True)
chatglm_model = AutoModel.from_pretrained(
chatglm_model_name,
trust_remote_code=True,
device_map="cpu",
torch_dtype=torch.float32
).eval()
use_chatglm = True
print("✅ 4-bit量化版ChatGLM加载成功")
except Exception as e:
print(f"❌ ChatGLM加载失败: {e}")
# ======================== Knowledge-graph structures ========================
knowledge_graph = {"entities": set(), "relations": set()}
def update_knowledge_graph(entities, relations):
    # Persist entities and mirror them into the in-memory graph used by visualize_kg_text()
    for e in entities:
        if isinstance(e, dict) and 'text' in e and 'type' in e:
            knowledge_graph["entities"].add((e['text'], e['type']))
            save_to_db('entities', {
                'text': e['text'],
                'type': e['type'],
                'start_pos': e.get('start', -1),
                'end_pos': e.get('end', -1),
                'source': 'user_input'
            })
    # Persist relations
    for r in relations:
        if isinstance(r, dict) and all(k in r for k in ("head", "tail", "relation")):
            knowledge_graph["relations"].add((r['head'], r['tail'], r['relation']))
            save_to_db('relations', {
                'head_entity': r['head'],
                'tail_entity': r['tail'],
                'relation_type': r['relation'],
                'source_text': ''  # the originating text could be stored here
            })
def visualize_kg_text():
nodes = [f"{ent[0]} ({ent[1]})" for ent in knowledge_graph["entities"]]
edges = [f"{h} --[{r}]-> {t}" for h, t, r in knowledge_graph["relations"]]
return "\n".join(["📌 实体:"] + nodes + ["", "📎 关系:"] + edges)
# ======================== Named entity recognition (NER) ========================
def merge_adjacent_entities(entities):
if not entities:
return entities
merged = [entities[0]]
for entity in entities[1:]:
last = merged[-1]
        # Merge adjacent entities of the same type
if (entity["type"] == last["type"] and
entity["start"] == last["end"]):
last["text"] += entity["text"]
last["end"] = entity["end"]
else:
merged.append(entity)
return merged
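# Illustration of merge_adjacent_entities with made-up values: two same-type entities whose spans touch
# are fused into a single entity:
#   [{"text": "北", "type": "LOC", "start": 0, "end": 1},
#    {"text": "京", "type": "LOC", "start": 1, "end": 2}]
#   -> [{"text": "北京", "type": "LOC", "start": 0, "end": 2}]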
def ner(text, model_type="bert"):
start_time = time.time()
    # If ChatGLM was selected and is available, run ChatGLM-based NER
if model_type == "chatglm" and use_chatglm:
try:
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
示例:[{{"text": "北京", "type": "LOC", "start": 0, "end": 2}}]
文本:{text}"""
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.1)
if isinstance(response, tuple):
response = response[0]
try:
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
entities = json.loads(json_str)
valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
return valid_entities, time.time() - start_time
            except Exception as e:
                print(f"Failed to parse ChatGLM JSON output: {e}")
                return [], time.time() - start_time
        except Exception as e:
            print(f"ChatGLM call failed: {e}")
            return [], time.time() - start_time
    # BERT-based NER
    text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)]  # split into 510-character chunks to stay within the model's input limit
raw_results = []
for idx, chunk in enumerate(text_chunks):
chunk_results = bert_ner_pipeline(chunk)
for r in chunk_results:
r["start"] += idx * 510
r["end"] += idx * 510
raw_results.extend(chunk_results)
entities = [{
"text": r['word'].replace(' ', ''),
"start": r['start'],
"end": r['end'],
"type": LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
} for r in raw_results]
entities = merge_adjacent_entities(entities)
return entities, time.time() - start_time
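# ner() returns (entities, elapsed_seconds). The entities list holds dicts of the form shown below
# (illustrative values only):
#   [{"text": "马云", "type": "PER", "start": 0, "end": 2},
#    {"text": "阿里巴巴", "type": "ORG", "start": 3, "end": 7}]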
# ======================== Relation extraction (RE) ========================
def re_extract(entities, text):
    # Validate inputs
if not entities or not text:
return []
    # Entity-type filter (adjust to match the business requirements)
valid_entity_types = {"PER", "LOC", "ORG", "TITLE"}
filtered_entities = [e for e in entities if e.get("type") in valid_entity_types]
    # --------------------- Single-entity case ---------------------
if len(filtered_entities) == 1:
single_relations = []
ent = filtered_entities[0]
        # Rule 1: person + position keyword
if ent["type"] == "PER":
position_keywords = ["CEO", "经理", "总监", "工程师", "教授"]
for keyword in position_keywords:
if keyword in text:
single_relations.append({
"head": ent["text"],
"tail": keyword,
"relation": "担任职位"
})
break
        # Rule 2: organization/location + location verb
if ent["type"] in ["ORG", "LOC"]:
location_verbs = ["位于", "坐落于", "地处"]
for verb in location_verbs:
if verb in text:
                    match = re.search(fr"{re.escape(ent['text'])}{verb}(.*?)[,。]", text)
if match:
single_relations.append({
"head": ent["text"],
"tail": match.group(1).strip(),
"relation": "位置"
})
break
return single_relations
    # --------------------- Multi-entity relation extraction ---------------------
relations = []
    # Option 1: extract relations with ChatGLM
if use_chatglm and len(filtered_entities) >= 2:
try:
entity_list = [e["text"] for e in filtered_entities]
prompt = f"""请分析以下文本中的实体关系,严格按照JSON列表格式返回:
文本内容:{text}
候选实体:{entity_list}
要求:
1. 只返回存在明确关系的实体对
2. 关系类型使用:属于、位于、任职于、合作、其他
3. 示例格式:[{{"head":"实体1", "tail":"实体2", "relation":"关系类型"}}]
请直接返回JSON,不要多余内容:"""
response = chatglm_model.chat(chatglm_tokenizer, prompt, temperature=0.01)
if isinstance(response, tuple):
response = response[0]
            # More robust JSON parsing
try:
json_str = re.search(r'(\[.*?\])', response, re.DOTALL)
if json_str:
json_str = json_str.group(1)
                    json_str = re.sub(r'[\u201c\u201d]', '"', json_str)  # normalize Chinese-style curly quotes
                    json_str = re.sub(r'(?<!,)\n', '', json_str)  # strip newlines except those right after a comma
relations = json.loads(json_str)
                    # Keep only well-formed relations between known entities
valid_relations = []
valid_rel_types = {"属于", "位于", "任职于", "合作", "其他"}
for rel in relations:
if (isinstance(rel, dict) and
rel.get("head") in entity_list and
rel.get("tail") in entity_list and
rel.get("relation") in valid_rel_types):
valid_relations.append(rel)
relations = valid_relations
except Exception as e:
print(f"[DEBUG] 关系解析失败: {str(e)}")
except Exception as e:
print(f"ChatGLM关系抽取异常: {str(e)}")
    # Option 2: rule-based fallback (when the model is unavailable or extracted no relations)
    if len(relations) == 0:
        # Rule 1: A is located in B
        location_matches = re.finditer(r'([^\s,。]+)(?:位于|坐落于|地处)([^\s,。]+)', text)
for match in location_matches:
head, tail = match.groups()
relations.append({"head": head, "tail": tail, "relation": "位于"})
        # Rule 2: A belongs to B
belong_matches = re.finditer(r'([^\s,。]+)(属于|隶属于)([^\s,。]+)', text)
for match in belong_matches:
head, _, tail = match.groups()
relations.append({"head": head, "tail": tail, "relation": "属于"})
        # Rule 3: person-organization employment relations
person_org_pattern = r'([\u4e00-\u9fa5]{2,4})(现任|担任|就职于)([\u4e00-\u9fa5]+?公司|[\u4e00-\u9fa5]+?大学)'
for match in re.finditer(person_org_pattern, text):
head, _, tail = match.groups()
relations.append({"head": head, "tail": tail, "relation": "任职于"})
    # Post-processing: deduplicate and validate against the recognized entities
seen = set()
final_relations = []
for rel in relations:
key = (rel["head"], rel["tail"], rel["relation"])
if key not in seen:
            # Keep the relation only if both endpoints were recognized as entities
head_exists = any(e["text"] == rel["head"] for e in filtered_entities)
tail_exists = any(e["text"] == rel["tail"] for e in filtered_entities)
if head_exists and tail_exists:
final_relations.append(rel)
seen.add(key)
return final_relations
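# A quick illustration of the single-entity rule path in re_extract() (made-up input):
#   entities = [{"text": "张三", "type": "PER", "start": 0, "end": 2}]
#   re_extract(entities, "张三担任公司CEO")
#   -> [{"head": "张三", "tail": "CEO", "relation": "担任职位"}]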
# ======================== Main text-analysis pipeline ========================
def process_text(text, model_type="bert"):
entities, duration = ner(text, model_type)
relations = re_extract(entities, text)
update_knowledge_graph(entities, relations)
ent_text = "\n".join(f"{e['text']} ({e['type']}) [{e['start']}-{e['end']}]" for e in entities)
rel_text = "\n".join(f"{r['head']} --[{r['relation']}]-> {r['tail']}" for r in relations)
kg_text = visualize_kg_text()
return ent_text, rel_text, kg_text, f"{duration:.2f} 秒"
def process_file(file, model_type="bert"):
try:
with open(file.name, 'rb') as f:
content = f.read()
if len(content) > 5 * 1024 * 1024:
return "❌ 文件太大", "", "", ""
        # Detect the file encoding
try:
encoding = chardet.detect(content)['encoding'] or 'utf-8'
text = content.decode(encoding)
except UnicodeDecodeError:
            # Fall back to common Chinese encodings
            for enc in ['gb18030', 'utf-16', 'big5']:
                try:
                    text = content.decode(enc)
                    break
                except (UnicodeDecodeError, LookupError):
                    continue
else:
return "❌ 编码解析失败", "", "", ""
return process_text(text, model_type)
except Exception as e:
return f"❌ 文件处理错误: {str(e)}", "", "", ""
# ======================== Model evaluation & auto-annotation ========================
def convert_telegram_json_to_eval_format(path):
with open(path, encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict) and "text" in data:
return [{"text": data["text"], "entities": [
{"text": data["text"][e["start"]:e["end"]]} for e in data.get("entities", [])
]}]
elif isinstance(data, list):
return data
elif isinstance(data, dict) and "messages" in data:
result = []
for m in data.get("messages", []):
if isinstance(m.get("text"), str):
result.append({"text": m["text"], "entities": []})
elif isinstance(m.get("text"), list):
txt = ''.join([x["text"] if isinstance(x, dict) else x for x in m["text"]])
result.append({"text": txt, "entities": []})
return result
return []
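# convert_telegram_json_to_eval_format() accepts three JSON shapes (illustrative examples):
#   1. a single annotated record:        {"text": "...", "entities": [{"start": 0, "end": 2, ...}]}
#   2. a list of such records:           [{"text": "...", "entities": [...]}, ...]
#   3. a Telegram export with messages:  {"messages": [{"text": "..."}, {"text": [{"text": "..."}]}]}
# Anything else yields an empty list.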
def evaluate_ner_model(data, model_type):
tp, fp, fn = 0, 0, 0
POS_TOLERANCE = 1
for item in data:
text = item["text"]
        # Normalize the gold-standard annotations
gold_entities = []
for e in item.get("entities", []):
if "text" in e and "type" in e:
norm_type = LABEL_MAPPING.get(e["type"], e["type"])
gold_entities.append({
"text": e["text"],
"type": norm_type,
"start": e.get("start", -1),
"end": e.get("end", -1)
})
        # Run the model to get predictions
        pred_entities, _ = ner(text, model_type)
        # Track which gold / predicted entities have been matched
        matched_gold = [False] * len(gold_entities)
        matched_pred = [False] * len(pred_entities)
        # For each predicted entity, look for an unmatched gold entity within the position tolerance
for p_idx, p in enumerate(pred_entities):
for g_idx, g in enumerate(gold_entities):
if not matched_gold[g_idx] and \
p["text"] == g["text"] and \
p["type"] == g["type"] and \
abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
abs(p["end"] - g["end"]) <= POS_TOLERANCE:
matched_gold[g_idx] = True
matched_pred[p_idx] = True
break
        # Accumulate counts for this document
tp += sum(matched_pred)
fp += len(pred_entities) - sum(matched_pred)
fn += len(gold_entities) - sum(matched_gold)
    # Guard against division by zero
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
return (f"Precision: {precision:.2f}\n"
f"Recall: {recall:.2f}\n"
f"F1: {f1:.2f}")
def auto_annotate(file, model_type):
data = convert_telegram_json_to_eval_format(file.name)
for item in data:
ents, _ = ner(item["text"], model_type)
item["entities"] = ents
return json.dumps(data, ensure_ascii=False, indent=2)
def save_json(json_text):
fname = f"auto_labeled_{int(time.time())}.json"
with open(fname, "w", encoding="utf-8") as f:
f.write(json_text)
return fname
# ======================== Dataset import ========================
def import_dataset(path="D:/云边智算/暗语识别/filtered_results"):
    logs = []
    for filename in os.listdir(path):
        if filename.endswith('.json'):
            filepath = os.path.join(path, filename)
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Run each file's text through the existing analysis pipeline
            process_text(data['text'])
            logs.append(f"Processed file: {filename}")
    return "\n".join(logs) if logs else "No JSON files found"
# ======================== Gradio UI ========================
with gr.Blocks(css="""
.kg-graph {height: 500px; overflow-y: auto;}
.warning {color: #ff6b6b;}
""") as demo:
gr.Markdown("# 🤖 聊天记录实体关系识别系统")
with gr.Tab("📄 文本分析"):
input_text = gr.Textbox(lines=6, label="输入文本")
model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
btn = gr.Button("开始分析")
out1 = gr.Textbox(label="识别实体")
out2 = gr.Textbox(label="识别关系")
out3 = gr.Textbox(label="知识图谱")
out4 = gr.Textbox(label="耗时")
btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
with gr.Tab("🗂 文件分析"):
file_input = gr.File(file_types=[".txt", ".json"])
file_btn = gr.Button("上传并分析")
fout1, fout2, fout3, fout4 = gr.Textbox(), gr.Textbox(), gr.Textbox(), gr.Textbox()
file_btn.click(fn=process_file, inputs=[file_input, model_type], outputs=[fout1, fout2, fout3, fout4])
with gr.Tab("📊 模型评估"):
eval_file = gr.File(label="上传标注 JSON")
eval_model = gr.Radio(["bert", "chatglm"], value="bert")
eval_btn = gr.Button("开始评估")
eval_output = gr.Textbox(label="评估结果", lines=5)
eval_btn.click(lambda f, m: evaluate_ner_model(convert_telegram_json_to_eval_format(f.name), m),
[eval_file, eval_model], eval_output)
with gr.Tab("✏️ 自动标注"):
raw_file = gr.File(label="上传 Telegram 原始 JSON")
auto_model = gr.Radio(["bert", "chatglm"], value="bert")
auto_btn = gr.Button("自动标注")
marked_texts = gr.Textbox(label="标注结果", lines=20)
download_btn = gr.Button("💾 下载标注文件")
auto_btn.click(fn=auto_annotate, inputs=[raw_file, auto_model], outputs=marked_texts)
download_btn.click(fn=save_json, inputs=marked_texts, outputs=gr.File())
with gr.Tab("📂 数据管理"):
gr.Markdown("### 数据集导入")
dataset_path = gr.Textbox(
value="D:/云边智算/暗语识别/filtered_results",
label="数据集路径"
)
import_btn = gr.Button("导入数据集到数据库")
import_output = gr.Textbox(label="导入日志")
        import_btn.click(fn=import_dataset, inputs=dataset_path, outputs=import_output)
demo.launch(server_name="0.0.0.0", server_port=7860)