HY-2012 committed
Commit 1ebaeb9 · verified · 1 Parent(s): 1ff1404

first commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. .gitattributes +11 -0
  2. README.md +118 -3
  3. ax_model/.gitattributes +2 -0
  4. ax_model/auto.npy +3 -0
  5. ax_model/chn_jpn_yue_eng_ko_spectok.bpe.model +3 -0
  6. ax_model/event_emo.npy +3 -0
  7. ax_model/sensevoice.axmodel +3 -0
  8. ax_model/sensevoice/am.mvn +8 -0
  9. ax_model/sensevoice/config.yaml +97 -0
  10. ax_model/vad/am.mvn +8 -0
  11. ax_model/vad/config.yaml +56 -0
  12. ax_model/withitn.npy +3 -0
  13. ax_speech_translate_demo.py +347 -0
  14. libmelotts/install/libonnxruntime.so +3 -0
  15. libmelotts/install/libonnxruntime.so.1.14.0 +3 -0
  16. libmelotts/install/libonnxruntime_providers_shared.so +0 -0
  17. libmelotts/install/melotts +3 -0
  18. libmelotts/models/decoder-en.axmodel +3 -0
  19. libmelotts/models/decoder-zh.axmodel +3 -0
  20. libmelotts/models/encoder-en.onnx +3 -0
  21. libmelotts/models/encoder-zh.onnx +3 -0
  22. libmelotts/models/g-en.bin +3 -0
  23. libmelotts/models/g-jp.bin +3 -0
  24. libmelotts/models/g-zh_mix_en.bin +3 -0
  25. libmelotts/models/lexicon.txt +0 -0
  26. libmelotts/models/tokens.txt +112 -0
  27. libtranslate/libax_translate.so +3 -0
  28. libtranslate/libsentencepiece.so.0 +3 -0
  29. libtranslate/opus-mt-en-zh.axmodel +3 -0
  30. libtranslate/opus-mt-en-zh/.gitattributes +9 -0
  31. libtranslate/opus-mt-en-zh/README.md +96 -0
  32. libtranslate/opus-mt-en-zh/config.json +61 -0
  33. libtranslate/opus-mt-en-zh/generation_config.json +16 -0
  34. libtranslate/opus-mt-en-zh/metadata.json +1 -0
  35. libtranslate/opus-mt-en-zh/source.spm +3 -0
  36. libtranslate/opus-mt-en-zh/target.spm +3 -0
  37. libtranslate/opus-mt-en-zh/tokenizer_config.json +1 -0
  38. libtranslate/opus-mt-en-zh/vocab.json +0 -0
  39. libtranslate/test_translate +0 -0
  40. model.py +942 -0
  41. requirements.txt +5 -0
  42. utils/__init__.py +0 -0
  43. utils/ax_model_bin.py +241 -0
  44. utils/ax_vad_bin.py +156 -0
  45. utils/ctc_alignment.py +76 -0
  46. utils/frontend.py +433 -0
  47. utils/infer_utils.py +312 -0
  48. utils/utils/__init__.py +0 -0
  49. utils/utils/e2e_vad.py +711 -0
  50. utils/utils/frontend.py +448 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ax_model/sensevoice.axmodel filter=lfs diff=lfs merge=lfs -text
+ libmelotts/install/libonnxruntime.so filter=lfs diff=lfs merge=lfs -text
+ libmelotts/install/libonnxruntime.so.1.14.0 filter=lfs diff=lfs merge=lfs -text
+ libmelotts/install/melotts filter=lfs diff=lfs merge=lfs -text
+ libmelotts/models/decoder-en.axmodel filter=lfs diff=lfs merge=lfs -text
+ libmelotts/models/decoder-zh.axmodel filter=lfs diff=lfs merge=lfs -text
+ libtranslate/libax_translate.so filter=lfs diff=lfs merge=lfs -text
+ libtranslate/libsentencepiece.so.0 filter=lfs diff=lfs merge=lfs -text
+ libtranslate/opus-mt-en-zh/source.spm filter=lfs diff=lfs merge=lfs -text
+ libtranslate/opus-mt-en-zh/target.spm filter=lfs diff=lfs merge=lfs -text
+ libtranslate/opus-mt-en-zh.axmodel filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,118 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ language:
+ - en
+ - zh
+ pipeline_tag: Speech-Translation
+ base_model:
+ - FunAudioLLM/SenseVoiceSmall
+ - opus-mt-en-zh
+ - MeloTTS
+ tags:
+ - VAD
+ - ASR
+ - Translation
+ - TTS
+ ---
+
+ # Speech-Translation.axera
+
+ Speech-translation demo on Axera.
+
+ - [x] Python example
+ - [ ] C++ example
+
+ ## Conversion tool links
+
+ If you are interested in model conversion, you can export the axmodel files yourself from the original repos. How to convert from ONNX to axmodel:
+ - [ASR](https://github.com/AXERA-TECH/3D-Speaker-MT.axera/tree/main/model_convert)
+ - [MeloTTS](https://github.com/ml-inory/melotts.axera/tree/main/model_convert)
+
+ ## Supported platforms
+
+ - AX650N
+
+ ## Features
+
+ Speech translation (supported direction: English -> Chinese)
+
+ ## Pipeline components
+
+ The pipeline chains three stages; a minimal sketch of how they fit together follows this list.
+ - [ASR](https://github.com/AXERA-TECH/3D-Speaker-MT.axera/tree/main)
+ - [Text-Translate](https://github.com/AXERA-TECH/libtranslate.axera/tree/master): build the library files as described there and place them in libtranslate
+ - [MeloTTS](https://github.com/ml-inory/melotts.axera/tree/main/cpp): build the library files as described there and place them in libmelotts
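+
+ The sketch below is an illustrative example of how the three stages can be chained: ASR text from SenseVoice is passed to the prebuilt translation CLI, and the result to the MeloTTS CLI. The binary paths and flags mirror ax_speech_translate_demo.py in this repo; the helper function names (`translate`, `synthesize`) are made up for the example.
+
+ ```python
+ import subprocess
+
+ def translate(english_text: str) -> str:
+     # Run the prebuilt opus-mt-en-zh CLI; it prints the result after an "output: " marker
+     out = subprocess.run(
+         ["libtranslate/test_translate",
+          "--model", "libtranslate/opus-mt-en-zh.axmodel",
+          "--tokenizer_dir", "libtranslate/opus-mt-en-zh/",
+          "--text", english_text],
+         capture_output=True, text=True, check=True).stdout
+     # Keep only the translated text between "output: " and the engine teardown log
+     return out.split("output: ")[-1].split("\nAX_ENGINE_Deinit")[0].strip()
+
+ def synthesize(chinese_text: str, wav_path: str) -> None:
+     # Run the prebuilt MeloTTS CLI; it writes the synthesized speech to wav_path
+     subprocess.run(
+         ["libmelotts/install/melotts",
+          "--g", "libmelotts/models/g-zh_mix_en.bin",
+          "-e", "libmelotts/models/encoder-zh.onnx",
+          "-l", "libmelotts/models/lexicon.txt",
+          "-t", "libmelotts/models/tokens.txt",
+          "-d", "libmelotts/models/decoder-zh.axmodel",
+          "-w", wav_path,
+          "-s", chinese_text],
+         check=True)
+
+ # english_text would come from the SenseVoice ASR step (see ax_speech_translate_demo.py):
+ # synthesize(translate(english_text), "output/output.wav")
+ ```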
+
+ ## On-board deployment
+
+ - AX650N devices come with Ubuntu 22.04 preinstalled
+ - Log in to the AX650N board as root
+ - Connect the board to the Internet so that commands such as apt install and pip install work normally
+ - Verified device: AX650N DEMO Board
+
+ ## Running the Python API
+
+ Verified with Python 3.10.
+
+ 1. Add the dynamic libraries:
+ ```
+ export LD_LIBRARY_PATH=./libtranslate/:$LD_LIBRARY_PATH
+ export LD_LIBRARY_PATH=./libmelotts/install/:$LD_LIBRARY_PATH
+ ```
+ 2. Install the Python dependencies:
+ ```
+ pip3 install -r requirements.txt
+ ```
+
+ ## Run the following command on the board
+
+ Supported input audio formats: wav, mp3.
+
+ ```
+ python3 pipeline.py --audio_file wav/en.mp3 --output_dir output
+ ```
+
+ Runtime arguments:
+
+ | Argument | Description |
+ |-------|------|
+ | `--audio_file` | Path to the input audio file |
+ | `--output_dir` | Directory where results are saved |
+
+ The output is saved as a wav file; a sample run looks like this:
+ ```
+ 原始音频: wav/en.mp3
+ 原始文本: The tribal chieftain called for the boy and presented him with 50 pieces of gold.
+ 翻译文本: 部落酋长召唤了男孩 给他50块黄金
+ 生成音频: output/output.wav
+ ```
+
+ ## Latency
+
+ AX650N
+
+ RTF (real-time factor) = total inference time / audio duration. On this platform it is roughly 2.0, for example:
+ ```
+ Inference time for en.mp3: 13.04 seconds
+ - VAD + ASR processing time: 0.89 seconds
+ - Translate time: 3.95 seconds
+ - TTS time: 8.20 seconds
+ Audio duration: 7.18 seconds
+ RTF: 1.82
+ ```
+
+ References:
+ - [sensevoice.axera](https://github.com/ml-inory/sensevoice.axera/tree/main)
+ - [3D-Speaker.axera](https://github.com/AXERA-TECH/3D-Speaker.axera/tree/master)
+ - [libtranslate.axera](https://github.com/AXERA-TECH/libtranslate.axera/tree/master)
+ - [melotts.axera](https://github.com/ml-inory/melotts.axera/tree/main)
+
+ ## Technical discussion
+
+ - GitHub issues
+ - QQ group: 139953715
ax_model/.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.axmodel filter=lfs diff=lfs merge=lfs -text
2
+ *.npy filter=lfs diff=lfs merge=lfs -text
ax_model/auto.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d0997706b30274f7ff3b157ca90df50b7ed8ced35091a0231700355d5ee1374
3
+ size 2368
ax_model/chn_jpn_yue_eng_ko_spectok.bpe.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
3
+ size 377341
ax_model/event_emo.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d22e3df5d192fdc3e73e368a2cb576975a5a43a114a8432a91c036adf8e2263
3
+ size 4608
ax_model/sensevoice.axmodel ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b64a36fa15e75ab5e3b75f18ae87a058970cff76219407e503b54fb53dd8e38
3
+ size 262170623
ax_model/sensevoice/am.mvn ADDED
@@ -0,0 +1,8 @@
1
+ <Nnet>
2
+ <Splice> 560 560
3
+ [ 0 ]
4
+ <AddShift> 560 560
5
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
6
+ <Rescale> 560 560
7
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
8
+ </Nnet>
ax_model/sensevoice/config.yaml ADDED
@@ -0,0 +1,97 @@
1
+ encoder: SenseVoiceEncoderSmall
2
+ encoder_conf:
3
+ output_size: 512
4
+ attention_heads: 4
5
+ linear_units: 2048
6
+ num_blocks: 50
7
+ tp_blocks: 20
8
+ dropout_rate: 0.1
9
+ positional_dropout_rate: 0.1
10
+ attention_dropout_rate: 0.1
11
+ input_layer: pe
12
+ pos_enc_class: SinusoidalPositionEncoder
13
+ normalize_before: true
14
+ kernel_size: 11
15
+ sanm_shfit: 0
16
+ selfattention_layer_type: sanm
17
+
18
+
19
+ model: SenseVoiceSmall
20
+ model_conf:
21
+ length_normalized_loss: true
22
+ sos: 1
23
+ eos: 2
24
+ ignore_id: -1
25
+
26
+ tokenizer: SentencepiecesTokenizer
27
+ tokenizer_conf:
28
+ bpemodel: null
29
+ unk_symbol: <unk>
30
+ split_with_space: true
31
+
32
+ frontend: WavFrontend
33
+ frontend_conf:
34
+ fs: 16000
35
+ window: hamming
36
+ n_mels: 80
37
+ frame_length: 25
38
+ frame_shift: 10
39
+ lfr_m: 7
40
+ lfr_n: 6
41
+ cmvn_file: null
42
+
43
+
44
+ dataset: SenseVoiceCTCDataset
45
+ dataset_conf:
46
+ index_ds: IndexDSJsonl
47
+ batch_sampler: EspnetStyleBatchSampler
48
+ data_split_num: 32
49
+ batch_type: token
50
+ batch_size: 14000
51
+ max_token_length: 2000
52
+ min_token_length: 60
53
+ max_source_length: 2000
54
+ min_source_length: 60
55
+ max_target_length: 200
56
+ min_target_length: 0
57
+ shuffle: true
58
+ num_workers: 4
59
+ sos: ${model_conf.sos}
60
+ eos: ${model_conf.eos}
61
+ IndexDSJsonl: IndexDSJsonl
62
+ retry: 20
63
+
64
+ train_conf:
65
+ accum_grad: 1
66
+ grad_clip: 5
67
+ max_epoch: 20
68
+ keep_nbest_models: 10
69
+ avg_nbest_model: 10
70
+ log_interval: 100
71
+ resume: true
72
+ validate_interval: 10000
73
+ save_checkpoint_interval: 10000
74
+
75
+ optim: adamw
76
+ optim_conf:
77
+ lr: 0.00002
78
+ scheduler: warmuplr
79
+ scheduler_conf:
80
+ warmup_steps: 25000
81
+
82
+ specaug: SpecAugLFR
83
+ specaug_conf:
84
+ apply_time_warp: false
85
+ time_warp_window: 5
86
+ time_warp_mode: bicubic
87
+ apply_freq_mask: true
88
+ freq_mask_width_range:
89
+ - 0
90
+ - 30
91
+ lfr_rate: 6
92
+ num_freq_mask: 1
93
+ apply_time_mask: true
94
+ time_mask_width_range:
95
+ - 0
96
+ - 12
97
+ num_time_mask: 1
ax_model/vad/am.mvn ADDED
@@ -0,0 +1,8 @@
1
+ <Nnet>
2
+ <Splice> 400 400
3
+ [ 0 ]
4
+ <AddShift> 400 400
5
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
6
+ <Rescale> 400 400
7
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
8
+ </Nnet>
ax_model/vad/config.yaml ADDED
@@ -0,0 +1,56 @@
1
+ frontend: WavFrontendOnline
2
+ frontend_conf:
3
+ fs: 16000
4
+ window: hamming
5
+ n_mels: 80
6
+ frame_length: 25
7
+ frame_shift: 10
8
+ dither: 0.0
9
+ lfr_m: 5
10
+ lfr_n: 1
11
+
12
+ model: FsmnVADStreaming
13
+ model_conf:
14
+ sample_rate: 16000
15
+ detect_mode: 1
16
+ snr_mode: 0
17
+ max_end_silence_time: 800
18
+ max_start_silence_time: 3000
19
+ do_start_point_detection: True
20
+ do_end_point_detection: True
21
+ window_size_ms: 200
22
+ sil_to_speech_time_thres: 150
23
+ speech_to_sil_time_thres: 150
24
+ speech_2_noise_ratio: 1.0
25
+ do_extend: 1
26
+ lookback_time_start_point: 200
27
+ lookahead_time_end_point: 100
28
+ max_single_segment_time: 60000
29
+ snr_thres: -100.0
30
+ noise_frame_num_used_for_snr: 100
31
+ decibel_thres: -100.0
32
+ speech_noise_thres: 0.6
33
+ fe_prior_thres: 0.0001
34
+ silence_pdf_num: 1
35
+ sil_pdf_ids: [0]
36
+ speech_noise_thresh_low: -0.1
37
+ speech_noise_thresh_high: 0.3
38
+ output_frame_probs: False
39
+ frame_in_ms: 10
40
+ frame_length_ms: 25
41
+
42
+ encoder: FSMN
43
+ encoder_conf:
44
+ input_dim: 400
45
+ input_affine_dim: 140
46
+ fsmn_layers: 4
47
+ linear_dim: 250
48
+ proj_dim: 128
49
+ lorder: 20
50
+ rorder: 0
51
+ lstride: 1
52
+ rstride: 0
53
+ output_affine_dim: 140
54
+ output_dim: 248
55
+
56
+
ax_model/withitn.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39bf02586f59237894fc2918ab2db4f12ec3c084c41465718832fbd7646ea729
3
+ size 2368
ax_speech_translate_demo.py ADDED
@@ -0,0 +1,347 @@
1
+ import subprocess
2
+ import tempfile
3
+ import os
4
+ import json
5
+ import shutil
6
+ import time
7
+ import librosa
8
+ import torch
9
+ import argparse
10
+ import soundfile as sf
11
+ from pathlib import Path
12
+ import cn2an
13
+
14
+ # 导入SenseVoice相关模块
15
+ from model import SinusoidalPositionEncoder
16
+ from utils.ax_model_bin import AX_SenseVoiceSmall
17
+ from utils.ax_vad_bin import AX_Fsmn_vad
18
+ from utils.vad_utils import merge_vad
19
+ from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
20
+
21
+ # 配置参数
22
+ # translate 参数
23
+ TRANSLATE_EXECUTABLE = "libtranslate/test_translate"
24
+ TRANSLATE_MODEL = "libtranslate/opus-mt-en-zh.axmodel"
25
+ TRANSLATE_TOKENIZER_DIR = "libtranslate/opus-mt-en-zh/"
26
+
27
+ # tts 参数
28
+ TTS_EXECUTABLE = "libmelotts/install/melotts"
29
+ TTS_MODEL_DIR = "libmelotts/models"
30
+ TTS_MODEL_FILES = {
31
+ "g": "g-zh_mix_en.bin",
32
+ "encoder": "encoder-zh.onnx",
33
+ "lexicon": "lexicon.txt",
34
+ "tokens": "tokens.txt",
35
+ "decoder": "decoder-zh.axmodel"
36
+ }
37
+
38
+ class SpeechTranslationPipeline:
39
+ def __init__(self,
40
+ translate_exec, translate_model, translate_tokenizer,
41
+ tts_exec, tts_model_dir, tts_model_files,
42
+ asr_model_dir="ax_model", seq_len=132):
43
+ self.translate_exec = translate_exec
44
+ self.translate_model = translate_model
45
+ self.translate_tokenizer = translate_tokenizer
46
+ self.tts_exec = tts_exec
47
+ self.tts_model_dir = tts_model_dir
48
+ self.tts_model_files = tts_model_files
49
+ self.asr_model_dir = asr_model_dir
50
+ self.seq_len = seq_len
51
+
52
+ # 初始化ASR模型
53
+ self._init_asr_models()
54
+
55
+ # 验证所有必需文件存在
56
+ self._validate_files()
57
+
58
+ def _init_asr_models(self):
59
+ """初始化语音识别相关模型"""
60
+ print("Initializing SenseVoice models...")
61
+
62
+ # VAD模型
63
+ self.model_vad = AX_Fsmn_vad(self.asr_model_dir)
64
+
65
+ # 位置编码
66
+ self.embed = SinusoidalPositionEncoder()
67
+ self.position_encoding = self.embed.get_position_encoding(
68
+ torch.randn(1, self.seq_len, 560)).numpy()
69
+
70
+ # ASR模型
71
+ self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)
72
+
73
+ # Tokenizer
74
+ tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
75
+ self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
76
+
77
+ print("SenseVoice models initialized successfully.")
78
+
79
+ def _validate_files(self):
80
+ """验证所有必需的文件都存在"""
81
+ # 检查翻译相关文件
82
+ if not os.path.exists(self.translate_exec):
83
+ raise FileNotFoundError(f"翻译可执行文件不存在: {self.translate_exec}")
84
+ if not os.path.exists(self.translate_model):
85
+ raise FileNotFoundError(f"翻译模型不存在: {self.translate_model}")
86
+ if not os.path.exists(self.translate_tokenizer):
87
+ raise FileNotFoundError(f"翻译tokenizer目录不存在: {self.translate_tokenizer}")
88
+
89
+ # 检查TTS相关文件
90
+ if not os.path.exists(self.tts_exec):
91
+ raise FileNotFoundError(f"TTS可执行文件不存在: {self.tts_exec}")
92
+
93
+ for key, filename in self.tts_model_files.items():
94
+ filepath = os.path.join(self.tts_model_dir, filename)
95
+ if not os.path.exists(filepath):
96
+ raise FileNotFoundError(f"TTS模型文件不存在: {filepath}")
97
+
98
+ def speech_recognition(self, speech, fs):
99
+ """
100
+ 第一步:语音识别(ASR)
101
+ """
102
+ speech_lengths = len(speech)
103
+
104
+ # VAD处理
105
+ print("Running VAD...")
106
+ vad_start_time = time.time()
107
+ res_vad = self.model_vad(speech)[0]
108
+ vad_segments = merge_vad(res_vad, 15 * 1000)
109
+ vad_time_cost = time.time() - vad_start_time
110
+ print(f"VAD processing time: {vad_time_cost:.2f} seconds")
111
+ print(f"VAD segments detected: {len(vad_segments)}")
112
+
113
+ # ASR处理
114
+ print("Running ASR...")
115
+ asr_start_time = time.time()
116
+ all_results = ""
117
+
118
+ # 遍历每个VAD片段并处理
119
+ for i, segment in enumerate(vad_segments):
120
+ segment_start, segment_end = segment
121
+ start_sample = int(segment_start / 1000 * fs)
122
+ end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
123
+ segment_speech = speech[start_sample:end_sample]
124
+
125
+ # 为当前片段创建临时文件
126
+ segment_filename = f"temp_segment_{i}.wav"
127
+ sf.write(segment_filename, segment_speech, fs)
128
+
129
+ # 对当前片段进行识别
130
+ try:
131
+ segment_res = self.model_bin(
132
+ segment_filename,
133
+ "auto", # 语言自动检测
134
+ True, # withitn
135
+ self.position_encoding,
136
+ tokenizer=self.tokenizer,
137
+ )
138
+
139
+ all_results += segment_res
140
+
141
+ # 清理临时文件
142
+ if os.path.exists(segment_filename):
143
+ os.remove(segment_filename)
144
+
145
+ except Exception as e:
146
+ if os.path.exists(segment_filename):
147
+ os.remove(segment_filename)
148
+ print(f"Error processing segment {i}: {e}")
149
+ continue
150
+
151
+ asr_time_cost = time.time() - asr_start_time
152
+ print(f"ASR processing time: {asr_time_cost:.2f} seconds")
153
+ print(f"ASR Result: {all_results}")
154
+
155
+ return all_results.strip()
156
+
157
+ def run_translation(self, english_text):
158
+ """
159
+ 第二步:调用翻译程序将英文翻译成中文
160
+ """
161
+ # 构建命令参数
162
+ cmd = [
163
+ self.translate_exec,
164
+ "--model", self.translate_model,
165
+ "--tokenizer_dir", self.translate_tokenizer,
166
+ "--text", f'"{english_text}"' # 添加引号处理包含空格和特殊字符的文本
167
+ ]
168
+
169
+ try:
170
+ # 执行命令
171
+ result = subprocess.run(
172
+ cmd,
173
+ capture_output=True,
174
+ text=True,
175
+ timeout=30 # 设置超时时间,单位秒
176
+ )
177
+
178
+ # 检查执行结果
179
+ if result.returncode != 0:
180
+ error_msg = f"翻译程序执行失败: {result.stderr}"
181
+ raise RuntimeError(error_msg)
182
+
183
+ # 提取翻译结果
184
+ chinese_text = result.stdout.strip()
185
+
186
+ # 清理可能的额外输出
187
+ # if "翻译结果:" in chinese_text:
188
+ # chinese_text = chinese_text.split("翻译结果:", 1)[-1].strip()
189
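+ # Keep only the text between the CLI's "output: " marker and the trailing "AX_ENGINE_Deinit" log line.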
+ chinese_text = chinese_text.split("output: ")[-1].split("\nAX_ENGINE_Deinit")[0]
190
+
191
+ print(f"翻译结果: {chinese_text}")
192
+ return chinese_text
193
+
194
+ except subprocess.TimeoutExpired:
195
+ raise RuntimeError("翻译程序执行超时")
196
+ except Exception as e:
197
+ raise e
198
+
199
+ def run_tts(self, chinese_text, output_dir, output_wav=None):
200
+ """
201
+ 第三步:调用TTS程序合成中文语音
202
+ """
203
+ output_path = os.path.join(output_dir, output_wav)
204
+ #chinese_text = chinese_text.split("output: ")[-1].split("\nAX_ENGINE_Deinit")[0]
205
+
206
+ chinese_text = cn2an.transform(chinese_text, "an2cn")
207
+
208
+ # 构建命令参数
209
+ cmd = [
210
+ self.tts_exec,
211
+ "--g", os.path.join(self.tts_model_dir, self.tts_model_files["g"]),
212
+ "-e", os.path.join(self.tts_model_dir, self.tts_model_files["encoder"]),
213
+ "-l", os.path.join(self.tts_model_dir, self.tts_model_files["lexicon"]),
214
+ "-t", os.path.join(self.tts_model_dir, self.tts_model_files["tokens"]),
215
+ "-d", os.path.join(self.tts_model_dir, self.tts_model_files["decoder"]),
216
+ "-w", output_path,
217
+ "-s", f'"{chinese_text}"'
218
+ ]
219
+
220
+ try:
221
+ # 执行命令
222
+ result = subprocess.run(
223
+ cmd,
224
+ capture_output=False,
225
+ text=True,
226
+ timeout=60 # TTS可能需要更长时间
227
+ )
228
+
229
+ # 检查执行结果
230
+ if result.returncode != 0:
231
+ error_msg = f"TTS程序执行失败: {result.stderr}"
232
+ raise RuntimeError(error_msg)
233
+
234
+ # 验证输出文件是否存在
235
+ if not os.path.exists(output_path):
236
+ raise FileNotFoundError(f"输出文件未生成: {output_path}")
237
+
238
+ return output_path
239
+
240
+ except subprocess.TimeoutExpired:
241
+ raise RuntimeError("TTS程序执行超时")
242
+ except Exception as e:
243
+ # 清理临时文件
244
+ if output_path and os.path.exists(os.path.dirname(output_path)):
245
+ shutil.rmtree(os.path.dirname(output_path))
246
+ raise e
247
+
248
+ def full_pipeline(self, speech, fs, output_dir=None,output_tts = None):
249
+ """
250
+ 完整Pipeline:语音识别 -> 翻译 -> TTS合成
251
+ """
252
+
253
+ # 第一步:语音识别
254
+ print("\n----------------------VAD+ASR----------------------------\n")
255
+ start_time = time.time() # 记录开始时间
256
+ english_text = self.speech_recognition(speech, fs)
257
+ asr_time = time.time() - start_time # 计算耗时
258
+ print(f"语音识别耗时: {asr_time:.2f} 秒")
259
+
260
+ # 第二步:翻译
261
+ print("\n---------------------translate---------------------------\n")
262
+ start_time = time.time() # 记录开始时间
263
+ chinese_text = self.run_translation(english_text)
264
+ translate_time = time.time() - start_time # 计算耗时
265
+ print(f"翻译耗时: {translate_time:.2f} 秒")
266
+
267
+ # 第三步:TTS合成
268
+ print("-------------------------TTS-------------------------------\n")
269
+ start_time = time.time() # 记录开始时间
270
+ output_path = self.run_tts(chinese_text, output_dir, output_tts)
271
+ tts_time = time.time() - start_time # 计算耗时
272
+ print(f"TTS合成耗时: {tts_time:.2f} 秒")
273
+
274
+ return {
275
+ "original_text": english_text,
276
+ "translated_text": chinese_text,
277
+ "audio_path": output_path
278
+ }
279
+
280
+ def main():
281
+ parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
282
+ parser.add_argument("--audio_file", type=str, required=True, help="Input audio file path")
283
+ parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
284
+ parser.add_argument("--output_tts", type=str, default="output.wav", help="Output directory")
285
+
286
+ args = parser.parse_args()
287
+ print("-------------------START------------------------\n")
288
+ os.makedirs(args.output_dir ,exist_ok=True)
289
+
290
+ print(f"Processing audio file: {args.audio_file}")
291
+ # 加载音频
292
+ speech, fs = librosa.load(args.audio_file, sr=None)
293
+ if fs != 16000:
294
+ print(f"Resampling audio from {fs}Hz to 16000Hz")
295
+ speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
296
+ fs = 16000
297
+ audio_duration = librosa.get_duration(y=speech, sr=fs)
298
+
299
+
300
+ # 初始化
301
+ pipeline = SpeechTranslationPipeline(
302
+ translate_exec=TRANSLATE_EXECUTABLE,
303
+ translate_model=TRANSLATE_MODEL,
304
+ translate_tokenizer=TRANSLATE_TOKENIZER_DIR,
305
+ tts_exec=TTS_EXECUTABLE,
306
+ tts_model_dir=TTS_MODEL_DIR,
307
+ tts_model_files=TTS_MODEL_FILES,
308
+ asr_model_dir="ax_model",
309
+ seq_len=132
310
+ )
311
+
312
+ start_time = time.time()
313
+ try:
314
+ # 运行
315
+ result = pipeline.full_pipeline(speech, fs, args.output_dir, args.output_tts)
316
+
317
+ print("\n" + "="*50)
318
+ print("speech translate 完成!")
319
+ print("="*50 + "\n")
320
+ print(f"原始音频: {args.audio_file}")
321
+ print(f"原始文本: {result['original_text']}")
322
+ print(f"翻译文本: {result['translated_text']}")
323
+ print(f"生成音频: {result['audio_path']}")
324
+
325
+ # 保存结果到文件
326
+ result_file = os.path.join(args.output_dir, "pipeline_result.txt")
327
+ with open(result_file, 'w', encoding='utf-8') as f:
328
+ f.write(f"原始音频: {args.audio_file}\n")
329
+ f.write(f"识别文本: {result['original_text']}\n")
330
+ f.write(f"翻译结果: {result['translated_text']}\n")
331
+ f.write(f"合成音频: {result['audio_path']}\n")
332
+
333
+ # print(f"\n详细结果已保存到: {result_file}")
334
+ time_cost = time.time() - start_time
335
+ rtf = time_cost / audio_duration
336
+ print(f"Inference time for {args.audio_file}: {time_cost:.2f} seconds")
337
+ print(f"Audio duration: {audio_duration:.2f} seconds")
338
+ print(f"RTF: {rtf:.2f}\n")
339
+ except Exception as e:
340
+ print(f"Pipeline执行失败: {e}")
341
+ import traceback
342
+ traceback.print_exc()
343
+
344
+ if __name__ == "__main__":
345
+ main()
346
+
347
+
libmelotts/install/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0b062d058b15bbf87bc3232b9c905a2bd49ea7e88e203a308df5c08c4c15129
3
+ size 16014680
libmelotts/install/libonnxruntime.so.1.14.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0b062d058b15bbf87bc3232b9c905a2bd49ea7e88e203a308df5c08c4c15129
3
+ size 16014680
libmelotts/install/libonnxruntime_providers_shared.so ADDED
Binary file (9.92 kB).
 
libmelotts/install/melotts ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88d12e4b441d2398ea7646afb677f2dd5819f347caa680781ce2bcfa26e81d85
3
+ size 187256
libmelotts/models/decoder-en.axmodel ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90c93c0fa978cc1c68fbac6a78707dd75b8b9069cb01a1ade6846e2435aa1eb1
3
+ size 44093802
libmelotts/models/decoder-zh.axmodel ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37ea2d8401f18dd371eec50b90bd39dcadf9684aaf3543dace8ce1a9499ef253
3
+ size 44092592
libmelotts/models/encoder-en.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cc51185fb81934c7490c5f9ac993fff7efa98ab41c08cd3753c96abcb297582
3
+ size 31488385
libmelotts/models/encoder-zh.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2b0a5bc2789faef16b4bfc56ab4905364f8163a59f2db3d071b4a14792bfee5
3
+ size 31397760
libmelotts/models/g-en.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:094bf0dbe1cd6c9408707209b2b7261b9df2cd5917d310bfac5945a15a31821a
3
+ size 1024
libmelotts/models/g-jp.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c01dd0961bbe1effca4ed378d2969d6fbd9b579133b722f6968db5cf4d22281e
3
+ size 1024
libmelotts/models/g-zh_mix_en.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c70d897674847882bd35e780aee696ddaff8d04d5c57e4f9cf37611b6821879f
3
+ size 1024
libmelotts/models/lexicon.txt ADDED
The diff for this file is too large to render.
 
libmelotts/models/tokens.txt ADDED
@@ -0,0 +1,112 @@
1
+ _ 0
2
+ AA 1
3
+ E 2
4
+ EE 3
5
+ En 4
6
+ N 5
7
+ OO 6
8
+ V 7
9
+ a 8
10
+ a: 9
11
+ aa 10
12
+ ae 11
13
+ ah 12
14
+ ai 13
15
+ an 14
16
+ ang 15
17
+ ao 16
18
+ aw 17
19
+ ay 18
20
+ b 19
21
+ by 20
22
+ c 21
23
+ ch 22
24
+ d 23
25
+ dh 24
26
+ dy 25
27
+ e 26
28
+ e: 27
29
+ eh 28
30
+ ei 29
31
+ en 30
32
+ eng 31
33
+ er 32
34
+ ey 33
35
+ f 34
36
+ g 35
37
+ gy 36
38
+ h 37
39
+ hh 38
40
+ hy 39
41
+ i 40
42
+ i0 41
43
+ i: 42
44
+ ia 43
45
+ ian 44
46
+ iang 45
47
+ iao 46
48
+ ie 47
49
+ ih 48
50
+ in 49
51
+ ing 50
52
+ iong 51
53
+ ir 52
54
+ iu 53
55
+ iy 54
56
+ j 55
57
+ jh 56
58
+ k 57
59
+ ky 58
60
+ l 59
61
+ m 60
62
+ my 61
63
+ n 62
64
+ ng 63
65
+ ny 64
66
+ o 65
67
+ o: 66
68
+ ong 67
69
+ ou 68
70
+ ow 69
71
+ oy 70
72
+ p 71
73
+ py 72
74
+ q 73
75
+ r 74
76
+ ry 75
77
+ s 76
78
+ sh 77
79
+ t 78
80
+ th 79
81
+ ts 80
82
+ ty 81
83
+ u 82
84
+ u: 83
85
+ ua 84
86
+ uai 85
87
+ uan 86
88
+ uang 87
89
+ uh 88
90
+ ui 89
91
+ un 90
92
+ uo 91
93
+ uw 92
94
+ v 93
95
+ van 94
96
+ ve 95
97
+ vn 96
98
+ w 97
99
+ x 98
100
+ y 99
101
+ z 100
102
+ zh 101
103
+ zy 102
104
+ ! 103
105
+ ? 104
106
+ … 105
107
+ , 106
108
+ . 107
109
+ ' 108
110
+ - 109
111
+ SP 110
112
+ UNK 111
libtranslate/libax_translate.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d6c5901387ed8d0a3c0611a5ad8d60281487e34d20afc83a58923ea2462b456
3
+ size 1222544
libtranslate/libsentencepiece.so.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e93872d229a074c73e66eeeff8988197deb35b81b09ee893e022e115f886ed4
3
+ size 1320848
libtranslate/opus-mt-en-zh.axmodel ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e262aab07f3a5478a0a0df24a92e1f88c138c4999be5b4cb18b618d996bcacff
3
+ size 217562368
libtranslate/opus-mt-en-zh/.gitattributes ADDED
@@ -0,0 +1,9 @@
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
libtranslate/opus-mt-en-zh/README.md ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ language:
3
+ - en
4
+ - zh
5
+ tags:
6
+ - translation
7
+ license: apache-2.0
8
+ ---
9
+
10
+ ### eng-zho
11
+
12
+ * source group: English
13
+ * target group: Chinese
14
+ * OPUS readme: [eng-zho](https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-zho/README.md)
15
+
16
+ * model: transformer
17
+ * source language(s): eng
18
+ * target language(s): cjy_Hans cjy_Hant cmn cmn_Hans cmn_Hant gan lzh lzh_Hans nan wuu yue yue_Hans yue_Hant
19
+ * model: transformer
20
+ * pre-processing: normalization + SentencePiece (spm32k,spm32k)
21
+ * a sentence initial language token is required in the form of `>>id<<` (id = valid target language ID; see the illustrative example after this list)
22
+ * download original weights: [opus-2020-07-17.zip](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.zip)
23
+ * test set translations: [opus-2020-07-17.test.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.test.txt)
24
+ * test set scores: [opus-2020-07-17.eval.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.eval.txt)
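+
+ To illustrate the `>>id<<` convention above (an editorial sketch, not part of the upstream model card; the ID `cmn_Hans` is taken from the target constituents listed under System Info below):
+
+ ```python
+ # Hypothetical source string: a target-language token is prepended to the English input
+ src_text = ">>cmn_Hans<< The tribal chieftain called for the boy."
+ ```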
25
+
26
+ ## Benchmarks
27
+
28
+ | testset | BLEU | chr-F |
29
+ |-----------------------|-------|-------|
30
+ | Tatoeba-test.eng.zho | 31.4 | 0.268 |
31
+
32
+
33
+ ### System Info:
34
+ - hf_name: eng-zho
35
+
36
+ - source_languages: eng
37
+
38
+ - target_languages: zho
39
+
40
+ - opus_readme_url: https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-zho/README.md
41
+
42
+ - original_repo: Tatoeba-Challenge
43
+
44
+ - tags: ['translation']
45
+
46
+ - languages: ['en', 'zh']
47
+
48
+ - src_constituents: {'eng'}
49
+
50
+ - tgt_constituents: {'cmn_Hans', 'nan', 'nan_Hani', 'gan', 'yue', 'cmn_Kana', 'yue_Hani', 'wuu_Bopo', 'cmn_Latn', 'yue_Hira', 'cmn_Hani', 'cjy_Hans', 'cmn', 'lzh_Hang', 'lzh_Hira', 'cmn_Hant', 'lzh_Bopo', 'zho', 'zho_Hans', 'zho_Hant', 'lzh_Hani', 'yue_Hang', 'wuu', 'yue_Kana', 'wuu_Latn', 'yue_Bopo', 'cjy_Hant', 'yue_Hans', 'lzh', 'cmn_Hira', 'lzh_Yiii', 'lzh_Hans', 'cmn_Bopo', 'cmn_Hang', 'hak_Hani', 'cmn_Yiii', 'yue_Hant', 'lzh_Kana', 'wuu_Hani'}
51
+
52
+ - src_multilingual: False
53
+
54
+ - tgt_multilingual: False
55
+
56
+ - prepro: normalization + SentencePiece (spm32k,spm32k)
57
+
58
+ - url_model: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.zip
59
+
60
+ - url_test_set: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.test.txt
61
+
62
+ - src_alpha3: eng
63
+
64
+ - tgt_alpha3: zho
65
+
66
+ - short_pair: en-zh
67
+
68
+ - chrF2_score: 0.268
69
+
70
+ - bleu: 31.4
71
+
72
+ - brevity_penalty: 0.8959999999999999
73
+
74
+ - ref_len: 110468.0
75
+
76
+ - src_name: English
77
+
78
+ - tgt_name: Chinese
79
+
80
+ - train_date: 2020-07-17
81
+
82
+ - src_alpha2: en
83
+
84
+ - tgt_alpha2: zh
85
+
86
+ - prefer_old: False
87
+
88
+ - long_pair: eng-zho
89
+
90
+ - helsinki_git_sha: 480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535
91
+
92
+ - transformers_git_sha: 2207e5d8cb224e954a7cba69fa4ac2309e9ff30b
93
+
94
+ - port_machine: brutasse
95
+
96
+ - port_time: 2020-08-21-14:41
libtranslate/opus-mt-en-zh/config.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "_name_or_path": "./",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "swish",
5
+ "add_bias_logits": false,
6
+ "add_final_layer_norm": false,
7
+ "architectures": [
8
+ "MarianMTModel"
9
+ ],
10
+ "attention_dropout": 0.0,
11
+ "bad_words_ids": [
12
+ [
13
+ 65000
14
+ ]
15
+ ],
16
+ "bos_token_id": 0,
17
+ "classif_dropout": 0.0,
18
+ "classifier_dropout": 0.0,
19
+ "d_model": 512,
20
+ "decoder_attention_heads": 8,
21
+ "decoder_ffn_dim": 2048,
22
+ "decoder_layerdrop": 0.0,
23
+ "decoder_layers": 6,
24
+ "decoder_start_token_id": 65000,
25
+ "do_blenderbot_90_layernorm": false,
26
+ "dropout": 0.1,
27
+ "encoder_attention_heads": 8,
28
+ "encoder_ffn_dim": 2048,
29
+ "encoder_layerdrop": 0.0,
30
+ "encoder_layers": 6,
31
+ "eos_token_id": 0,
32
+ "extra_pos_embeddings": 0,
33
+ "force_bos_token_to_be_generated": false,
34
+ "forced_eos_token_id": 0,
35
+ "gradient_checkpointing": false,
36
+ "id2label": {
37
+ "0": "LABEL_0",
38
+ "1": "LABEL_1",
39
+ "2": "LABEL_2"
40
+ },
41
+ "init_std": 0.02,
42
+ "is_encoder_decoder": true,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1,
46
+ "LABEL_2": 2
47
+ },
48
+ "max_length": 512,
49
+ "max_position_embeddings": 512,
50
+ "model_type": "marian",
51
+ "normalize_before": false,
52
+ "normalize_embedding": false,
53
+ "num_beams": 4,
54
+ "num_hidden_layers": 6,
55
+ "pad_token_id": 65000,
56
+ "scale_embedding": true,
57
+ "static_position_embeddings": true,
58
+ "transformers_version": "4.9.0.dev0",
59
+ "use_cache": true,
60
+ "vocab_size": 65001
61
+ }
libtranslate/opus-mt-en-zh/generation_config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "bad_words_ids": [
3
+ [
4
+ 65000
5
+ ]
6
+ ],
7
+ "bos_token_id": 0,
8
+ "decoder_start_token_id": 65000,
9
+ "eos_token_id": 0,
10
+ "forced_eos_token_id": 0,
11
+ "max_length": 512,
12
+ "num_beams": 4,
13
+ "pad_token_id": 65000,
14
+ "renormalize_logits": true,
15
+ "transformers_version": "4.32.0.dev0"
16
+ }
libtranslate/opus-mt-en-zh/metadata.json ADDED
@@ -0,0 +1 @@
1
+ {"hf_name":"eng-zho","source_languages":"eng","target_languages":"zho","opus_readme_url":"https:\/\/github.com\/Helsinki-NLP\/Tatoeba-Challenge\/tree\/master\/models\/eng-zho\/README.md","original_repo":"Tatoeba-Challenge","tags":["translation"],"languages":["en","zh"],"src_constituents":["eng"],"tgt_constituents":["cmn_Hans","nan","nan_Hani","gan","yue","cmn_Kana","yue_Hani","wuu_Bopo","cmn_Latn","yue_Hira","cmn_Hani","cjy_Hans","cmn","lzh_Hang","lzh_Hira","cmn_Hant","lzh_Bopo","zho","zho_Hans","zho_Hant","lzh_Hani","yue_Hang","wuu","yue_Kana","wuu_Latn","yue_Bopo","cjy_Hant","yue_Hans","lzh","cmn_Hira","lzh_Yiii","lzh_Hans","cmn_Bopo","cmn_Hang","hak_Hani","cmn_Yiii","yue_Hant","lzh_Kana","wuu_Hani"],"src_multilingual":false,"tgt_multilingual":false,"prepro":" normalization + SentencePiece (spm32k,spm32k)","url_model":"https:\/\/object.pouta.csc.fi\/Tatoeba-MT-models\/eng-zho\/opus-2020-07-17.zip","url_test_set":"https:\/\/object.pouta.csc.fi\/Tatoeba-MT-models\/eng-zho\/opus-2020-07-17.test.txt","src_alpha3":"eng","tgt_alpha3":"zho","short_pair":"en-zh","chrF2_score":0.268,"bleu":31.4,"brevity_penalty":0.896,"ref_len":110468.0,"src_name":"English","tgt_name":"Chinese","train_date":"2020-07-17","src_alpha2":"en","tgt_alpha2":"zh","prefer_old":false,"long_pair":"eng-zho","helsinki_git_sha":"480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535","transformers_git_sha":"2207e5d8cb224e954a7cba69fa4ac2309e9ff30b","port_machine":"brutasse","port_time":"2020-08-21-14:41"}
libtranslate/opus-mt-en-zh/source.spm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5775ddc9e3ff2fae91554da56468ad35ff56edaba870fea74447bc7234bfdaa8
3
+ size 806435
libtranslate/opus-mt-en-zh/target.spm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81dc94efa84e4025ef38d25d5d07429fe41e3eb29d44003f1db6fe98487b0052
3
+ size 804600
libtranslate/opus-mt-en-zh/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"target_lang": "zho", "source_lang": "eng"}
libtranslate/opus-mt-en-zh/vocab.json ADDED
The diff for this file is too large to render.
 
libtranslate/test_translate ADDED
Binary file (82.1 kB).
 
model.py ADDED
@@ -0,0 +1,942 @@
1
+
2
+ import time
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from typing import Iterable, Optional
7
+
8
+ from funasr.register import tables
9
+ from funasr.models.ctc.ctc import CTC
10
+ from funasr.utils.datadir_writer import DatadirWriter
11
+ from funasr.models.paraformer.search import Hypothesis
12
+ from funasr.train_utils.device_funcs import force_gatherable
13
+ from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
14
+ from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
15
+ from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
16
+ from utils.ctc_alignment import ctc_forced_align
17
+
18
+ class SinusoidalPositionEncoder(torch.nn.Module):
19
+ """ """
20
+
21
+ def __init__(self, d_model=80, dropout_rate=0.1):
22
+ super().__init__()
23
+
24
+ def encode(
25
+ self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
26
+ ):
27
+ batch_size = positions.size(0)
28
+ positions = positions.type(dtype)
29
+ device = positions.device
30
+ log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
31
+ depth / 2 - 1
32
+ )
33
+ inv_timescales = torch.exp(
34
+ torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
35
+ )
36
+ inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
37
+ scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
38
+ inv_timescales, [1, 1, -1]
39
+ )
40
+ encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
41
+ return encoding.type(dtype)
42
+
43
+ def forward(self, x):
44
+ batch_size, timesteps, input_dim = x.size()
45
+ positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
46
+ position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
47
+
48
+ return x + position_encoding
49
+
50
+ def get_position_encoding(self, x):
51
+ batch_size, timesteps, input_dim = x.size()
52
+ positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
53
+ position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
54
+ return position_encoding
55
+
56
+
57
+ class PositionwiseFeedForward(torch.nn.Module):
58
+ """Positionwise feed forward layer.
59
+
60
+ Args:
61
+ idim (int): Input dimension.
62
+ hidden_units (int): The number of hidden units.
63
+ dropout_rate (float): Dropout rate.
64
+
65
+ """
66
+
67
+ def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
68
+ """Construct an PositionwiseFeedForward object."""
69
+ super(PositionwiseFeedForward, self).__init__()
70
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
71
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
72
+ self.dropout = torch.nn.Dropout(dropout_rate)
73
+ self.activation = activation
74
+
75
+ def forward(self, x):
76
+ """Forward function."""
77
+ return self.w_2(self.dropout(self.activation(self.w_1(x))))
78
+
79
+
80
+ class MultiHeadedAttentionSANM(nn.Module):
81
+ """Multi-Head Attention layer.
82
+
83
+ Args:
84
+ n_head (int): The number of heads.
85
+ n_feat (int): The number of features.
86
+ dropout_rate (float): Dropout rate.
87
+
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ n_head,
93
+ in_feat,
94
+ n_feat,
95
+ dropout_rate,
96
+ kernel_size,
97
+ sanm_shfit=0,
98
+ lora_list=None,
99
+ lora_rank=8,
100
+ lora_alpha=16,
101
+ lora_dropout=0.1,
102
+ ):
103
+ """Construct an MultiHeadedAttention object."""
104
+ super().__init__()
105
+ assert n_feat % n_head == 0
106
+ # We assume d_v always equals d_k
107
+ self.d_k = n_feat // n_head
108
+ self.h = n_head
109
+ # self.linear_q = nn.Linear(n_feat, n_feat)
110
+ # self.linear_k = nn.Linear(n_feat, n_feat)
111
+ # self.linear_v = nn.Linear(n_feat, n_feat)
112
+
113
+ self.linear_out = nn.Linear(n_feat, n_feat)
114
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
115
+ self.attn = None
116
+ self.dropout = nn.Dropout(p=dropout_rate)
117
+
118
+ self.fsmn_block = nn.Conv1d(
119
+ n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
120
+ )
121
+ # padding
122
+ left_padding = (kernel_size - 1) // 2
123
+ if sanm_shfit > 0:
124
+ left_padding = left_padding + sanm_shfit
125
+ right_padding = kernel_size - 1 - left_padding
126
+ self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
127
+
128
+ def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
129
+ b, t, d = inputs.size()
130
+ if mask is not None:
131
+ mask = torch.reshape(mask, (b, -1, 1))
132
+ if mask_shfit_chunk is not None:
133
+ mask = mask * mask_shfit_chunk
134
+ inputs = inputs * mask
135
+
136
+ x = inputs.transpose(1, 2)
137
+ x = self.pad_fn(x)
138
+ x = self.fsmn_block(x)
139
+ x = x.transpose(1, 2)
140
+ x += inputs
141
+ x = self.dropout(x)
142
+ if mask is not None:
143
+ x = x * mask
144
+ return x
145
+
146
+ def forward_qkv(self, x):
147
+ """Transform query, key and value.
148
+
149
+ Args:
150
+ query (torch.Tensor): Query tensor (#batch, time1, size).
151
+ key (torch.Tensor): Key tensor (#batch, time2, size).
152
+ value (torch.Tensor): Value tensor (#batch, time2, size).
153
+
154
+ Returns:
155
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
156
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
157
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
158
+
159
+ """
160
+ b, t, d = x.size()
161
+ q_k_v = self.linear_q_k_v(x)
162
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
163
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
164
+ 1, 2
165
+ ) # (batch, head, time1, d_k)
166
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
167
+ 1, 2
168
+ ) # (batch, head, time2, d_k)
169
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
170
+ 1, 2
171
+ ) # (batch, head, time2, d_k)
172
+
173
+ return q_h, k_h, v_h, v
174
+
175
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
176
+ """Compute attention context vector.
177
+
178
+ Args:
179
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
180
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
181
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
182
+
183
+ Returns:
184
+ torch.Tensor: Transformed value (#batch, time1, d_model)
185
+ weighted by the attention score (#batch, time1, time2).
186
+
187
+ """
188
+ n_batch = value.size(0)
189
+ if mask is not None:
190
+ if mask_att_chunk_encoder is not None:
191
+ mask = mask * mask_att_chunk_encoder
192
+
193
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
194
+
195
+ min_value = -float(
196
+ "inf"
197
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
198
+ scores = scores.masked_fill(mask, min_value)
199
+ attn = torch.softmax(scores, dim=-1).masked_fill(
200
+ mask, 0.0
201
+ ) # (batch, head, time1, time2)
202
+ else:
203
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
204
+
205
+ p_attn = self.dropout(attn)
206
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
207
+ x = (
208
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
209
+ ) # (batch, time1, d_model)
210
+
211
+ return self.linear_out(x) # (batch, time1, d_model)
212
+
213
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
214
+ """Compute scaled dot product attention.
215
+
216
+ Args:
217
+ query (torch.Tensor): Query tensor (#batch, time1, size).
218
+ key (torch.Tensor): Key tensor (#batch, time2, size).
219
+ value (torch.Tensor): Value tensor (#batch, time2, size).
220
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
221
+ (#batch, time1, time2).
222
+
223
+ Returns:
224
+ torch.Tensor: Output tensor (#batch, time1, d_model).
225
+
226
+ """
227
+ q_h, k_h, v_h, v = self.forward_qkv(x)
228
+ fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
229
+ q_h = q_h * self.d_k ** (-0.5)
230
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
231
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
232
+ return att_outs + fsmn_memory
233
+
234
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
235
+ """Compute scaled dot product attention.
236
+
237
+ Args:
238
+ query (torch.Tensor): Query tensor (#batch, time1, size).
239
+ key (torch.Tensor): Key tensor (#batch, time2, size).
240
+ value (torch.Tensor): Value tensor (#batch, time2, size).
241
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
242
+ (#batch, time1, time2).
243
+
244
+ Returns:
245
+ torch.Tensor: Output tensor (#batch, time1, d_model).
246
+
247
+ """
248
+ q_h, k_h, v_h, v = self.forward_qkv(x)
249
+ if chunk_size is not None and look_back > 0 or look_back == -1:
250
+ if cache is not None:
251
+ k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
252
+ v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
253
+ k_h = torch.cat((cache["k"], k_h), dim=2)
254
+ v_h = torch.cat((cache["v"], v_h), dim=2)
255
+
256
+ cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
257
+ cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
258
+ if look_back != -1:
259
+ cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
260
+ cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
261
+ else:
262
+ cache_tmp = {
263
+ "k": k_h[:, :, : -(chunk_size[2]), :],
264
+ "v": v_h[:, :, : -(chunk_size[2]), :],
265
+ }
266
+ cache = cache_tmp
267
+ fsmn_memory = self.forward_fsmn(v, None)
268
+ q_h = q_h * self.d_k ** (-0.5)
269
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
270
+ att_outs = self.forward_attention(v_h, scores, None)
271
+ return att_outs + fsmn_memory, cache
272
+
273
+
274
+ class LayerNorm(nn.LayerNorm):
275
+ def __init__(self, *args, **kwargs):
276
+ super().__init__(*args, **kwargs)
277
+
278
+ def forward(self, input):
279
+ output = F.layer_norm(
280
+ input.float(),
281
+ self.normalized_shape,
282
+ self.weight.float() if self.weight is not None else None,
283
+ self.bias.float() if self.bias is not None else None,
284
+ self.eps,
285
+ )
286
+ return output.type_as(input)
287
+
288
+
289
+ def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
290
+ if maxlen is None:
291
+ maxlen = lengths.max()
292
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
293
+ matrix = torch.unsqueeze(lengths, dim=-1)
294
+ mask = row_vector < matrix
295
+ mask = mask.detach()
296
+
297
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
298
+
299
+
300
+ class EncoderLayerSANM(nn.Module):
301
+ def __init__(
302
+ self,
303
+ in_size,
304
+ size,
305
+ self_attn,
306
+ feed_forward,
307
+ dropout_rate,
308
+ normalize_before=True,
309
+ concat_after=False,
310
+ stochastic_depth_rate=0.0,
311
+ ):
312
+ """Construct an EncoderLayer object."""
313
+ super(EncoderLayerSANM, self).__init__()
314
+ self.self_attn = self_attn
315
+ self.feed_forward = feed_forward
316
+ self.norm1 = LayerNorm(in_size)
317
+ self.norm2 = LayerNorm(size)
318
+ self.dropout = nn.Dropout(dropout_rate)
319
+ self.in_size = in_size
320
+ self.size = size
321
+ self.normalize_before = normalize_before
322
+ self.concat_after = concat_after
323
+ if self.concat_after:
324
+ self.concat_linear = nn.Linear(size + size, size)
325
+ self.stochastic_depth_rate = stochastic_depth_rate
326
+ self.dropout_rate = dropout_rate
327
+
328
+ def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
329
+ """Compute encoded features.
330
+
331
+ Args:
332
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
333
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
334
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
335
+
336
+ Returns:
337
+ torch.Tensor: Output tensor (#batch, time, size).
338
+ torch.Tensor: Mask tensor (#batch, time).
339
+
340
+ """
341
+ skip_layer = False
342
+ # with stochastic depth, residual connection `x + f(x)` becomes
343
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
344
+ stoch_layer_coeff = 1.0
345
+ if self.training and self.stochastic_depth_rate > 0:
346
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
347
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
348
+
349
+ if skip_layer:
350
+ if cache is not None:
351
+ x = torch.cat([cache, x], dim=1)
352
+ return x, mask
353
+
354
+ residual = x
355
+ if self.normalize_before:
356
+ x = self.norm1(x)
357
+
358
+ if self.concat_after:
359
+ x_concat = torch.cat(
360
+ (
361
+ x,
362
+ self.self_attn(
363
+ x,
364
+ mask,
365
+ mask_shfit_chunk=mask_shfit_chunk,
366
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
367
+ ),
368
+ ),
369
+ dim=-1,
370
+ )
371
+ if self.in_size == self.size:
372
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
373
+ else:
374
+ x = stoch_layer_coeff * self.concat_linear(x_concat)
375
+ else:
376
+ if self.in_size == self.size:
377
+ x = residual + stoch_layer_coeff * self.dropout(
378
+ self.self_attn(
379
+ x,
380
+ mask,
381
+ mask_shfit_chunk=mask_shfit_chunk,
382
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
383
+ )
384
+ )
385
+ else:
386
+ x = stoch_layer_coeff * self.dropout(
387
+ self.self_attn(
388
+ x,
389
+ mask,
390
+ mask_shfit_chunk=mask_shfit_chunk,
391
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
392
+ )
393
+ )
394
+ if not self.normalize_before:
395
+ x = self.norm1(x)
396
+
397
+ residual = x
398
+ if self.normalize_before:
399
+ x = self.norm2(x)
400
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
401
+ if not self.normalize_before:
402
+ x = self.norm2(x)
403
+
404
+ return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
405
+
406
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
407
+ """Compute encoded features.
408
+
409
+ Args:
410
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
411
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
412
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
413
+
414
+ Returns:
415
+ torch.Tensor: Output tensor (#batch, time, size).
416
+ torch.Tensor: Mask tensor (#batch, time).
417
+
418
+ """
419
+
420
+ residual = x
421
+ if self.normalize_before:
422
+ x = self.norm1(x)
423
+
424
+ if self.in_size == self.size:
425
+ attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
426
+ x = residual + attn
427
+ else:
428
+ x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
429
+
430
+ if not self.normalize_before:
431
+ x = self.norm1(x)
432
+
433
+ residual = x
434
+ if self.normalize_before:
435
+ x = self.norm2(x)
436
+ x = residual + self.feed_forward(x)
437
+ if not self.normalize_before:
438
+ x = self.norm2(x)
439
+
440
+ return x, cache
441
+
442
+
443
+ @tables.register("encoder_classes", "SenseVoiceEncoderSmall")
444
+ class SenseVoiceEncoderSmall(nn.Module):
445
+ """
446
+ Author: Speech Lab of DAMO Academy, Alibaba Group
447
+ SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
448
+ https://arxiv.org/abs/2006.01713
449
+ """
450
+
451
+ def __init__(
452
+ self,
453
+ input_size: int,
454
+ output_size: int = 256,
455
+ attention_heads: int = 4,
456
+ linear_units: int = 2048,
457
+ num_blocks: int = 6,
458
+ tp_blocks: int = 0,
459
+ dropout_rate: float = 0.1,
460
+ positional_dropout_rate: float = 0.1,
461
+ attention_dropout_rate: float = 0.0,
462
+ stochastic_depth_rate: float = 0.0,
463
+ input_layer: Optional[str] = "conv2d",
464
+ pos_enc_class=SinusoidalPositionEncoder,
465
+ normalize_before: bool = True,
466
+ concat_after: bool = False,
467
+ positionwise_layer_type: str = "linear",
468
+ positionwise_conv_kernel_size: int = 1,
469
+ padding_idx: int = -1,
470
+ kernel_size: int = 11,
471
+ sanm_shfit: int = 0,
472
+ selfattention_layer_type: str = "sanm",
473
+ **kwargs,
474
+ ):
475
+ super().__init__()
476
+ self._output_size = output_size
477
+
478
+ self.embed = SinusoidalPositionEncoder()
479
+
480
+ self.normalize_before = normalize_before
481
+
482
+ positionwise_layer = PositionwiseFeedForward
483
+ positionwise_layer_args = (
484
+ output_size,
485
+ linear_units,
486
+ dropout_rate,
487
+ )
488
+
489
+ encoder_selfattn_layer = MultiHeadedAttentionSANM
490
+ encoder_selfattn_layer_args0 = (
491
+ attention_heads,
492
+ input_size,
493
+ output_size,
494
+ attention_dropout_rate,
495
+ kernel_size,
496
+ sanm_shfit,
497
+ )
498
+ encoder_selfattn_layer_args = (
499
+ attention_heads,
500
+ output_size,
501
+ output_size,
502
+ attention_dropout_rate,
503
+ kernel_size,
504
+ sanm_shfit,
505
+ )
506
+
507
+ self.encoders0 = nn.ModuleList(
508
+ [
509
+ EncoderLayerSANM(
510
+ input_size,
511
+ output_size,
512
+ encoder_selfattn_layer(*encoder_selfattn_layer_args0),
513
+ positionwise_layer(*positionwise_layer_args),
514
+ dropout_rate,
515
+ )
516
+ for i in range(1)
517
+ ]
518
+ )
519
+ self.encoders = nn.ModuleList(
520
+ [
521
+ EncoderLayerSANM(
522
+ output_size,
523
+ output_size,
524
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
525
+ positionwise_layer(*positionwise_layer_args),
526
+ dropout_rate,
527
+ )
528
+ for i in range(num_blocks - 1)
529
+ ]
530
+ )
531
+
532
+ self.tp_encoders = nn.ModuleList(
533
+ [
534
+ EncoderLayerSANM(
535
+ output_size,
536
+ output_size,
537
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
538
+ positionwise_layer(*positionwise_layer_args),
539
+ dropout_rate,
540
+ )
541
+ for i in range(tp_blocks)
542
+ ]
543
+ )
544
+
545
+ self.after_norm = LayerNorm(output_size)
546
+
547
+ self.tp_norm = LayerNorm(output_size)
548
+
549
+ def output_size(self) -> int:
550
+ return self._output_size
551
+
552
+ def forward(
553
+ self,
554
+ xs_pad: torch.Tensor,
555
+ # ilens: torch.Tensor,
556
+ masks: torch.Tensor,
557
+ position_encoding: torch.Tensor
558
+ ):
559
+ """Embed positions in tensor."""
560
+ # masks = sequence_mask(ilens, device=ilens.device)[:, None, :]
561
+
562
+ xs_pad *= self.output_size() ** 0.5
563
+
564
+ # xs_pad = self.embed(xs_pad)
565
+ xs_pad += position_encoding
566
+
567
+ # forward encoder1
568
+ for layer_idx, encoder_layer in enumerate(self.encoders0):
569
+ encoder_outs = encoder_layer(xs_pad, masks)
570
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
571
+
572
+ for layer_idx, encoder_layer in enumerate(self.encoders):
573
+ encoder_outs = encoder_layer(xs_pad, masks)
574
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
575
+
576
+ xs_pad = self.after_norm(xs_pad)
577
+
578
+ # forward encoder2
579
+ olens = masks.squeeze(1).sum(1).int()
580
+
581
+ for layer_idx, encoder_layer in enumerate(self.tp_encoders):
582
+ encoder_outs = encoder_layer(xs_pad, masks)
583
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
584
+
585
+ xs_pad = self.tp_norm(xs_pad)
586
+ return xs_pad, olens
587
+
588
+
589
+ @tables.register("model_classes", "SenseVoiceSmall")
590
+ class SenseVoiceSmall(nn.Module):
591
+ """CTC-attention hybrid Encoder-Decoder model"""
592
+
593
+ def __init__(
594
+ self,
595
+ specaug: str = None,
596
+ specaug_conf: dict = None,
597
+ normalize: str = None,
598
+ normalize_conf: dict = None,
599
+ encoder: str = None,
600
+ encoder_conf: dict = None,
601
+ ctc_conf: dict = None,
602
+ input_size: int = 80,
603
+ vocab_size: int = -1,
604
+ ignore_id: int = -1,
605
+ blank_id: int = 0,
606
+ sos: int = 1,
607
+ eos: int = 2,
608
+ length_normalized_loss: bool = False,
609
+ seq_len = 68,
610
+ **kwargs,
611
+ ):
612
+
613
+ super().__init__()
614
+
615
+ if specaug is not None:
616
+ specaug_class = tables.specaug_classes.get(specaug)
617
+ specaug = specaug_class(**specaug_conf)
618
+ if normalize is not None:
619
+ normalize_class = tables.normalize_classes.get(normalize)
620
+ normalize = normalize_class(**normalize_conf)
621
+ encoder_class = tables.encoder_classes.get(encoder)
622
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
623
+ encoder_output_size = encoder.output_size()
624
+
625
+ if ctc_conf is None:
626
+ ctc_conf = {}
627
+ ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
628
+
629
+ self.blank_id = blank_id
630
+ self.sos = sos if sos is not None else vocab_size - 1
631
+ self.eos = eos if eos is not None else vocab_size - 1
632
+ self.vocab_size = vocab_size
633
+ self.ignore_id = ignore_id
634
+ self.specaug = specaug
635
+ self.normalize = normalize
636
+ self.encoder = encoder
637
+ self.error_calculator = None
638
+
639
+ self.ctc = ctc
640
+
641
+ self.length_normalized_loss = length_normalized_loss
642
+ self.encoder_output_size = encoder_output_size
643
+
644
+ self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
645
+ self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
646
+ self.textnorm_dict = {"withitn": 14, "woitn": 15}
647
+ self.textnorm_int_dict = {25016: 14, 25017: 15}
648
+ self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)
649
+ self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
650
+
651
+ self.criterion_att = LabelSmoothingLoss(
652
+ size=self.vocab_size,
653
+ padding_idx=self.ignore_id,
654
+ smoothing=kwargs.get("lsm_weight", 0.0),
655
+ normalize_length=self.length_normalized_loss,
656
+ )
657
+
658
+ self.seq_len = seq_len
659
+
660
+ @staticmethod
661
+ def from_pretrained(model:str=None, **kwargs):
662
+ from funasr import AutoModel
663
+ model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
664
+
665
+ return model, kwargs
666
+
667
+ def forward(
668
+ self,
669
+ speech: torch.Tensor,
670
+ speech_lengths: torch.Tensor,
671
+ text: torch.Tensor,
672
+ text_lengths: torch.Tensor,
673
+ **kwargs,
674
+ ):
675
+ """Encoder + Decoder + Calc loss
676
+ Args:
677
+ speech: (Batch, Length, ...)
678
+ speech_lengths: (Batch, )
679
+ text: (Batch, Length)
680
+ text_lengths: (Batch,)
681
+ """
682
+ # import pdb;
683
+ # pdb.set_trace()
684
+ if len(text_lengths.size()) > 1:
685
+ text_lengths = text_lengths[:, 0]
686
+ if len(speech_lengths.size()) > 1:
687
+ speech_lengths = speech_lengths[:, 0]
688
+
689
+ batch_size = speech.shape[0]
690
+
691
+ # 1. Encoder
692
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
693
+
694
+ loss_ctc, cer_ctc = None, None
695
+ loss_rich, acc_rich = None, None
696
+ stats = dict()
697
+
698
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
699
+ encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
700
+ )
701
+
702
+ loss_rich, acc_rich = self._calc_rich_ce_loss(
703
+ encoder_out[:, :4, :], text[:, :4]
704
+ )
705
+
706
+ loss = loss_ctc + loss_rich
707
+ # Collect total loss stats
708
+ stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
709
+ stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
710
+ stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
711
+ stats["acc_rich"] = acc_rich
712
+
713
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
714
+ if self.length_normalized_loss:
715
+ batch_size = int((text_lengths + 1).sum())
716
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
717
+ return loss, stats, weight
718
+
719
+ def encode(
720
+ self,
721
+ speech: torch.Tensor,
722
+ speech_lengths: torch.Tensor,
723
+ text: torch.Tensor,
724
+ **kwargs,
725
+ ):
726
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
727
+ Args:
728
+ speech: (Batch, Length, ...)
729
+ speech_lengths: (Batch, )
730
+ ind: int
731
+ """
732
+
733
+ # Data augmentation
734
+ if self.specaug is not None and self.training:
735
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
736
+
737
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
738
+ if self.normalize is not None:
739
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
740
+
741
+
742
+ lids = torch.LongTensor([[self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0 ] for lid in text[:, 0]]).to(speech.device)
743
+ language_query = self.embed(lids)
744
+
745
+ styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device)
746
+ style_query = self.embed(styles)
747
+ speech = torch.cat((style_query, speech), dim=1)
748
+ speech_lengths += 1
749
+
750
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
751
+ input_query = torch.cat((language_query, event_emo_query), dim=1)
752
+ speech = torch.cat((input_query, speech), dim=1)
753
+ speech_lengths += 3
754
+
755
+ encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
756
+
757
+ return encoder_out, encoder_out_lens
758
+
759
+ def _calc_ctc_loss(
760
+ self,
761
+ encoder_out: torch.Tensor,
762
+ encoder_out_lens: torch.Tensor,
763
+ ys_pad: torch.Tensor,
764
+ ys_pad_lens: torch.Tensor,
765
+ ):
766
+ # Calc CTC loss
767
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
768
+
769
+ # Calc CER using CTC
770
+ cer_ctc = None
771
+ if not self.training and self.error_calculator is not None:
772
+ ys_hat = self.ctc.argmax(encoder_out).data
773
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
774
+ return loss_ctc, cer_ctc
775
+
776
+ def _calc_rich_ce_loss(
777
+ self,
778
+ encoder_out: torch.Tensor,
779
+ ys_pad: torch.Tensor,
780
+ ):
781
+ decoder_out = self.ctc.ctc_lo(encoder_out)
782
+ # 2. Compute attention loss
783
+ loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
784
+ acc_rich = th_accuracy(
785
+ decoder_out.view(-1, self.vocab_size),
786
+ ys_pad.contiguous(),
787
+ ignore_label=self.ignore_id,
788
+ )
789
+
790
+ return loss_rich, acc_rich
791
+
792
+
793
+ def inference(
794
+ self,
795
+ data_in,
796
+ data_lengths=None,
797
+ key: list = ["wav_file_tmp_name"],
798
+ tokenizer=None,
799
+ frontend=None,
800
+ **kwargs,
801
+ ):
802
+
803
+
804
+ meta_data = {}
805
+ if (
806
+ isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
807
+ ): # fbank
808
+ speech, speech_lengths = data_in, data_lengths
809
+ if len(speech.shape) < 3:
810
+ speech = speech[None, :, :]
811
+ if speech_lengths is None:
812
+ speech_lengths = speech.shape[1]
813
+ else:
814
+ # extract fbank feats
815
+ time1 = time.perf_counter()
816
+ audio_sample_list = load_audio_text_image_video(
817
+ data_in,
818
+ fs=frontend.fs,
819
+ audio_fs=kwargs.get("fs", 16000),
820
+ data_type=kwargs.get("data_type", "sound"),
821
+ tokenizer=tokenizer,
822
+ )
823
+ time2 = time.perf_counter()
824
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
825
+ speech, speech_lengths = extract_fbank(
826
+ audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
827
+ )
828
+ time3 = time.perf_counter()
829
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
830
+ meta_data["batch_data_time"] = (
831
+ speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
832
+ )
833
+
834
+ speech = speech.to(device=kwargs["device"])
835
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
836
+
837
+ language = kwargs.get("language", "auto")
838
+ language_query = self.embed(
839
+ torch.LongTensor(
840
+ [[self.lid_dict[language] if language in self.lid_dict else 0]]
841
+ ).to(speech.device)
842
+ ).repeat(speech.size(0), 1, 1)
843
+
844
+ use_itn = kwargs.get("use_itn", False)
845
+ output_timestamp = kwargs.get("output_timestamp", False)
846
+
847
+ textnorm = kwargs.get("text_norm", None)
848
+ if textnorm is None:
849
+ textnorm = "withitn" if use_itn else "woitn"
850
+ textnorm_query = self.embed(
851
+ torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
852
+ ).repeat(speech.size(0), 1, 1)
853
+ speech = torch.cat((textnorm_query, speech), dim=1)
854
+ speech_lengths += 1
855
+
856
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
857
+ speech.size(0), 1, 1
858
+ )
859
+ input_query = torch.cat((language_query, event_emo_query), dim=1)
860
+ speech = torch.cat((input_query, speech), dim=1)
861
+ speech_lengths += 3
862
+
863
+ # Encoder
864
+ encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
865
+ if isinstance(encoder_out, tuple):
866
+ encoder_out = encoder_out[0]
867
+
868
+ # c. Passed the encoder result and the beam search
869
+ ctc_logits = self.ctc.log_softmax(encoder_out)
870
+ if kwargs.get("ban_emo_unk", False):
871
+ ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")
872
+
873
+ results = []
874
+ b, n, d = encoder_out.size()
875
+ if isinstance(key[0], (list, tuple)):
876
+ key = key[0]
877
+ if len(key) < b:
878
+ key = key * b
879
+ for i in range(b):
880
+ x = ctc_logits[i, : encoder_out_lens[i].item(), :]
881
+ yseq = x.argmax(dim=-1)
882
+ yseq = torch.unique_consecutive(yseq, dim=-1)
883
+
884
+ ibest_writer = None
885
+ if kwargs.get("output_dir") is not None:
886
+ if not hasattr(self, "writer"):
887
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
888
+ ibest_writer = self.writer[f"1best_recog"]
889
+
890
+ mask = yseq != self.blank_id
891
+ token_int = yseq[mask].tolist()
892
+
893
+ # Change integer-ids to tokens
894
+ text = tokenizer.decode(token_int)
895
+ if ibest_writer is not None:
896
+ ibest_writer["text"][key[i]] = text
897
+
898
+ if output_timestamp:
899
+ from itertools import groupby
900
+ timestamp = []
901
+ tokens = tokenizer.text2tokens(text)[4:]
902
+
903
+ logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :]
904
+
905
+ pred = logits_speech.argmax(-1).cpu()
906
+ logits_speech[pred==self.blank_id, self.blank_id] = 0
907
+
908
+ align = ctc_forced_align(
909
+ logits_speech.unsqueeze(0).float(),
910
+ torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
911
+ (encoder_out_lens-4).long(),
912
+ torch.tensor(len(token_int)-4).unsqueeze(0).long().to(logits_speech.device),
913
+ ignore_id=self.ignore_id,
914
+ )
915
+
916
+ pred = groupby(align[0, :encoder_out_lens[0]])
917
+ _start = 0
918
+ token_id = 0
919
+ ts_max = encoder_out_lens[i] - 4
920
+ for pred_token, pred_frame in pred:
921
+ _end = _start + len(list(pred_frame))
922
+ if pred_token != 0:
923
+ ts_left = max((_start*60-30)/1000, 0)
924
+ ts_right = min((_end*60-30)/1000, (ts_max*60-30)/1000)
925
+ timestamp.append([tokens[token_id], ts_left, ts_right])
926
+ token_id += 1
927
+ _start = _end
928
+
929
+ result_i = {"key": key[i], "text": text, "timestamp": timestamp}
930
+ results.append(result_i)
931
+ else:
932
+ result_i = {"key": key[i], "text": text}
933
+ results.append(result_i)
934
+ return results, meta_data
935
+
936
+ def export(self, **kwargs):
937
+ from export_meta import export_rebuild_model
938
+
939
+ if "max_seq_len" not in kwargs:
940
+ kwargs["max_seq_len"] = 512
941
+ models = export_rebuild_model(model=self, **kwargs)
942
+ return models
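A minimal usage sketch for the torch-side SenseVoiceSmall defined above, modelled on the upstream SenseVoice demo pattern; the model id "iic/SenseVoiceSmall" and the wav path are placeholders, not files in this repo:

# Sketch: load the torch model via funasr and run one file through inference().
from model import SenseVoiceSmall

m, kwargs = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
m.eval()

res = m.inference(
    data_in="example.wav",   # placeholder path, 16 kHz mono audio
    language="auto",         # or "zh" / "en" / "yue" / "ja" / "ko"
    use_itn=True,
    **kwargs,                # carries tokenizer, frontend, device from build_model
)
# inference() returns (results, meta_data); results is a list of dicts with "key"/"text"
print(res[0][0]["text"])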
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ tqdm
+ funasr==1.2.7
+ torchaudio
+ cn2an
utils/__init__.py ADDED
File without changes
utils/ax_model_bin.py ADDED
@@ -0,0 +1,241 @@
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ import os.path
7
+ from pathlib import Path
8
+ from typing import List, Union, Tuple
9
+ import torch
10
+ import numpy as np
11
+ import axengine as axe
12
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
13
+ try:
14
+ import librosa
15
+ except ImportError:
16
+ print("Warning: librosa not found. Please install it using 'pip install librosa'.")
17
+ # Provide a fallback implementation if needed
18
+ def load_wav_fallback(path, sr=None):
19
+ import wave
20
+ import numpy as np
21
+ with wave.open(path, 'rb') as wf:
22
+ num_frames = wf.getnframes()
23
+ frames = wf.readframes(num_frames)
24
+ return np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0, wf.getframerate()
25
+
26
+ from utils.infer_utils import (
27
+ CharTokenizer,
28
+ get_logger,
29
+ read_yaml,
30
+ )
31
+ from utils.frontend import WavFrontend
32
+ from utils.ctc_alignment import ctc_forced_align
33
+
34
+ logging = get_logger()
35
+
36
+
37
+ def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
38
+ if maxlen is None:
39
+ maxlen = lengths.max()
40
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
41
+ matrix = torch.unsqueeze(lengths, dim=-1)
42
+ mask = row_vector < matrix
43
+ mask = mask.detach()
44
+
45
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
46
+
47
+
48
+ class AX_SenseVoiceSmall:
49
+ """
50
+ Author: Speech Lab of DAMO Academy, Alibaba Group
51
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
52
+ https://arxiv.org/abs/2206.08317
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ model_dir: Union[str, Path] = None,
58
+ batch_size: int = 1,
59
+ seq_len: int = 68
60
+ ):
61
+
62
+ model_file = os.path.join(model_dir, "sensevoice.axmodel")
63
+ config_file = os.path.join(model_dir, "sensevoice/config.yaml")
64
+ cmvn_file = os.path.join(model_dir, "sensevoice/am.mvn")
65
+ config = read_yaml(config_file)
66
+ self.model_dir = model_dir
67
+ # token_list = os.path.join(model_dir, "tokens.json")
68
+ # with open(token_list, "r", encoding="utf-8") as f:
69
+ # token_list = json.load(f)
70
+
71
+ # self.converter = TokenIDConverter(token_list)
72
+ self.tokenizer = CharTokenizer()
73
+ config["frontend_conf"]['cmvn_file'] = cmvn_file
74
+ self.frontend = WavFrontend(**config["frontend_conf"])
75
+ # self.ort_infer = OrtInferSession(
76
+ # model_file, device_id, intra_op_num_threads=intra_op_num_threads
77
+ # )
78
+ self.session = axe.InferenceSession(model_file, providers='AxEngineExecutionProvider')
79
+ self.batch_size = batch_size
80
+ self.blank_id = 0
81
+ self.seq_len = seq_len
82
+
83
+ self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
84
+ self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
85
+ self.textnorm_dict = {"withitn": 14, "woitn": 15}
86
+ self.textnorm_int_dict = {25016: 14, 25017: 15}
87
+ self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
88
+
89
+ def __call__(self,
90
+ wav_content: Union[str, np.ndarray, List[str]],
91
+ language: str,
92
+ withitn: bool,
93
+ position_encoding: np.ndarray,
94
+ tokenizer=None,
95
+ **kwargs) -> List:
96
+ """Enhanced model inference with additional features from model.py
97
+
98
+ Args:
99
+ wav_content: Audio data or path
100
+ language: Language code for processing
101
+ withitn: Whether to use ITN (inverse text normalization)
102
+ position_encoding: Position encoding tensor
103
+ tokenizer: Tokenizer for text conversion
104
+ **kwargs: Additional arguments
105
+ """
106
+ # Start time tracking for metadata
107
+ import time
108
+ meta_data = {}
109
+ time_start = time.perf_counter()
110
+
111
+ # Load waveform data
112
+ waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
113
+ waveform_nums = len(waveform_list)
114
+ time_load = time.perf_counter()
115
+ meta_data["load_data"] = f"{time_load - time_start:0.3f}"
116
+
117
+ # Load queries from saved numpy files
118
+ language_query = np.load(os.path.join(self.model_dir, f"{language}.npy"))
119
+ textnorm_query = np.load(os.path.join(self.model_dir, "withitn.npy") if withitn
120
+ else os.path.join(self.model_dir, "woitn.npy"))
121
+ event_emo_query = np.load(os.path.join(self.model_dir, "event_emo.npy"))
122
+
123
+ # Concatenate queries to form input_query
124
+ input_query = np.concatenate((language_query, event_emo_query, textnorm_query), axis=1)
125
+
126
+ # Process features
127
+ results = ""
128
+
129
+ # Handle output_dir without using DatadirWriter (which is not available)
130
+ slice_len = self.seq_len - 4
131
+ time_pre = time.perf_counter()
132
+ meta_data["preprocess"] = f"{time_pre - time_load:0.3f}"
133
+ for beg_idx in range(0, waveform_nums, self.batch_size):
134
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
135
+ feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
136
+
137
+ time_feat = time.perf_counter()
138
+ meta_data["extract_feat"] = f"{time_feat - time_pre:0.3f}"
139
+
140
+ for i in range(int(np.ceil(feats.shape[1] / slice_len))):
141
+ sub_feats = np.concatenate([input_query, feats[:, i*slice_len : (i+1)*slice_len, :]], axis=1)
142
+ feats_len[0] = sub_feats.shape[1]
143
+
144
+
145
+ if feats_len[0] < self.seq_len:
146
+ sub_feats = np.concatenate([sub_feats, np.zeros((1, self.seq_len - feats_len[0], 560), dtype=np.float32)], axis=1)
147
+
148
+ masks = sequence_mask(torch.IntTensor([self.seq_len]), maxlen=self.seq_len, dtype=torch.float32)[:, None, :]
149
+ masks = masks.numpy()
150
+
151
+ # Run inference
152
+
153
+ ctc_logits, encoder_out_lens = self.infer(sub_feats, masks, position_encoding)
154
+ # Convert to torch tensor for processing
155
+ ctc_logits = torch.from_numpy(ctc_logits).float()
156
+
157
+ # Process results for each batch
158
+ b, _, _ = ctc_logits.size()
159
+
160
+ for j in range(b):
161
+ x = ctc_logits[j, : encoder_out_lens[j].item(), :]
162
+ yseq = x.argmax(dim=-1)
163
+ yseq = torch.unique_consecutive(yseq, dim=-1)
164
+
165
+ mask = yseq != self.blank_id
166
+ token_int = yseq[mask].tolist()[4:]  # drop the first 4 tokens: <|zh|><|ANGRY|><|Speech|><|withitn|>
167
+
168
+ # Convert tokens to text
169
+ text = tokenizer.decode(token_int) if tokenizer is not None else str(token_int)
170
+
171
+ if tokenizer is not None:
172
+ results+= text
173
+ else:
174
+ results+= token_int
175
+ return results
176
+
177
+ def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
178
+ def load_wav(path: str) -> np.ndarray:
179
+ try:
180
+ # Use librosa if available
181
+ if 'librosa' in globals():
182
+ waveform, _ = librosa.load(path, sr=fs)
183
+ else:
184
+ # Use fallback implementation
185
+ waveform, native_sr = load_wav_fallback(path)
186
+ if fs is not None and native_sr != fs:
187
+ # Implement resampling if needed
188
+ print(f"Warning: Resampling from {native_sr} to {fs} is not implemented in fallback mode")
189
+ return waveform
190
+ except Exception as e:
191
+ print(f"Error loading audio file {path}: {e}")
192
+ # Return empty audio in case of error
193
+ return np.zeros(1600, dtype=np.float32)
194
+
195
+ if isinstance(wav_content, np.ndarray):
196
+ return [wav_content]
197
+
198
+ if isinstance(wav_content, str):
199
+ return [load_wav(wav_content)]
200
+
201
+ if isinstance(wav_content, list):
202
+ return [load_wav(path) for path in wav_content]
203
+
204
+ raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
205
+
206
+ def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
207
+ feats, feats_len = [], []
208
+ for waveform in waveform_list:
209
+ speech, _ = self.frontend.fbank(waveform)
210
+
211
+ feat, feat_len = self.frontend.lfr_cmvn(speech)
212
+
213
+ feats.append(feat)
214
+ feats_len.append(feat_len)
215
+
216
+ feats = self.pad_feats(feats, np.max(feats_len))
217
+ feats_len = np.array(feats_len).astype(np.int32)
218
+ return feats, feats_len
219
+
220
+ @staticmethod
221
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
222
+ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
223
+ pad_width = ((0, max_feat_len - cur_len), (0, 0))
224
+ return np.pad(feat, pad_width, "constant", constant_values=0)
225
+
226
+ feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
227
+ feats = np.array(feat_res).astype(np.float32)
228
+ return feats
229
+
230
+ def infer(self,
231
+ feats: np.ndarray,
232
+ masks: np.ndarray,
233
+ position_encoding: np.ndarray,
234
+ ) -> Tuple[np.ndarray, np.ndarray]:
235
+ #outputs = self.ort_infer([feats, masks, position_encoding])
236
+ outputs =self.session.run(None, {
237
+ 'speech': feats,
238
+ 'masks': masks,
239
+ 'position_encoding': position_encoding
240
+ })
241
+ return outputs
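How the demo assembles the inputs for AX_SenseVoiceSmall lives in ax_speech_translate_demo.py; the sketch below is only an approximation of that wiring. It rebuilds the sinusoidal position table with the SinusoidalPositionEncoder from model.py and decodes token ids with the raw SentencePiece model shipped in ax_model/, which may differ in detail from the demo script. The wav path is a placeholder.

# Sketch: run the AXEngine SenseVoice wrapper above on one wav file (on-device, needs axengine).
import numpy as np
import torch
import sentencepiece as spm
from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall

asr = AX_SenseVoiceSmall(model_dir="ax_model", batch_size=1, seq_len=68)

# 68 x 560 sinusoidal table: seq_len 68, 80-dim fbank after LFR stacking (80 * 7 = 560)
pos_enc = SinusoidalPositionEncoder().get_position_encoding(torch.zeros(1, 68, 560))
pos_enc = pos_enc.numpy().astype(np.float32)

sp = spm.SentencePieceProcessor(model_file="ax_model/chn_jpn_yue_eng_ko_spectok.bpe.model")

text = asr("example.wav", language="auto", withitn=True,
           position_encoding=pos_enc, tokenizer=sp)
print(text)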
utils/ax_vad_bin.py ADDED
@@ -0,0 +1,156 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
3
+ # MIT License (https://opensource.org/licenses/MIT)
4
+
5
+ import os.path
6
+ from typing import List, Tuple
7
+
8
+ import numpy as np
9
+
10
+ from utils.utils.utils import read_yaml
11
+ from utils.utils.frontend import WavFrontend
12
+ from utils.utils.e2e_vad import E2EVadModel
13
+ import axengine as axe
14
+
15
+ class AX_Fsmn_vad:
16
+ def __init__(self, model_dir, batch_size=1, max_end_sil=None):
17
+ """Initialize VAD model for inference"""
18
+
19
+ # Export model if needed
20
+ model_file = os.path.join(model_dir, "vad.axmodel")
21
+
22
+ # Load config and frontend
23
+ config_file = os.path.join(model_dir, "vad/config.yaml")
24
+ cmvn_file = os.path.join(model_dir, "vad/am.mvn")
25
+ self.config = read_yaml(config_file)
26
+ self.frontend = WavFrontend(cmvn_file=cmvn_file, **self.config["frontend_conf"])
27
+ self.session = axe.InferenceSession(model_file, providers='AxEngineExecutionProvider')
28
+ self.batch_size = batch_size
29
+ self.vad_scorer = E2EVadModel(self.config["model_conf"])
30
+ self.max_end_sil = max_end_sil if max_end_sil is not None else self.config["model_conf"]["max_end_silence_time"]
31
+
32
+ def extract_feat(self, waveform_list):
33
+ """Extract features from waveform"""
34
+ feats, feats_len = [], []
35
+ for waveform in waveform_list:
36
+ speech, _ = self.frontend.fbank(waveform)
37
+ feat, feat_len = self.frontend.lfr_cmvn(speech)
38
+ feats.append(feat)
39
+ feats_len.append(feat_len)
40
+
41
+ max_len = max(feats_len)
42
+ padded_feats = [np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), 'constant') for f in feats]
43
+ feats = np.array(padded_feats).astype(np.float32)
44
+ feats_len = np.array(feats_len).astype(np.int32)
45
+ return feats, feats_len
46
+
47
+ def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
48
+ """Run inference with ONNX Runtime"""
49
+ # Get all input names from the model
50
+ input_names = [input.name for input in self.session.get_inputs()]
51
+ output_names = [x.name for x in self.session.get_outputs()]
52
+
53
+ # Create input dictionary for all inputs
54
+ input_dict = {}
55
+ for i, (name, tensor) in enumerate(zip(input_names, feats)):
56
+ input_dict[name] = tensor
57
+
58
+ # Run inference with all inputs
59
+ outputs = self.session.run(output_names, input_dict)
60
+ scores, out_caches = outputs[0], outputs[1:]
61
+ return scores, out_caches
62
+
63
+ def __call__(self, wav_file, **kwargs):
64
+ """Process audio file with sliding window approach"""
65
+ # Load audio and prepare data
66
+ # waveform = self.load_wav(wav_file)
67
+ # waveform, _ = librosa.load(wav_file, sr=16000)
68
+ waveform_list = [wav_file]
69
+ waveform_nums = len(waveform_list)
70
+ is_final = kwargs.get("is_final", False)
71
+ segments = [[]] * self.batch_size
72
+
73
+ for beg_idx in range(0, waveform_nums, self.batch_size):
74
+ vad_scorer = E2EVadModel(self.config["model_conf"])
75
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
76
+ waveform = waveform_list[beg_idx:end_idx]
77
+ feats, feats_len = self.extract_feat(waveform)
78
+ waveform = np.array(waveform)
79
+ param_dict = kwargs.get("param_dict", dict())
80
+ in_cache = param_dict.get("in_cache", list())
81
+ in_cache = self.prepare_cache(in_cache)
82
+
83
+ t_offset = 0
84
+ step = int(min(feats_len.max(), 6000))
85
+ for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)):
86
+ if t_offset + step >= feats_len - 1:
87
+ step = feats_len - t_offset
88
+ is_final = True
89
+ else:
90
+ is_final = False
91
+
92
+ # Extract feature segment
93
+ feats_package = feats[:, t_offset:int(t_offset + step), :]
94
+
95
+ # Pad if it's the final segment
96
+ if is_final:
97
+ pad_length = 6000 - int(step)
98
+ feats_package = np.pad(
99
+ feats_package,
100
+ ((0, 0), (0, pad_length), (0, 0)),
101
+ mode='constant',
102
+ constant_values=0
103
+ )
104
+
105
+ # Extract corresponding waveform segment
106
+ waveform_package = waveform[
107
+ :,
108
+ t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400),
109
+ ]
110
+
111
+ # Pad waveform if it's the final segment
112
+ if is_final:
113
+ expected_wave_length = 6000 * 160 + 240
114
+ current_wave_length = waveform_package.shape[-1]
115
+ pad_wave_length = expected_wave_length - current_wave_length
116
+ if pad_wave_length > 0:
117
+ waveform_package = np.pad(
118
+ waveform_package,
119
+ ((0, 0), (0, pad_wave_length)),
120
+ mode='constant',
121
+ constant_values=0
122
+ )
123
+
124
+ # Run inference
125
+ inputs = [feats_package]
126
+ inputs.extend(in_cache)
127
+ scores, out_caches = self.infer(inputs)
128
+ in_cache = out_caches
129
+
130
+ # Get VAD segments for this chunk
131
+ segments_part = vad_scorer(
132
+ scores,
133
+ waveform_package,
134
+ is_final=is_final,
135
+ max_end_sil=self.max_end_sil,
136
+ online=False,
137
+ )
138
+
139
+ # Accumulate segments
140
+ if segments_part:
141
+ for batch_num in range(0, self.batch_size):
142
+ segments[batch_num] += segments_part[batch_num]
143
+
144
+ return segments
145
+
146
+ def prepare_cache(self, in_cache: list = []):
147
+ if len(in_cache) > 0:
148
+ return in_cache
149
+ fsmn_layers = 4
150
+ proj_dim = 128
151
+ lorder = 20
152
+ for i in range(fsmn_layers):
153
+ cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
154
+ in_cache.append(cache)
155
+ return in_cache
156
+
utils/ctc_alignment.py ADDED
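A minimal sketch of calling the VAD wrapper above. Note that __call__ expects an already-loaded 16 kHz float waveform (the librosa load is commented out inside) and returns one list of [start_ms, end_ms] segments per batch entry; model_dir is assumed to be the folder holding vad.axmodel plus the vad/ config shipped here, and the wav path is a placeholder.

# Sketch: voice-activity detection with the AXEngine FSMN-VAD wrapper above.
import librosa
from utils.ax_vad_bin import AX_Fsmn_vad

vad = AX_Fsmn_vad(model_dir="ax_model", batch_size=1)
waveform, _ = librosa.load("example.wav", sr=16000)   # placeholder wav path

segments = vad(waveform)
for start_ms, end_ms in segments[0]:
    print(f"speech from {start_ms} ms to {end_ms} ms")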
@@ -0,0 +1,76 @@
 
1
+ import torch
2
+
3
+ def ctc_forced_align(
4
+ log_probs: torch.Tensor,
5
+ targets: torch.Tensor,
6
+ input_lengths: torch.Tensor,
7
+ target_lengths: torch.Tensor,
8
+ blank: int = 0,
9
+ ignore_id: int = -1,
10
+ ) -> torch.Tensor:
11
+ """Align a CTC label sequence to an emission.
12
+
13
+ Args:
14
+ log_probs (Tensor): log probability of CTC emission output.
15
+ Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
16
+ `C` is the number of characters in alphabet including blank.
17
+ targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
18
+ where `L` is the target length.
19
+ input_lengths (Tensor):
20
+ Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
21
+ target_lengths (Tensor):
22
+ Lengths of the targets. 1-D Tensor of shape `(B,)`.
23
+ blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
24
+ ignore_id (int, optional): The index of ignore symbol in CTC emission. (Default: -1)
25
+ """
26
+ targets[targets == ignore_id] = blank
27
+
28
+ batch_size, input_time_size, _ = log_probs.size()
29
+ bsz_indices = torch.arange(batch_size, device=input_lengths.device)
30
+
31
+ _t_a_r_g_e_t_s_ = torch.cat(
32
+ (
33
+ torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
34
+ torch.full_like(targets[:, :1], blank),
35
+ ),
36
+ dim=-1,
37
+ )
38
+ diff_labels = torch.cat(
39
+ (
40
+ torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
41
+ _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
42
+ ),
43
+ dim=1,
44
+ )
45
+
46
+ neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
47
+ padding_num = 2
48
+ padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
49
+ best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
50
+ best_score[:, padding_num + 0] = log_probs[:, 0, blank]
51
+ best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]
52
+
53
+ backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)
54
+
55
+ for t in range(1, input_time_size):
56
+ prev = torch.stack(
57
+ (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
58
+ )
59
+ prev_max_value, prev_max_idx = prev.max(dim=0)
60
+ best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
61
+ backpointers[:, t, padding_num:] = prev_max_idx
62
+
63
+ l1l2 = best_score.gather(
64
+ -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
65
+ )
66
+
67
+ path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
68
+ path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)
69
+
70
+ for t in range(input_time_size - 1, 0, -1):
71
+ target_indices = path[:, t]
72
+ prev_max_idx = backpointers[bsz_indices, t, target_indices]
73
+ path[:, t - 1] += target_indices - prev_max_idx
74
+
75
+ alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
76
+ return alignments
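A toy, self-contained example of ctc_forced_align above: random emissions, two target labels, shapes as in the docstring ((B, T, C) and (B, L)); all values are made up for illustration.

# Sketch: toy forced alignment with the function above.
import torch
from utils.ctc_alignment import ctc_forced_align

torch.manual_seed(0)
B, T, C = 1, 6, 5                      # batch, frames, vocab size (id 0 is the CTC blank)
log_probs = torch.randn(B, T, C).log_softmax(dim=-1)
targets = torch.tensor([[2, 4]])       # label sequence to align
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([2])

align = ctc_forced_align(log_probs, targets, input_lengths, target_lengths, blank=0)
print(align)   # shape (1, 6): per-frame label ids, 0 where the blank is emitted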
utils/frontend.py ADDED
@@ -0,0 +1,433 @@
1
+ # -*- encoding: utf-8 -*-
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
4
+ import copy
5
+
6
+ import numpy as np
7
+ import kaldi_native_fbank as knf
8
+
9
+ root_dir = Path(__file__).resolve().parent
10
+
11
+ logger_initialized = {}
12
+
13
+
14
+ class WavFrontend:
15
+ """Conventional frontend structure for ASR."""
16
+
17
+ def __init__(
18
+ self,
19
+ cmvn_file: str = None,
20
+ fs: int = 16000,
21
+ window: str = "hamming",
22
+ n_mels: int = 80,
23
+ frame_length: int = 25,
24
+ frame_shift: int = 10,
25
+ lfr_m: int = 1,
26
+ lfr_n: int = 1,
27
+ dither: float = 1.0,
28
+ **kwargs,
29
+ ) -> None:
30
+
31
+ opts = knf.FbankOptions()
32
+ opts.frame_opts.samp_freq = fs
33
+ opts.frame_opts.dither = dither
34
+ opts.frame_opts.window_type = window
35
+ opts.frame_opts.frame_shift_ms = float(frame_shift)
36
+ opts.frame_opts.frame_length_ms = float(frame_length)
37
+ opts.mel_opts.num_bins = n_mels
38
+ opts.energy_floor = 0
39
+ opts.frame_opts.snip_edges = True
40
+ opts.mel_opts.debug_mel = False
41
+ self.opts = opts
42
+
43
+ self.lfr_m = lfr_m
44
+ self.lfr_n = lfr_n
45
+ self.cmvn_file = cmvn_file
46
+
47
+ if self.cmvn_file:
48
+ self.cmvn = self.load_cmvn()
49
+ self.fbank_fn = None
50
+ self.fbank_beg_idx = 0
51
+ self.reset_status()
52
+
53
+ def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
54
+ waveform = waveform * (1 << 15)
55
+ self.fbank_fn = knf.OnlineFbank(self.opts)
56
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
57
+ frames = self.fbank_fn.num_frames_ready
58
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
59
+ for i in range(frames):
60
+ mat[i, :] = self.fbank_fn.get_frame(i)
61
+ feat = mat.astype(np.float32)
62
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
63
+ return feat, feat_len
64
+
65
+ def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
66
+ waveform = waveform * (1 << 15)
67
+ # self.fbank_fn = knf.OnlineFbank(self.opts)
68
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
69
+ frames = self.fbank_fn.num_frames_ready
70
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
71
+ for i in range(self.fbank_beg_idx, frames):
72
+ mat[i, :] = self.fbank_fn.get_frame(i)
73
+ # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
74
+ feat = mat.astype(np.float32)
75
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
76
+ return feat, feat_len
77
+
78
+ def reset_status(self):
79
+ self.fbank_fn = knf.OnlineFbank(self.opts)
80
+ self.fbank_beg_idx = 0
81
+
82
+ def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
83
+ if self.lfr_m != 1 or self.lfr_n != 1:
84
+ feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
85
+
86
+ if self.cmvn_file:
87
+ feat = self.apply_cmvn(feat)
88
+
89
+ feat_len = np.array(feat.shape[0]).astype(np.int32)
90
+ return feat, feat_len
91
+
92
+ @staticmethod
93
+ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
94
+ LFR_inputs = []
95
+
96
+ T = inputs.shape[0]
97
+ T_lfr = int(np.ceil(T / lfr_n))
98
+ left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
99
+ inputs = np.vstack((left_padding, inputs))
100
+ T = T + (lfr_m - 1) // 2
101
+ for i in range(T_lfr):
102
+ if lfr_m <= T - i * lfr_n:
103
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
104
+ else:
105
+ # process last LFR frame
106
+ num_padding = lfr_m - (T - i * lfr_n)
107
+ frame = inputs[i * lfr_n :].reshape(-1)
108
+ for _ in range(num_padding):
109
+ frame = np.hstack((frame, inputs[-1]))
110
+
111
+ LFR_inputs.append(frame)
112
+ LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
113
+ return LFR_outputs
114
+
115
+ def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
116
+ """
117
+ Apply CMVN with mvn data
118
+ """
119
+ frame, dim = inputs.shape
120
+ means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
121
+ vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
122
+ inputs = (inputs + means) * vars
123
+ return inputs
124
+
125
+ def load_cmvn(
126
+ self,
127
+ ) -> np.ndarray:
128
+ with open(self.cmvn_file, "r", encoding="utf-8") as f:
129
+ lines = f.readlines()
130
+
131
+ means_list = []
132
+ vars_list = []
133
+ for i in range(len(lines)):
134
+ line_item = lines[i].split()
135
+ if line_item[0] == "<AddShift>":
136
+ line_item = lines[i + 1].split()
137
+ if line_item[0] == "<LearnRateCoef>":
138
+ add_shift_line = line_item[3 : (len(line_item) - 1)]
139
+ means_list = list(add_shift_line)
140
+ continue
141
+ elif line_item[0] == "<Rescale>":
142
+ line_item = lines[i + 1].split()
143
+ if line_item[0] == "<LearnRateCoef>":
144
+ rescale_line = line_item[3 : (len(line_item) - 1)]
145
+ vars_list = list(rescale_line)
146
+ continue
147
+
148
+ means = np.array(means_list).astype(np.float64)
149
+ vars = np.array(vars_list).astype(np.float64)
150
+ cmvn = np.array([means, vars])
151
+ return cmvn
152
+
153
+
154
+ class WavFrontendOnline(WavFrontend):
155
+ def __init__(self, **kwargs):
156
+ super().__init__(**kwargs)
157
+ # self.fbank_fn = knf.OnlineFbank(self.opts)
158
+ # add variables
159
+ self.frame_sample_length = int(
160
+ self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
161
+ )
162
+ self.frame_shift_sample_length = int(
163
+ self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
164
+ )
165
+ self.waveform = None
166
+ self.reserve_waveforms = None
167
+ self.input_cache = None
168
+ self.lfr_splice_cache = []
169
+
170
+ @staticmethod
171
+ # the cache has already been concatenated onto inputs
172
+ def apply_lfr(
173
+ inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
174
+ ) -> Tuple[np.ndarray, np.ndarray, int]:
175
+ """
176
+ Apply lfr with data
177
+ """
178
+
179
+ LFR_inputs = []
180
+ T = inputs.shape[0] # include the right context
181
+ T_lfr = int(
182
+ np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
183
+ ) # minus the right context: (lfr_m - 1) // 2
184
+ splice_idx = T_lfr
185
+ for i in range(T_lfr):
186
+ if lfr_m <= T - i * lfr_n:
187
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
188
+ else: # process last LFR frame
189
+ if is_final:
190
+ num_padding = lfr_m - (T - i * lfr_n)
191
+ frame = (inputs[i * lfr_n :]).reshape(-1)
192
+ for _ in range(num_padding):
193
+ frame = np.hstack((frame, inputs[-1]))
194
+ LFR_inputs.append(frame)
195
+ else:
196
+ # update splice_idx and break the circle
197
+ splice_idx = i
198
+ break
199
+ splice_idx = min(T - 1, splice_idx * lfr_n)
200
+ lfr_splice_cache = inputs[splice_idx:, :]
201
+ LFR_outputs = np.vstack(LFR_inputs)
202
+ return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
203
+
204
+ @staticmethod
205
+ def compute_frame_num(
206
+ sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
207
+ ) -> int:
208
+ frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
209
+ return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
210
+
211
+ def fbank(
212
+ self, input: np.ndarray, input_lengths: np.ndarray
213
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
214
+ self.fbank_fn = knf.OnlineFbank(self.opts)
215
+ batch_size = input.shape[0]
216
+ if self.input_cache is None:
217
+ self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
218
+ input = np.concatenate((self.input_cache, input), axis=1)
219
+ frame_num = self.compute_frame_num(
220
+ input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
221
+ )
222
+ # update self.in_cache
223
+ self.input_cache = input[
224
+ :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
225
+ ]
226
+ waveforms = np.empty(0, dtype=np.float32)
227
+ feats_pad = np.empty(0, dtype=np.float32)
228
+ feats_lens = np.empty(0, dtype=np.int32)
229
+ if frame_num:
230
+ waveforms = []
231
+ feats = []
232
+ feats_lens = []
233
+ for i in range(batch_size):
234
+ waveform = input[i]
235
+ waveforms.append(
236
+ waveform[
237
+ : (
238
+ (frame_num - 1) * self.frame_shift_sample_length
239
+ + self.frame_sample_length
240
+ )
241
+ ]
242
+ )
243
+ waveform = waveform * (1 << 15)
244
+
245
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
246
+ frames = self.fbank_fn.num_frames_ready
247
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
248
+ for i in range(frames):
249
+ mat[i, :] = self.fbank_fn.get_frame(i)
250
+ feat = mat.astype(np.float32)
251
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
252
+ feats.append(feat)
253
+ feats_lens.append(feat_len)
254
+
255
+ waveforms = np.stack(waveforms)
256
+ feats_lens = np.array(feats_lens)
257
+ feats_pad = np.array(feats)
258
+ self.fbanks = feats_pad
259
+ self.fbanks_lens = copy.deepcopy(feats_lens)
260
+ return waveforms, feats_pad, feats_lens
261
+
262
+ def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
263
+ return self.fbanks, self.fbanks_lens
264
+
265
+ def lfr_cmvn(
266
+ self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
267
+ ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
268
+ batch_size = input.shape[0]
269
+ feats = []
270
+ feats_lens = []
271
+ lfr_splice_frame_idxs = []
272
+ for i in range(batch_size):
273
+ mat = input[i, : input_lengths[i], :]
274
+ lfr_splice_frame_idx = -1
275
+ if self.lfr_m != 1 or self.lfr_n != 1:
276
+ # update self.lfr_splice_cache in self.apply_lfr
277
+ mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
278
+ mat, self.lfr_m, self.lfr_n, is_final
279
+ )
280
+ if self.cmvn_file is not None:
281
+ mat = self.apply_cmvn(mat)
282
+ feat_length = mat.shape[0]
283
+ feats.append(mat)
284
+ feats_lens.append(feat_length)
285
+ lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
286
+
287
+ feats_lens = np.array(feats_lens)
288
+ feats_pad = np.array(feats)
289
+ return feats_pad, feats_lens, lfr_splice_frame_idxs
290
+
291
+ def extract_fbank(
292
+ self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
293
+ ) -> Tuple[np.ndarray, np.ndarray]:
294
+ batch_size = input.shape[0]
295
+ assert (
296
+ batch_size == 1
297
+ ), "online feature extraction currently only supports batch size 1"
298
+ waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
299
+ if feats.shape[0]:
300
+ self.waveforms = (
301
+ waveforms
302
+ if self.reserve_waveforms is None
303
+ else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
304
+ )
305
+ if not self.lfr_splice_cache:
306
+ for i in range(batch_size):
307
+ self.lfr_splice_cache.append(
308
+ np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
309
+ )
310
+
311
+ if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
312
+ lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
313
+ feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
314
+ feats_lengths += lfr_splice_cache_np[0].shape[0]
315
+ frame_from_waveforms = int(
316
+ (self.waveforms.shape[1] - self.frame_sample_length)
317
+ / self.frame_shift_sample_length
318
+ + 1
319
+ )
320
+ minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
321
+ feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
322
+ feats, feats_lengths, is_final
323
+ )
324
+ if self.lfr_m == 1:
325
+ self.reserve_waveforms = None
326
+ else:
327
+ reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
328
+ # print('reserve_frame_idx: ' + str(reserve_frame_idx))
329
+ # print('frame_frame: ' + str(frame_from_waveforms))
330
+ self.reserve_waveforms = self.waveforms[
331
+ :,
332
+ reserve_frame_idx
333
+ * self.frame_shift_sample_length : frame_from_waveforms
334
+ * self.frame_shift_sample_length,
335
+ ]
336
+ sample_length = (
337
+ frame_from_waveforms - 1
338
+ ) * self.frame_shift_sample_length + self.frame_sample_length
339
+ self.waveforms = self.waveforms[:, :sample_length]
340
+ else:
341
+ # update self.reserve_waveforms and self.lfr_splice_cache
342
+ self.reserve_waveforms = self.waveforms[
343
+ :, : -(self.frame_sample_length - self.frame_shift_sample_length)
344
+ ]
345
+ for i in range(batch_size):
346
+ self.lfr_splice_cache[i] = np.concatenate(
347
+ (self.lfr_splice_cache[i], feats[i]), axis=0
348
+ )
349
+ return np.empty(0, dtype=np.float32), feats_lengths
350
+ else:
351
+ if is_final:
352
+ self.waveforms = (
353
+ waveforms if self.reserve_waveforms is None else self.reserve_waveforms
354
+ )
355
+ feats = np.stack(self.lfr_splice_cache)
356
+ feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
357
+ feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
358
+ if is_final:
359
+ self.cache_reset()
360
+ return feats, feats_lengths
361
+
362
+ def get_waveforms(self):
363
+ return self.waveforms
364
+
365
+ def cache_reset(self):
366
+ self.fbank_fn = knf.OnlineFbank(self.opts)
367
+ self.reserve_waveforms = None
368
+ self.input_cache = None
369
+ self.lfr_splice_cache = []
370
+
371
+
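The class above caches raw samples, fbank frames and LFR splices across calls so that features can be produced chunk by chunk. A minimal streaming sketch follows; the chunk size, LFR factors and cmvn path are illustrative assumptions, not values taken from this file (the real ones come from the model's config.yaml and am.mvn):

import numpy as np

# hypothetical streaming feature extraction with WavFrontendOnline (all sizes illustrative)
online_frontend = WavFrontendOnline(
    cmvn_file="ax_model/sensevoice/am.mvn",  # assumed path inside this repo
    lfr_m=7, lfr_n=6,                        # assumed LFR factors, see frontend_conf in config.yaml
)
stream = np.random.randn(1, 16000).astype(np.float32)   # 1 s of fake 16 kHz audio, shape (1, samples)
chunks = np.split(stream, 10, axis=1)                    # 100 ms chunks
for i, chunk in enumerate(chunks):
    feats, feats_len = online_frontend.extract_fbank(
        chunk, np.array([chunk.shape[1]]), is_final=(i == len(chunks) - 1)
    )
    if feats.size:   # empty while frames are still being cached for the next LFR window
        pass         # feed feats to the streaming acoustic model here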
372
+ def load_bytes(input):
373
+ middle_data = np.frombuffer(input, dtype=np.int16)
374
+ middle_data = np.asarray(middle_data)
375
+ if middle_data.dtype.kind not in "iu":
376
+ raise TypeError("'middle_data' must be an array of integers")
377
+ dtype = np.dtype("float32")
378
+ if dtype.kind != "f":
379
+ raise TypeError("'dtype' must be a floating point type")
380
+
381
+ i = np.iinfo(middle_data.dtype)
382
+ abs_max = 2 ** (i.bits - 1)
383
+ offset = i.min + abs_max
384
+ array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
385
+ return array
386
+
387
+
388
+ class SinusoidalPositionEncoderOnline:
389
+ """Streaming Positional encoding."""
390
+
391
+ def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
392
+ batch_size = positions.shape[0]
393
+ positions = positions.astype(dtype)
394
+ log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
395
+ inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
396
+ inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
397
+ scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
398
+ encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
399
+ return encoding.astype(dtype)
400
+
401
+ def forward(self, x, start_idx=0):
402
+ batch_size, timesteps, input_dim = x.shape
403
+ positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
404
+ position_encoding = self.encode(positions, input_dim, x.dtype)
405
+
406
+ return x + position_encoding[:, start_idx : start_idx + timesteps]
407
+
408
+
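A small sketch of driving the streaming positional encoder above: each new chunk of encoder features is offset by start_idx so the sinusoidal positions stay continuous across chunks. The 560-dim features and the chunk split are illustrative assumptions:

import numpy as np

# hypothetical chunked use of SinusoidalPositionEncoderOnline (shapes are illustrative)
pos_encoder = SinusoidalPositionEncoderOnline()
features = np.random.randn(1, 30, 560).astype(np.float32)   # (batch, frames, dim), dim must be even
start_idx = 0
for chunk in np.split(features, 3, axis=1):                  # three 10-frame chunks
    encoded = pos_encoder.forward(chunk, start_idx=start_idx)
    start_idx += chunk.shape[1]                               # keep positions continuous across chunks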
409
+ def test():
410
+ path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
411
+ import librosa
412
+
413
+ cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
414
+ config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
415
+ from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
416
+
417
+ config = read_yaml(config_file)
418
+ waveform, _ = librosa.load(path, sr=None)
419
+ frontend = WavFrontend(
420
+ cmvn_file=cmvn_file,
421
+ **config["frontend_conf"],
422
+ )
423
+ speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
424
+ feat, feat_len = frontend.lfr_cmvn(
425
+ speech
426
+ ) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
427
+
428
+ frontend.reset_status() # clear cache
429
+ return feat, feat_len
430
+
431
+
432
+ if __name__ == "__main__":
433
+ test()
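Beyond the hard-coded paths in test() above, the offline path of this frontend is simply fbank followed by lfr_cmvn, and load_bytes converts raw 16-bit PCM into the float waveform that fbank expects. A minimal sketch; the cmvn path, the PCM file name and the LFR factors are assumptions for illustration:

# hypothetical offline feature extraction with the WavFrontend defined above
frontend = WavFrontend(
    cmvn_file="ax_model/sensevoice/am.mvn",  # assumed path inside this repo
    fs=16000, n_mels=80,
    lfr_m=7, lfr_n=6,                        # illustrative; the real factors live in config.yaml
)

with open("example_16k_mono.pcm", "rb") as f:   # hypothetical 16 kHz, 16-bit mono PCM file
    waveform = load_bytes(f.read())              # int16 bytes -> float32 samples in [-1, 1)

feat, feat_len = frontend.fbank(waveform)        # (frames, 80) log-mel filterbank
feat, feat_len = frontend.lfr_cmvn(feat)         # LFR stacking + CMVN -> (frames', 80 * lfr_m)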
utils/infer_utils.py ADDED
@@ -0,0 +1,312 @@
1
+ # -*- encoding: utf-8 -*-
2
+
3
+ import functools
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
7
+
8
+ import re
9
+ import numpy as np
10
+ import yaml
11
+
12
+ import jieba
13
+ import warnings
14
+
15
+ root_dir = Path(__file__).resolve().parent
16
+
17
+ logger_initialized = {}
18
+
19
+
20
+ def pad_list(xs, pad_value, max_len=None):
21
+ n_batch = len(xs)
22
+ if max_len is None:
23
+ max_len = max(x.shape[0] for x in xs)
24
+ # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
25
+ # numpy format
26
+ pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
27
+ for i in range(n_batch):
28
+ pad[i, : xs[i].shape[0]] = xs[i]
29
+
30
+ return pad
31
+
32
+
33
+ """
34
+ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
35
+ if length_dim == 0:
36
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
37
+
38
+ if not isinstance(lengths, list):
39
+ lengths = lengths.tolist()
40
+ bs = int(len(lengths))
41
+ if maxlen is None:
42
+ if xs is None:
43
+ maxlen = int(max(lengths))
44
+ else:
45
+ maxlen = xs.size(length_dim)
46
+ else:
47
+ assert xs is None
48
+ assert maxlen >= int(max(lengths))
49
+
50
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
51
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
52
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
53
+ mask = seq_range_expand >= seq_length_expand
54
+
55
+ if xs is not None:
56
+ assert xs.size(0) == bs, (xs.size(0), bs)
57
+
58
+ if length_dim < 0:
59
+ length_dim = xs.dim() + length_dim
60
+ # ind = (:, None, ..., None, :, , None, ..., None)
61
+ ind = tuple(
62
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
63
+ )
64
+ mask = mask[ind].expand_as(xs).to(xs.device)
65
+ return mask
66
+ """
67
+
68
+
69
+ class TokenIDConverter:
70
+ def __init__(
71
+ self,
72
+ token_list: Union[List, str],
73
+ ):
74
+
75
+ self.token_list = token_list
76
+ self.unk_symbol = token_list[-1]
77
+ self.token2id = {v: i for i, v in enumerate(self.token_list)}
78
+ self.unk_id = self.token2id[self.unk_symbol]
79
+
80
+ def get_num_vocabulary_size(self) -> int:
81
+ return len(self.token_list)
82
+
83
+ def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
84
+ if isinstance(integers, np.ndarray) and integers.ndim != 1:
85
+ raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
86
+ return [self.token_list[i] for i in integers]
87
+
88
+ def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
89
+
90
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
91
+
92
+
93
+ class CharTokenizer:
94
+ def __init__(
95
+ self,
96
+ symbol_value: Union[Path, str, Iterable[str]] = None,
97
+ space_symbol: str = "<space>",
98
+ remove_non_linguistic_symbols: bool = False,
99
+ ):
100
+
101
+ self.space_symbol = space_symbol
102
+ self.non_linguistic_symbols = self.load_symbols(symbol_value)
103
+ self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
104
+
105
+ @staticmethod
106
+ def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
107
+ if value is None:
108
+ return set()
109
+
110
+ if not isinstance(value, (Path, str)):
111
+ return set(value)
112
+
113
+ file_path = Path(value)
114
+ if not file_path.exists():
115
+ logging.warning("%s doesn't exist.", file_path)
116
+ return set()
117
+
118
+ with file_path.open("r", encoding="utf-8") as f:
119
+ return set(line.rstrip() for line in f)
120
+
121
+ def text2tokens(self, line: Union[str, list]) -> List[str]:
122
+ tokens = []
123
+ while len(line) != 0:
124
+ for w in self.non_linguistic_symbols:
125
+ if line.startswith(w):
126
+ if not self.remove_non_linguistic_symbols:
127
+ tokens.append(line[: len(w)])
128
+ line = line[len(w) :]
129
+ break
130
+ else:
131
+ t = line[0]
132
+ if t == " ":
133
+ t = "<space>"
134
+ tokens.append(t)
135
+ line = line[1:]
136
+ return tokens
137
+
138
+ def tokens2text(self, tokens: Iterable[str]) -> str:
139
+ tokens = [t if t != self.space_symbol else " " for t in tokens]
140
+ return "".join(tokens)
141
+
142
+ def __repr__(self):
143
+ return (
144
+ f"{self.__class__.__name__}("
145
+ f'space_symbol="{self.space_symbol}"'
146
+ f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
147
+ f")"
148
+ )
149
+
150
+
151
+ class Hypothesis(NamedTuple):
152
+ """Hypothesis data type."""
153
+
154
+ yseq: np.ndarray
155
+ score: Union[float, np.ndarray] = 0
156
+ scores: Dict[str, Union[float, np.ndarray]] = dict()
157
+ states: Dict[str, Any] = dict()
158
+
159
+ def asdict(self) -> dict:
160
+ """Convert data to JSON-friendly dict."""
161
+ return self._replace(
162
+ yseq=self.yseq.tolist(),
163
+ score=float(self.score),
164
+ scores={k: float(v) for k, v in self.scores.items()},
165
+ )._asdict()
166
+
167
+
168
+ class TokenIDConverterError(Exception):
169
+ pass
170
+
171
+
172
+ class ONNXRuntimeError(Exception):
173
+ pass
174
+
175
+
176
+ def split_to_mini_sentence(words: list, word_limit: int = 20):
177
+ assert word_limit > 1
178
+ if len(words) <= word_limit:
179
+ return [words]
180
+ sentences = []
181
+ length = len(words)
182
+ sentence_len = length // word_limit
183
+ for i in range(sentence_len):
184
+ sentences.append(words[i * word_limit : (i + 1) * word_limit])
185
+ if length % word_limit > 0:
186
+ sentences.append(words[sentence_len * word_limit :])
187
+ return sentences
188
+
189
+
190
+ def code_mix_split_words(text: str):
191
+ words = []
192
+ segs = text.split()
193
+ for seg in segs:
194
+ # There is no space in seg.
195
+ current_word = ""
196
+ for c in seg:
197
+ if len(c.encode()) == 1:
198
+ # This is an ASCII char.
199
+ current_word += c
200
+ else:
201
+ # This is a Chinese char.
202
+ if len(current_word) > 0:
203
+ words.append(current_word)
204
+ current_word = ""
205
+ words.append(c)
206
+ if len(current_word) > 0:
207
+ words.append(current_word)
208
+ return words
209
+
210
+
211
+ def isEnglish(text: str):
212
+ if re.search("^[a-zA-Z']+$", text):
213
+ return True
214
+ else:
215
+ return False
216
+
217
+
218
+ def join_chinese_and_english(input_list):
219
+ line = ""
220
+ for token in input_list:
221
+ if isEnglish(token):
222
+ line = line + " " + token
223
+ else:
224
+ line = line + token
225
+
226
+ line = line.strip()
227
+ return line
228
+
229
+
230
+ def code_mix_split_words_jieba(seg_dict_file: str):
231
+ jieba.load_userdict(seg_dict_file)
232
+
233
+ def _fn(text: str):
234
+ input_list = text.split()
235
+ token_list_all = []
236
+ language_list = []
237
+ token_list_tmp = []
238
+ language_flag = None
239
+ for token in input_list:
240
+ if isEnglish(token) and language_flag == "Chinese":
241
+ token_list_all.append(token_list_tmp)
242
+ language_list.append("Chinese")
243
+ token_list_tmp = []
244
+ elif not isEnglish(token) and language_flag == "English":
245
+ token_list_all.append(token_list_tmp)
246
+ language_list.append("English")
247
+ token_list_tmp = []
248
+
249
+ token_list_tmp.append(token)
250
+
251
+ if isEnglish(token):
252
+ language_flag = "English"
253
+ else:
254
+ language_flag = "Chinese"
255
+
256
+ if token_list_tmp:
257
+ token_list_all.append(token_list_tmp)
258
+ language_list.append(language_flag)
259
+
260
+ result_list = []
261
+ for token_list_tmp, language_flag in zip(token_list_all, language_list):
262
+ if language_flag == "English":
263
+ result_list.extend(token_list_tmp)
264
+ else:
265
+ seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
266
+ result_list.extend(seg_list)
267
+
268
+ return result_list
269
+
270
+ return _fn
271
+
272
+
273
+ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
274
+ if not Path(yaml_path).exists():
275
+ raise FileNotFoundError(f"The {yaml_path} does not exist.")
276
+
277
+ with open(str(yaml_path), "rb") as f:
278
+ data = yaml.load(f, Loader=yaml.Loader)
279
+ return data
280
+
281
+
282
+ @functools.lru_cache()
283
+ def get_logger(name="funasr_onnx"):
284
+ """Initialize and get a logger by name.
285
+ If the logger has not been initialized, this method will initialize the
286
+ logger by adding one or two handlers, otherwise the initialized logger will
287
+ be directly returned. During initialization, a StreamHandler will always be
288
+ added.
289
+ Args:
290
+ name (str): Logger name.
291
+ Returns:
292
+ logging.Logger: The expected logger.
293
+ """
294
+ logger = logging.getLogger(name)
295
+ if name in logger_initialized:
296
+ return logger
297
+
298
+ for logger_name in logger_initialized:
299
+ if name.startswith(logger_name):
300
+ return logger
301
+
302
+ formatter = logging.Formatter(
303
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
304
+ )
305
+
306
+ sh = logging.StreamHandler()
307
+ sh.setFormatter(formatter)
308
+ logger.addHandler(sh)
309
+ logger_initialized[name] = True
310
+ logger.propagate = False
311
+ logging.basicConfig(level=logging.ERROR)
312
+ return logger
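A self-contained sketch of how the helpers in this file fit together; the toy vocabulary is an assumption for illustration, the real token list ships with the ASR model:

import numpy as np

# toy round trip through the text utilities defined above
token_list = ["hello", "world", "你", "好", "<unk>"]      # last entry is used as the unk symbol
converter = TokenIDConverter(token_list)
ids = converter.tokens2ids(["hello", "你", "missing"])     # unknown tokens fall back to the unk id
print(converter.ids2tokens(ids))                           # ['hello', '你', '<unk>']

tokenizer = CharTokenizer()
chars = tokenizer.text2tokens("hi 你好")                   # ['h', 'i', '<space>', '你', '好']
print(tokenizer.tokens2text(chars))                        # 'hi 你好'

words = code_mix_split_words("hello你好 world")            # ['hello', '你', '好', 'world']
print(split_to_mini_sentence(words, word_limit=2))         # chunks of at most two words

hyp = Hypothesis(yseq=np.array(ids), score=0.0)
print(hyp.asdict())                                        # JSON-friendly dict of the hypothesis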
utils/utils/__init__.py ADDED
File without changes
utils/utils/e2e_vad.py ADDED
@@ -0,0 +1,711 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
3
+ # MIT License (https://opensource.org/licenses/MIT)
4
+
5
+ from enum import Enum
6
+ from typing import List, Tuple, Dict, Any
7
+
8
+ import math
9
+ import numpy as np
10
+
11
+
12
+ class VadStateMachine(Enum):
13
+ kVadInStateStartPointNotDetected = 1
14
+ kVadInStateInSpeechSegment = 2
15
+ kVadInStateEndPointDetected = 3
16
+
17
+
18
+ class FrameState(Enum):
19
+ kFrameStateInvalid = -1
20
+ kFrameStateSpeech = 1
21
+ kFrameStateSil = 0
22
+
23
+
24
+ # frame-level voiced/unvoiced state transitions
25
+ class AudioChangeState(Enum):
26
+ kChangeStateSpeech2Speech = 0
27
+ kChangeStateSpeech2Sil = 1
28
+ kChangeStateSil2Sil = 2
29
+ kChangeStateSil2Speech = 3
30
+ kChangeStateNoBegin = 4
31
+ kChangeStateInvalid = 5
32
+
33
+
34
+ class VadDetectMode(Enum):
35
+ kVadSingleUtteranceDetectMode = 0
36
+ kVadMutipleUtteranceDetectMode = 1
37
+
38
+
39
+ class VADXOptions:
40
+ def __init__(
41
+ self,
42
+ sample_rate: int = 16000,
43
+ detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
44
+ snr_mode: int = 0,
45
+ max_end_silence_time: int = 800,
46
+ max_start_silence_time: int = 3000,
47
+ do_start_point_detection: bool = True,
48
+ do_end_point_detection: bool = True,
49
+ window_size_ms: int = 200,
50
+ sil_to_speech_time_thres: int = 150,
51
+ speech_to_sil_time_thres: int = 150,
52
+ speech_2_noise_ratio: float = 1.0,
53
+ do_extend: int = 1,
54
+ lookback_time_start_point: int = 200,
55
+ lookahead_time_end_point: int = 100,
56
+ max_single_segment_time: int = 60000,
57
+ nn_eval_block_size: int = 8,
58
+ dcd_block_size: int = 4,
59
+ snr_thres: float = -100.0,
60
+ noise_frame_num_used_for_snr: int = 100,
61
+ decibel_thres: float = -100.0,
62
+ speech_noise_thres: float = 0.6,
63
+ fe_prior_thres: float = 1e-4,
64
+ silence_pdf_num: int = 1,
65
+ sil_pdf_ids: List[int] = [0],
66
+ speech_noise_thresh_low: float = -0.1,
67
+ speech_noise_thresh_high: float = 0.3,
68
+ output_frame_probs: bool = False,
69
+ frame_in_ms: int = 10,
70
+ frame_length_ms: int = 25,
71
+ ):
72
+ self.sample_rate = sample_rate
73
+ self.detect_mode = detect_mode
74
+ self.snr_mode = snr_mode
75
+ self.max_end_silence_time = max_end_silence_time
76
+ self.max_start_silence_time = max_start_silence_time
77
+ self.do_start_point_detection = do_start_point_detection
78
+ self.do_end_point_detection = do_end_point_detection
79
+ self.window_size_ms = window_size_ms
80
+ self.sil_to_speech_time_thres = sil_to_speech_time_thres
81
+ self.speech_to_sil_time_thres = speech_to_sil_time_thres
82
+ self.speech_2_noise_ratio = speech_2_noise_ratio
83
+ self.do_extend = do_extend
84
+ self.lookback_time_start_point = lookback_time_start_point
85
+ self.lookahead_time_end_point = lookahead_time_end_point
86
+ self.max_single_segment_time = max_single_segment_time
87
+ self.nn_eval_block_size = nn_eval_block_size
88
+ self.dcd_block_size = dcd_block_size
89
+ self.snr_thres = snr_thres
90
+ self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
91
+ self.decibel_thres = decibel_thres
92
+ self.speech_noise_thres = speech_noise_thres
93
+ self.fe_prior_thres = fe_prior_thres
94
+ self.silence_pdf_num = silence_pdf_num
95
+ self.sil_pdf_ids = sil_pdf_ids
96
+ self.speech_noise_thresh_low = speech_noise_thresh_low
97
+ self.speech_noise_thresh_high = speech_noise_thresh_high
98
+ self.output_frame_probs = output_frame_probs
99
+ self.frame_in_ms = frame_in_ms
100
+ self.frame_length_ms = frame_length_ms
101
+
102
+
103
+ class E2EVadSpeechBufWithDoa(object):
104
+ def __init__(self):
105
+ self.start_ms = 0
106
+ self.end_ms = 0
107
+ self.buffer = []
108
+ self.contain_seg_start_point = False
109
+ self.contain_seg_end_point = False
110
+ self.doa = 0
111
+
112
+ def Reset(self):
113
+ self.start_ms = 0
114
+ self.end_ms = 0
115
+ self.buffer = []
116
+ self.contain_seg_start_point = False
117
+ self.contain_seg_end_point = False
118
+ self.doa = 0
119
+
120
+
121
+ class E2EVadFrameProb(object):
122
+ def __init__(self):
123
+ self.noise_prob = 0.0
124
+ self.speech_prob = 0.0
125
+ self.score = 0.0
126
+ self.frame_id = 0
127
+ self.frm_state = 0
128
+
129
+
130
+ class WindowDetector(object):
131
+ def __init__(
132
+ self,
133
+ window_size_ms: int,
134
+ sil_to_speech_time: int,
135
+ speech_to_sil_time: int,
136
+ frame_size_ms: int,
137
+ ):
138
+ self.window_size_ms = window_size_ms
139
+ self.sil_to_speech_time = sil_to_speech_time
140
+ self.speech_to_sil_time = speech_to_sil_time
141
+ self.frame_size_ms = frame_size_ms
142
+
143
+ self.win_size_frame = int(window_size_ms / frame_size_ms)
144
+ self.win_sum = 0
145
+ self.win_state = [0] * self.win_size_frame  # initialize the sliding window
146
+
147
+ self.cur_win_pos = 0
148
+ self.pre_frame_state = FrameState.kFrameStateSil
149
+ self.cur_frame_state = FrameState.kFrameStateSil
150
+ self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
151
+ self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
152
+
153
+ self.voice_last_frame_count = 0
154
+ self.noise_last_frame_count = 0
155
+ self.hydre_frame_count = 0
156
+
157
+ def Reset(self) -> None:
158
+ self.cur_win_pos = 0
159
+ self.win_sum = 0
160
+ self.win_state = [0] * self.win_size_frame
161
+ self.pre_frame_state = FrameState.kFrameStateSil
162
+ self.cur_frame_state = FrameState.kFrameStateSil
163
+ self.voice_last_frame_count = 0
164
+ self.noise_last_frame_count = 0
165
+ self.hydre_frame_count = 0
166
+
167
+ def GetWinSize(self) -> int:
168
+ return int(self.win_size_frame)
169
+
170
+ def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
171
+ cur_frame_state = FrameState.kFrameStateSil
172
+ if frameState == FrameState.kFrameStateSpeech:
173
+ cur_frame_state = 1
174
+ elif frameState == FrameState.kFrameStateSil:
175
+ cur_frame_state = 0
176
+ else:
177
+ return AudioChangeState.kChangeStateInvalid
178
+ self.win_sum -= self.win_state[self.cur_win_pos]
179
+ self.win_sum += cur_frame_state
180
+ self.win_state[self.cur_win_pos] = cur_frame_state
181
+ self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
182
+
183
+ if (
184
+ self.pre_frame_state == FrameState.kFrameStateSil
185
+ and self.win_sum >= self.sil_to_speech_frmcnt_thres
186
+ ):
187
+ self.pre_frame_state = FrameState.kFrameStateSpeech
188
+ return AudioChangeState.kChangeStateSil2Speech
189
+
190
+ if (
191
+ self.pre_frame_state == FrameState.kFrameStateSpeech
192
+ and self.win_sum <= self.speech_to_sil_frmcnt_thres
193
+ ):
194
+ self.pre_frame_state = FrameState.kFrameStateSil
195
+ return AudioChangeState.kChangeStateSpeech2Sil
196
+
197
+ if self.pre_frame_state == FrameState.kFrameStateSil:
198
+ return AudioChangeState.kChangeStateSil2Sil
199
+ if self.pre_frame_state == FrameState.kFrameStateSpeech:
200
+ return AudioChangeState.kChangeStateSpeech2Speech
201
+ return AudioChangeState.kChangeStateInvalid
202
+
203
+ def FrameSizeMs(self) -> int:
204
+ return int(self.frame_size_ms)
205
+
206
+
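On its own, the WindowDetector above already implements the frame-level hangover logic: it keeps a sliding window of recent frame decisions and reports a state change once enough speech (or silence) frames have accumulated. A toy drive, with made-up timing values:

# toy drive of the WindowDetector defined above (all timing values are illustrative)
detector = WindowDetector(window_size_ms=200, sil_to_speech_time=150,
                          speech_to_sil_time=150, frame_size_ms=10)
frame_states = [FrameState.kFrameStateSil] * 5 + [FrameState.kFrameStateSpeech] * 30
for idx, state in enumerate(frame_states):
    change = detector.DetectOneFrame(state, idx)
    if change == AudioChangeState.kChangeStateSil2Speech:
        print(f"speech onset reported around frame {idx}")  # fires once ~15 speech frames are in the window
detector.Reset()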
207
+ class E2EVadModel:
208
+ """
209
+ Author: Speech Lab of DAMO Academy, Alibaba Group
210
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
211
+ https://arxiv.org/abs/1803.05030
212
+ """
213
+
214
+ def __init__(self, vad_post_args: Dict[str, Any]):
215
+ super(E2EVadModel, self).__init__()
216
+ self.vad_opts = VADXOptions(**vad_post_args)
217
+ self.windows_detector = WindowDetector(
218
+ self.vad_opts.window_size_ms,
219
+ self.vad_opts.sil_to_speech_time_thres,
220
+ self.vad_opts.speech_to_sil_time_thres,
221
+ self.vad_opts.frame_in_ms,
222
+ )
223
+ # self.encoder = encoder
224
+ # init variables
225
+ self.is_final = False
226
+ self.data_buf_start_frame = 0
227
+ self.frm_cnt = 0
228
+ self.latest_confirmed_speech_frame = 0
229
+ self.lastest_confirmed_silence_frame = -1
230
+ self.continous_silence_frame_count = 0
231
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
232
+ self.confirmed_start_frame = -1
233
+ self.confirmed_end_frame = -1
234
+ self.number_end_time_detected = 0
235
+ self.sil_frame = 0
236
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
237
+ self.noise_average_decibel = -100.0
238
+ self.pre_end_silence_detected = False
239
+ self.next_seg = True
240
+
241
+ self.output_data_buf = []
242
+ self.output_data_buf_offset = 0
243
+ self.frame_probs = []
244
+ self.max_end_sil_frame_cnt_thresh = (
245
+ self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
246
+ )
247
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
248
+ self.scores = None
249
+ self.idx_pre_chunk = 0
250
+ self.max_time_out = False
251
+ self.decibel = []
252
+ self.data_buf_size = 0
253
+ self.data_buf_all_size = 0
254
+ self.waveform = None
255
+ self.ResetDetection()
256
+
257
+ def AllResetDetection(self):
258
+ self.is_final = False
259
+ self.data_buf_start_frame = 0
260
+ self.frm_cnt = 0
261
+ self.latest_confirmed_speech_frame = 0
262
+ self.lastest_confirmed_silence_frame = -1
263
+ self.continous_silence_frame_count = 0
264
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
265
+ self.confirmed_start_frame = -1
266
+ self.confirmed_end_frame = -1
267
+ self.number_end_time_detected = 0
268
+ self.sil_frame = 0
269
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
270
+ self.noise_average_decibel = -100.0
271
+ self.pre_end_silence_detected = False
272
+ self.next_seg = True
273
+
274
+ self.output_data_buf = []
275
+ self.output_data_buf_offset = 0
276
+ self.frame_probs = []
277
+ self.max_end_sil_frame_cnt_thresh = (
278
+ self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
279
+ )
280
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
281
+ self.scores = None
282
+ self.idx_pre_chunk = 0
283
+ self.max_time_out = False
284
+ self.decibel = []
285
+ self.data_buf_size = 0
286
+ self.data_buf_all_size = 0
287
+ self.waveform = None
288
+ self.ResetDetection()
289
+
290
+ def ResetDetection(self):
291
+ self.continous_silence_frame_count = 0
292
+ self.latest_confirmed_speech_frame = 0
293
+ self.lastest_confirmed_silence_frame = -1
294
+ self.confirmed_start_frame = -1
295
+ self.confirmed_end_frame = -1
296
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
297
+ self.windows_detector.Reset()
298
+ self.sil_frame = 0
299
+ self.frame_probs = []
300
+
301
+ def ComputeDecibel(self) -> None:
302
+ frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
303
+ frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
304
+ if self.data_buf_all_size == 0:
305
+ self.data_buf_all_size = len(self.waveform[0])
306
+ self.data_buf_size = self.data_buf_all_size
307
+ else:
308
+ self.data_buf_all_size += len(self.waveform[0])
309
+ for offset in range(
310
+ 0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length
311
+ ):
312
+ self.decibel.append(
313
+ 10
314
+ * math.log10(
315
+ np.square((self.waveform[0][offset : offset + frame_sample_length])).sum()
316
+ + 0.000001
317
+ )
318
+ )
319
+
320
+ def ComputeScores(self, scores: np.ndarray) -> None:
321
+ # scores = self.encoder(feats, in_cache) # return B * T * D
322
+ self.vad_opts.nn_eval_block_size = scores.shape[1]
323
+ self.frm_cnt += scores.shape[1] # count total frames
324
+ self.scores = scores
325
+
326
+ def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
327
+ while self.data_buf_start_frame < frame_idx:
328
+ if self.data_buf_size >= int(
329
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
330
+ ):
331
+ self.data_buf_start_frame += 1
332
+ self.data_buf_size = self.data_buf_all_size - self.data_buf_start_frame * int(
333
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
334
+ )
335
+
336
+ def PopDataToOutputBuf(
337
+ self,
338
+ start_frm: int,
339
+ frm_cnt: int,
340
+ first_frm_is_start_point: bool,
341
+ last_frm_is_end_point: bool,
342
+ end_point_is_sent_end: bool,
343
+ ) -> None:
344
+ self.PopDataBufTillFrame(start_frm)
345
+ expected_sample_number = int(
346
+ frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
347
+ )
348
+ if last_frm_is_end_point:
349
+ extra_sample = max(
350
+ 0,
351
+ int(
352
+ self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
353
+ - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
354
+ ),
355
+ )
356
+ expected_sample_number += int(extra_sample)
357
+ if end_point_is_sent_end:
358
+ expected_sample_number = max(expected_sample_number, self.data_buf_size)
359
+ if self.data_buf_size < expected_sample_number:
360
+ print("error in calling pop data_buf\n")
361
+
362
+ if len(self.output_data_buf) == 0 or first_frm_is_start_point:
363
+ self.output_data_buf.append(E2EVadSpeechBufWithDoa())
364
+ self.output_data_buf[-1].Reset()
365
+ self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
366
+ self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
367
+ self.output_data_buf[-1].doa = 0
368
+ cur_seg = self.output_data_buf[-1]
369
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
370
+ print("warning\n")
371
+ out_pos = len(cur_seg.buffer)  # cur_seg.buffer is not actually filled in this implementation
372
+ data_to_pop = 0
373
+ if end_point_is_sent_end:
374
+ data_to_pop = expected_sample_number
375
+ else:
376
+ data_to_pop = int(
377
+ frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
378
+ )
379
+ if data_to_pop > self.data_buf_size:
380
+ print("VAD data_to_pop is bigger than self.data_buf_size!!!\n")
381
+ data_to_pop = self.data_buf_size
382
+ expected_sample_number = self.data_buf_size
383
+
384
+ cur_seg.doa = 0
385
+ for sample_cpy_out in range(0, data_to_pop):
386
+ # cur_seg.buffer[out_pos ++] = data_buf_.back();
387
+ out_pos += 1
388
+ for sample_cpy_out in range(data_to_pop, expected_sample_number):
389
+ # cur_seg.buffer[out_pos++] = data_buf_.back()
390
+ out_pos += 1
391
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
392
+ print("Something wrong with the VAD algorithm\n")
393
+ self.data_buf_start_frame += frm_cnt
394
+ cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
395
+ if first_frm_is_start_point:
396
+ cur_seg.contain_seg_start_point = True
397
+ if last_frm_is_end_point:
398
+ cur_seg.contain_seg_end_point = True
399
+
400
+ def OnSilenceDetected(self, valid_frame: int):
401
+ self.lastest_confirmed_silence_frame = valid_frame
402
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
403
+ self.PopDataBufTillFrame(valid_frame)
404
+ # silence_detected_callback_
405
+ # pass
406
+
407
+ def OnVoiceDetected(self, valid_frame: int) -> None:
408
+ self.latest_confirmed_speech_frame = valid_frame
409
+ self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
410
+
411
+ def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
412
+ if self.vad_opts.do_start_point_detection:
413
+ pass
414
+ if self.confirmed_start_frame != -1:
415
+ print("not reset vad properly\n")
416
+ else:
417
+ self.confirmed_start_frame = start_frame
418
+
419
+ if (
420
+ not fake_result
421
+ and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
422
+ ):
423
+ self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
424
+
425
+ def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
426
+ for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
427
+ self.OnVoiceDetected(t)
428
+ if self.vad_opts.do_end_point_detection:
429
+ pass
430
+ if self.confirmed_end_frame != -1:
431
+ print("not reset vad properly\n")
432
+ else:
433
+ self.confirmed_end_frame = end_frame
434
+ if not fake_result:
435
+ self.sil_frame = 0
436
+ self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
437
+ self.number_end_time_detected += 1
438
+
439
+ def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
440
+ if is_final_frame:
441
+ self.OnVoiceEnd(cur_frm_idx, False, True)
442
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
443
+
444
+ def GetLatency(self) -> int:
445
+ return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
446
+
447
+ def LatencyFrmNumAtStartPoint(self) -> int:
448
+ vad_latency = self.windows_detector.GetWinSize()
449
+ if self.vad_opts.do_extend:
450
+ vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
451
+ return vad_latency
452
+
453
+ def GetFrameState(self, t: int) -> FrameState:
454
+ frame_state = FrameState.kFrameStateInvalid
455
+ cur_decibel = self.decibel[t]
456
+ cur_snr = cur_decibel - self.noise_average_decibel
457
+ # for each frame, calc log posterior probability of each state
458
+ if cur_decibel < self.vad_opts.decibel_thres:
459
+ frame_state = FrameState.kFrameStateSil
460
+ self.DetectOneFrame(frame_state, t, False)
461
+ return frame_state
462
+
463
+ sum_score = 0.0
464
+ noise_prob = 0.0
465
+ assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
466
+ if len(self.sil_pdf_ids) > 0:
467
+ assert len(self.scores) == 1  # only batch_size = 1 is supported
468
+ sil_pdf_scores = [
469
+ self.scores[0][t - self.idx_pre_chunk][sil_pdf_id]
470
+ for sil_pdf_id in self.sil_pdf_ids
471
+ ]
472
+ sum_score = sum(sil_pdf_scores)
473
+ # added by huyuan: avoid sum_score <= 0
474
+ if sum_score <= 0.0:
475
+ sum_score = 1e-10
476
+ # import pdb
477
+ # pdb.set_trace()
478
+ noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
479
+ total_score = 1.0
480
+ sum_score = total_score - sum_score
481
+ speech_prob = math.log(sum_score)
482
+ if self.vad_opts.output_frame_probs:
483
+ frame_prob = E2EVadFrameProb()
484
+ frame_prob.noise_prob = noise_prob
485
+ frame_prob.speech_prob = speech_prob
486
+ frame_prob.score = sum_score
487
+ frame_prob.frame_id = t
488
+ self.frame_probs.append(frame_prob)
489
+ if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
490
+ if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
491
+ frame_state = FrameState.kFrameStateSpeech
492
+ else:
493
+ frame_state = FrameState.kFrameStateSil
494
+ else:
495
+ frame_state = FrameState.kFrameStateSil
496
+ if self.noise_average_decibel < -99.9:
497
+ self.noise_average_decibel = cur_decibel
498
+ else:
499
+ self.noise_average_decibel = (
500
+ cur_decibel
501
+ + self.noise_average_decibel * (self.vad_opts.noise_frame_num_used_for_snr - 1)
502
+ ) / self.vad_opts.noise_frame_num_used_for_snr
503
+
504
+ return frame_state
505
+
506
+ def __call__(
507
+ self,
508
+ score: np.ndarray,
509
+ waveform: np.ndarray,
510
+ is_final: bool = False,
511
+ max_end_sil: int = 800,
512
+ online: bool = False,
513
+ ):
514
+ self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
515
+ self.waveform = waveform # compute decibel for each frame
516
+ self.ComputeDecibel()
517
+ self.ComputeScores(score)
518
+ #import pdb
519
+ #pdb.set_trace()
520
+ if not is_final:
521
+ self.DetectCommonFrames()
522
+ else:
523
+ self.DetectLastFrames()
524
+ segments = []
525
+ for batch_num in range(0, score.shape[0]): # only support batch_size = 1 now
526
+ segment_batch = []
527
+ if len(self.output_data_buf) > 0:
528
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
529
+ if online:
530
+ if not self.output_data_buf[i].contain_seg_start_point:
531
+ continue
532
+ if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
533
+ continue
534
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
535
+ if self.output_data_buf[i].contain_seg_end_point:
536
+ end_ms = self.output_data_buf[i].end_ms
537
+ self.next_seg = True
538
+ self.output_data_buf_offset += 1
539
+ else:
540
+ end_ms = -1
541
+ self.next_seg = False
542
+ else:
543
+ if not is_final and (
544
+ not self.output_data_buf[i].contain_seg_start_point
545
+ or not self.output_data_buf[i].contain_seg_end_point
546
+ ):
547
+ continue
548
+ start_ms = self.output_data_buf[i].start_ms
549
+ end_ms = self.output_data_buf[i].end_ms
550
+ self.output_data_buf_offset += 1
551
+ segment = [start_ms, end_ms]
552
+ segment_batch.append(segment)
553
+
554
+ if segment_batch:
555
+ segments.append(segment_batch)
556
+ if is_final:
557
+ # reset class variables and clear the dict for the next query
558
+ self.AllResetDetection()
559
+ return segments
560
+
561
+ def DetectCommonFrames(self) -> int:
562
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
563
+ return 0
564
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
565
+ frame_state = FrameState.kFrameStateInvalid
566
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
567
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
568
+ self.idx_pre_chunk += self.scores.shape[1]
569
+ return 0
570
+
571
+ def DetectLastFrames(self) -> int:
572
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
573
+ return 0
574
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
575
+ frame_state = FrameState.kFrameStateInvalid
576
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
577
+ if i != 0:
578
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
579
+ else:
580
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
581
+
582
+ return 0
583
+
584
+ def DetectOneFrame(
585
+ self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
586
+ ) -> None:
587
+ tmp_cur_frm_state = FrameState.kFrameStateInvalid
588
+ if cur_frm_state == FrameState.kFrameStateSpeech:
589
+ if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
590
+ tmp_cur_frm_state = FrameState.kFrameStateSpeech
591
+ else:
592
+ tmp_cur_frm_state = FrameState.kFrameStateSil
593
+ elif cur_frm_state == FrameState.kFrameStateSil:
594
+ tmp_cur_frm_state = FrameState.kFrameStateSil
595
+ state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
596
+ frm_shift_in_ms = self.vad_opts.frame_in_ms
597
+ if AudioChangeState.kChangeStateSil2Speech == state_change:
598
+ silence_frame_count = self.continous_silence_frame_count
599
+ self.continous_silence_frame_count = 0
600
+ self.pre_end_silence_detected = False
601
+ start_frame = 0
602
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
603
+ start_frame = max(
604
+ self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint()
605
+ )
606
+ self.OnVoiceStart(start_frame)
607
+ self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
608
+ for t in range(start_frame + 1, cur_frm_idx + 1):
609
+ self.OnVoiceDetected(t)
610
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
611
+ for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
612
+ self.OnVoiceDetected(t)
613
+ if (
614
+ cur_frm_idx - self.confirmed_start_frame + 1
615
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
616
+ ):
617
+ self.OnVoiceEnd(cur_frm_idx, False, False)
618
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
619
+ elif not is_final_frame:
620
+ self.OnVoiceDetected(cur_frm_idx)
621
+ else:
622
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
623
+ else:
624
+ pass
625
+ elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
626
+ self.continous_silence_frame_count = 0
627
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
628
+ pass
629
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
630
+ if (
631
+ cur_frm_idx - self.confirmed_start_frame + 1
632
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
633
+ ):
634
+ self.OnVoiceEnd(cur_frm_idx, False, False)
635
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
636
+ elif not is_final_frame:
637
+ self.OnVoiceDetected(cur_frm_idx)
638
+ else:
639
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
640
+ else:
641
+ pass
642
+ elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
643
+ self.continous_silence_frame_count = 0
644
+ if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
645
+ if (
646
+ cur_frm_idx - self.confirmed_start_frame + 1
647
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
648
+ ):
649
+ self.max_time_out = True
650
+ self.OnVoiceEnd(cur_frm_idx, False, False)
651
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
652
+ elif not is_final_frame:
653
+ self.OnVoiceDetected(cur_frm_idx)
654
+ else:
655
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
656
+ else:
657
+ pass
658
+ elif AudioChangeState.kChangeStateSil2Sil == state_change:
659
+ self.continous_silence_frame_count += 1
660
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
661
+ # silence timeout, return zero length decision
662
+ if (
663
+ (self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
664
+ and (
665
+ self.continous_silence_frame_count * frm_shift_in_ms
666
+ > self.vad_opts.max_start_silence_time
667
+ )
668
+ ) or (is_final_frame and self.number_end_time_detected == 0):
669
+ for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
670
+ self.OnSilenceDetected(t)
671
+ self.OnVoiceStart(0, True)
672
+ self.OnVoiceEnd(0, True, False)
673
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
674
+ else:
675
+ if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
676
+ self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
677
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
678
+ if (
679
+ self.continous_silence_frame_count * frm_shift_in_ms
680
+ >= self.max_end_sil_frame_cnt_thresh
681
+ ):
682
+ lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
683
+ if self.vad_opts.do_extend:
684
+ lookback_frame -= int(
685
+ self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
686
+ )
687
+ lookback_frame -= 1
688
+ lookback_frame = max(0, lookback_frame)
689
+ self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
690
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
691
+ elif (
692
+ cur_frm_idx - self.confirmed_start_frame + 1
693
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
694
+ ):
695
+ self.OnVoiceEnd(cur_frm_idx, False, False)
696
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
697
+ elif self.vad_opts.do_extend and not is_final_frame:
698
+ if self.continous_silence_frame_count <= int(
699
+ self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
700
+ ):
701
+ self.OnVoiceDetected(cur_frm_idx)
702
+ else:
703
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
704
+ else:
705
+ pass
706
+
707
+ if (
708
+ self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
709
+ and self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value
710
+ ):
711
+ self.ResetDetection()
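How this post-processor is typically driven: a separate FSMN VAD network (not part of this file) scores each audio chunk frame by frame, and E2EVadModel turns those scores plus the raw waveform into [start_ms, end_ms] segments, where -1 marks a boundary that has not been observed yet. A hedged sketch; the scorer, the chunking and the vad_post_args values below are assumptions (in this repo they would normally be read from ax_model/vad/config.yaml):

import numpy as np

# hypothetical streaming use of E2EVadModel with an external frame scorer
vad_post_args = {                      # illustrative subset of VADXOptions keyword arguments
    "sample_rate": 16000,
    "max_end_silence_time": 800,
    "speech_noise_thres": 0.8,
}
vad = E2EVadModel(vad_post_args)

for i, chunk in enumerate(chunks):     # chunks: list of (1, samples) float32 waveform blocks (assumed)
    scores = vad_scorer(chunk)         # hypothetical FSMN VAD network, shape (1, frames, pdf_num)
    segments = vad(scores, chunk, is_final=(i == len(chunks) - 1), online=True)
    for segment_batch in segments:
        for start_ms, end_ms in segment_batch:
            print(start_ms, end_ms)    # e.g. [120, -1] then [-1, 980] as a segment opens and closes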
utils/utils/frontend.py ADDED
@@ -0,0 +1,448 @@
1
+ # -*- encoding: utf-8 -*-
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
4
+ import copy
5
+ from functools import lru_cache
6
+
7
+ import numpy as np
8
+ import kaldi_native_fbank as knf
9
+
10
+ root_dir = Path(__file__).resolve().parent
11
+
12
+ logger_initialized = {}
13
+
14
+
15
+ class WavFrontend:
16
+ """Conventional frontend structure for ASR."""
17
+
18
+ def __init__(
19
+ self,
20
+ cmvn_file: str = None,
21
+ fs: int = 16000,
22
+ window: str = "hamming",
23
+ n_mels: int = 80,
24
+ frame_length: int = 25,
25
+ frame_shift: int = 10,
26
+ lfr_m: int = 1,
27
+ lfr_n: int = 1,
28
+ dither: float = 1.0,
29
+ **kwargs,
30
+ ) -> None:
31
+
32
+ opts = knf.FbankOptions()
33
+ opts.frame_opts.samp_freq = fs
34
+ opts.frame_opts.dither = dither
35
+ opts.frame_opts.window_type = window
36
+ opts.frame_opts.frame_shift_ms = float(frame_shift)
37
+ opts.frame_opts.frame_length_ms = float(frame_length)
38
+ opts.mel_opts.num_bins = n_mels
39
+ opts.energy_floor = 0
40
+ opts.frame_opts.snip_edges = True
41
+ opts.mel_opts.debug_mel = False
42
+ self.opts = opts
43
+
44
+ self.lfr_m = lfr_m
45
+ self.lfr_n = lfr_n
46
+ self.cmvn_file = cmvn_file
47
+
48
+ if self.cmvn_file:
49
+ self.cmvn = load_cmvn(self.cmvn_file)
50
+ self.fbank_fn = None
51
+ self.fbank_beg_idx = 0
52
+ self.reset_status()
53
+
54
+ def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
55
+ waveform = waveform * (1 << 15)
56
+ fbank_fn = knf.OnlineFbank(self.opts)
57
+ fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
58
+ frames = fbank_fn.num_frames_ready
59
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
60
+ for i in range(frames):
61
+ mat[i, :] = fbank_fn.get_frame(i)
62
+ feat = mat.astype(np.float32)
63
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
64
+ return feat, feat_len
65
+
66
+ def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
67
+ waveform = waveform * (1 << 15)
68
+ # self.fbank_fn = knf.OnlineFbank(self.opts)
69
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
70
+ frames = self.fbank_fn.num_frames_ready
71
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
72
+ for i in range(self.fbank_beg_idx, frames):
73
+ mat[i, :] = self.fbank_fn.get_frame(i)
74
+ # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
75
+ feat = mat.astype(np.float32)
76
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
77
+ return feat, feat_len
78
+
79
+ def reset_status(self):
80
+ self.fbank_fn = knf.OnlineFbank(self.opts)
81
+ self.fbank_beg_idx = 0
82
+
83
+ def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
84
+ if self.lfr_m != 1 or self.lfr_n != 1:
85
+ feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
86
+
87
+ if self.cmvn_file:
88
+ feat = self.apply_cmvn(feat)
89
+
90
+ feat_len = np.array(feat.shape[0]).astype(np.int32)
91
+ return feat, feat_len
92
+
93
+ @staticmethod
94
+ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
95
+ LFR_inputs = []
96
+
97
+ T = inputs.shape[0]
98
+ T_lfr = int(np.ceil(T / lfr_n))
99
+ left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
100
+ inputs = np.vstack((left_padding, inputs))
101
+ T = T + (lfr_m - 1) // 2
102
+ for i in range(T_lfr):
103
+ if lfr_m <= T - i * lfr_n:
104
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
105
+ else:
106
+ # process last LFR frame
107
+ num_padding = lfr_m - (T - i * lfr_n)
108
+ frame = inputs[i * lfr_n :].reshape(-1)
109
+ for _ in range(num_padding):
110
+ frame = np.hstack((frame, inputs[-1]))
111
+
112
+ LFR_inputs.append(frame)
113
+ LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
114
+ return LFR_outputs
115
+
116
+ def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
117
+ """
118
+ Apply CMVN with mvn data
119
+ """
120
+ frame, dim = inputs.shape
121
+ means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
122
+ vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
123
+ inputs = (inputs + means) * vars
124
+ return inputs
125
+
126
+ @lru_cache()
127
+ def load_cmvn(cmvn_file: Union[str, Path]) -> np.ndarray:
128
+ """load cmvn file to numpy array.
129
+
130
+ Args:
131
+ cmvn_file (Union[str, Path]): cmvn file path.
132
+
133
+ Raises:
134
+ FileNotFoundError: cmvn file does not exist.
135
+
136
+ Returns:
137
+ np.ndarray: cmvn array. shape is (2, dim). The first row is means, the second row is vars.
138
+ """
139
+
140
+ cmvn_file = Path(cmvn_file)
141
+ if not cmvn_file.exists():
142
+ raise FileNotFoundError("cmvn file does not exist")
143
+
144
+ with open(cmvn_file, "r", encoding="utf-8") as f:
145
+ lines = f.readlines()
146
+ means_list = []
147
+ vars_list = []
148
+ for i in range(len(lines)):
149
+ line_item = lines[i].split()
150
+ if line_item[0] == "<AddShift>":
151
+ line_item = lines[i + 1].split()
152
+ if line_item[0] == "<LearnRateCoef>":
153
+ add_shift_line = line_item[3 : (len(line_item) - 1)]
154
+ means_list = list(add_shift_line)
155
+ continue
156
+ elif line_item[0] == "<Rescale>":
157
+ line_item = lines[i + 1].split()
158
+ if line_item[0] == "<LearnRateCoef>":
159
+ rescale_line = line_item[3 : (len(line_item) - 1)]
160
+ vars_list = list(rescale_line)
161
+ continue
162
+
163
+ means = np.array(means_list).astype(np.float64)
164
+ vars = np.array(vars_list).astype(np.float64)
165
+ cmvn = np.array([means, vars])
166
+ return cmvn
167
+
168
+
169
+ class WavFrontendOnline(WavFrontend):
170
+ def __init__(self, **kwargs):
171
+ super().__init__(**kwargs)
172
+ # self.fbank_fn = knf.OnlineFbank(self.opts)
173
+ # add variables
174
+ self.frame_sample_length = int(
175
+ self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
176
+ )
177
+ self.frame_shift_sample_length = int(
178
+ self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
179
+ )
180
+ self.waveform = None
181
+ self.reserve_waveforms = None
182
+ self.input_cache = None
183
+ self.lfr_splice_cache = []
184
+
185
+ @staticmethod
186
+ # the cache has already been concatenated onto inputs
187
+ def apply_lfr(
188
+ inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
189
+ ) -> Tuple[np.ndarray, np.ndarray, int]:
190
+ """
191
+ Apply lfr with data
192
+ """
193
+
194
+ LFR_inputs = []
195
+ T = inputs.shape[0] # include the right context
196
+ T_lfr = int(
197
+ np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
198
+ ) # minus the right context: (lfr_m - 1) // 2
199
+ splice_idx = T_lfr
200
+ for i in range(T_lfr):
201
+ if lfr_m <= T - i * lfr_n:
202
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
203
+ else: # process last LFR frame
204
+ if is_final:
205
+ num_padding = lfr_m - (T - i * lfr_n)
206
+ frame = (inputs[i * lfr_n :]).reshape(-1)
207
+ for _ in range(num_padding):
208
+ frame = np.hstack((frame, inputs[-1]))
209
+ LFR_inputs.append(frame)
210
+ else:
211
+ # update splice_idx and break out of the loop
212
+ splice_idx = i
213
+ break
214
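+ # Everything from the first unconsumed input frame onwards is returned as the
+ # cache to be prepended to the next chunk.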
+ splice_idx = min(T - 1, splice_idx * lfr_n)
215
+ lfr_splice_cache = inputs[splice_idx:, :]
216
+ LFR_outputs = np.vstack(LFR_inputs)
217
+ return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
218
+
219
+ @staticmethod
220
+ def compute_frame_num(
221
+ sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
222
+ ) -> int:
223
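+ # Number of complete frames that fit into sample_length:
+ # floor((samples - frame_len) / frame_shift) + 1, or 0 if not even one frame fits.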
+ frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
224
+ return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
225
+
226
+ def fbank(
227
+ self, input: np.ndarray, input_lengths: np.ndarray
228
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
229
+ self.fbank_fn = knf.OnlineFbank(self.opts)
230
+ batch_size = input.shape[0]
231
+ if self.input_cache is None:
232
+ self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
233
+ input = np.concatenate((self.input_cache, input), axis=1)
234
+ frame_num = self.compute_frame_num(
235
+ input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
236
+ )
237
+ # update self.input_cache with the leftover samples that do not yet form a full frame
238
+ self.input_cache = input[
239
+ :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
240
+ ]
241
+ waveforms = np.empty(0, dtype=np.float32)
242
+ feats_pad = np.empty(0, dtype=np.float32)
243
+ feats_lens = np.empty(0, dtype=np.int32)
244
+ if frame_num:
245
+ waveforms = []
246
+ feats = []
247
+ feats_lens = []
248
+ for i in range(batch_size):
249
+ waveform = input[i]
250
+ waveforms.append(
251
+ waveform[
252
+ : (
253
+ (frame_num - 1) * self.frame_shift_sample_length
254
+ + self.frame_sample_length
255
+ )
256
+ ]
257
+ )
258
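+ # Scale the [-1, 1) float waveform by 2**15 so the fbank sees Kaldi-style
+ # int16-range samples (matching the statistics in am.mvn).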
+ waveform = waveform * (1 << 15)
259
+
260
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
261
+ frames = self.fbank_fn.num_frames_ready
262
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
263
+ for frame_idx in range(frames):
264
+ mat[frame_idx, :] = self.fbank_fn.get_frame(frame_idx)
265
+ feat = mat.astype(np.float32)
266
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
267
+ feats.append(feat)
268
+ feats_lens.append(feat_len)
269
+
270
+ waveforms = np.stack(waveforms)
271
+ feats_lens = np.array(feats_lens)
272
+ feats_pad = np.array(feats)
273
+ self.fbanks = feats_pad
274
+ self.fbanks_lens = copy.deepcopy(feats_lens)
275
+ return waveforms, feats_pad, feats_lens
276
+
277
+ def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
278
+ return self.fbanks, self.fbanks_lens
279
+
280
+ def lfr_cmvn(
281
+ self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
282
+ ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
283
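+ # Per-item LFR + CMVN; also returns, for each batch item, the frame index from
+ # which the input features were kept back as splice cache for the next chunk.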
+ batch_size = input.shape[0]
284
+ feats = []
285
+ feats_lens = []
286
+ lfr_splice_frame_idxs = []
287
+ for i in range(batch_size):
288
+ mat = input[i, : input_lengths[i], :]
289
+ lfr_splice_frame_idx = -1
290
+ if self.lfr_m != 1 or self.lfr_n != 1:
291
+ # update self.lfr_splice_cache in self.apply_lfr
292
+ mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
293
+ mat, self.lfr_m, self.lfr_n, is_final
294
+ )
295
+ if self.cmvn_file is not None:
296
+ mat = self.apply_cmvn(mat)
297
+ feat_length = mat.shape[0]
298
+ feats.append(mat)
299
+ feats_lens.append(feat_length)
300
+ lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
301
+
302
+ feats_lens = np.array(feats_lens)
303
+ feats_pad = np.array(feats)
304
+ return feats_pad, feats_lens, lfr_splice_frame_idxs
305
+
306
+ def extract_fbank(
307
+ self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
308
+ ) -> Tuple[np.ndarray, np.ndarray]:
309
+ batch_size = input.shape[0]
310
+ assert (
311
+ batch_size == 1
312
+ ), "we support to extract feature online only when the batch size is equal to 1 now"
313
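+ # fbank() consumes the raw chunk (leaving leftover samples in input_cache) and
+ # returns only the frames that are ready for this chunk.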
+ waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
314
+ if feats.shape[0]:
315
+ self.waveforms = (
316
+ waveforms
317
+ if self.reserve_waveforms is None
318
+ else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
319
+ )
320
+ if not self.lfr_splice_cache:
321
+ for i in range(batch_size):
322
+ self.lfr_splice_cache.append(
323
+ np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
324
+ )
325
+
326
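+ # Enough frames (splice cache + new) to emit at least one LFR frame: run
+ # LFR + CMVN and keep the tail as cache; otherwise just grow the caches and
+ # return an empty feature array.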
+ if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
327
+ lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
328
+ feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
329
+ feats_lengths += lfr_splice_cache_np[0].shape[0]
330
+ frame_from_waveforms = int(
331
+ (self.waveforms.shape[1] - self.frame_sample_length)
332
+ / self.frame_shift_sample_length
333
+ + 1
334
+ )
335
+ minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
336
+ feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
337
+ feats, feats_lengths, is_final
338
+ )
339
+ if self.lfr_m == 1:
340
+ self.reserve_waveforms = None
341
+ else:
342
+ reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
343
+ # print('reserve_frame_idx: ' + str(reserve_frame_idx))
344
+ # print('frame_frame: ' + str(frame_from_waveforms))
345
+ self.reserve_waveforms = self.waveforms[
346
+ :,
347
+ reserve_frame_idx
348
+ * self.frame_shift_sample_length : frame_from_waveforms
349
+ * self.frame_shift_sample_length,
350
+ ]
351
+ sample_length = (
352
+ frame_from_waveforms - 1
353
+ ) * self.frame_shift_sample_length + self.frame_sample_length
354
+ self.waveforms = self.waveforms[:, :sample_length]
355
+ else:
356
+ # update self.reserve_waveforms and self.lfr_splice_cache
357
+ self.reserve_waveforms = self.waveforms[
358
+ :, : -(self.frame_sample_length - self.frame_shift_sample_length)
359
+ ]
360
+ for i in range(batch_size):
361
+ self.lfr_splice_cache[i] = np.concatenate(
362
+ (self.lfr_splice_cache[i], feats[i]), axis=0
363
+ )
364
+ return np.empty(0, dtype=np.float32), feats_lengths
365
+ else:
366
+ if is_final:
367
+ self.waveforms = (
368
+ waveforms if self.reserve_waveforms is None else self.reserve_waveforms
369
+ )
370
+ feats = np.stack(self.lfr_splice_cache)
371
+ feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
372
+ feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
373
+ if is_final:
374
+ self.cache_reset()
375
+ return feats, feats_lengths
376
+
377
+ def get_waveforms(self):
378
+ return self.waveforms
379
+
380
+ def cache_reset(self):
381
+ self.fbank_fn = knf.OnlineFbank(self.opts)
382
+ self.reserve_waveforms = None
383
+ self.input_cache = None
384
+ self.lfr_splice_cache = []
385
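+ # Minimal streaming usage sketch (assumes the cmvn file shipped under
+ # ax_model/sensevoice/ and a frontend_conf dict providing the fbank/LFR options;
+ # the constructor keyword names are an assumption here):
+ #
+ # frontend = WavFrontendOnline(cmvn_file="ax_model/sensevoice/am.mvn", **frontend_conf)
+ # for chunk in wav_chunks: # chunk: np.ndarray of shape (1, n_samples)
+ # feats, feat_lens = frontend.extract_fbank(chunk, np.array([chunk.shape[1]]),
+ # is_final=(chunk is wav_chunks[-1]))
+ # if feats.size:
+ # ... # feed feats to the acoustic model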
+
386
+
387
+ def load_bytes(input):
388
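+ # Interpret a raw byte buffer as int16 PCM samples and rescale them to float32 in [-1, 1).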
+ middle_data = np.frombuffer(input, dtype=np.int16)
389
+ middle_data = np.asarray(middle_data)
390
+ if middle_data.dtype.kind not in "iu":
391
+ raise TypeError("'middle_data' must be an array of integers")
392
+ dtype = np.dtype("float32")
393
+ if dtype.kind != "f":
394
+ raise TypeError("'dtype' must be a floating point type")
395
+
396
+ i = np.iinfo(middle_data.dtype)
397
+ abs_max = 2 ** (i.bits - 1)
398
+ offset = i.min + abs_max
399
+ array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
400
+ return array
401
+
402
+
403
+ class SinusoidalPositionEncoderOnline:
404
+ """Streaming Positional encoding."""
405
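+ # Standard transformer sinusoidal encoding: position p, channel k gets
+ # [sin(p * w_k), cos(p * w_k)] with w_k = 10000 ** (-k / (depth / 2 - 1));
+ # forward() shifts positions by start_idx so successive streaming chunks
+ # continue the same phase.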
+
406
+ def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
407
+ batch_size = positions.shape[0]
408
+ positions = positions.astype(dtype)
409
+ log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
410
+ inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
411
+ inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
412
+ scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
413
+ encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
414
+ return encoding.astype(dtype)
415
+
416
+ def forward(self, x, start_idx=0):
417
+ batch_size, timesteps, input_dim = x.shape
418
+ positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
419
+ position_encoding = self.encode(positions, input_dim, x.dtype)
420
+
421
+ return x + position_encoding[:, start_idx : start_idx + timesteps]
422
+
423
+
424
+ def test():
425
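+ # Developer smoke test apparently inherited from upstream FunASR: the model paths
+ # below are hard-coded to the original author's machine, and the frontend method
+ # names used here (fbank_online / reset_status) do not appear in the classes above,
+ # so this function needs adapting before it can run.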
+ path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
426
+ import librosa
427
+
428
+ cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
429
+ config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
430
+ from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
431
+
432
+ config = read_yaml(config_file)
433
+ waveform, _ = librosa.load(path, sr=None)
434
+ frontend = WavFrontend(
435
+ cmvn_file=cmvn_file,
436
+ **config["frontend_conf"],
437
+ )
438
+ speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
439
+ feat, feat_len = frontend.lfr_cmvn(
440
+ speech
441
+ ) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
442
+
443
+ frontend.reset_status() # clear cache
444
+ return feat, feat_len
445
+
446
+
447
+ if __name__ == "__main__":
448
+ test()