Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| import os, sys | |
| import gc, torch | |
| from loguru import logger | |
| import asyncio | |
| from typing import Optional | |
# Append the current working directory to the import path so the project-local
# TTS / VITS packages (imported lazily inside the handlers) can be found.
sys.path.append("./")

app = FastAPI()

# Model singletons. All start unloaded; change_model creates them on demand.
tts = cosyvoice = edgetts = vits = None
class TTSRequest(BaseModel):
    """Union of every parameter understood by the TTS backends in this service.

    ``tts_response`` fills all fields from form data; each ``predict_*``
    helper then reads only the fields its backend uses.
    """

    # Text to synthesize (all backends).
    text: str = "你好,我是Linly-Talker。"
    # EdgeTTS voice and prosody controls.
    voice: str = "zh-CN-XiaoxiaoNeural"
    rate: float = 1.0
    volume: float = 1.0
    pitch: float = 1.0
    # CosyVoice speaking-speed multiplier.
    speed_factor: float = 1.0
    # PaddleTTS acoustic model, vocoder, language and voice gender flag.
    am: str = "FastSpeech2"
    voc: str = "PWGan"
    lang: str = "zh"
    male: bool = False
    # GPT-SoVITS prompt transcript and its language.
    prompt_text: str = ""
    prompt_language: str = "中文"
    # Path of the uploaded reference clip ("" when none was uploaded).
    ref_audio: str = ""
    # CosyVoice zero-shot prompt transcript.
    ref_text: str = ""
    # Passed to GPT-SoVITS as the target text language.
    ref_language: str = "中文"
    # GPT-SoVITS text-splitting strategy.
    cut_method: str = "凑四句一切"
    # CosyVoice inference mode and built-in speaker.
    cosyvoice_mode: str = "预训练音色"
    sft_dropdown: str = "中文男"
    # Not read by any predict_* helper in this file.
    seed: int = 42
    # Backend selector and output file path.
    tts_method: str = "EdgeTTS"
    save_path: str = "answer.wav"
# Checkpoint directory of the CosyVoice model currently held in ``cosyvoice``.
# Both CosyVoice modes share that one global, so this is what tells them apart.
_cosyvoice_model_path: Optional[str] = None


def _ensure_cosyvoice(model_path: str) -> None:
    """Make ``cosyvoice`` hold the model from ``model_path``, loading it if needed."""
    global cosyvoice, _cosyvoice_model_path
    if cosyvoice is not None and _cosyvoice_model_path == model_path:
        return
    from VITS import CosyVoiceTTS
    # Drop the other checkpoint first so it can be garbage-collected before the
    # new one is allocated (two models would otherwise coexist in memory).
    cosyvoice = None
    _cosyvoice_model_path = None
    cosyvoice = CosyVoiceTTS(model_path)
    _cosyvoice_model_path = model_path


@app.post("/change_model")
async def change_model(model_name: str = Query(..., description="要加载的TTS模型名称")):
    """Load (or reuse) the TTS backend named ``model_name``.

    Supported names: 'EdgeTTS', 'PaddleTTS', 'GPT-SoVITS克隆声音',
    'CosyVoice-SFT模式' and 'CosyVoice-克隆翻译模式'. EdgeTTS, PaddleTTS and
    GPT-SoVITS are constructed once and then reused. The two CosyVoice modes
    use different checkpoints, so switching between them reloads the model.

    Returns:
        ``{"message": ...}`` on success.

    Raises:
        HTTPException: 400 for an unknown model name, 503 when EdgeTTS reports
            no network, 500 when importing or constructing a backend fails.
    """
    global tts, edgetts, vits
    await clear_memory()
    try:
        if model_name == 'EdgeTTS':
            from TTS import EdgeTTS
            if edgetts is None:
                edgetts = EdgeTTS()
            if edgetts.network:
                logger.info("EdgeTTS模型加载成功")
            else:
                logger.warning("EdgeTTS模型加载失败,请检查网络连接")
                raise HTTPException(status_code=503, detail="EdgeTTS模型加载失败,请检查网络连接")
        elif model_name == 'PaddleTTS':
            from TTS import PaddleTTS
            if tts is None:
                tts = PaddleTTS()
            logger.info("PaddleTTS模型加载成功")
        elif model_name == 'GPT-SoVITS克隆声音':
            gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
            sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
            if vits is None:
                from VITS import GPT_SoVITS
                vits = GPT_SoVITS()
                vits.load_model(gpt_path, sovits_path)
            logger.info("GPT-SoVITS模型加载成功")
        elif model_name == 'CosyVoice-SFT模式':
            _ensure_cosyvoice('checkpoints/CosyVoice_ckpt/CosyVoice-300M-SFT')
            logger.info("CosyVoice-SFT模式模型加载成功")
        elif model_name == 'CosyVoice-克隆翻译模式':
            _ensure_cosyvoice('checkpoints/CosyVoice_ckpt/CosyVoice-300M')
            logger.info("CosyVoice-克隆翻译模式模型加载成功")
        else:
            logger.warning(f"未知的TTS模型: {model_name}")
            raise HTTPException(status_code=400, detail=f"未知的TTS模型: {model_name}")
    except HTTPException:
        # Already carries the intended status (400/503); don't rewrap it as a 500.
        raise
    except ImportError as e:
        logger.error(f"导入模型 {model_name} 失败: {e}")
        raise HTTPException(status_code=500, detail=f"导入模型 {model_name} 失败: {e}") from e
    except Exception as e:
        logger.error(f"{model_name} 模型加载失败: {e}")
        raise HTTPException(status_code=500, detail=f"{model_name} 模型加载失败: {e}") from e
    return {"message": f"{model_name} 模型加载成功"}
async def clear_memory():
    """Run Python garbage collection and, if CUDA is present, release cached GPU memory.

    When a GPU is available, the currently allocated CUDA memory is logged
    afterwards in MB.
    """
    logger.info("清理显存资源")
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Return cached allocator blocks and collect CUDA IPC memory.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    allocated_mb = torch.cuda.memory_allocated() / (1024 ** 2)
    logger.info(f"显存使用情况: {allocated_mb:.2f} MB")
def save_upload_file(upload_file: UploadFile, destination: str) -> str:
    """Write the uploaded file's contents to ``destination`` and return that path.

    The data is copied in chunks with ``shutil.copyfileobj`` instead of being
    read into memory in one piece, so large uploads do not spike memory use.
    An existing file at ``destination`` is overwritten.
    """
    import shutil

    with open(destination, "wb") as buffer:
        shutil.copyfileobj(upload_file.file, buffer)
    return destination
def predict_edge_tts(request: TTSRequest):
    """Synthesize ``request.text`` with EdgeTTS into ``request.save_path``.

    Subtitles are written to ``answer.vtt``. If ``edgetts.predict`` raises,
    the ``edge-tts`` command-line tool is tried once as a fallback. The
    fallback does not pass rate/volume/pitch.

    Returns:
        ``request.save_path``; the caller checks whether the file was produced.

    Raises:
        HTTPException: 400 if EdgeTTS is not loaded, 503 if it reports no network.
    """
    global edgetts
    if edgetts is None:
        raise HTTPException(status_code=400, detail="EdgeTTS 模型未加载")
    if not edgetts.network:
        raise HTTPException(status_code=503, detail="EdgeTTS 模型网络问题")
    try:
        edgetts.predict(request.text, request.voice, request.rate, request.volume, request.pitch, request.save_path, 'answer.vtt')
    except Exception as e:
        import subprocess

        logger.warning(f"EdgeTTS 预测失败: {e},尝试使用 edge-tts 命令行")
        # Client-supplied values go in as separate argv entries with no shell,
        # so quotes, `$()` or `;` in the text cannot run commands. The
        # ``--opt=value`` form also stops text starting with '-' from being
        # parsed as a flag.
        cmd = [
            "edge-tts",
            f"--text={request.text}",
            f"--voice={request.voice}",
            f"--write-media={request.save_path}",
            "--write-subtitles=answer.vtt",
        ]
        try:
            subprocess.run(cmd, check=False)
        except OSError as cli_err:
            # e.g. the edge-tts executable is not installed; the caller's
            # file-existence check then reports the missing audio.
            logger.error(f"edge-tts 命令行调用失败: {cli_err}")
    return request.save_path
def predict_paddle_tts(request: TTSRequest):
    """Synthesize ``request.text`` with PaddleTTS and return ``request.save_path``.

    The acoustic model (``request.am``), vocoder (``request.voc``), language
    (``request.lang``) and voice flag (``request.male``) are forwarded as-is.

    Raises:
        HTTPException: 400 if PaddleTTS is not loaded, 500 if synthesis fails.
    """
    engine = tts
    if engine is None:
        raise HTTPException(status_code=400, detail="PaddleTTS 模型未加载")
    try:
        engine.predict(
            request.text,
            request.am,
            request.voc,
            lang=request.lang,
            male=request.male,
            save_path=request.save_path,
        )
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"PaddleTTS 预测失败: {err}")
    return request.save_path
def predict_gpt_sovits(request: TTSRequest):
    """Speak ``request.text`` in the voice of ``request.ref_audio`` using GPT-SoVITS.

    ``request.ref_language`` is passed as ``text_language`` (the language of
    the text being synthesized). ``prompt_text`` / ``prompt_language`` are
    forwarded unchanged; presumably they describe the reference clip.

    Returns:
        ``request.save_path``.

    Raises:
        HTTPException: 400 if GPT-SoVITS is not loaded, 500 if synthesis fails.
    """
    model = vits
    if model is None:
        raise HTTPException(status_code=400, detail="GPT-SoVITS 模型未加载")
    synth_kwargs = dict(
        ref_wav_path=request.ref_audio,
        prompt_text=request.prompt_text,
        prompt_language=request.prompt_language,
        text=request.text,
        text_language=request.ref_language,
        how_to_cut=request.cut_method,
        save_path=request.save_path,
    )
    try:
        model.predict(**synth_kwargs)
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"GPT-SoVITS 预测失败: {err}")
    return request.save_path
def predict_cosyvoice(request: TTSRequest):
    """Synthesize ``request.text`` with the loaded CosyVoice model.

    ``request.cosyvoice_mode`` selects the inference call:

    * '预训练音色' -> ``predict_sft`` with speaker ``request.sft_dropdown``
    * '3s极速复刻' -> ``predict_zero_shot`` with ``request.ref_text`` + reference audio
    * '跨语种复刻' -> ``predict_cross_lingual`` with reference audio

    Returns:
        Whatever the selected ``cosyvoice.predict_*`` call returns, passed
        through unchanged (presumably the written audio path; TODO confirm).

    Raises:
        HTTPException: 400 if no model is loaded, the mode is unknown, or a
            cloning mode has no reference audio; 500 if inference fails.
    """
    global cosyvoice
    if cosyvoice is None:
        raise HTTPException(status_code=400, detail="CosyVoice 模型未加载")
    prompt_wav = request.ref_audio or None
    mode = request.cosyvoice_mode
    # Reject unknown modes up front; previously they fell through every branch
    # and crashed on the unbound ``output`` variable.
    if mode not in ('预训练音色', '3s极速复刻', '跨语种复刻'):
        raise HTTPException(status_code=400, detail=f"未知的 CosyVoice 模式: {mode}")
    if mode in ('3s极速复刻', '跨语种复刻') and not prompt_wav:
        raise HTTPException(status_code=400, detail="选择的模式需要提供 prompt 音频")
    try:
        if mode == '预训练音色':
            output = cosyvoice.predict_sft(request.text, request.sft_dropdown, speed_factor=request.speed_factor, save_path=request.save_path)
        elif mode == '3s极速复刻':
            output = cosyvoice.predict_zero_shot(request.text, request.ref_text, prompt_wav, speed_factor=request.speed_factor, save_path=request.save_path)
        else:
            output = cosyvoice.predict_cross_lingual(request.text, prompt_wav, speed_factor=request.speed_factor, save_path=request.save_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"CosyVoice 预测失败: {e}") from e
    return output
@app.post("/tts_response")
async def tts_response(
    text: str = Form('你好,我是Linly-Talker。'),
    voice: str = Form('zh-CN-XiaoxiaoNeural'),
    rate: float = Form(1.0),
    volume: float = Form(1.0),
    pitch: float = Form(1.0),
    speed_factor: float = Form(1.0),
    am: str = Form('FastSpeech2'),
    voc: str = Form('PWGan'),
    lang: str = Form('zh'),
    male: bool = Form(False),
    prompt_text: str = Form(''),
    prompt_language: str = Form('中文'),
    ref_text: str = Form(''),
    ref_language: str = Form('中文'),
    cut_method: str = Form('凑四句一切'),
    cosyvoice_mode: str = Form('预训练音色'),
    sft_dropdown: str = Form('中文男'),
    seed: int = Form(42),
    tts_method: str = Form('EdgeTTS'),
    save_path: str = Form('answer.wav'),
    ref_audio: Optional[UploadFile] = File(None)
):
    """Synthesize speech with the backend named by ``tts_method`` and return it as a WAV file.

    All parameters arrive as multipart form fields. ``ref_audio`` is an
    optional uploaded reference clip for the voice-cloning backends. The
    chosen backend must already be loaded via ``change_model``.

    Raises:
        HTTPException:
            * 400 for empty text or an unknown ``tts_method``.
            * Errors from the predict_* helpers (model not loaded, network,
              inference failure) pass through with their own status.
            * 500 for any other error.
            * 404 if no audio file exists at ``save_path`` afterwards.
    """
    # NOTE(review): the upload destination and the default ``save_path`` are
    # fixed names shared by all requests, so concurrent requests overwrite each
    # other's files. ``save_path`` is also client-controlled and can point
    # anywhere the server process can write. Consider per-request temp files.
    ref_audio_path = None
    if ref_audio:
        # Persist the upload so the model wrappers can read it by path.
        ref_audio_path = save_upload_file(ref_audio, "uploaded_ref_audio.wav")
    request = TTSRequest(
        text=text,
        voice=voice,
        rate=rate,
        volume=volume,
        pitch=pitch,
        speed_factor=speed_factor,
        am=am,
        voc=voc,
        lang=lang,
        male=male,
        prompt_text=prompt_text,
        prompt_language=prompt_language,
        ref_audio=ref_audio_path or '',
        ref_text=ref_text,
        ref_language=ref_language,
        cut_method=cut_method,
        cosyvoice_mode=cosyvoice_mode,
        sft_dropdown=sft_dropdown,
        seed=seed,
        tts_method=tts_method,
        save_path=save_path
    )
    if not request.text:
        raise HTTPException(status_code=400, detail="文本内容为空")
    try:
        if request.tts_method == 'EdgeTTS':
            file_path = predict_edge_tts(request)
        elif request.tts_method == 'PaddleTTS':
            file_path = predict_paddle_tts(request)
        elif request.tts_method == 'GPT-SoVITS克隆声音':
            file_path = predict_gpt_sovits(request)
        elif "CosyVoice" in request.tts_method:
            file_path = predict_cosyvoice(request)
        else:
            raise HTTPException(status_code=400, detail=f"未知的TTS方法: {request.tts_method}")
    except HTTPException:
        # Keep the helper's status code instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.error(f"处理TTS请求失败: {e}")
        raise HTTPException(status_code=500, detail=f"处理TTS请求失败: {e}") from e
    if not os.path.exists(request.save_path):
        logger.error(f"音频文件不存在: {request.save_path}")
        raise HTTPException(status_code=404, detail="Audio file not found")
    # Generated files are not deleted here. FileResponse reads the file only
    # after this handler returns, so removing it in a ``finally`` would break
    # the response.
    return FileResponse(file_path, media_type='audio/wav', filename=os.path.basename(request.save_path))
if __name__ == "__main__":
    # Run the API directly (outside an external ASGI server) on all interfaces.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8001
    uvicorn.run(app, host=bind_host, port=bind_port)