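"""FastAPI service that wraps several LLM backends behind a local `LLM` helper class.

One endpoint switches the active model (Linly, Qwen/Qwen2, Gemini, ChatGLM,
ChatGPT, GPT4Free, or direct reply); another answers questions with whichever
model is currently loaded. A usage sketch follows the entry point at the bottom.
"""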
from typing import Optional

from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from loguru import logger
import gc
import sys

import torch

sys.path.append('./')

app = FastAPI()
# Global variable holding the currently loaded LLM model.
from LLM import LLM

llm_class = LLM(mode='offline')
# By default, no LLM is used: questions are answered directly,
# which also reduces GPU memory usage.
llm = llm_class.init_model('直接回复 Direct Reply')

# Default prefix prompt. The original code assigned a length-capping prompt
# and then immediately overwrote it with an empty string; the unused value is
# kept here as a commented-out alternative.
# PREFIX_PROMPT = '请用少于25个字回答以下问题\n\n'  # "Answer the following question in fewer than 25 characters"
PREFIX_PROMPT = ''
DEFAULT_SYSTEM = '你是一个很有帮助的助手'  # "You are a very helpful assistant"
class LLMRequest(BaseModel):
    question: str = '请问什么是FastAPI?'  # "What is FastAPI?"
    model_name: str = 'Linly'
    gemini_apikey: str = ''            # API key for the Gemini model
    openai_apikey: str = ''            # API key for OpenAI
    proxy_url: Optional[str] = None    # optional proxy URL
# NOTE: the extracted source had lost its FastAPI route decorators, so the
# handlers were never registered. The route paths below are assumptions.
@app.post("/change_model")
async def change_model(
    model_name: str = Query(..., description="Name of the LLM model to load"),
    gemini_apikey: str = Query('', description="Gemini API key"),
    openai_apikey: str = Query('', description="OpenAI API key"),
    proxy_url: Optional[str] = Query(None, description="Proxy URL")
):
    """Switch the active LLM model and load its resources."""
    global llm
    # Free GPU memory first (details depend on the model library).
    await clear_memory()
    try:
        if model_name == 'Linly':
            llm = llm_class.init_model('Linly', 'Linly-AI/Chinese-LLaMA-2-7B-hf', prefix_prompt=PREFIX_PROMPT)
            logger.info("Linly model loaded successfully")
        elif model_name in ['Qwen', 'Qwen2']:
            model_path = 'Qwen/Qwen-1_8B-Chat' if model_name == 'Qwen' else 'Qwen/Qwen1.5-0.5B-Chat'
            llm = llm_class.init_model(model_name, model_path, prefix_prompt=PREFIX_PROMPT)
            logger.info(f"{model_name} model loaded successfully")
        elif model_name == 'Gemini':
            if gemini_apikey:
                llm = llm_class.init_model('Gemini', 'gemini-pro', gemini_apikey, proxy_url)
                logger.info("Gemini model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide a Gemini API key")
        elif model_name == 'ChatGLM':
            llm = llm_class.init_model('ChatGLM', 'THUDM/chatglm3-6b', prefix_prompt=PREFIX_PROMPT)
            logger.info("ChatGLM model loaded successfully")
        elif model_name == 'ChatGPT':
            if openai_apikey:
                llm = llm_class.init_model('ChatGPT', api_key=openai_apikey, proxy_url=proxy_url, prefix_prompt=PREFIX_PROMPT)
                logger.info("ChatGPT model loaded successfully")
            else:
                raise HTTPException(status_code=400, detail="Please provide an OpenAI API key")
        elif model_name == 'GPT4Free':
            llm = llm_class.init_model('GPT4Free', prefix_prompt=PREFIX_PROMPT)
            logger.info("GPT4Free model loaded; note that this backend may be unstable")
        elif model_name == '直接回复 Direct Reply':
            llm = llm_class.init_model(model_name)
            logger.info("Direct-reply mode enabled; no LLM model is used")
        else:
            raise HTTPException(status_code=400, detail=f"Unknown LLM model: {model_name}")
    except HTTPException:
        # Re-raise the 400s above unchanged instead of rewrapping them as 500s.
        raise
    except Exception as e:
        logger.error(f"Failed to load model {model_name}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load model {model_name}: {e}")
    return {"message": f"Model {model_name} loaded successfully"}
@app.post("/llm_response")  # route path assumed, see note above
async def llm_response(request: LLMRequest):
    """Answer a question with the currently loaded LLM model."""
    global llm
    if not request.question:
        raise HTTPException(status_code=400, detail="The question must not be empty")
    if llm is None:
        raise HTTPException(status_code=400, detail="No LLM model is loaded; load one first")
    try:
        answer = llm.generate(request.question, DEFAULT_SYSTEM)
        logger.info(f"LLM answer: {answer}")
        return {"answer": answer}
    except Exception as e:
        logger.error(f"Failed to handle LLM request: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to handle LLM request: {e}")
async def clear_memory():
    """Free GPU memory before loading a new model."""
    logger.info("Clearing GPU memory")
    gc.collect()
    if torch.cuda.is_available():
        # Note: empty_cache() only releases PyTorch's cached blocks; tensors
        # still referenced elsewhere are not freed.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)
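
# --- Usage sketch ---------------------------------------------------------
# A minimal client example, assuming the server runs on localhost:8002 and
# the assumed route paths "/change_model" and "/llm_response" defined above.
#
#   import requests
#
#   # Switch to the Qwen model (sent as query parameters, matching the
#   # Query(...) signature of change_model).
#   r = requests.post(
#       "http://localhost:8002/change_model",
#       params={"model_name": "Qwen"},
#   )
#   print(r.json())  # {"message": "Model Qwen loaded successfully"}
#
#   # Ask a question (sent as a JSON body, matching the LLMRequest model).
#   r = requests.post(
#       "http://localhost:8002/llm_response",
#       json={"question": "请问什么是FastAPI?", "model_name": "Qwen"},
#   )
#   print(r.json())  # {"answer": "..."}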