Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +33 -44
- db_config.ini +6 -0
app.py
CHANGED
|
@@ -10,13 +10,13 @@ import time
|
|
| 10 |
# ======================== 数据库模块 ========================
|
| 11 |
import pymysql
|
| 12 |
from configparser import ConfigParser
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
def get_db_connection():
|
| 16 |
config = ConfigParser()
|
| 17 |
config.read('db_config.ini')
|
| 18 |
-
|
| 19 |
-
return pymysql.connect(
|
| 20 |
host=config.get('mysql', 'host'),
|
| 21 |
user=config.get('mysql', 'user'),
|
| 22 |
password=config.get('mysql', 'password'),
|
|
@@ -25,7 +25,10 @@ def get_db_connection():
|
|
| 25 |
charset=config.get('mysql', 'charset', fallback='utf8mb4'),
|
| 26 |
cursorclass=pymysql.cursors.DictCursor
|
| 27 |
)
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def save_to_db(table, data):
|
| 31 |
conn = None
|
|
@@ -122,30 +125,27 @@ def visualize_kg_text():
|
|
| 122 |
|
| 123 |
# ======================== 实体识别(NER) ========================
|
| 124 |
def merge_adjacent_entities(entities):
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
if not merged:
|
| 128 |
-
merged.append(entity)
|
| 129 |
-
continue
|
| 130 |
|
|
|
|
|
|
|
| 131 |
last = merged[-1]
|
| 132 |
# 合并相邻的同类型实体
|
| 133 |
if (entity["type"] == last["type"] and
|
| 134 |
-
entity["start"] == last["end"]
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
"text": last["text"] + entity["text"],
|
| 138 |
-
"type": last["type"],
|
| 139 |
-
"start": last["start"],
|
| 140 |
-
"end": entity["end"]
|
| 141 |
-
}
|
| 142 |
else:
|
| 143 |
merged.append(entity)
|
|
|
|
| 144 |
return merged
|
| 145 |
|
| 146 |
|
| 147 |
def ner(text, model_type="bert"):
|
| 148 |
start_time = time.time()
|
|
|
|
|
|
|
| 149 |
if model_type == "chatglm" and use_chatglm:
|
| 150 |
try:
|
| 151 |
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
|
|
@@ -155,47 +155,35 @@ def ner(text, model_type="bert"):
|
|
| 155 |
if isinstance(response, tuple):
|
| 156 |
response = response[0]
|
| 157 |
|
| 158 |
-
# 增强 JSON 解析
|
| 159 |
try:
|
| 160 |
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
| 161 |
entities = json.loads(json_str)
|
| 162 |
-
|
| 163 |
-
valid_entities = []
|
| 164 |
-
for ent in entities:
|
| 165 |
-
if all(k in ent for k in ("text", "type", "start", "end")):
|
| 166 |
-
valid_entities.append(ent)
|
| 167 |
return valid_entities, time.time() - start_time
|
| 168 |
except Exception as e:
|
| 169 |
-
print(f"JSON
|
| 170 |
return [], time.time() - start_time
|
| 171 |
except Exception as e:
|
| 172 |
-
print(f"ChatGLM
|
| 173 |
return [], time.time() - start_time
|
| 174 |
|
| 175 |
-
#
|
|
|
|
| 176 |
raw_results = []
|
| 177 |
-
max_len = 510 # 安全一点,留一点空余
|
| 178 |
-
text_chunks = [text[i:i + max_len] for i in range(0, len(text), max_len)]
|
| 179 |
-
|
| 180 |
for idx, chunk in enumerate(text_chunks):
|
| 181 |
chunk_results = bert_ner_pipeline(chunk)
|
| 182 |
-
# 修正每个 chunk 里识别的实体在整体文本中的位置
|
| 183 |
for r in chunk_results:
|
| 184 |
-
r["start"] += idx *
|
| 185 |
-
r["end"] += idx *
|
| 186 |
raw_results.extend(chunk_results)
|
| 187 |
|
| 188 |
-
entities = [
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
"type": mapped_type
|
| 196 |
-
})
|
| 197 |
-
|
| 198 |
-
# 执行合并处理
|
| 199 |
entities = merge_adjacent_entities(entities)
|
| 200 |
return entities, time.time() - start_time
|
| 201 |
|
|
@@ -349,7 +337,7 @@ def process_file(file, model_type="bert"):
|
|
| 349 |
text = content.decode(encoding)
|
| 350 |
except UnicodeDecodeError:
|
| 351 |
# 尝试常见中文编码
|
| 352 |
-
for enc in ['gb18030', 'utf-16', 'big5']:
|
| 353 |
try:
|
| 354 |
text = content.decode(enc)
|
| 355 |
break
|
|
@@ -363,6 +351,7 @@ def process_file(file, model_type="bert"):
|
|
| 363 |
return f"❌ 文件处理错误: {str(e)}", "", "", ""
|
| 364 |
|
| 365 |
|
|
|
|
| 366 |
# ======================== 模型评估与自动标注 ========================
|
| 367 |
def convert_telegram_json_to_eval_format(path):
|
| 368 |
with open(path, encoding="utf-8") as f:
|
|
|
|
| 10 |
# ======================== 数据库模块 ========================
|
| 11 |
import pymysql
|
| 12 |
from configparser import ConfigParser
|
| 13 |
+
from contextlib import contextmanager
|
| 14 |
|
| 15 |
+
@contextmanager
|
| 16 |
def get_db_connection():
|
| 17 |
config = ConfigParser()
|
| 18 |
config.read('db_config.ini')
|
| 19 |
+
conn = pymysql.connect(
|
|
|
|
| 20 |
host=config.get('mysql', 'host'),
|
| 21 |
user=config.get('mysql', 'user'),
|
| 22 |
password=config.get('mysql', 'password'),
|
|
|
|
| 25 |
charset=config.get('mysql', 'charset', fallback='utf8mb4'),
|
| 26 |
cursorclass=pymysql.cursors.DictCursor
|
| 27 |
)
|
| 28 |
+
try:
|
| 29 |
+
yield conn
|
| 30 |
+
finally:
|
| 31 |
+
conn.close()
|
| 32 |
|
| 33 |
def save_to_db(table, data):
|
| 34 |
conn = None
|
|
|
|
| 125 |
|
| 126 |
# ======================== 实体识别(NER) ========================
|
| 127 |
def merge_adjacent_entities(entities):
    """Merge adjacent same-type entity spans into single entities.

    Two consecutive entities are merged when they have the same ``type``
    and the second starts exactly where the first ends
    (``entity["start"] == last["end"]``), concatenating their ``text``.

    Fix over the previous version: the input list's dicts are no longer
    mutated in place (``merged = [entities[0]]`` aliased the caller's
    first dict and ``last["text"] += ...`` wrote through it). Merged
    entries are now built from copies, so callers keep their data intact.

    Args:
        entities: list of dicts, each with ``text``, ``type``, ``start``
            and ``end`` keys, assumed ordered by position in the text.

    Returns:
        A new list of entity dicts with contiguous same-type spans merged.
    """
    if not entities:
        return entities

    # Copy the first entity so in-place merging never touches the input.
    merged = [dict(entities[0])]
    for entity in entities[1:]:
        last = merged[-1]
        # Merge adjacent entities of the same type (contiguous spans).
        if entity["type"] == last["type"] and entity["start"] == last["end"]:
            last["text"] += entity["text"]
            last["end"] = entity["end"]
        else:
            merged.append(dict(entity))

    return merged
|
| 143 |
|
| 144 |
|
| 145 |
def ner(text, model_type="bert"):
|
| 146 |
start_time = time.time()
|
| 147 |
+
|
| 148 |
+
# 如果使用的是 ChatGLM 模型,执行 ChatGLM 的NER
|
| 149 |
if model_type == "chatglm" and use_chatglm:
|
| 150 |
try:
|
| 151 |
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
|
|
|
|
| 155 |
if isinstance(response, tuple):
|
| 156 |
response = response[0]
|
| 157 |
|
|
|
|
| 158 |
try:
|
| 159 |
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
| 160 |
entities = json.loads(json_str)
|
| 161 |
+
valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
return valid_entities, time.time() - start_time
|
| 163 |
except Exception as e:
|
| 164 |
+
print(f"JSON解析失败: {e}")
|
| 165 |
return [], time.time() - start_time
|
| 166 |
except Exception as e:
|
| 167 |
+
print(f"ChatGLM调用失败: {e}")
|
| 168 |
return [], time.time() - start_time
|
| 169 |
|
| 170 |
+
# 使用BERT NER
|
| 171 |
+
text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)] # 安全分段
|
| 172 |
raw_results = []
|
|
|
|
|
|
|
|
|
|
| 173 |
for idx, chunk in enumerate(text_chunks):
|
| 174 |
chunk_results = bert_ner_pipeline(chunk)
|
|
|
|
| 175 |
for r in chunk_results:
|
| 176 |
+
r["start"] += idx * 510
|
| 177 |
+
r["end"] += idx * 510
|
| 178 |
raw_results.extend(chunk_results)
|
| 179 |
|
| 180 |
+
entities = [{
|
| 181 |
+
"text": r['word'].replace(' ', ''),
|
| 182 |
+
"start": r['start'],
|
| 183 |
+
"end": r['end'],
|
| 184 |
+
"type": LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
|
| 185 |
+
} for r in raw_results]
|
| 186 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
entities = merge_adjacent_entities(entities)
|
| 188 |
return entities, time.time() - start_time
|
| 189 |
|
|
|
|
| 337 |
text = content.decode(encoding)
|
| 338 |
except UnicodeDecodeError:
|
| 339 |
# 尝试常见中文编码
|
| 340 |
+
for enc in ['gb18030', 'utf-16', 'big5'] :
|
| 341 |
try:
|
| 342 |
text = content.decode(enc)
|
| 343 |
break
|
|
|
|
| 351 |
return f"❌ 文件处理错误: {str(e)}", "", "", ""
|
| 352 |
|
| 353 |
|
| 354 |
+
|
| 355 |
# ======================== 模型评估与自动标注 ========================
|
| 356 |
def convert_telegram_json_to_eval_format(path):
|
| 357 |
with open(path, encoding="utf-8") as f:
|
db_config.ini
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[mysql]
|
| 2 |
+
host = localhost
|
| 3 |
+
user = root
|
| 4 |
+
# NOTE(review): plaintext credential committed to the repo — replace with an
# environment variable or secret store before deploying.
password = 123456
|
| 5 |
+
database = entity_kg
|
| 6 |
+
charset = utf8mb4
|