# import subprocess
import tempfile
import os
# import json
# import shutil
import time
import librosa
import torch
import argparse
import soundfile as sf
# from pathlib import Path
import cn2an
import requests
import re
import numpy as np
import onnxruntime as ort
import axengine as axe
import threading
import queue
from collections import deque
# Import SenseVoice-related modules
from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
# Import MeloTTS-related modules
from libmelotts.python.split_utils import split_sentence
from libmelotts.python.text import cleaned_text_to_sequence
from libmelotts.python.text.cleaner import clean_text
from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP
# Configuration
# TTS parameters
TTS_MODEL_DIR = "libmelotts/models"
TTS_MODEL_FILES = {
    "g": "g-zh_mix_en.bin",
    "encoder": "encoder-zh.onnx",
    "decoder": "decoder-zh.axmodel"
}
# Qwen LLM translation API parameters
QWEN_API_URL = ""  # API server address, e.g. http://10.126.29.158:8000
# TTS helper functions
def intersperse(lst, item):
    """Pad `lst` with `item` before, between, and after every element."""
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result
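# A minimal illustration of what intersperse produces (defined but never
# called; kept here only as documentation):
def _demo_intersperse():
    # A separator of 0 is placed before, between, and after the items.
    assert intersperse([1, 2, 3], 0) == [0, 1, 0, 2, 0, 3, 0]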
# def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
# norm_text, phone, tone, word2ph = clean_text(text, language_str)
# phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)
# phone = intersperse(phone, 0)
# tone = intersperse(tone, 0)
# language = intersperse(language, 0)
# phone = np.array(phone, dtype=np.int32)
# tone = np.array(tone, dtype=np.int32)
# language = np.array(language, dtype=np.int32)
# word2ph = np.array(word2ph, dtype=np.int32) * 2
# word2ph[0] += 1
# return phone, tone, language, norm_text, word2ph
# Handle phonemes that the symbol table cannot recognize
def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Fixed phoneme processing: keep all arrays the same length."""
    try:
        norm_text, phone, tone, word2ph = clean_text(text, language_str)
        # Special phonemes are mapped straight to the empty string (dropped)
        phone_mapping = {
            'ɛ': '', 'æ': '', 'ʌ': '', 'ʊ': '', 'ɔ': '', 'ɪ': '', 'ɝ': '', 'ɚ': '', 'ɑ': '',
            'ʒ': '', 'θ': '', 'ð': '', 'ŋ': '', 'ʃ': '', 'ʧ': '', 'ʤ': '', 'ː': '', 'ˈ': '',
            'ˌ': '', 'ʰ': '', 'ʲ': '', 'ʷ': '', 'ʔ': '', 'ɾ': '', 'ɹ': '', 'ɫ': '', 'ɡ': '',
        }
        # Process phone and tone together so their lengths stay in sync
        processed_phone = []
        processed_tone = []
        removed_symbols = set()
        for p, t in zip(phone, tone):
            if p in phone_mapping:
                # Drop the special phoneme together with its tone
                removed_symbols.add(p)
            elif p in symbol_to_id:
                # Keep a normal phoneme together with its tone
                processed_phone.append(p)
                processed_tone.append(t)
            else:
                # Drop any other unknown phoneme as well
                removed_symbols.add(p)
        # Log which phonemes were dropped
        if removed_symbols:
            print(f"[phoneme filter] Removed {len(removed_symbols)} special phonemes: {sorted(removed_symbols)}")
            print(f"[phoneme filter] Phoneme sequence length after filtering: {len(processed_phone)}")
            print(f"[phoneme filter] Tone sequence length after filtering: {len(processed_tone)}")
        # If nothing is left, fall back to default Chinese phonemes
        if not processed_phone:
            print("[warning] No valid phonemes; falling back to default Chinese phonemes")
            processed_phone = ['ni', 'hao']
            processed_tone = [1, 3]  # tones are integers, matching clean_text output
            word2ph = [1, 1]
        # Make sure word2ph matches the filtered phoneme sequence
        if len(processed_phone) != len(phone):
            print(f"[warning] Phoneme sequence length changed: {len(phone)} -> {len(processed_phone)}")
            # Simple fallback: recompute word2ph as one phoneme per word
            word2ph = [1] * len(processed_phone)
        phone, tone, language = cleaned_text_to_sequence(processed_phone, processed_tone, language_str, symbol_to_id)
        phone = intersperse(phone, 0)
        tone = intersperse(tone, 0)
        language = intersperse(language, 0)
        phone = np.array(phone, dtype=np.int32)
        tone = np.array(tone, dtype=np.int32)
        language = np.array(language, dtype=np.int32)
        word2ph = np.array(word2ph, dtype=np.int32) * 2
        word2ph[0] += 1
        return phone, tone, language, norm_text, word2ph
    except Exception as e:
        print(f"[error] Text processing failed: {e}")
        import traceback
        traceback.print_exc()
        raise
def audio_numpy_concat(segment_data_list, sr, speed=1.):
    """Concatenate audio segments, inserting ~50 ms of silence between them."""
    audio_segments = []
    for segment_data in segment_data_list:
        audio_segments += segment_data.reshape(-1).tolist()
        audio_segments += [0] * int((sr * 0.05) / speed)
    audio_segments = np.array(audio_segments).astype(np.float32)
    return audio_segments
def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    """Merge decoder sub-audio chunks, cross-fading `pad_size` overlapping samples."""
    if pad_size > 0:
        for i in range(len(sub_audio_list) - 1):
            # Average the overlapping tail/head of adjacent chunks
            sub_audio_list[i][-pad_size:] += sub_audio_list[i + 1][:pad_size]
            sub_audio_list[i][-pad_size:] /= 2
            if i > 0:
                sub_audio_list[i] = sub_audio_list[i][pad_size:]
    sub_audio = np.concatenate(sub_audio_list, axis=-1)
    return sub_audio[:audio_len]
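# Note: in this pipeline merge_sub_audio is always called with pad_size=0
# (see run_tts below), so the cross-fade branch is effectively dormant.
# The sketch below shows what a non-zero pad would do (illustrative only,
# never called):
def _demo_merge_sub_audio():
    a = np.ones(4, dtype=np.float32)
    b = np.zeros(4, dtype=np.float32)
    # With pad_size=1, the last sample of `a` is averaged with the first of `b`.
    merged = merge_sub_audio([a.copy(), b.copy()], pad_size=1, audio_len=8)
    assert merged[3] == 0.5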
def calc_word2pronoun(word2ph, pronoun_lens):
    """Sum per-phoneme pronunciation lengths into per-word lengths."""
    indice = [0]
    for ph in word2ph[:-1]:
        indice.append(indice[-1] + ph)
    word2pronoun = []
    for i, ph in zip(indice, word2ph):
        word2pronoun.append(np.sum(pronoun_lens[i : i + ph]))
    return word2pronoun
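# A small worked example (illustrative only, never called): with
# word2ph = [2, 1] and pronoun_lens = [3, 4, 5], the cumulative start indices
# are [0, 2], so the per-word lengths come out as [3 + 4, 5] = [7, 5].
def _demo_calc_word2pronoun():
    assert list(calc_word2pronoun([2, 1], np.array([3, 4, 5]))) == [7, 5]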
def generate_slices(word2pronoun, dec_len):
    """Split the latent sequence into decoder-sized windows.

    Returns parallel slice lists: `pn_slices` index words, `zp_slices` index
    frames of z_p. When possible, a window re-decodes the previous two words
    as left context so adjacent chunks can be stitched together smoothly.
    """
    pn_start, pn_end = 0, 0
    zp_start, zp_end = 0, 0
    zp_len = 0
    pn_slices = []
    zp_slices = []
    while pn_end < len(word2pronoun):
        # Re-use the previous two words as context if they still fit in dec_len
        if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
            zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
            zp_start = zp_end - zp_len
            pn_start = pn_end - 2
        else:
            zp_len = 0
            zp_start = zp_end
            pn_start = pn_end
        # Greedily extend the window while it fits into the decoder length
        while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
            zp_len += word2pronoun[pn_end]
            pn_end += 1
        zp_end = zp_start + zp_len
        pn_slices.append(slice(pn_start, pn_end))
        zp_slices.append(slice(zp_start, zp_end))
    return pn_slices, zp_slices
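# Illustrative check of the windowing (defined but never called): with three
# words of 60 frames each and dec_len=128, only two words fit per window, so
# the third word is decoded in a second window.
def _demo_generate_slices():
    pn, zp = generate_slices(np.array([60, 60, 60]), dec_len=128)
    assert len(pn) == 2 and pn[0] == slice(0, 2) and pn[1] == slice(2, 3)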
# Detect whether text is Chinese or English
def lang_detect_with_regex(text):
    text_without_digits = re.sub(r'\d+', '', text)
    if not text_without_digits:
        return 'unknown'
    if re.search(r'[\u4e00-\u9fff]', text_without_digits):
        return 'chinese'
    elif re.search(r'[a-zA-Z]', text_without_digits):
        return 'english'
    else:
        return 'unknown'
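# Quick sanity examples for the detector (illustrative only, never called):
def _demo_lang_detect():
    assert lang_detect_with_regex("你好世界") == 'chinese'
    assert lang_detect_with_regex("hello world") == 'english'
    assert lang_detect_with_regex("12345") == 'unknown'  # digits are ignored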
class QwenTranslationAPI:
    def __init__(self, api_url=QWEN_API_URL):
        self.api_url = api_url
        self.session_id = f"speech_translate_{int(time.time())}"
        self.last_reset_time = time.time()
        self.request_count = 0
        self.max_requests_before_reset = 10
    def reset_context(self):
        """Reset the API conversation context."""
        try:
            reset_url = f"{self.api_url}/api/reset"
            response = requests.post(reset_url, json={}, timeout=5)
            if response.status_code == 200:
                print("[translation API] ✓ Context reset succeeded")
                self.last_reset_time = time.time()
                self.request_count = 0
                return True
            else:
                print(f"[translation API] ✗ Reset failed, status code: {response.status_code}, response: {response.text}")
        except Exception as e:
            print(f"[translation API] ✗ Failed to reset context: {e}")
        return False
    def check_and_reset_if_needed(self):
        """Reset the context when the request count or elapsed time exceeds its limit."""
        current_time = time.time()
        if (self.request_count >= self.max_requests_before_reset or
                current_time - self.last_reset_time > 120):  # 2 minutes
            print(f"[translation API] Resetting (requests: {self.request_count}, elapsed: {current_time - self.last_reset_time:.1f}s)")
            return self.reset_context()
        return True
    def translate(self, text_content, max_retries=3, timeout=120):
        if not text_content or text_content.strip() == "":
            return "Input text is empty"
        # Skip text that is too short to translate
        if len(text_content.strip()) < 3:
            return text_content
        if lang_detect_with_regex(text_content) == 'chinese':
            prompt_f = "翻译成英文"  # "translate into English" -- prompt kept in Chinese for the Qwen model
        else:
            prompt_f = "翻译成中文"  # "translate into Chinese"
        prompt = f"{prompt_f}:{text_content}"
        print(f"[translation API] Sending request: {prompt}")
        # Reset the context first if needed
        self.check_and_reset_if_needed()
        for attempt in range(max_retries):
            try:
                generate_url = f"{self.api_url}/api/generate"
                payload = {
                    "prompt": prompt,
                    "temperature": 0.1,
                    "repetition_penalty": 1.0,
                    "top-p": 0.9,
                    "top-k": 40,
                    "max_new_tokens": 512
                }
                print(f"[translation API] Starting generate request (attempt {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[translation API] Generate request succeeded")
                result_url = f"{self.api_url}/api/generate_provider"
                start_time = time.time()
                full_translation = ""
                error_detected = False
                while time.time() - start_time < timeout:
                    try:
                        result_response = requests.get(result_url, timeout=10)
                        result_data = result_response.json()
                        current_chunk = result_data.get("response", "")
                        # Check the streamed chunk for server-side errors
                        if "error:" in current_chunk.lower() or "setkvcache failed" in current_chunk.lower():
                            print(f"[translation API] ✗ Error detected: {current_chunk}")
                            error_detected = True
                            print("[translation API] Resetting context immediately...")
                            self.reset_context()
                            break
                        full_translation += current_chunk
                        if result_data.get("done", False):
                            if full_translation and len(full_translation.strip()) > 0:
                                self.request_count += 1
                                print(f"[translation API] ✓ Translation complete: {full_translation}")
                                return full_translation
                            else:
                                print("[translation API] ✗ Translation result is empty")
                                break
                        time.sleep(0.05)
                    except requests.exceptions.RequestException as e:
                        print(f"[translation API] Polling request failed: {e}")
                        if time.time() - start_time > timeout:
                            break
                        time.sleep(0.05)
                        continue
                if error_detected:
                    if attempt < max_retries - 1:
                        wait_time = 1
                        print(f"[translation API] Waiting {wait_time}s before retrying...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print("[translation API] Max retries reached; returning the original text")
                        return text_content
                print(f"[translation API] Polling timed out, starting retry {attempt + 1}")
            except requests.exceptions.RequestException as e:
                print(f"[translation API] Request failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"[translation API] Waiting {wait_time}s before retrying...")
                    time.sleep(wait_time)
                else:
                    return text_content
            except Exception as e:
                print(f"[translation API] Translation error: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                    continue
                return text_content
        print("[translation API] Translation timed out; returning the original text")
        return text_content
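# A minimal standalone usage sketch for the client above (never called here;
# it assumes a compatible Qwen HTTP service is running at the hypothetical
# address below):
def _demo_translation_api():
    translator = QwenTranslationAPI(api_url="http://localhost:8000")
    translator.reset_context()
    print(translator.translate("今天天气很好。"))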
class AudioResampler:
    """Audio resampler."""
    def __init__(self, target_sr=16000):
        self.target_sr = target_sr
    def resample_audio(self, audio_data, original_sr):
        """Resample audio to the target rate; ASR always takes 16000 Hz input."""
        if original_sr == self.target_sr:
            return audio_data
        print(f"[resample] {original_sr}Hz -> {self.target_sr}Hz")
        return librosa.resample(y=audio_data, orig_sr=original_sr, target_sr=self.target_sr)
    def resample_chunk(self, audio_chunk, original_sr):
        """Resample an audio chunk; chunks cut from already-resampled long audio can skip this."""
        if original_sr == self.target_sr:
            return audio_chunk
        if len(audio_chunk) < 1000:
            # Very short chunks: cheap linear interpolation is good enough
            return self._linear_resample(audio_chunk, original_sr, self.target_sr)
        else:
            return librosa.resample(y=audio_chunk, orig_sr=original_sr, target_sr=self.target_sr)
    def _linear_resample(self, audio_chunk, original_sr, target_sr):
        """Linear-interpolation resampling."""
        ratio = target_sr / original_sr
        old_length = len(audio_chunk)
        new_length = int(old_length * ratio)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(0, old_length - 1, new_length)
        resampled = np.interp(new_indices, old_indices, audio_chunk)
        return resampled
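# Illustrative check of the linear fallback (defined but never called):
def _demo_linear_resample():
    resampler = AudioResampler(target_sr=16000)
    chunk = np.sin(np.linspace(0, 2 * np.pi, 480)).astype(np.float32)  # 15 ms at 32 kHz
    out = resampler.resample_chunk(chunk, original_sr=32000)
    assert len(out) == 240  # 15 ms at 16 kHz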
class StreamProcessor:
    """Streaming processor."""
    def __init__(self, pipeline, chunk_duration=7.0, overlap_duration=0.01, target_sr=16000):
        self.pipeline = pipeline
        self.chunk_duration = chunk_duration  # increased from 4s to 7s
        self.overlap_duration = overlap_duration  # reduced from 0.1s to 0.01s
        self.target_sr = target_sr
        self.chunk_samples = int(chunk_duration * target_sr)
        self.overlap_samples = int(overlap_duration * target_sr)
        self.audio_buffer = deque()
        self.result_queue = queue.Queue()
        self.is_running = False
        self.processing_thread = None
        self.resampler = AudioResampler(target_sr=target_sr)
        self.segment_counter = 0  # audio segment counter
        self.processed_texts = set()  # texts already processed, to avoid duplicates
    def start_processing(self):
        """Start streaming processing."""
        self.is_running = True
        self.processing_thread = threading.Thread(target=self._process_loop)
        self.processing_thread.daemon = True
        self.processing_thread.start()
    def stop_processing(self):
        """Stop streaming processing."""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)
    def add_audio_chunk(self, audio_chunk, original_sr=None):
        """Append an audio chunk to the buffer, resampling if needed."""
        if original_sr and original_sr != self.target_sr:
            audio_chunk = self.resampler.resample_chunk(audio_chunk, original_sr)
        self.audio_buffer.append(audio_chunk)
    def get_next_result(self, timeout=1.0):
        """Fetch the next processing result, or None on timeout."""
        try:
            return self.result_queue.get(timeout=timeout)
        except queue.Empty:
            return None
    def _process_loop(self):
        """Main processing loop."""
        accumulated_audio = np.array([], dtype=np.float32)
        last_asr_result = ""  # previous ASR result, to avoid reprocessing duplicates
        while self.is_running:
            if len(self.audio_buffer) > 0:
                audio_chunk = self.audio_buffer.popleft()
                accumulated_audio = np.concatenate([accumulated_audio, audio_chunk])
                # Once enough audio has accumulated, process one chunk
                if len(accumulated_audio) >= self.chunk_samples:
                    # Extract the chunk to process (keeping a small overlap)
                    process_chunk = accumulated_audio[:self.chunk_samples]
                    accumulated_audio = accumulated_audio[self.chunk_samples - self.overlap_samples:]
                    try:
                        # Real-time ASR
                        asr_result = self._stream_asr(process_chunk)
                        # Filter conditions:
                        # 1. the text is valid and long enough
                        # 2. it differs from the previous result (no duplicates)
                        # 3. it has not been processed before
                        if (asr_result and asr_result.strip() and
                                # len(asr_result.strip()) >= 5 and
                                asr_result != last_asr_result and
                                asr_result not in self.processed_texts):
                            print(f"[realtime ASR] {asr_result}")
                            last_asr_result = asr_result
                            self.processed_texts.add(asr_result)
                            # Real-time translation
                            try:
                                translation_result = self.pipeline.run_translation(asr_result)
                                # Check that the translation result is valid
                                if (translation_result and
                                        translation_result != asr_result and
                                        "翻译失败" not in translation_result and  # "translation failed", as the model may emit it
                                        "error:" not in translation_result.lower() and
                                        "Input text is empty" not in translation_result):
                                    print(f"[realtime translation] {translation_result}")
                                    # TTS synthesis
                                    try:
                                        self.segment_counter += 1
                                        tts_filename = f"stream_segment_{self.segment_counter:04d}.wav"
                                        tts_start_time = time.time()
                                        tts_path = self.pipeline.run_tts(
                                            translation_result,
                                            self.pipeline.output_dir,
                                            tts_filename
                                        )
                                        tts_time = time.time() - tts_start_time
                                        print(f"[realtime TTS] Audio saved: {tts_path} (took {tts_time:.2f}s)")
                                        # Put the complete result onto the queue
                                        self.result_queue.put({
                                            'type': 'complete',
                                            'original': asr_result,
                                            'translated': translation_result,
                                            'audio_path': tts_path,
                                            'timestamp': time.time(),
                                            'segment_id': self.segment_counter
                                        })
                                    except Exception as tts_error:
                                        print(f"[realtime TTS error] {tts_error}")
                                        import traceback
                                        traceback.print_exc()
                                else:
                                    print("[realtime translation] Invalid translation result, skipped")
                            except Exception as translation_error:
                                print(f"[realtime translation error] {translation_error}")
                        else:
                            if asr_result == last_asr_result:
                                print(f"[realtime ASR] Duplicate content skipped: {asr_result}")
                    except Exception as e:
                        print(f"[stream processing error] {e}")
                        import traceback
                        traceback.print_exc()
            time.sleep(0.01)
    def _stream_asr(self, audio_chunk):
        """Streaming ASR with VAD pre-filtering."""
        try:
            # Step 1: VAD detection - filter out silent spans
            res_vad = self.pipeline.model_vad(audio_chunk)[0]
            vad_segments = merge_vad(res_vad, 15 * 1000)
            # If no speech segments were detected, return empty text
            if not vad_segments or len(vad_segments) == 0:
                print("[VAD] No speech activity detected, skipping this chunk")
                return ""
            print(f"[VAD] Detected {len(vad_segments)} speech segments")
            # Step 2: run ASR on each detected speech segment
            all_results = ""
            for i, segment in enumerate(vad_segments):
                segment_start, segment_end = segment
                start_sample = int(segment_start / 1000 * self.target_sr)
                end_sample = min(int(segment_end / 1000 * self.target_sr), len(audio_chunk))
                segment_audio = audio_chunk[start_sample:end_sample]
                # Skip segments shorter than 0.3s to reduce misrecognition
                if len(segment_audio) < int(0.3 * self.target_sr):
                    continue
                # Write the segment to a temporary file
                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
                    sf.write(temp_file.name, segment_audio, self.target_sr)
                    temp_filename = temp_file.name
                try:
                    # ASR recognition
                    segment_result = self.pipeline.model_bin(
                        temp_filename,
                        "auto",
                        True,
                        self.pipeline.position_encoding,
                        tokenizer=self.pipeline.tokenizer,
                    )
                    if segment_result and segment_result.strip():
                        all_results += segment_result + " "
                    # Clean up the temporary file
                    os.unlink(temp_filename)
                except Exception as e:
                    print(f"[ASR error] Failed on VAD segment {i}: {e}")
                    if os.path.exists(temp_filename):
                        os.unlink(temp_filename)
                    continue
            return all_results.strip()
        except Exception as e:
            print(f"[ASR error] {e}")
            return ""
class SpeechTranslationPipeline:
    def __init__(self,
                 tts_model_dir, tts_model_files,
                 asr_model_dir="ax_model", seq_len=132,
                 tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
                 qwen_api_url=QWEN_API_URL, target_sr=16000,
                 output_dir="./output"):
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len
        self.tts_dec_len = tts_dec_len
        self.sample_rate = sample_rate
        self.tts_speed = tts_speed
        self.qwen_api_url = qwen_api_url
        self.target_sr = target_sr
        self.output_dir = output_dir
        # Output directory
        os.makedirs(self.output_dir, exist_ok=True)
        # Initialize the audio resampler
        self.resampler = AudioResampler(target_sr=target_sr)
        # Initialize the ASR models
        self._init_asr_models()
        # Initialize the TTS models
        self._init_tts_models()
        # Initialize the translation API client
        self.translator = QwenTranslationAPI(api_url=qwen_api_url)
        # Initialize the stream processor
        self.stream_processor = StreamProcessor(self, target_sr=target_sr)
        # Verify that all required files exist
        self._validate_files()
        # Reset the API context at startup
        print("[init] Resetting API context...")
        self.translator.reset_context()
    def _init_asr_models(self):
        """Initialize the speech-recognition models."""
        print("Initializing SenseVoice models...")
        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)
        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()
        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)
        tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
        print("SenseVoice models initialized successfully.")
    def _init_tts_models(self):
        """Initialize the TTS models."""
        print("Initializing MeloTTS models...")
        init_start = time.time()
        enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
        dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])
        model_load_start = time.time()
        self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
        self.sess_dec = axe.InferenceSession(dec_model)
        print(f" Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")
        g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
        self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)
        self.tts_language = "ZH_MIX_EN"
        self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}
        print(" Warming up TTS modules...")
        warmup_start = time.time()
        try:
            warmup_text_mix = "这是一个test测试。"  # mixed ZH-EN warm-up sentence
            _, _, _, _, _ = get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
            print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start)*1000:.2f}ms")
        except Exception as e:
            print(f" Warning: Mixed warm-up failed: {e}")
        total_init_time = (time.time() - init_start) * 1000
        print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms")
    def _validate_files(self):
        """Verify that all required files exist."""
        for key, filename in self.tts_model_files.items():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS model file not found: {filepath}")
        try:
            response = requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API check] Qwen API service reachable")
        except requests.exceptions.RequestException:
            print("[API warning] Cannot reach the Qwen API service; make sure it is running")
    def start_stream_processing(self):
        """Start streaming processing."""
        self.stream_processor.start_processing()
        print("[streaming] Started")
    def stop_stream_processing(self):
        """Stop streaming processing."""
        self.stream_processor.stop_processing()
        print("[streaming] Stopped")
    def process_audio_stream(self, audio_chunk, original_sr=None):
        """Feed one chunk of streamed audio into the processor."""
        self.stream_processor.add_audio_chunk(audio_chunk, original_sr)
    def get_stream_results(self):
        """Fetch the next streaming result."""
        return self.stream_processor.get_next_result()
    def load_and_resample_audio(self, audio_file):
        """Load an audio file and resample it to the target rate."""
        print(f"Loading audio file: {audio_file}")
        speech, original_sr = librosa.load(audio_file, sr=None)
        audio_duration = len(speech) / original_sr
        print(f"Original audio: {original_sr}Hz, duration: {audio_duration:.2f}s")
        if original_sr != self.target_sr:
            speech = self.resampler.resample_audio(speech, original_sr)
            print(f"After resampling: {self.target_sr}Hz, duration: {len(speech)/self.target_sr:.2f}s")
        return speech, self.target_sr
    def run_translation(self, text_content):
        """Translate between Chinese and English via the Qwen LLM API."""
        print("Starting translation via API...")
        translation_start_time = time.time()
        translate_content = self.translator.translate(text_content)
        translation_time_cost = time.time() - translation_start_time
        print(f"Translation processing time: {translation_time_cost:.2f} seconds")
        print(f"Translation Result: {translate_content}")
        return translate_content
    def run_tts(self, translate_content, output_dir, output_wav=None):
        """Synthesize speech from text with the TTS models."""
        if output_wav is None:
            # Fallback filename so the default argument is actually usable
            output_wav = f"tts_{int(time.time())}.wav"
        output_path = os.path.join(output_dir, output_wav)
        try:
            if lang_detect_with_regex(translate_content) == "chinese":
                # Convert Arabic numerals to Chinese numerals for synthesis
                translate_content = cn2an.transform(translate_content, "an2cn")
            print(f"TTS synthesis for text: {translate_content}")
            sens = split_sentence(translate_content, language_str=self.tts_language)
            print(f"Text split into {len(sens)} sentences")
            audio_list = []
            for n, se in enumerate(sens):
                if self.tts_language in ['EN', 'ZH_MIX_EN']:
                    # Split camelCase runs so each word is pronounced separately
                    se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)
                print(f"Processing sentence[{n}]: {se}")
                phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
                    se, self.tts_language, symbol_to_id=self.symbol_to_id)
                encoder_start = time.time()
                z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
                    'phone': phones, 'g': self.tts_g,
                    'tone': tones, 'language': lang_ids,
                    'noise_scale': np.array([0], dtype=np.float32),
                    'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
                    'noise_scale_w': np.array([0], dtype=np.float32),
                    'sdp_ratio': np.array([0], dtype=np.float32)})
                print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")
                word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
                pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)
                audio_len = audio_len[0]
                sub_audio_list = []
                for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
                    zp_slice = z_p[..., zs]
                    sub_dec_len = zp_slice.shape[-1]
                    sub_audio_len = 512 * sub_dec_len
                    # Zero-pad the last window up to the fixed decoder length
                    if zp_slice.shape[-1] < self.tts_dec_len:
                        zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)
                    decoder_start = time.time()
                    audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()
                    # Trim the re-decoded context at the head and tail of each window
                    audio_start = 0
                    if len(sub_audio_list) > 0:
                        if pn_slices[i - 1].stop > ps.start:
                            audio_start = 512 * word2pronoun[ps.start]
                    audio_end = sub_audio_len
                    if i < len(pn_slices) - 1:
                        if ps.stop > pn_slices[i + 1].start:
                            audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]
                    audio = audio[audio_start:audio_end]
                    print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
                    sub_audio_list.append(audio)
                sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
                audio_list.append(sub_audio)
            audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)
            sf.write(output_path, audio, self.sample_rate)
            print(f"TTS audio saved to {output_path}")
            return output_path
        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            import traceback
            traceback.print_exc()
            raise
    def process_long_audio_stream(self, audio_file, chunk_size=64000):
        """
        Simulate streaming over a long audio file.
        chunk_size was raised to 64000 (4s * 16000Hz) to match the StreamProcessor
        chunk_duration; 4s felt a little short, and 7s works noticeably better.
        """
        print(f"[streaming] Start processing long audio: {audio_file}")
        # Load and resample the audio
        speech, fs = self.load_and_resample_audio(audio_file)
        # Start streaming processing
        self.start_stream_processing()
        total_chunks = (len(speech) + chunk_size - 1) // chunk_size
        print(f"[streaming] Total audio length: {len(speech)/fs:.2f}s, chunks: {total_chunks}")
        # Collect all results
        all_results = []
        # Simulate streamed input
        chunk_count = 0
        for i in range(0, len(speech), chunk_size):
            chunk = speech[i:i+chunk_size]
            chunk_count += 1
            # Handle the last chunk: zero-pad if shorter than chunk_size
            if len(chunk) < chunk_size:
                padding_size = chunk_size - len(chunk)
                chunk = np.concatenate([chunk, np.zeros(padding_size, dtype=np.float32)])
                print(f"\n[streaming] Processing chunk {chunk_count}/{total_chunks} (last chunk, zero-padded by {padding_size} samples)")
            else:
                print(f"\n[streaming] Processing chunk {chunk_count}/{total_chunks}")
            self.process_audio_stream(chunk, fs)
            # Fetch and display real-time results
            result = self.get_stream_results()
            while result:
                print(f"\n{'='*70}")
                print(f"[realtime result #{len(all_results) + 1}]")
                print(f"Segment ID: {result['segment_id']}")
                print(f"Original: {result['original']}")
                print(f"Translation: {result['translated']}")
                print(f"Audio: {result['audio_path']}")
                print(f"{'='*70}")
                all_results.append(result)
                result = self.get_stream_results()
            time.sleep(0.01)
        # Drain the remaining results
        # print("\n[streaming] Waiting for remaining chunks...")
        max_wait_time = 20  # wait up to 20 seconds
        wait_start = time.time()
        while time.time() - wait_start < max_wait_time:
            result = self.get_stream_results()
            if result:
                print(f"\n{'='*70}")
                print(f"[realtime result #{len(all_results) + 1}]")
                print(f"Segment ID: {result['segment_id']}")
                print(f"Original: {result['original']}")
                print(f"Translation: {result['translated']}")
                print(f"Audio: {result['audio_path']}")
                print(f"{'='*70}")
                all_results.append(result)
                wait_start = time.time()  # reset the wait timer
            else:
                time.sleep(0.02)
        # Stop streaming processing
        self.stop_stream_processing()
        print(f"\n[streaming] Done! {len(all_results)} valid results processed")
        return all_results
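# A minimal end-to-end usage sketch (illustrative only, never called; it
# assumes the model directories exist and a Qwen HTTP service is reachable at
# the hypothetical address below):
def _demo_pipeline():
    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        qwen_api_url="http://localhost:8000",
    )
    for r in pipeline.process_long_audio_stream("./wav/en_6mins.wav"):
        print(r['segment_id'], r['original'], '->', r['translated'])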
def main():
    parser = argparse.ArgumentParser(description="Real-time speech translation pipeline")
    parser.add_argument("--audio_file", type=str, default="./wav/en_6mins.wav", help="Input audio file path")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")
    parser.add_argument("--target_sr", type=int, default=16000, help="ASR target sample rate (default: 16000)")
    parser.add_argument("--chunk_duration", type=float, default=7.0, help="Audio chunk duration in seconds (default: 7.0)")
    parser.add_argument("--overlap_duration", type=float, default=0.01, help="Overlap duration in seconds (default: 0.01)")
    args = parser.parse_args()
    print("-------------------Real-time speech translation pipeline-------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Processing audio file: {args.audio_file}")
    print(f"Output directory: {args.output_dir}")
    print(f"Chunk duration: {args.chunk_duration}s")
    print(f"Overlap duration: {args.overlap_duration}s\n")
    # Initialize the pipeline
    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132,
        tts_dec_len=128,
        sample_rate=44100,
        tts_speed=0.8,
        qwen_api_url=args.api_url,
        target_sr=args.target_sr,
        output_dir=args.output_dir
    )
    # # Optional: adjust the streaming parameters
    # if args.chunk_duration != 7.0 or args.overlap_duration != 0.01:
    #     pipeline.stream_processor.chunk_duration = args.chunk_duration
    #     pipeline.stream_processor.overlap_duration = args.overlap_duration
    #     pipeline.stream_processor.chunk_samples = int(args.chunk_duration * args.target_sr)
    #     pipeline.stream_processor.overlap_samples = int(args.overlap_duration * args.target_sr)
    #     print(f"[config] Streaming parameters updated: chunk_duration={args.chunk_duration}s, overlap_duration={args.overlap_duration}s\n")
    start_time = time.time()
    try:
        # Streaming mode
        print("="*70 + "\n")
        # Compute chunk_size to match chunk_duration
        chunk_size = int(args.chunk_duration * args.target_sr)
        results = pipeline.process_long_audio_stream(args.audio_file, chunk_size=chunk_size)
        print("\n" + "="*70)
        print(" Processing complete")
        print("="*70)
        print(f"\n Successfully processed {len(results)} translated segments\n")
        # Display all results
        if results:
            print("All translation results:")
            print("-" * 70)
            for idx, result in enumerate(results, 1):
                print(f"\n[Segment {idx}] (ID: {result['segment_id']})")
                print(f" Original: {result['original']}")
                print(f" Translation: {result['translated']}")
                print(f" Audio: {result['audio_path']}")
                print(f" Time: {time.strftime('%H:%M:%S', time.localtime(result['timestamp']))}")
            print("-" * 70)
            # Save the results to a file
            result_file = os.path.join(args.output_dir, "stream_results.txt")
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(f"Streaming translation + TTS results - {args.audio_file}\n")
                f.write(f"Processed at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Chunk duration: {args.chunk_duration}s, overlap duration: {args.overlap_duration}s\n")
                f.write("="*70 + "\n\n")
                for idx, result in enumerate(results, 1):
                    f.write(f"[Segment {idx}] (ID: {result['segment_id']})\n")
                    f.write(f"Original: {result['original']}\n")
                    f.write(f"Translation: {result['translated']}\n")
                    f.write(f"Audio: {result['audio_path']}\n")
                    f.write(f"Time: {time.strftime('%H:%M:%S', time.localtime(result['timestamp']))}\n")
                    f.write("\n" + "-"*70 + "\n\n")
            print(f"\n✓ Results saved to: {result_file}")
            # List the generated audio files
            audio_files = [r['audio_path'] for r in results]
            print(f"\n Generated {len(audio_files)} TTS audio files:")
            for audio_file in audio_files:
                print(f" - {audio_file}")
        else:
            print("\n No valid translation results obtained")
        print("="*70)
        # Total elapsed time
        total_time = time.time() - start_time
        print(f"\nTotal processing time: {total_time:.2f} seconds")
    except Exception as e:
        print(f"Pipeline execution failed: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
    main()