import gc
import sys
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Query
from loguru import logger
from pydantic import BaseModel

sys.path.append('./')  # make the local LLM package importable

app = FastAPI()

# Global LLM instance shared across requests
from LLM import LLM
llm_class = LLM(mode='offline')

# By default no LLM model is loaded; questions are answered in direct-reply mode,
# which also keeps GPU memory usage low.
llm = llm_class.init_model('直接回复 Direct Reply')

# Default prefix prompt ("please answer the question in fewer than 25 characters").
# Disabled by default; restore the Chinese instruction below to enable it.
# PREFIX_PROMPT = '请用少于25个字回答以下问题\n\n'
PREFIX_PROMPT = ''

# Default system prompt ("You are a very helpful assistant.")
DEFAULT_SYSTEM = '你是一个很有帮助的助手'

class LLMRequest(BaseModel):
    question: str = '请问什么是FastAPI?'
    model_name: str = 'Linly'
    gemini_apikey: str = ''  # API key for the Gemini model
    openai_apikey: str = ''  # API key for OpenAI
    proxy_url: Optional[str] = None  # optional proxy URL

@app.post("/llm_change_model/")
async def change_model(
    model_name: str = Query(..., description="Name of the LLM model to load"),
    gemini_apikey: str = Query('', description="Gemini API key"),
    openai_apikey: str = Query('', description="OpenAI API key"),
    proxy_url: Optional[str] = Query(None, description="Proxy URL")
):
    """Switch to another LLM model and load its resources."""
    global llm

    # Free GPU memory before loading the new model (details depend on the model backend)
    await clear_memory()

    try:
        if model_name == 'Linly':
            llm = llm_class.init_model('Linly', 'Linly-AI/Chinese-LLaMA-2-7B-hf', prefix_prompt=PREFIX_PROMPT)
            logger.info("Linly model loaded successfully")
        elif model_name in ['Qwen', 'Qwen2']:
            model_path = 'Qwen/Qwen-1_8B-Chat' if model_name == 'Qwen' else 'Qwen/Qwen1.5-0.5B-Chat'
            llm = llm_class.init_model(model_name, model_path, prefix_prompt=PREFIX_PROMPT)
            logger.info(f"{model_name} model loaded successfully")
        elif model_name == 'Gemini':
            if gemini_apikey:
                llm = llm_class.init_model('Gemini', 'gemini-pro', gemini_apikey, proxy_url)
                logger.info("Gemini model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide a Gemini API key")
        elif model_name == 'ChatGLM':
            llm = llm_class.init_model('ChatGLM', 'THUDM/chatglm3-6b', prefix_prompt=PREFIX_PROMPT)
            logger.info("ChatGLM model loaded successfully")
        elif model_name == 'ChatGPT':
            if openai_apikey:
                llm = llm_class.init_model('ChatGPT', api_key=openai_apikey, proxy_url=proxy_url, prefix_prompt=PREFIX_PROMPT)
                logger.info("ChatGPT model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide an OpenAI API key")
        elif model_name == 'GPT4Free':
            llm = llm_class.init_model('GPT4Free', prefix_prompt=PREFIX_PROMPT)
            logger.info("GPT4Free model loaded successfully; note that it may be unstable")
        elif model_name == '直接回复 Direct Reply':
            llm = llm_class.init_model(model_name)
            logger.info("Direct-reply mode activated; no LLM model is used")
        else:
            raise HTTPException(status_code=400, detail=f"Unknown LLM model: {model_name}")
    except HTTPException:
        # Re-raise client errors (missing API key, unknown model name) unchanged
        raise
    except Exception as e:
        logger.error(f"Failed to load the {model_name} model: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load the {model_name} model: {e}")

    return {"message": f"{model_name} model loaded successfully"}

@app.post("/llm_response/")
async def llm_response(request: LLMRequest):
    """处理LLM模型的问答请求。"""
    global llm

    if not request.question:
        raise HTTPException(status_code=400, detail="问题内容不能为空")

    if llm is None:
        raise HTTPException(status_code=400, detail="LLM模型未加载,请先加载模型")

    try:
        answer = llm.generate(request.question, DEFAULT_SYSTEM)
        logger.info(f"LLM 回复:{answer}")
        return {"answer": answer}
    except Exception as e:
        logger.error(f"处理LLM请求失败: {e}")
        raise HTTPException(status_code=500, detail=f"处理LLM请求失败: {e}")

async def clear_memory():
    """清理显存的异步函数"""
    logger.info("清理显存资源")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        logger.info(f"显存使用情况: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8002)
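
# Example client call for /llm_response/ (a minimal sketch, not executed by this
# module). It assumes the server is running on localhost:8002 and that the
# `requests` package is installed; the payload fields match LLMRequest above.
#
#   import requests
#   r = requests.post(
#       "http://localhost:8002/llm_response/",
#       json={"question": "请问什么是FastAPI?"},
#   )
#   print(r.json())  # {"answer": "..."}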