""" 路由器和评分器模块 包含查询路由、文档相关性评分、答案质量评分和幻觉检测 """ from typing import List try: from langchain_core.prompts import PromptTemplate except ImportError: try: from langchain_core.prompts import PromptTemplate except ImportError: from langchain.prompts import PromptTemplate from langchain_community.chat_models import ChatOllama from langchain_core.output_parsers import JsonOutputParser, StrOutputParser from config import LOCAL_LLM class QueryRouter: """查询路由器,决定使用向量存储还是网络搜索""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。 对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。 你不需要严格匹配问题中与这些主题相关的关键词。 否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。 返回一个只包含'datasource'键的JSON,不要前言或解释。 要路由的问题:{question}""", input_variables=["question"], ) self.router = self.prompt | self.llm | JsonOutputParser() def route(self, question: str) -> str: """路由问题到相应的数据源""" result = self.router.invoke({"question": question}) return result.get("datasource", "web_search") class DocumentGrader: """文档相关性评分器""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个评分员,评估检索到的文档是否与用户问题相关。 如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。 给出二进制分数'yes'或'no',以表明文档是否与问题相关。 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。 检索到的文档: {document} 用户问题:{question}""", input_variables=["question", "document"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, question: str, document: str) -> str: """评估文档与问题的相关性""" result = self.grader.invoke({"question": question, "document": document}) return result.get("score", "no") class AnswerGrader: """答案质量评分器""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个评分员,评估答案是否有助于解决问题。 这里是答案: \n ------- \n {generation} \n ------- \n 这里是问题:{question} 给出二进制分数'yes'或'no',表示答案是否有助于解决问题。 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""", input_variables=["generation", "question"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, question: str, generation: str) -> str: """评估答案质量""" result = self.grader.invoke({"question": question, "generation": generation}) return result.get("score", "no") class HallucinationGrader: """ 幻觉检测器 - 使用专业模型(Vectara + NLI) 相比 LLM-as-a-Judge 方法: - 准确率从 60-75% 提升到 85-95% - 速度提升 5-10 倍 - 成本降低 90% """ def __init__(self, method: str = "hybrid"): """ 初始化幻觉检测器 Args: method: 'vectara', 'nli', 或 'hybrid' (推荐) """ # 尝试加载专业检测模型 try: from hallucination_detector import initialize_hallucination_detector self.detector = initialize_hallucination_detector(method=method) self.use_professional_detector = True print(f"✅ 使用专业幻觉检测器: {method}") except Exception as e: print(f"⚠️ 专业检测器加载失败,回退到 LLM 方法: {e}") self.use_professional_detector = False # 回退到原有的 LLM 方法 self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。 给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。 检索到的文档: {documents} LLM生成:{generation}""", input_variables=["generation", "documents"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, generation: str, documents) -> str: """ 检测生成内容是否存在幻觉 Args: generation: LLM 生成的内容 documents: 参考文档 Returns: "yes" 表示无幻觉,"no" 表示有幻觉 """ if self.use_professional_detector: # 使用专业检测器 return self.detector.grade(generation, documents) else: # 回退到 LLM 方法 result = self.grader.invoke({"generation": generation, "documents": documents}) return result.get("score", "no") class QueryDecomposer: """查询分解器,将复杂的多跳问题分解为子问题序列""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个查询分解专家。你的任务是将一个复杂的多跳问题分解为一系列简单的子问题,这些子问题可以按顺序检索来回答原始问题。 分解规则: 1. 如果问题很简单,不需要分解,返回只包含原始问题的列表。 2. 如果问题需要多步推理(例如"A的作者的大学在哪里"),分解为逻辑步骤: - 步骤1: "谁是A的作者?" - 步骤2: "该作者在哪个大学?" 3. 保持子问题简洁明了。 4. 即使返回单个问题,也必须包装在JSON的 sub_queries 列表中。 输出格式:返回一个包含 'sub_queries' 键的 JSON,其值为字符串列表。 不要输出任何前言或解释。 复杂问题: {question}""", input_variables=["question"], ) self.decomposer = self.prompt | self.llm | JsonOutputParser() def decompose(self, question: str) -> List[str]: """分解问题""" print(f"---分解问题: {question}---") try: result = self.decomposer.invoke({"question": question}) sub_queries = result.get("sub_queries", [question]) # 确保至少包含原始问题 if not sub_queries: sub_queries = [question] print(f"---子问题: {sub_queries}---") return sub_queries except Exception as e: print(f"⚠️ 分解失败: {e},使用原始问题") return [question] class AnswerabilityGrader: """答案可回答性评分器,用于判断当前检索到的文档是否足够回答原始问题""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个专家评分员,负责评估检索到的文档是否包含足够的信息来回答用户的问题。 原始问题: {question} 目前检索到的文档集合: {documents} 任务: 判断上述文档是否已经包含了回答原始问题所需的全部关键信息。 - 如果信息充足,可以终止进一步的检索,返回 'yes'。 - 如果信息缺失,需要继续检索更多信息,返回 'no'。 输出格式: 返回一个只包含 'score' 键的 JSON,值为 'yes' 或 'no'。 不要输出任何前言或解释。""", input_variables=["question", "documents"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, question: str, documents: str) -> str: """评估文档是否足以回答问题""" result = self.grader.invoke({"question": question, "documents": documents}) return result.get("score", "no") class QueryRewriter: """查询重写器,优化查询以获得更好的检索结果""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, temperature=0) self.prompt = PromptTemplate( template="""你是一个问题重写器,负责将输入问题转换为更适合向量存储检索的更好版本。 你的目标是根据原始问题和(可选的)之前的检索上下文,生成一个新的查询,以便检索到回答问题所需的缺失信息。 如果提供了之前的上下文,请分析其中缺少什么信息,并针对缺失的信息构建查询。 初始问题: {question} 之前的上下文(如果有): {context} 改进的问题(只输出问题,无前言):""", input_variables=["question", "context"], ) self.rewriter = self.prompt | self.llm | StrOutputParser() def rewrite(self, question: str, context: str = "") -> str: """重写查询以获得更好的检索效果""" print(f"---原始查询: {question}---") if context: print(f"---参考上下文长度: {len(context)} 字符---") rewritten_query = self.rewriter.invoke({"question": question, "context": context}) print(f"---重写查询: {rewritten_query}---") return rewritten_query def initialize_graders_and_router(): """初始化所有评分器和路由器""" # Load detection method from config try: from hallucination_config import HALLUCINATION_DETECTION_METHOD detection_method = HALLUCINATION_DETECTION_METHOD except ImportError: detection_method = "hybrid" # Default to hybrid query_router = QueryRouter() document_grader = DocumentGrader() answer_grader = AnswerGrader() hallucination_grader = HallucinationGrader(method=detection_method) query_rewriter = QueryRewriter() query_decomposer = QueryDecomposer() answerability_grader = AnswerabilityGrader() return { "query_router": query_router, "document_grader": document_grader, "answer_grader": answer_grader, "hallucination_grader": hallucination_grader, "query_rewriter": query_rewriter, "query_decomposer": query_decomposer, "answerability_grader": answerability_grader }