first commit
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +11 -0
- README.md +118 -3
- ax_model/.gitattributes +2 -0
- ax_model/auto.npy +3 -0
- ax_model/chn_jpn_yue_eng_ko_spectok.bpe.model +3 -0
- ax_model/event_emo.npy +3 -0
- ax_model/sensevoice.axmodel +3 -0
- ax_model/sensevoice/am.mvn +8 -0
- ax_model/sensevoice/config.yaml +97 -0
- ax_model/vad/am.mvn +8 -0
- ax_model/vad/config.yaml +56 -0
- ax_model/withitn.npy +3 -0
- ax_speech_translate_demo.py +347 -0
- libmelotts/install/libonnxruntime.so +3 -0
- libmelotts/install/libonnxruntime.so.1.14.0 +3 -0
- libmelotts/install/libonnxruntime_providers_shared.so +0 -0
- libmelotts/install/melotts +3 -0
- libmelotts/models/decoder-en.axmodel +3 -0
- libmelotts/models/decoder-zh.axmodel +3 -0
- libmelotts/models/encoder-en.onnx +3 -0
- libmelotts/models/encoder-zh.onnx +3 -0
- libmelotts/models/g-en.bin +3 -0
- libmelotts/models/g-jp.bin +3 -0
- libmelotts/models/g-zh_mix_en.bin +3 -0
- libmelotts/models/lexicon.txt +0 -0
- libmelotts/models/tokens.txt +112 -0
- libtranslate/libax_translate.so +3 -0
- libtranslate/libsentencepiece.so.0 +3 -0
- libtranslate/opus-mt-en-zh.axmodel +3 -0
- libtranslate/opus-mt-en-zh/.gitattributes +9 -0
- libtranslate/opus-mt-en-zh/README.md +96 -0
- libtranslate/opus-mt-en-zh/config.json +61 -0
- libtranslate/opus-mt-en-zh/generation_config.json +16 -0
- libtranslate/opus-mt-en-zh/metadata.json +1 -0
- libtranslate/opus-mt-en-zh/source.spm +3 -0
- libtranslate/opus-mt-en-zh/target.spm +3 -0
- libtranslate/opus-mt-en-zh/tokenizer_config.json +1 -0
- libtranslate/opus-mt-en-zh/vocab.json +0 -0
- libtranslate/test_translate +0 -0
- model.py +942 -0
- requirements.txt +5 -0
- utils/__init__.py +0 -0
- utils/ax_model_bin.py +241 -0
- utils/ax_vad_bin.py +156 -0
- utils/ctc_alignment.py +76 -0
- utils/frontend.py +433 -0
- utils/infer_utils.py +312 -0
- utils/utils/__init__.py +0 -0
- utils/utils/e2e_vad.py +711 -0
- utils/utils/frontend.py +448 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ax_model/sensevoice.axmodel filter=lfs diff=lfs merge=lfs -text
+libmelotts/install/libonnxruntime.so filter=lfs diff=lfs merge=lfs -text
+libmelotts/install/libonnxruntime.so.1.14.0 filter=lfs diff=lfs merge=lfs -text
+libmelotts/install/melotts filter=lfs diff=lfs merge=lfs -text
+libmelotts/models/decoder-en.axmodel filter=lfs diff=lfs merge=lfs -text
+libmelotts/models/decoder-zh.axmodel filter=lfs diff=lfs merge=lfs -text
+libtranslate/libax_translate.so filter=lfs diff=lfs merge=lfs -text
+libtranslate/libsentencepiece.so.0 filter=lfs diff=lfs merge=lfs -text
+libtranslate/opus-mt-en-zh/source.spm filter=lfs diff=lfs merge=lfs -text
+libtranslate/opus-mt-en-zh/target.spm filter=lfs diff=lfs merge=lfs -text
+libtranslate/opus-mt-en-zh.axmodel filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,118 @@
----
--license: mit
-
+---
+license: mit
+language:
+- en
+- zh
+pipeline_tag: Speech-Translation
+base_model:
+- FunAudioLLM/SenseVoiceSmall
+- opus-mt-en-zh
+- MeloTTS
+tags:
+- VAD
+- ASR
+- Translation
+- TTS
+---
+
+
+# Speech-Translation.axera
+
+Speech translation demo on Axera
+
+- [x] Python example
+- [ ] C++ example
+
+## Convert tools links:
+
+If you are interested in model conversion, you can export the axmodel from the original repos.
+How to convert from ONNX to axmodel:
+- [ASR](https://github.com/AXERA-TECH/3D-Speaker-MT.axera/tree/main/model_convert)
+- [MeloTTS](https://github.com/ml-inory/melotts.axera/tree/main/model_convert)
+
+## Supported platforms
+
+- AX650N
+
+## Features
+
+Speech translation (supported direction: English -> Chinese)
+
+## Pipeline components
+
+- [ASR](https://github.com/AXERA-TECH/3D-Speaker-MT.axera/tree/main)
+- [Text-Translate](https://github.com/AXERA-TECH/libtranslate.axera/tree/master): build the library files as described there and place them in libtranslate
+- [MeloTTS](https://github.com/ml-inory/melotts.axera/tree/main/cpp): build the library files as described there and place them in libmelotts
+
+## On-board deployment
+
+- The AX650N device ships with Ubuntu 22.04 pre-installed
+- Log in to the AX650N board as root
+- Connect the board to the Internet so that apt install, pip install, etc. work
+- Verified device: AX650N DEMO Board
+
+## Running the Python API
+
+Verified with Python 3.10.
+
+1. Add the shared libraries to the loader path
+```
+export LD_LIBRARY_PATH=./libtranslate/:$LD_LIBRARY_PATH
+export LD_LIBRARY_PATH=./libmelotts/install/:$LD_LIBRARY_PATH
+```
+2. Install the Python dependencies
+```
+pip3 install -r requirements.txt
+```
+
+## Run the following command on the board
+
+```
+Supported input audio formats: wav, mp3
+```
+
+```
+python3 pipeline.py --audio_file wav/en.mp3 --output_dir output
+```
+
+Run-time arguments:
+
+| Argument | Description |
+|-------|------|
+| `--audio_file` | Path to the input audio file |
+| `--output_dir` | Directory where results are saved |
+
+
+The output is saved as a wav file; a sample run looks like this:
+```
+Source audio: wav/en.mp3
+Source text: The tribal chieftain called for the boy and presented him with 50 pieces of gold.
+Translated text: 部落酋长召唤了男孩 给他50块黄金
+Generated audio: output/output.wav
+```
+
+## Latency
+
+AX650N
+
+RTF: approximately 2.0
+```
+e.g.:
+Inference time for en.mp3: 13.04 seconds
+- VAD + ASR processing time: 0.89 seconds
+- Translate time: 3.95 seconds
+- TTS time: 8.20 seconds
+Audio duration: 7.18 seconds
+RTF: 1.82
+```
+
+References:
+- [sensevoice.axera](https://github.com/ml-inory/sensevoice.axera/tree/main)
+- [3D-Speaker.axera](https://github.com/AXERA-TECH/3D-Speaker.axera/tree/master)
+- [libtranslate.axera](https://github.com/AXERA-TECH/libtranslate.axera/tree/master)
+- [melotts.axera](https://github.com/ml-inory/melotts.axera/tree/main)
+
+## Technical discussion
+
+- GitHub issues
+- QQ group: 139953715
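As a quick check of the Latency figures quoted in the README above, the RTF is simply total inference time divided by audio duration; a minimal sketch using only the numbers from the example:

```
# Real-time factor from the example figures quoted above.
vad_asr, translate, tts = 0.89, 3.95, 8.20
audio_duration = 7.18
total = vad_asr + translate + tts          # 13.04 seconds
print(round(total / audio_duration, 2))    # 1.82
```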
ax_model/.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.axmodel filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
ax_model/auto.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d0997706b30274f7ff3b157ca90df50b7ed8ced35091a0231700355d5ee1374
size 2368
ax_model/chn_jpn_yue_eng_ko_spectok.bpe.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
size 377341
ax_model/event_emo.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d22e3df5d192fdc3e73e368a2cb576975a5a43a114a8432a91c036adf8e2263
size 4608
ax_model/sensevoice.axmodel
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b64a36fa15e75ab5e3b75f18ae87a058970cff76219407e503b54fb53dd8e38
size 262170623
ax_model/sensevoice/am.mvn
ADDED
@@ -0,0 +1,8 @@
<Nnet>
<Splice> 560 560
[ 0 ]
<AddShift> 560 560
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 
-13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
<Rescale> 560 560
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
</Nnet>
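The am.mvn file above stores Kaldi-style normalization for the 560-dim LFR features: an <AddShift> vector (negative means) followed by a <Rescale> vector (inverse standard deviations). A minimal sketch of how such stats are typically applied, assuming the two vectors have already been parsed into NumPy arrays (the dummy values below are placeholders, not the real stats):

```
import numpy as np

def apply_cmvn(features, shift, scale):
    # Kaldi-style CMVN: add the shift, then rescale, per feature dimension.
    # features: (num_frames, 560) LFR features; shift, scale: (560,) vectors.
    return (features + shift) * scale

frames = np.random.randn(10, 560).astype(np.float32)
shift = np.zeros(560, dtype=np.float32)   # would come from the <AddShift> block
scale = np.ones(560, dtype=np.float32)    # would come from the <Rescale> block
normalized = apply_cmvn(frames, shift, scale)
print(normalized.shape)  # (10, 560)
```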
ax_model/sensevoice/config.yaml
ADDED
@@ -0,0 +1,97 @@
encoder: SenseVoiceEncoderSmall
encoder_conf:
  output_size: 512
  attention_heads: 4
  linear_units: 2048
  num_blocks: 50
  tp_blocks: 20
  dropout_rate: 0.1
  positional_dropout_rate: 0.1
  attention_dropout_rate: 0.1
  input_layer: pe
  pos_enc_class: SinusoidalPositionEncoder
  normalize_before: true
  kernel_size: 11
  sanm_shfit: 0
  selfattention_layer_type: sanm


model: SenseVoiceSmall
model_conf:
  length_normalized_loss: true
  sos: 1
  eos: 2
  ignore_id: -1

tokenizer: SentencepiecesTokenizer
tokenizer_conf:
  bpemodel: null
  unk_symbol: <unk>
  split_with_space: true

frontend: WavFrontend
frontend_conf:
  fs: 16000
  window: hamming
  n_mels: 80
  frame_length: 25
  frame_shift: 10
  lfr_m: 7
  lfr_n: 6
  cmvn_file: null


dataset: SenseVoiceCTCDataset
dataset_conf:
  index_ds: IndexDSJsonl
  batch_sampler: EspnetStyleBatchSampler
  data_split_num: 32
  batch_type: token
  batch_size: 14000
  max_token_length: 2000
  min_token_length: 60
  max_source_length: 2000
  min_source_length: 60
  max_target_length: 200
  min_target_length: 0
  shuffle: true
  num_workers: 4
  sos: ${model_conf.sos}
  eos: ${model_conf.eos}
  IndexDSJsonl: IndexDSJsonl
  retry: 20

train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 20
  keep_nbest_models: 10
  avg_nbest_model: 10
  log_interval: 100
  resume: true
  validate_interval: 10000
  save_checkpoint_interval: 10000

optim: adamw
optim_conf:
  lr: 0.00002
scheduler: warmuplr
scheduler_conf:
  warmup_steps: 25000

specaug: SpecAugLFR
specaug_conf:
  apply_time_warp: false
  time_warp_window: 5
  time_warp_mode: bicubic
  apply_freq_mask: true
  freq_mask_width_range:
  - 0
  - 30
  lfr_rate: 6
  num_freq_mask: 1
  apply_time_mask: true
  time_mask_width_range:
  - 0
  - 12
  num_time_mask: 1
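The frontend_conf above uses low frame rate (LFR) stacking with lfr_m: 7 and lfr_n: 6: each output frame concatenates 7 consecutive 80-dim fbank frames (7 x 80 = 560 dims, matching the am.mvn stats above) and the window hops by 6 frames. A simplified sketch of that stacking (it ignores the left-padding FunASR's frontend applies at the utterance start):

```
import numpy as np

def lfr_stack(fbank, lfr_m=7, lfr_n=6):
    # fbank: (T, 80) filterbank features -> (ceil(T / lfr_n), lfr_m * 80).
    # Tail windows are padded by repeating the last frame.
    T, _ = fbank.shape
    out = []
    for start in range(0, T, lfr_n):
        window = fbank[start:start + lfr_m]
        if window.shape[0] < lfr_m:
            pad = np.repeat(window[-1:], lfr_m - window.shape[0], axis=0)
            window = np.concatenate([window, pad], axis=0)
        out.append(window.reshape(-1))
    return np.stack(out)

features = lfr_stack(np.random.randn(100, 80).astype(np.float32))
print(features.shape)  # (17, 560)
```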
ax_model/vad/am.mvn
ADDED
@@ -0,0 +1,8 @@
<Nnet>
<Splice> 400 400
[ 0 ]
<AddShift> 400 400
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 
-13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
<Rescale> 400 400
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
</Nnet>
ax_model/vad/config.yaml
ADDED
@@ -0,0 +1,56 @@
frontend: WavFrontendOnline
frontend_conf:
  fs: 16000
  window: hamming
  n_mels: 80
  frame_length: 25
  frame_shift: 10
  dither: 0.0
  lfr_m: 5
  lfr_n: 1

model: FsmnVADStreaming
model_conf:
  sample_rate: 16000
  detect_mode: 1
  snr_mode: 0
  max_end_silence_time: 800
  max_start_silence_time: 3000
  do_start_point_detection: True
  do_end_point_detection: True
  window_size_ms: 200
  sil_to_speech_time_thres: 150
  speech_to_sil_time_thres: 150
  speech_2_noise_ratio: 1.0
  do_extend: 1
  lookback_time_start_point: 200
  lookahead_time_end_point: 100
  max_single_segment_time: 60000
  snr_thres: -100.0
  noise_frame_num_used_for_snr: 100
  decibel_thres: -100.0
  speech_noise_thres: 0.6
  fe_prior_thres: 0.0001
  silence_pdf_num: 1
  sil_pdf_ids: [0]
  speech_noise_thresh_low: -0.1
  speech_noise_thresh_high: 0.3
  output_frame_probs: False
  frame_in_ms: 10
  frame_length_ms: 25

encoder: FSMN
encoder_conf:
  input_dim: 400
  input_affine_dim: 140
  fsmn_layers: 4
  linear_dim: 250
  proj_dim: 128
  lorder: 20
  rorder: 0
  lstride: 1
  rstride: 0
  output_affine_dim: 140
  output_dim: 248
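The FSMN-VAD configured above works on 16 kHz audio (frontend_conf fs: 16000) and reports speech segments in milliseconds; downstream code (see ax_speech_translate_demo.py later in this commit) slices the waveform with those boundaries. A small illustrative sketch, assuming segments shaped like [[start_ms, end_ms], ...]:

```
import numpy as np

def slice_segments(speech, segments, fs=16000):
    # Cut a mono waveform into chunks given VAD segments in milliseconds.
    chunks = []
    for start_ms, end_ms in segments:
        start = int(start_ms / 1000 * fs)
        end = min(int(end_ms / 1000 * fs), len(speech))
        chunks.append(speech[start:end])
    return chunks

audio = np.zeros(16000 * 10, dtype=np.float32)            # 10 s placeholder signal
chunks = slice_segments(audio, [[0, 2500], [3000, 9000]])
print([len(c) / 16000 for c in chunks])                   # [2.5, 6.0]
```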
ax_model/withitn.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:39bf02586f59237894fc2918ab2db4f12ec3c084c41465718832fbd7646ea729
size 2368
ax_speech_translate_demo.py
ADDED
@@ -0,0 +1,347 @@
import subprocess
import tempfile
import os
import json
import shutil
import time
import librosa
import torch
import argparse
import soundfile as sf
from pathlib import Path
import cn2an

# SenseVoice-related modules
from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

# Configuration
# translation settings
TRANSLATE_EXECUTABLE = "libtranslate/test_translate"
TRANSLATE_MODEL = "libtranslate/opus-mt-en-zh.axmodel"
TRANSLATE_TOKENIZER_DIR = "libtranslate/opus-mt-en-zh/"

# TTS settings
TTS_EXECUTABLE = "libmelotts/install/melotts"
TTS_MODEL_DIR = "libmelotts/models"
TTS_MODEL_FILES = {
    "g": "g-zh_mix_en.bin",
    "encoder": "encoder-zh.onnx",
    "lexicon": "lexicon.txt",
    "tokens": "tokens.txt",
    "decoder": "decoder-zh.axmodel"
}

class SpeechTranslationPipeline:
    def __init__(self,
                 translate_exec, translate_model, translate_tokenizer,
                 tts_exec, tts_model_dir, tts_model_files,
                 asr_model_dir="ax_model", seq_len=132):
        self.translate_exec = translate_exec
        self.translate_model = translate_model
        self.translate_tokenizer = translate_tokenizer
        self.tts_exec = tts_exec
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len

        # Initialize the ASR models
        self._init_asr_models()

        # Check that all required files exist
        self._validate_files()

    def _init_asr_models(self):
        """Initialize the speech recognition models."""
        print("Initializing SenseVoice models...")

        # VAD model
        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)

        # Positional encoding
        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()

        # ASR model
        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)

        # Tokenizer
        tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)

        print("SenseVoice models initialized successfully.")

    def _validate_files(self):
        """Check that all required files exist."""
        # Check translation-related files
        if not os.path.exists(self.translate_exec):
            raise FileNotFoundError(f"Translation executable not found: {self.translate_exec}")
        if not os.path.exists(self.translate_model):
            raise FileNotFoundError(f"Translation model not found: {self.translate_model}")
        if not os.path.exists(self.translate_tokenizer):
            raise FileNotFoundError(f"Translation tokenizer directory not found: {self.translate_tokenizer}")

        # Check TTS-related files
        if not os.path.exists(self.tts_exec):
            raise FileNotFoundError(f"TTS executable not found: {self.tts_exec}")

        for key, filename in self.tts_model_files.items():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS model file not found: {filepath}")

    def speech_recognition(self, speech, fs):
        """
        Step 1: speech recognition (ASR)
        """
        speech_lengths = len(speech)

        # VAD
        print("Running VAD...")
        vad_start_time = time.time()
        res_vad = self.model_vad(speech)[0]
        vad_segments = merge_vad(res_vad, 15 * 1000)
        vad_time_cost = time.time() - vad_start_time
        print(f"VAD processing time: {vad_time_cost:.2f} seconds")
        print(f"VAD segments detected: {len(vad_segments)}")

        # ASR
        print("Running ASR...")
        asr_start_time = time.time()
        all_results = ""

        # Process each VAD segment
        for i, segment in enumerate(vad_segments):
            segment_start, segment_end = segment
            start_sample = int(segment_start / 1000 * fs)
            end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
            segment_speech = speech[start_sample:end_sample]

            # Write the current segment to a temporary file
            segment_filename = f"temp_segment_{i}.wav"
            sf.write(segment_filename, segment_speech, fs)

            # Recognize the current segment
            try:
                segment_res = self.model_bin(
                    segment_filename,
                    "auto",  # automatic language detection
                    True,  # withitn
                    self.position_encoding,
                    tokenizer=self.tokenizer,
                )

                all_results += segment_res

                # Clean up the temporary file
                if os.path.exists(segment_filename):
                    os.remove(segment_filename)

            except Exception as e:
                if os.path.exists(segment_filename):
                    os.remove(segment_filename)
                print(f"Error processing segment {i}: {e}")
                continue

        asr_time_cost = time.time() - asr_start_time
        print(f"ASR processing time: {asr_time_cost:.2f} seconds")
        print(f"ASR Result: {all_results}")

        return all_results.strip()

    def run_translation(self, english_text):
        """
        Step 2: call the translation program to translate English into Chinese
        """
        # Build the command line
        cmd = [
            self.translate_exec,
            "--model", self.translate_model,
            "--tokenizer_dir", self.translate_tokenizer,
            "--text", f'"{english_text}"'  # wrap in quotes for text with spaces and special characters
        ]

        try:
            # Run the command
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=30  # timeout in seconds
            )

            # Check the return code
            if result.returncode != 0:
                error_msg = f"Translation program failed: {result.stderr}"
                raise RuntimeError(error_msg)

            # Extract the translation result
            chinese_text = result.stdout.strip()

            # Strip any extra output
            # if "翻译结果:" in chinese_text:
            #     chinese_text = chinese_text.split("翻译结果:", 1)[-1].strip()
            chinese_text = chinese_text.split("output: ")[-1].split("\nAX_ENGINE_Deinit")[0]

            print(f"Translation result: {chinese_text}")
            return chinese_text

        except subprocess.TimeoutExpired:
            raise RuntimeError("Translation program timed out")
        except Exception as e:
            raise e

    def run_tts(self, chinese_text, output_dir, output_wav=None):
        """
        Step 3: call the TTS program to synthesize Chinese speech
        """
        output_path = os.path.join(output_dir, output_wav)
        # chinese_text = chinese_text.split("output: ")[-1].split("\nAX_ENGINE_Deinit")[0]

        chinese_text = cn2an.transform(chinese_text, "an2cn")

        # Build the command line
        cmd = [
            self.tts_exec,
            "--g", os.path.join(self.tts_model_dir, self.tts_model_files["g"]),
            "-e", os.path.join(self.tts_model_dir, self.tts_model_files["encoder"]),
            "-l", os.path.join(self.tts_model_dir, self.tts_model_files["lexicon"]),
            "-t", os.path.join(self.tts_model_dir, self.tts_model_files["tokens"]),
            "-d", os.path.join(self.tts_model_dir, self.tts_model_files["decoder"]),
            "-w", output_path,
            "-s", f'"{chinese_text}"'
        ]

        try:
            # Run the command
            result = subprocess.run(
                cmd,
                capture_output=False,
                text=True,
                timeout=60  # TTS may take longer
            )

            # Check the return code
            if result.returncode != 0:
                error_msg = f"TTS program failed: {result.stderr}"
                raise RuntimeError(error_msg)

            # Make sure the output file was generated
            if not os.path.exists(output_path):
                raise FileNotFoundError(f"Output file was not generated: {output_path}")

            return output_path

        except subprocess.TimeoutExpired:
            raise RuntimeError("TTS program timed out")
        except Exception as e:
            # Clean up temporary files
            if output_path and os.path.exists(os.path.dirname(output_path)):
                shutil.rmtree(os.path.dirname(output_path))
            raise e

    def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
        """
        Full pipeline: speech recognition -> translation -> TTS synthesis
        """

        # Step 1: speech recognition
        print("\n----------------------VAD+ASR----------------------------\n")
        start_time = time.time()  # start timing
        english_text = self.speech_recognition(speech, fs)
        asr_time = time.time() - start_time  # elapsed time
        print(f"Speech recognition took {asr_time:.2f} seconds")

        # Step 2: translation
        print("\n---------------------translate---------------------------\n")
        start_time = time.time()  # start timing
        chinese_text = self.run_translation(english_text)
        translate_time = time.time() - start_time  # elapsed time
        print(f"Translation took {translate_time:.2f} seconds")

        # Step 3: TTS synthesis
        print("-------------------------TTS-------------------------------\n")
        start_time = time.time()  # start timing
        output_path = self.run_tts(chinese_text, output_dir, output_tts)
        tts_time = time.time() - start_time  # elapsed time
        print(f"TTS synthesis took {tts_time:.2f} seconds")

        return {
            "original_text": english_text,
            "translated_text": chinese_text,
            "audio_path": output_path
        }

def main():
    parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
    parser.add_argument("--audio_file", type=str, required=True, help="Input audio file path")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--output_tts", type=str, default="output.wav", help="Output wav file name")

    args = parser.parse_args()
    print("-------------------START------------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Processing audio file: {args.audio_file}")
    # Load the audio
    speech, fs = librosa.load(args.audio_file, sr=None)
    if fs != 16000:
        print(f"Resampling audio from {fs}Hz to 16000Hz")
        speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
        fs = 16000
    audio_duration = librosa.get_duration(y=speech, sr=fs)


    # Initialize the pipeline
    pipeline = SpeechTranslationPipeline(
        translate_exec=TRANSLATE_EXECUTABLE,
        translate_model=TRANSLATE_MODEL,
        translate_tokenizer=TRANSLATE_TOKENIZER_DIR,
        tts_exec=TTS_EXECUTABLE,
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132
    )

    start_time = time.time()
    try:
        # Run
        result = pipeline.full_pipeline(speech, fs, args.output_dir, args.output_tts)

        print("\n" + "="*50)
        print("Speech translation finished!")
        print("="*50 + "\n")
        print(f"Source audio: {args.audio_file}")
        print(f"Source text: {result['original_text']}")
        print(f"Translated text: {result['translated_text']}")
        print(f"Generated audio: {result['audio_path']}")

        # Save the results to a file
        result_file = os.path.join(args.output_dir, "pipeline_result.txt")
        with open(result_file, 'w', encoding='utf-8') as f:
            f.write(f"Source audio: {args.audio_file}\n")
            f.write(f"Recognized text: {result['original_text']}\n")
            f.write(f"Translation result: {result['translated_text']}\n")
            f.write(f"Synthesized audio: {result['audio_path']}\n")

        # print(f"\nDetailed results saved to: {result_file}")
        time_cost = time.time() - start_time
        rtf = time_cost / audio_duration
        print(f"Inference time for {args.audio_file}: {time_cost:.2f} seconds")
        print(f"Audio duration: {audio_duration:.2f} seconds")
        print(f"RTF: {rtf:.2f}\n")
    except Exception as e:
        print(f"Pipeline failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
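For reference, a hedged sketch of driving the pipeline from Python instead of the CLI entry point, using the class and constants defined in ax_speech_translate_demo.py above. It assumes the shared libraries from libtranslate/ and libmelotts/install/ are already on LD_LIBRARY_PATH and that the sample file wav/en.mp3 from the README exists:

```
import os
import librosa

from ax_speech_translate_demo import (
    SpeechTranslationPipeline,
    TRANSLATE_EXECUTABLE, TRANSLATE_MODEL, TRANSLATE_TOKENIZER_DIR,
    TTS_EXECUTABLE, TTS_MODEL_DIR, TTS_MODEL_FILES,
)

# The pipeline expects 16 kHz mono audio.
speech, fs = librosa.load("wav/en.mp3", sr=16000)
os.makedirs("output", exist_ok=True)

pipeline = SpeechTranslationPipeline(
    translate_exec=TRANSLATE_EXECUTABLE,
    translate_model=TRANSLATE_MODEL,
    translate_tokenizer=TRANSLATE_TOKENIZER_DIR,
    tts_exec=TTS_EXECUTABLE,
    tts_model_dir=TTS_MODEL_DIR,
    tts_model_files=TTS_MODEL_FILES,
)
result = pipeline.full_pipeline(speech, fs, output_dir="output", output_tts="output.wav")
print(result["translated_text"], result["audio_path"])
```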
libmelotts/install/libonnxruntime.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f0b062d058b15bbf87bc3232b9c905a2bd49ea7e88e203a308df5c08c4c15129
size 16014680
libmelotts/install/libonnxruntime.so.1.14.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f0b062d058b15bbf87bc3232b9c905a2bd49ea7e88e203a308df5c08c4c15129
size 16014680
libmelotts/install/libonnxruntime_providers_shared.so
ADDED
Binary file (9.92 kB).
libmelotts/install/melotts
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88d12e4b441d2398ea7646afb677f2dd5819f347caa680781ce2bcfa26e81d85
size 187256
libmelotts/models/decoder-en.axmodel
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:90c93c0fa978cc1c68fbac6a78707dd75b8b9069cb01a1ade6846e2435aa1eb1
size 44093802
libmelotts/models/decoder-zh.axmodel
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37ea2d8401f18dd371eec50b90bd39dcadf9684aaf3543dace8ce1a9499ef253
size 44092592
libmelotts/models/encoder-en.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6cc51185fb81934c7490c5f9ac993fff7efa98ab41c08cd3753c96abcb297582
size 31488385
libmelotts/models/encoder-zh.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2b0a5bc2789faef16b4bfc56ab4905364f8163a59f2db3d071b4a14792bfee5
size 31397760
libmelotts/models/g-en.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:094bf0dbe1cd6c9408707209b2b7261b9df2cd5917d310bfac5945a15a31821a
size 1024
libmelotts/models/g-jp.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c01dd0961bbe1effca4ed378d2969d6fbd9b579133b722f6968db5cf4d22281e
size 1024
libmelotts/models/g-zh_mix_en.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c70d897674847882bd35e780aee696ddaff8d04d5c57e4f9cf37611b6821879f
size 1024
libmelotts/models/lexicon.txt
ADDED
The diff for this file is too large to render.
libmelotts/models/tokens.txt
ADDED
@@ -0,0 +1,112 @@
_ 0
AA 1
E 2
EE 3
En 4
N 5
OO 6
V 7
a 8
a: 9
aa 10
ae 11
ah 12
ai 13
an 14
ang 15
ao 16
aw 17
ay 18
b 19
by 20
c 21
ch 22
d 23
dh 24
dy 25
e 26
e: 27
eh 28
ei 29
en 30
eng 31
er 32
ey 33
f 34
g 35
gy 36
h 37
hh 38
hy 39
i 40
i0 41
i: 42
ia 43
ian 44
iang 45
iao 46
ie 47
ih 48
in 49
ing 50
iong 51
ir 52
iu 53
iy 54
j 55
jh 56
k 57
ky 58
l 59
m 60
my 61
n 62
ng 63
ny 64
o 65
o: 66
ong 67
ou 68
ow 69
oy 70
p 71
py 72
q 73
r 74
ry 75
s 76
sh 77
t 78
th 79
ts 80
ty 81
u 82
u: 83
ua 84
uai 85
uan 86
uang 87
uh 88
ui 89
un 90
uo 91
uw 92
v 93
van 94
ve 95
vn 96
w 97
x 98
y 99
z 100
zh 101
zy 102
! 103
? 104
… 105
, 106
. 107
' 108
- 109
SP 110
UNK 111
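tokens.txt above maps each phone or punctuation symbol to the integer id consumed by the MeloTTS models. A minimal sketch of reading it into a lookup table (using the path as laid out in this repo):

```
def load_tokens(path="libmelotts/models/tokens.txt"):
    # Parse "symbol id" lines into a symbol -> id mapping.
    token_to_id = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split()
            if len(parts) == 2:
                symbol, idx = parts
                token_to_id[symbol] = int(idx)
    return token_to_id

tokens = load_tokens()
print(tokens.get("SP"), tokens.get("UNK"))  # 110 111
```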
libtranslate/libax_translate.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d6c5901387ed8d0a3c0611a5ad8d60281487e34d20afc83a58923ea2462b456
size 1222544
libtranslate/libsentencepiece.so.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e93872d229a074c73e66eeeff8988197deb35b81b09ee893e022e115f886ed4
size 1320848
libtranslate/opus-mt-en-zh.axmodel
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e262aab07f3a5478a0a0df24a92e1f88c138c4999be5b4cb18b618d996bcacff
size 217562368
libtranslate/opus-mt-en-zh/.gitattributes
ADDED
@@ -0,0 +1,9 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
libtranslate/opus-mt-en-zh/README.md
ADDED
@@ -0,0 +1,96 @@
---
language:
- en
- zh
tags:
- translation
license: apache-2.0
---

### eng-zho

* source group: English
* target group: Chinese
* OPUS readme: [eng-zho](https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-zho/README.md)

* model: transformer
* source language(s): eng
* target language(s): cjy_Hans cjy_Hant cmn cmn_Hans cmn_Hant gan lzh lzh_Hans nan wuu yue yue_Hans yue_Hant
* model: transformer
* pre-processing: normalization + SentencePiece (spm32k,spm32k)
* a sentence initial language token is required in the form of `>>id<<` (id = valid target language ID)
* download original weights: [opus-2020-07-17.zip](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.zip)
* test set translations: [opus-2020-07-17.test.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.test.txt)
* test set scores: [opus-2020-07-17.eval.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.eval.txt)

## Benchmarks

| testset | BLEU | chr-F |
|-----------------------|-------|-------|
| Tatoeba-test.eng.zho | 31.4 | 0.268 |

### System Info:
- hf_name: eng-zho
- source_languages: eng
- target_languages: zho
- opus_readme_url: https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-zho/README.md
- original_repo: Tatoeba-Challenge
- tags: ['translation']
- languages: ['en', 'zh']
- src_constituents: {'eng'}
- tgt_constituents: {'cmn_Hans', 'nan', 'nan_Hani', 'gan', 'yue', 'cmn_Kana', 'yue_Hani', 'wuu_Bopo', 'cmn_Latn', 'yue_Hira', 'cmn_Hani', 'cjy_Hans', 'cmn', 'lzh_Hang', 'lzh_Hira', 'cmn_Hant', 'lzh_Bopo', 'zho', 'zho_Hans', 'zho_Hant', 'lzh_Hani', 'yue_Hang', 'wuu', 'yue_Kana', 'wuu_Latn', 'yue_Bopo', 'cjy_Hant', 'yue_Hans', 'lzh', 'cmn_Hira', 'lzh_Yiii', 'lzh_Hans', 'cmn_Bopo', 'cmn_Hang', 'hak_Hani', 'cmn_Yiii', 'yue_Hant', 'lzh_Kana', 'wuu_Hani'}
- src_multilingual: False
- tgt_multilingual: False
- prepro: normalization + SentencePiece (spm32k,spm32k)
- url_model: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.zip
- url_test_set: https://object.pouta.csc.fi/Tatoeba-MT-models/eng-zho/opus-2020-07-17.test.txt
- src_alpha3: eng
- tgt_alpha3: zho
- short_pair: en-zh
- chrF2_score: 0.268
- bleu: 31.4
- brevity_penalty: 0.8959999999999999
- ref_len: 110468.0
- src_name: English
- tgt_name: Chinese
- train_date: 2020-07-17
- src_alpha2: en
- tgt_alpha2: zh
- prefer_old: False
- long_pair: eng-zho
- helsinki_git_sha: 480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535
- transformers_git_sha: 2207e5d8cb224e954a7cba69fa4ac2309e9ff30b
- port_machine: brutasse
- port_time: 2020-08-21-14:41
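Note: the README above describes the upstream Helsinki-NLP Marian checkpoint that `opus-mt-en-zh.axmodel` was converted from. A reference sketch using the Hugging Face `transformers` port, which can be handy when comparing against the on-device output; the `>>cmn_Hans<<` prefix follows the language-token convention mentioned above, and the beam settings mirror the config files below:

```python
# Sketch: run the upstream Hugging Face checkpoint this axmodel was converted from,
# as a reference for checking on-device translations. Assumes `transformers` is installed.
from transformers import MarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-zh"
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

# Per the README, a target-language token such as ">>cmn_Hans<<" can be prepended
# to steer the output variant.
text = ">>cmn_Hans<< Machine translation runs fully on the edge device."
batch = tokenizer([text], return_tensors="pt")
generated = model.generate(**batch, num_beams=4, max_length=512)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```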
libtranslate/opus-mt-en-zh/config.json
ADDED
@@ -0,0 +1,61 @@
{
  "_name_or_path": "./",
  "activation_dropout": 0.0,
  "activation_function": "swish",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "MarianMTModel"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": [
    [
      65000
    ]
  ],
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 512,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 2048,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 65000,
  "do_blenderbot_90_layernorm": false,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 2048,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 0,
  "extra_pos_embeddings": 0,
  "force_bos_token_to_be_generated": false,
  "forced_eos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_length": 512,
  "max_position_embeddings": 512,
  "model_type": "marian",
  "normalize_before": false,
  "normalize_embedding": false,
  "num_beams": 4,
  "num_hidden_layers": 6,
  "pad_token_id": 65000,
  "scale_embedding": true,
  "static_position_embeddings": true,
  "transformers_version": "4.9.0.dev0",
  "use_cache": true,
  "vocab_size": 65001
}
libtranslate/opus-mt-en-zh/generation_config.json
ADDED
@@ -0,0 +1,16 @@
{
  "bad_words_ids": [
    [
      65000
    ]
  ],
  "bos_token_id": 0,
  "decoder_start_token_id": 65000,
  "eos_token_id": 0,
  "forced_eos_token_id": 0,
  "max_length": 512,
  "num_beams": 4,
  "pad_token_id": 65000,
  "renormalize_logits": true,
  "transformers_version": "4.32.0.dev0"
}
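Note: `config.json` and `generation_config.json` pin the decoding constants (decoder start / pad token 65000, EOS 0, `max_length` 512, `num_beams` 4, and token 65000 banned via `bad_words_ids`). The sketch below shows, with greedy search instead of the configured 4-beam search for brevity, how those ids would drive a hand-rolled decode loop around a ported runtime; `encoder_step` and `decoder_step` are hypothetical placeholders standing in for whatever executes the model, not functions from this repo:

```python
# Sketch of how the ids in config.json / generation_config.json are used when the
# model is decoded outside transformers (e.g. through an .axmodel runtime).
import numpy as np

DECODER_START_ID = 65000   # decoder_start_token_id / pad_token_id
EOS_ID = 0                 # eos_token_id / forced_eos_token_id
MAX_LEN = 512              # max_length

def greedy_translate(src_ids, encoder_step, decoder_step):
    enc_out = encoder_step(src_ids)            # run the encoder once
    out = [DECODER_START_ID]
    for _ in range(MAX_LEN):
        logits = decoder_step(enc_out, out)    # next-token logits for the current prefix
        logits[DECODER_START_ID] = -np.inf     # bad_words_ids: never emit token 65000
        next_id = int(np.argmax(logits))
        out.append(next_id)
        if next_id == EOS_ID:
            break
    return out[1:]
```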
libtranslate/opus-mt-en-zh/metadata.json
ADDED
@@ -0,0 +1 @@
{"hf_name":"eng-zho","source_languages":"eng","target_languages":"zho","opus_readme_url":"https:\/\/github.com\/Helsinki-NLP\/Tatoeba-Challenge\/tree\/master\/models\/eng-zho\/README.md","original_repo":"Tatoeba-Challenge","tags":["translation"],"languages":["en","zh"],"src_constituents":["eng"],"tgt_constituents":["cmn_Hans","nan","nan_Hani","gan","yue","cmn_Kana","yue_Hani","wuu_Bopo","cmn_Latn","yue_Hira","cmn_Hani","cjy_Hans","cmn","lzh_Hang","lzh_Hira","cmn_Hant","lzh_Bopo","zho","zho_Hans","zho_Hant","lzh_Hani","yue_Hang","wuu","yue_Kana","wuu_Latn","yue_Bopo","cjy_Hant","yue_Hans","lzh","cmn_Hira","lzh_Yiii","lzh_Hans","cmn_Bopo","cmn_Hang","hak_Hani","cmn_Yiii","yue_Hant","lzh_Kana","wuu_Hani"],"src_multilingual":false,"tgt_multilingual":false,"prepro":" normalization + SentencePiece (spm32k,spm32k)","url_model":"https:\/\/object.pouta.csc.fi\/Tatoeba-MT-models\/eng-zho\/opus-2020-07-17.zip","url_test_set":"https:\/\/object.pouta.csc.fi\/Tatoeba-MT-models\/eng-zho\/opus-2020-07-17.test.txt","src_alpha3":"eng","tgt_alpha3":"zho","short_pair":"en-zh","chrF2_score":0.268,"bleu":31.4,"brevity_penalty":0.896,"ref_len":110468.0,"src_name":"English","tgt_name":"Chinese","train_date":"2020-07-17","src_alpha2":"en","tgt_alpha2":"zh","prefer_old":false,"long_pair":"eng-zho","helsinki_git_sha":"480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535","transformers_git_sha":"2207e5d8cb224e954a7cba69fa4ac2309e9ff30b","port_machine":"brutasse","port_time":"2020-08-21-14:41"}
libtranslate/opus-mt-en-zh/source.spm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5775ddc9e3ff2fae91554da56468ad35ff56edaba870fea74447bc7234bfdaa8
size 806435
libtranslate/opus-mt-en-zh/target.spm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81dc94efa84e4025ef38d25d5d07429fe41e3eb29d44003f1db6fe98487b0052
size 804600
libtranslate/opus-mt-en-zh/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"target_lang": "zho", "source_lang": "eng"}
libtranslate/opus-mt-en-zh/vocab.json
ADDED
The diff for this file is too large to render.
libtranslate/test_translate
ADDED
Binary file (82.1 kB).
model.py
ADDED
@@ -0,0 +1,942 @@
import time
import torch
from torch import nn
import torch.nn.functional as F
from typing import Iterable, Optional

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from utils.ctc_alignment import ctc_forced_align


class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal position encoding without learnable parameters."""

    def __init__(self, d_model=80, dropout_rate=0.1):
        # Note: the upstream file spells this `__int__` (a typo) with a bare `pass`;
        # renamed here and given a proper super().__init__() call so the module initializes.
        super().__init__()

    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
            depth / 2 - 1
        )
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

        return x + position_encoding

    def get_position_encoding(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return position_encoding


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        # self.linear_q = nn.Linear(n_feat, n_feat)
        # self.linear_k = nn.Linear(n_feat, n_feat)
        # self.linear_v = nn.Linear(n_feat, n_feat)

        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = -float(
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute scaled dot product attention (streaming chunk variant).

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        if chunk_size is not None and look_back > 0 or look_back == -1:
            if cache is not None:
                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)

                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
            else:
                cache_tmp = {
                    "k": k_h[:, :, : -(chunk_size[2]), :],
                    "v": v_h[:, :, : -(chunk_size[2]), :],
                }
                cache = cache_tmp
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache


class LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features (streaming chunk variant).

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
        """

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache


@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
class SenseVoiceEncoderSmall(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(1)
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(num_blocks - 1)
            ]
        )

        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for i in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)

        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        # ilens: torch.Tensor,
        masks: torch.Tensor,
        position_encoding: torch.Tensor,
    ):
        """Embed positions in tensor."""
        # masks = sequence_mask(ilens, device=ilens.device)[:, None, :]

        xs_pad *= self.output_size() ** 0.5

        # xs_pad = self.embed(xs_pad)
        xs_pad += position_encoding

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        # forward encoder2
        olens = masks.squeeze(1).sum(1).int()

        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens


@tables.register("model_classes", "SenseVoiceSmall")
class SenseVoiceSmall(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        seq_len=68,
        **kwargs,
    ):

        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.error_calculator = None

        self.ctc = ctc

        self.length_normalized_loss = length_normalized_loss
        self.encoder_output_size = encoder_output_size

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}

        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        self.seq_len = seq_len

    @staticmethod
    def from_pretrained(model: str = None, **kwargs):
        from funasr import AutoModel

        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)

        return model, kwargs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        loss_ctc, cer_ctc = None, None
        loss_rich, acc_rich = None, None
        stats = dict()

        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
        )

        loss_rich, acc_rich = self._calc_rich_ce_loss(
            encoder_out[:, :4, :], text[:, :4]
        )

        loss = loss_ctc + loss_rich
        # Collect total loss stats
        stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
        stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
        stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
        stats["acc_rich"] = acc_rich

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int
        """

        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        lids = torch.LongTensor(
            [
                [self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0]
                for lid in text[:, 0]
            ]
        ).to(speech.device)
        language_query = self.embed(lids)

        styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device)
        style_query = self.embed(styles)
        speech = torch.cat((style_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rich_ce_loss(
        self,
        encoder_out: torch.Tensor,
        ys_pad: torch.Tensor,
    ):
        decoder_out = self.ctc.ctc_lo(encoder_out)
        # 2. Compute attention loss
        loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
        acc_rich = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad.contiguous(),
            ignore_label=self.ignore_id,
        )

        return loss_rich, acc_rich

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = ["wav_file_tmp_name"],
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        language = kwargs.get("language", "auto")
        language_query = self.embed(
            torch.LongTensor(
                [[self.lid_dict[language] if language in self.lid_dict else 0]]
            ).to(speech.device)
        ).repeat(speech.size(0), 1, 1)

        use_itn = kwargs.get("use_itn", False)
        output_timestamp = kwargs.get("output_timestamp", False)

        textnorm = kwargs.get("text_norm", None)
        if textnorm is None:
            textnorm = "withitn" if use_itn else "woitn"
        textnorm_query = self.embed(
            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)
        speech = torch.cat((textnorm_query, speech), dim=1)
        speech_lengths += 1

        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3

        # Encoder
        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # c. Passed the encoder result and the beam search
        ctc_logits = self.ctc.log_softmax(encoder_out)
        if kwargs.get("ban_emo_unk", False):
            ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")

        results = []
        b, n, d = encoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b
        for i in range(b):
            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"1best_recog"]

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            # Change integer-ids to tokens
            text = tokenizer.decode(token_int)
            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

            if output_timestamp:
                from itertools import groupby

                timestamp = []
                tokens = tokenizer.text2tokens(text)[4:]

                logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :]

                pred = logits_speech.argmax(-1).cpu()
                logits_speech[pred == self.blank_id, self.blank_id] = 0

                align = ctc_forced_align(
                    logits_speech.unsqueeze(0).float(),
                    torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
                    (encoder_out_lens - 4).long(),
                    torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
                    ignore_id=self.ignore_id,
                )

                pred = groupby(align[0, :encoder_out_lens[0]])
                _start = 0
                token_id = 0
                ts_max = encoder_out_lens[i] - 4
                for pred_token, pred_frame in pred:
                    _end = _start + len(list(pred_frame))
                    if pred_token != 0:
                        ts_left = max((_start * 60 - 30) / 1000, 0)
                        ts_right = min((_end * 60 - 30) / 1000, (ts_max * 60 - 30) / 1000)
                        timestamp.append([tokens[token_id], ts_left, ts_right])
                        token_id += 1
                    _start = _end

                result_i = {"key": key[i], "text": text, "timestamp": timestamp}
                results.append(result_i)
            else:
                result_i = {"key": key[i], "text": text}
                results.append(result_i)
        return results, meta_data

    def export(self, **kwargs):
        from export_meta import export_rebuild_model

        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
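Note: a minimal usage sketch for the `SenseVoiceSmall` class defined above, driven through its `from_pretrained` helper (which delegates to FunASR's `AutoModel.build_model`). The model id, wav path and kwargs are illustrative only; FunASR normally supplies `tokenizer`, `frontend` and `device` in the returned kwargs:

```python
# Sketch: run the SenseVoiceSmall defined in model.py on a single wav file.
# Requires funasr (see requirements.txt); paths and ids here are examples, not fixed values.
from model import SenseVoiceSmall

m, kwargs = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
m.eval()

res, meta = m.inference(
    data_in="example.wav",              # any 16 kHz wav file
    language="auto",                    # or "zh", "en", "yue", "ja", "ko"
    use_itn=True,                       # emit inverse-text-normalized output
    ban_emo_unk=False,
    key=["example"],
    tokenizer=kwargs.get("tokenizer"),
    frontend=kwargs.get("frontend"),
    device=kwargs.get("device", "cpu"),
)
print(res[0]["text"])
```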
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
tqdm
funasr==1.2.7
torchaudio
cn2an
utils/__init__.py
ADDED
File without changes
utils/ax_model_bin.py
ADDED
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple
import torch
import numpy as np
import axengine as axe
from funasr.utils.postprocess_utils import rich_transcription_postprocess

try:
    import librosa
except ImportError:
    print("Warning: librosa not found. Please install it using 'pip install librosa'.")

    # Provide a fallback implementation if needed
    def load_wav_fallback(path, sr=None):
        import wave
        import numpy as np
        with wave.open(path, 'rb') as wf:
            num_frames = wf.getnframes()
            frames = wf.readframes(num_frames)
            return np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0, wf.getframerate()

from utils.infer_utils import (
    CharTokenizer,
    get_logger,
    read_yaml,
)
from utils.frontend import WavFrontend
from utils.ctc_alignment import ctc_forced_align

logging = get_logger()


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)


class AX_SenseVoiceSmall:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        seq_len: int = 68
    ):

        model_file = os.path.join(model_dir, "sensevoice.axmodel")
        config_file = os.path.join(model_dir, "sensevoice/config.yaml")
        cmvn_file = os.path.join(model_dir, "sensevoice/am.mvn")
        config = read_yaml(config_file)
        self.model_dir = model_dir
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)

        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        config["frontend_conf"]['cmvn_file'] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        # self.ort_infer = OrtInferSession(
        #     model_file, device_id, intra_op_num_threads=intra_op_num_threads
        # )
        self.session = axe.InferenceSession(model_file, providers='AxEngineExecutionProvider')
        self.batch_size = batch_size
        self.blank_id = 0
        self.seq_len = seq_len

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 language: str,
                 withitn: bool,
                 position_encoding: np.ndarray,
                 tokenizer=None,
                 **kwargs) -> List:
        """Enhanced model inference with additional features from model.py

        Args:
            wav_content: Audio data or path
            language: Language code for processing
            withitn: Whether to use ITN (inverse text normalization)
            position_encoding: Position encoding tensor
            tokenizer: Tokenizer for text conversion
            **kwargs: Additional arguments
        """
        # Start time tracking for metadata
        import time
        meta_data = {}
        time_start = time.perf_counter()

        # Load waveform data
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        time_load = time.perf_counter()
        meta_data["load_data"] = f"{time_load - time_start:0.3f}"

        # Load queries from saved numpy files
        language_query = np.load(os.path.join(self.model_dir, f"{language}.npy"))
        textnorm_query = np.load(os.path.join(self.model_dir, "withitn.npy") if withitn
                                 else os.path.join(self.model_dir, "woitn.npy"))
        event_emo_query = np.load(os.path.join(self.model_dir, "event_emo.npy"))

        # Concatenate queries to form input_query
        input_query = np.concatenate((language_query, event_emo_query, textnorm_query), axis=1)

        # Process features
        results = ""

        # Handle output_dir without using DatadirWriter (which is not available)
        slice_len = self.seq_len - 4
        time_pre = time.perf_counter()
        meta_data["preprocess"] = f"{time_pre - time_load:0.3f}"
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])

            time_feat = time.perf_counter()
            meta_data["extract_feat"] = f"{time_feat - time_pre:0.3f}"

            for i in range(int(np.ceil(feats.shape[1] / slice_len))):
                sub_feats = np.concatenate([input_query, feats[:, i*slice_len : (i+1)*slice_len, :]], axis=1)
                feats_len[0] = sub_feats.shape[1]

                if feats_len[0] < self.seq_len:
                    sub_feats = np.concatenate([sub_feats, np.zeros((1, self.seq_len - feats_len[0], 560), dtype=np.float32)], axis=1)

                masks = sequence_mask(torch.IntTensor([self.seq_len]), maxlen=self.seq_len, dtype=torch.float32)[:, None, :]
                masks = masks.numpy()

                # Run inference
                ctc_logits, encoder_out_lens = self.infer(sub_feats, masks, position_encoding)
                # Convert to torch tensor for processing
                ctc_logits = torch.from_numpy(ctc_logits).float()

                # Process results for each batch
                b, _, _ = ctc_logits.size()

                for j in range(b):
                    x = ctc_logits[j, : encoder_out_lens[j].item(), :]
|
| 162 |
+
yseq = x.argmax(dim=-1)
|
| 163 |
+
yseq = torch.unique_consecutive(yseq, dim=-1)
|
| 164 |
+
|
| 165 |
+
mask = yseq != self.blank_id
|
| 166 |
+
token_int = yseq[mask].tolist()[4:]  # drop the first 4 tokens: <|zh|><|ANGRY|><|Speech|><|withitn|>
|
| 167 |
+
|
| 168 |
+
# Convert tokens to text
|
| 169 |
+
text = tokenizer.decode(token_int) if tokenizer is not None else str(token_int)
|
| 170 |
+
|
| 171 |
+
if tokenizer is not None:
|
| 172 |
+
results += text
|
| 173 |
+
else:
|
| 174 |
+
results += str(token_int)  # results is a str, so stringify the raw token ids
|
| 175 |
+
return results
|
| 176 |
+
|
| 177 |
+
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
|
| 178 |
+
def load_wav(path: str) -> np.ndarray:
|
| 179 |
+
try:
|
| 180 |
+
# Use librosa if available
|
| 181 |
+
if 'librosa' in globals():
|
| 182 |
+
waveform, _ = librosa.load(path, sr=fs)
|
| 183 |
+
else:
|
| 184 |
+
# Use fallback implementation
|
| 185 |
+
waveform, native_sr = load_wav_fallback(path)
|
| 186 |
+
if fs is not None and native_sr != fs:
|
| 187 |
+
# Implement resampling if needed
|
| 188 |
+
print(f"Warning: Resampling from {native_sr} to {fs} is not implemented in fallback mode")
|
| 189 |
+
return waveform
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"Error loading audio file {path}: {e}")
|
| 192 |
+
# Return empty audio in case of error
|
| 193 |
+
return np.zeros(1600, dtype=np.float32)
|
| 194 |
+
|
| 195 |
+
if isinstance(wav_content, np.ndarray):
|
| 196 |
+
return [wav_content]
|
| 197 |
+
|
| 198 |
+
if isinstance(wav_content, str):
|
| 199 |
+
return [load_wav(wav_content)]
|
| 200 |
+
|
| 201 |
+
if isinstance(wav_content, list):
|
| 202 |
+
return [load_wav(path) for path in wav_content]
|
| 203 |
+
|
| 204 |
+
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
|
| 205 |
+
|
| 206 |
+
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 207 |
+
feats, feats_len = [], []
|
| 208 |
+
for waveform in waveform_list:
|
| 209 |
+
speech, _ = self.frontend.fbank(waveform)
|
| 210 |
+
|
| 211 |
+
feat, feat_len = self.frontend.lfr_cmvn(speech)
|
| 212 |
+
|
| 213 |
+
feats.append(feat)
|
| 214 |
+
feats_len.append(feat_len)
|
| 215 |
+
|
| 216 |
+
feats = self.pad_feats(feats, np.max(feats_len))
|
| 217 |
+
feats_len = np.array(feats_len).astype(np.int32)
|
| 218 |
+
return feats, feats_len
|
| 219 |
+
|
| 220 |
+
@staticmethod
|
| 221 |
+
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
|
| 222 |
+
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
|
| 223 |
+
pad_width = ((0, max_feat_len - cur_len), (0, 0))
|
| 224 |
+
return np.pad(feat, pad_width, "constant", constant_values=0)
|
| 225 |
+
|
| 226 |
+
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
|
| 227 |
+
feats = np.array(feat_res).astype(np.float32)
|
| 228 |
+
return feats
|
| 229 |
+
|
| 230 |
+
def infer(self,
|
| 231 |
+
feats: np.ndarray,
|
| 232 |
+
masks: np.ndarray,
|
| 233 |
+
position_encoding: np.ndarray,
|
| 234 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 235 |
+
#outputs = self.ort_infer([feats, masks, position_encoding])
|
| 236 |
+
outputs = self.session.run(None, {
|
| 237 |
+
'speech': feats,
|
| 238 |
+
'masks': masks,
|
| 239 |
+
'position_encoding': position_encoding
|
| 240 |
+
})
|
| 241 |
+
return outputs
|
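A minimal usage sketch for the wrapper above (illustrative, not from this commit): `demo.wav`, the `language="auto"` query, and the 560-dim position-encoding depth are assumptions on my part; the demo script in the repo shows the full wiring, including the SentencePiece tokenizer.

```python
import numpy as np

from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.frontend import SinusoidalPositionEncoderOnline

model = AX_SenseVoiceSmall(model_dir="ax_model", batch_size=1, seq_len=68)

# Position encoding of shape (1, seq_len, 560); 560 matches the zero-padding width
# used inside __call__ (assumed to be n_mels * lfr_m for this config).
pos_enc = SinusoidalPositionEncoderOnline().encode(
    positions=np.arange(1, 68 + 1)[None, :], depth=560, dtype=np.float32
)

# "auto" picks the auto.npy language query; with tokenizer=None the raw token ids
# come back as a string, so pass the project's tokenizer for readable text.
result = model("demo.wav", language="auto", withitn=True,
               position_encoding=pos_enc, tokenizer=None)
print(result)
```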
utils/ax_vad_bin.py
ADDED
|
@@ -0,0 +1,156 @@
|
| 1 |
+
# -*- encoding: utf-8 -*-
|
| 2 |
+
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
| 3 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 4 |
+
|
| 5 |
+
import os.path
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from utils.utils.utils import read_yaml
|
| 11 |
+
from utils.utils.frontend import WavFrontend
|
| 12 |
+
from utils.utils.e2e_vad import E2EVadModel
|
| 13 |
+
import axengine as axe
|
| 14 |
+
|
| 15 |
+
class AX_Fsmn_vad:
|
| 16 |
+
def __init__(self, model_dir, batch_size=1, max_end_sil=None):
|
| 17 |
+
"""Initialize VAD model for inference"""
|
| 18 |
+
|
| 19 |
+
# Locate the compiled VAD axmodel
|
| 20 |
+
model_file = os.path.join(model_dir, "vad.axmodel")
|
| 21 |
+
|
| 22 |
+
# Load config and frontend
|
| 23 |
+
config_file = os.path.join(model_dir, "vad/config.yaml")
|
| 24 |
+
cmvn_file = os.path.join(model_dir, "vad/am.mvn")
|
| 25 |
+
self.config = read_yaml(config_file)
|
| 26 |
+
self.frontend = WavFrontend(cmvn_file=cmvn_file, **self.config["frontend_conf"])
|
| 27 |
+
self.session = axe.InferenceSession(model_file, providers='AxEngineExecutionProvider')
|
| 28 |
+
self.batch_size = batch_size
|
| 29 |
+
self.vad_scorer = E2EVadModel(self.config["model_conf"])
|
| 30 |
+
self.max_end_sil = max_end_sil if max_end_sil is not None else self.config["model_conf"]["max_end_silence_time"]
|
| 31 |
+
|
| 32 |
+
def extract_feat(self, waveform_list):
|
| 33 |
+
"""Extract features from waveform"""
|
| 34 |
+
feats, feats_len = [], []
|
| 35 |
+
for waveform in waveform_list:
|
| 36 |
+
speech, _ = self.frontend.fbank(waveform)
|
| 37 |
+
feat, feat_len = self.frontend.lfr_cmvn(speech)
|
| 38 |
+
feats.append(feat)
|
| 39 |
+
feats_len.append(feat_len)
|
| 40 |
+
|
| 41 |
+
max_len = max(feats_len)
|
| 42 |
+
padded_feats = [np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), 'constant') for f in feats]
|
| 43 |
+
feats = np.array(padded_feats).astype(np.float32)
|
| 44 |
+
feats_len = np.array(feats_len).astype(np.int32)
|
| 45 |
+
return feats, feats_len
|
| 46 |
+
|
| 47 |
+
def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
|
| 48 |
+
"""Run inference with ONNX Runtime"""
|
| 49 |
+
# Get all input names from the model
|
| 50 |
+
input_names = [input.name for input in self.session.get_inputs()]
|
| 51 |
+
output_names = [x.name for x in self.session.get_outputs()]
|
| 52 |
+
|
| 53 |
+
# Create input dictionary for all inputs
|
| 54 |
+
input_dict = {}
|
| 55 |
+
for i, (name, tensor) in enumerate(zip(input_names, feats)):
|
| 56 |
+
input_dict[name] = tensor
|
| 57 |
+
|
| 58 |
+
# Run inference with all inputs
|
| 59 |
+
outputs = self.session.run(output_names, input_dict)
|
| 60 |
+
scores, out_caches = outputs[0], outputs[1:]
|
| 61 |
+
return scores, out_caches
|
| 62 |
+
|
| 63 |
+
def __call__(self, wav_file, **kwargs):
|
| 64 |
+
"""Process audio file with sliding window approach"""
|
| 65 |
+
# Load audio and prepare data
|
| 66 |
+
# waveform = self.load_wav(wav_file)
|
| 67 |
+
# waveform, _ = librosa.load(wav_file, sr=16000)
|
| 68 |
+
waveform_list = [wav_file]
|
| 69 |
+
waveform_nums = len(waveform_list)
|
| 70 |
+
is_final = kwargs.get("is_final", False)
|
| 71 |
+
segments = [[] for _ in range(self.batch_size)]  # independent list per batch item
|
| 72 |
+
|
| 73 |
+
for beg_idx in range(0, waveform_nums, self.batch_size):
|
| 74 |
+
vad_scorer = E2EVadModel(self.config["model_conf"])
|
| 75 |
+
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
| 76 |
+
waveform = waveform_list[beg_idx:end_idx]
|
| 77 |
+
feats, feats_len = self.extract_feat(waveform)
|
| 78 |
+
waveform = np.array(waveform)
|
| 79 |
+
param_dict = kwargs.get("param_dict", dict())
|
| 80 |
+
in_cache = param_dict.get("in_cache", list())
|
| 81 |
+
in_cache = self.prepare_cache(in_cache)
|
| 82 |
+
|
| 83 |
+
t_offset = 0
|
| 84 |
+
step = int(min(feats_len.max(), 6000))
|
| 85 |
+
for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)):
|
| 86 |
+
if t_offset + step >= feats_len - 1:
|
| 87 |
+
step = feats_len - t_offset
|
| 88 |
+
is_final = True
|
| 89 |
+
else:
|
| 90 |
+
is_final = False
|
| 91 |
+
|
| 92 |
+
# Extract feature segment
|
| 93 |
+
feats_package = feats[:, t_offset:int(t_offset + step), :]
|
| 94 |
+
|
| 95 |
+
# Pad if it's the final segment
|
| 96 |
+
if is_final:
|
| 97 |
+
pad_length = 6000 - int(step)
|
| 98 |
+
feats_package = np.pad(
|
| 99 |
+
feats_package,
|
| 100 |
+
((0, 0), (0, pad_length), (0, 0)),
|
| 101 |
+
mode='constant',
|
| 102 |
+
constant_values=0
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Extract corresponding waveform segment
|
| 106 |
+
waveform_package = waveform[
|
| 107 |
+
:,
|
| 108 |
+
t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400),
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
# Pad waveform if it's the final segment
|
| 112 |
+
if is_final:
|
| 113 |
+
expected_wave_length = 6000 * 160 + 240
|
| 114 |
+
current_wave_length = waveform_package.shape[-1]
|
| 115 |
+
pad_wave_length = expected_wave_length - current_wave_length
|
| 116 |
+
if pad_wave_length > 0:
|
| 117 |
+
waveform_package = np.pad(
|
| 118 |
+
waveform_package,
|
| 119 |
+
((0, 0), (0, pad_wave_length)),
|
| 120 |
+
mode='constant',
|
| 121 |
+
constant_values=0
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Run inference
|
| 125 |
+
inputs = [feats_package]
|
| 126 |
+
inputs.extend(in_cache)
|
| 127 |
+
scores, out_caches = self.infer(inputs)
|
| 128 |
+
in_cache = out_caches
|
| 129 |
+
|
| 130 |
+
# Get VAD segments for this chunk
|
| 131 |
+
segments_part = vad_scorer(
|
| 132 |
+
scores,
|
| 133 |
+
waveform_package,
|
| 134 |
+
is_final=is_final,
|
| 135 |
+
max_end_sil=self.max_end_sil,
|
| 136 |
+
online=False,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Accumulate segments
|
| 140 |
+
if segments_part:
|
| 141 |
+
for batch_num in range(0, self.batch_size):
|
| 142 |
+
segments[batch_num] += segments_part[batch_num]
|
| 143 |
+
|
| 144 |
+
return segments
|
| 145 |
+
|
| 146 |
+
def prepare_cache(self, in_cache: list = []):
|
| 147 |
+
if len(in_cache) > 0:
|
| 148 |
+
return in_cache
|
| 149 |
+
fsmn_layers = 4
|
| 150 |
+
proj_dim = 128
|
| 151 |
+
lorder = 20
|
| 152 |
+
for i in range(fsmn_layers):
|
| 153 |
+
cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
|
| 154 |
+
in_cache.append(cache)
|
| 155 |
+
return in_cache
|
| 156 |
+
|
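A hedged usage sketch for the VAD wrapper (illustrative): `demo.wav` is a placeholder, and the `vad.axmodel` referenced in `__init__` is assumed to sit under the same `ax_model/` directory as the `vad/` config.

```python
import librosa

from utils.ax_vad_bin import AX_Fsmn_vad

vad = AX_Fsmn_vad(model_dir="ax_model", batch_size=1)

# __call__ expects the waveform itself (16 kHz mono), not a file path.
waveform, _ = librosa.load("demo.wav", sr=16000)
segments = vad(waveform)
print(segments[0])  # start/end pairs detected by E2EVadModel for the first batch item
```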
utils/ctc_alignment.py
ADDED
|
@@ -0,0 +1,76 @@
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def ctc_forced_align(
|
| 4 |
+
log_probs: torch.Tensor,
|
| 5 |
+
targets: torch.Tensor,
|
| 6 |
+
input_lengths: torch.Tensor,
|
| 7 |
+
target_lengths: torch.Tensor,
|
| 8 |
+
blank: int = 0,
|
| 9 |
+
ignore_id: int = -1,
|
| 10 |
+
) -> torch.Tensor:
|
| 11 |
+
"""Align a CTC label sequence to an emission.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
log_probs (Tensor): log probability of CTC emission output.
|
| 15 |
+
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
|
| 16 |
+
`C` is the number of characters in alphabet including blank.
|
| 17 |
+
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
|
| 18 |
+
where `L` is the target length.
|
| 19 |
+
input_lengths (Tensor):
|
| 20 |
+
Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
|
| 21 |
+
target_lengths (Tensor):
|
| 22 |
+
Lengths of the targets. 1-D Tensor of shape `(B,)`.
|
| 23 |
+
blank (int, optional): The index of the blank symbol in CTC emission. (Default: 0)
|
| 24 |
+
ignore_id (int, optional): The index of ignore symbol in CTC emission. (Default: -1)
|
| 25 |
+
"""
|
| 26 |
+
targets[targets == ignore_id] = blank
|
| 27 |
+
|
| 28 |
+
batch_size, input_time_size, _ = log_probs.size()
|
| 29 |
+
bsz_indices = torch.arange(batch_size, device=input_lengths.device)
|
| 30 |
+
|
| 31 |
+
_t_a_r_g_e_t_s_ = torch.cat(
|
| 32 |
+
(
|
| 33 |
+
torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
|
| 34 |
+
torch.full_like(targets[:, :1], blank),
|
| 35 |
+
),
|
| 36 |
+
dim=-1,
|
| 37 |
+
)
|
| 38 |
+
diff_labels = torch.cat(
|
| 39 |
+
(
|
| 40 |
+
torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
|
| 41 |
+
_t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
|
| 42 |
+
),
|
| 43 |
+
dim=1,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
|
| 47 |
+
padding_num = 2
|
| 48 |
+
padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
|
| 49 |
+
best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
|
| 50 |
+
best_score[:, padding_num + 0] = log_probs[:, 0, blank]
|
| 51 |
+
best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]
|
| 52 |
+
|
| 53 |
+
backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)
|
| 54 |
+
|
| 55 |
+
for t in range(1, input_time_size):
|
| 56 |
+
prev = torch.stack(
|
| 57 |
+
(best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
|
| 58 |
+
)
|
| 59 |
+
prev_max_value, prev_max_idx = prev.max(dim=0)
|
| 60 |
+
best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
|
| 61 |
+
backpointers[:, t, padding_num:] = prev_max_idx
|
| 62 |
+
|
| 63 |
+
l1l2 = best_score.gather(
|
| 64 |
+
-1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
|
| 68 |
+
path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)
|
| 69 |
+
|
| 70 |
+
for t in range(input_time_size - 1, 0, -1):
|
| 71 |
+
target_indices = path[:, t]
|
| 72 |
+
prev_max_idx = backpointers[bsz_indices, t, target_indices]
|
| 73 |
+
path[:, t - 1] += target_indices - prev_max_idx
|
| 74 |
+
|
| 75 |
+
alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
|
| 76 |
+
return alignments
|
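A small self-contained check of `ctc_forced_align` (the values are made up for illustration): align a two-token target against five emission frames and get one label (or blank) back per frame.

```python
import torch

from utils.ctc_alignment import ctc_forced_align

torch.manual_seed(0)
log_probs = torch.randn(1, 5, 4).log_softmax(dim=-1)  # (B, T, C), blank at index 0
targets = torch.tensor([[2, 3]])                       # (B, L)
input_lengths = torch.tensor([5])
target_lengths = torch.tensor([2])

alignment = ctc_forced_align(log_probs, targets, input_lengths, target_lengths, blank=0)
print(alignment.shape)  # torch.Size([1, 5]): one target id (or blank) per input frame
```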
utils/frontend.py
ADDED
|
@@ -0,0 +1,433 @@
|
| 1 |
+
# -*- encoding: utf-8 -*-
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import kaldi_native_fbank as knf
|
| 8 |
+
|
| 9 |
+
root_dir = Path(__file__).resolve().parent
|
| 10 |
+
|
| 11 |
+
logger_initialized = {}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class WavFrontend:
|
| 15 |
+
"""Conventional frontend structure for ASR."""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
cmvn_file: str = None,
|
| 20 |
+
fs: int = 16000,
|
| 21 |
+
window: str = "hamming",
|
| 22 |
+
n_mels: int = 80,
|
| 23 |
+
frame_length: int = 25,
|
| 24 |
+
frame_shift: int = 10,
|
| 25 |
+
lfr_m: int = 1,
|
| 26 |
+
lfr_n: int = 1,
|
| 27 |
+
dither: float = 1.0,
|
| 28 |
+
**kwargs,
|
| 29 |
+
) -> None:
|
| 30 |
+
|
| 31 |
+
opts = knf.FbankOptions()
|
| 32 |
+
opts.frame_opts.samp_freq = fs
|
| 33 |
+
opts.frame_opts.dither = dither
|
| 34 |
+
opts.frame_opts.window_type = window
|
| 35 |
+
opts.frame_opts.frame_shift_ms = float(frame_shift)
|
| 36 |
+
opts.frame_opts.frame_length_ms = float(frame_length)
|
| 37 |
+
opts.mel_opts.num_bins = n_mels
|
| 38 |
+
opts.energy_floor = 0
|
| 39 |
+
opts.frame_opts.snip_edges = True
|
| 40 |
+
opts.mel_opts.debug_mel = False
|
| 41 |
+
self.opts = opts
|
| 42 |
+
|
| 43 |
+
self.lfr_m = lfr_m
|
| 44 |
+
self.lfr_n = lfr_n
|
| 45 |
+
self.cmvn_file = cmvn_file
|
| 46 |
+
|
| 47 |
+
if self.cmvn_file:
|
| 48 |
+
self.cmvn = self.load_cmvn()
|
| 49 |
+
self.fbank_fn = None
|
| 50 |
+
self.fbank_beg_idx = 0
|
| 51 |
+
self.reset_status()
|
| 52 |
+
|
| 53 |
+
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 54 |
+
waveform = waveform * (1 << 15)
|
| 55 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 56 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 57 |
+
frames = self.fbank_fn.num_frames_ready
|
| 58 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 59 |
+
for i in range(frames):
|
| 60 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 61 |
+
feat = mat.astype(np.float32)
|
| 62 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 63 |
+
return feat, feat_len
|
| 64 |
+
|
| 65 |
+
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 66 |
+
waveform = waveform * (1 << 15)
|
| 67 |
+
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 68 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 69 |
+
frames = self.fbank_fn.num_frames_ready
|
| 70 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 71 |
+
for i in range(self.fbank_beg_idx, frames):
|
| 72 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 73 |
+
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
|
| 74 |
+
feat = mat.astype(np.float32)
|
| 75 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 76 |
+
return feat, feat_len
|
| 77 |
+
|
| 78 |
+
def reset_status(self):
|
| 79 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 80 |
+
self.fbank_beg_idx = 0
|
| 81 |
+
|
| 82 |
+
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 83 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
| 84 |
+
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
|
| 85 |
+
|
| 86 |
+
if self.cmvn_file:
|
| 87 |
+
feat = self.apply_cmvn(feat)
|
| 88 |
+
|
| 89 |
+
feat_len = np.array(feat.shape[0]).astype(np.int32)
|
| 90 |
+
return feat, feat_len
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
|
| 94 |
+
LFR_inputs = []
|
| 95 |
+
|
| 96 |
+
T = inputs.shape[0]
|
| 97 |
+
T_lfr = int(np.ceil(T / lfr_n))
|
| 98 |
+
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
|
| 99 |
+
inputs = np.vstack((left_padding, inputs))
|
| 100 |
+
T = T + (lfr_m - 1) // 2
|
| 101 |
+
for i in range(T_lfr):
|
| 102 |
+
if lfr_m <= T - i * lfr_n:
|
| 103 |
+
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
|
| 104 |
+
else:
|
| 105 |
+
# process last LFR frame
|
| 106 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
| 107 |
+
frame = inputs[i * lfr_n :].reshape(-1)
|
| 108 |
+
for _ in range(num_padding):
|
| 109 |
+
frame = np.hstack((frame, inputs[-1]))
|
| 110 |
+
|
| 111 |
+
LFR_inputs.append(frame)
|
| 112 |
+
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
|
| 113 |
+
return LFR_outputs
|
| 114 |
+
|
| 115 |
+
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
|
| 116 |
+
"""
|
| 117 |
+
Apply CMVN with mvn data
|
| 118 |
+
"""
|
| 119 |
+
frame, dim = inputs.shape
|
| 120 |
+
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
|
| 121 |
+
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
|
| 122 |
+
inputs = (inputs + means) * vars
|
| 123 |
+
return inputs
|
| 124 |
+
|
| 125 |
+
def load_cmvn(
|
| 126 |
+
self,
|
| 127 |
+
) -> np.ndarray:
|
| 128 |
+
with open(self.cmvn_file, "r", encoding="utf-8") as f:
|
| 129 |
+
lines = f.readlines()
|
| 130 |
+
|
| 131 |
+
means_list = []
|
| 132 |
+
vars_list = []
|
| 133 |
+
for i in range(len(lines)):
|
| 134 |
+
line_item = lines[i].split()
|
| 135 |
+
if line_item[0] == "<AddShift>":
|
| 136 |
+
line_item = lines[i + 1].split()
|
| 137 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 138 |
+
add_shift_line = line_item[3 : (len(line_item) - 1)]
|
| 139 |
+
means_list = list(add_shift_line)
|
| 140 |
+
continue
|
| 141 |
+
elif line_item[0] == "<Rescale>":
|
| 142 |
+
line_item = lines[i + 1].split()
|
| 143 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 144 |
+
rescale_line = line_item[3 : (len(line_item) - 1)]
|
| 145 |
+
vars_list = list(rescale_line)
|
| 146 |
+
continue
|
| 147 |
+
|
| 148 |
+
means = np.array(means_list).astype(np.float64)
|
| 149 |
+
vars = np.array(vars_list).astype(np.float64)
|
| 150 |
+
cmvn = np.array([means, vars])
|
| 151 |
+
return cmvn
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class WavFrontendOnline(WavFrontend):
|
| 155 |
+
def __init__(self, **kwargs):
|
| 156 |
+
super().__init__(**kwargs)
|
| 157 |
+
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 158 |
+
# add variables
|
| 159 |
+
self.frame_sample_length = int(
|
| 160 |
+
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
|
| 161 |
+
)
|
| 162 |
+
self.frame_shift_sample_length = int(
|
| 163 |
+
self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
|
| 164 |
+
)
|
| 165 |
+
self.waveform = None
|
| 166 |
+
self.reserve_waveforms = None
|
| 167 |
+
self.input_cache = None
|
| 168 |
+
self.lfr_splice_cache = []
|
| 169 |
+
|
| 170 |
+
@staticmethod
|
| 171 |
+
# inputs has catted the cache
|
| 172 |
+
def apply_lfr(
|
| 173 |
+
inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
|
| 174 |
+
) -> Tuple[np.ndarray, np.ndarray, int]:
|
| 175 |
+
"""
|
| 176 |
+
Apply lfr with data
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
LFR_inputs = []
|
| 180 |
+
T = inputs.shape[0] # include the right context
|
| 181 |
+
T_lfr = int(
|
| 182 |
+
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
|
| 183 |
+
) # minus the right context: (lfr_m - 1) // 2
|
| 184 |
+
splice_idx = T_lfr
|
| 185 |
+
for i in range(T_lfr):
|
| 186 |
+
if lfr_m <= T - i * lfr_n:
|
| 187 |
+
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
|
| 188 |
+
else: # process last LFR frame
|
| 189 |
+
if is_final:
|
| 190 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
| 191 |
+
frame = (inputs[i * lfr_n :]).reshape(-1)
|
| 192 |
+
for _ in range(num_padding):
|
| 193 |
+
frame = np.hstack((frame, inputs[-1]))
|
| 194 |
+
LFR_inputs.append(frame)
|
| 195 |
+
else:
|
| 196 |
+
# update splice_idx and break the circle
|
| 197 |
+
splice_idx = i
|
| 198 |
+
break
|
| 199 |
+
splice_idx = min(T - 1, splice_idx * lfr_n)
|
| 200 |
+
lfr_splice_cache = inputs[splice_idx:, :]
|
| 201 |
+
LFR_outputs = np.vstack(LFR_inputs)
|
| 202 |
+
return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def compute_frame_num(
|
| 206 |
+
sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
|
| 207 |
+
) -> int:
|
| 208 |
+
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
|
| 209 |
+
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
|
| 210 |
+
|
| 211 |
+
def fbank(
|
| 212 |
+
self, input: np.ndarray, input_lengths: np.ndarray
|
| 213 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 214 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 215 |
+
batch_size = input.shape[0]
|
| 216 |
+
if self.input_cache is None:
|
| 217 |
+
self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
|
| 218 |
+
input = np.concatenate((self.input_cache, input), axis=1)
|
| 219 |
+
frame_num = self.compute_frame_num(
|
| 220 |
+
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
|
| 221 |
+
)
|
| 222 |
+
# update self.in_cache
|
| 223 |
+
self.input_cache = input[
|
| 224 |
+
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
|
| 225 |
+
]
|
| 226 |
+
waveforms = np.empty(0, dtype=np.float32)
|
| 227 |
+
feats_pad = np.empty(0, dtype=np.float32)
|
| 228 |
+
feats_lens = np.empty(0, dtype=np.int32)
|
| 229 |
+
if frame_num:
|
| 230 |
+
waveforms = []
|
| 231 |
+
feats = []
|
| 232 |
+
feats_lens = []
|
| 233 |
+
for i in range(batch_size):
|
| 234 |
+
waveform = input[i]
|
| 235 |
+
waveforms.append(
|
| 236 |
+
waveform[
|
| 237 |
+
: (
|
| 238 |
+
(frame_num - 1) * self.frame_shift_sample_length
|
| 239 |
+
+ self.frame_sample_length
|
| 240 |
+
)
|
| 241 |
+
]
|
| 242 |
+
)
|
| 243 |
+
waveform = waveform * (1 << 15)
|
| 244 |
+
|
| 245 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 246 |
+
frames = self.fbank_fn.num_frames_ready
|
| 247 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 248 |
+
for i in range(frames):
|
| 249 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 250 |
+
feat = mat.astype(np.float32)
|
| 251 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 252 |
+
feats.append(feat)
|
| 253 |
+
feats_lens.append(feat_len)
|
| 254 |
+
|
| 255 |
+
waveforms = np.stack(waveforms)
|
| 256 |
+
feats_lens = np.array(feats_lens)
|
| 257 |
+
feats_pad = np.array(feats)
|
| 258 |
+
self.fbanks = feats_pad
|
| 259 |
+
self.fbanks_lens = copy.deepcopy(feats_lens)
|
| 260 |
+
return waveforms, feats_pad, feats_lens
|
| 261 |
+
|
| 262 |
+
def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
|
| 263 |
+
return self.fbanks, self.fbanks_lens
|
| 264 |
+
|
| 265 |
+
def lfr_cmvn(
|
| 266 |
+
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
| 267 |
+
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
|
| 268 |
+
batch_size = input.shape[0]
|
| 269 |
+
feats = []
|
| 270 |
+
feats_lens = []
|
| 271 |
+
lfr_splice_frame_idxs = []
|
| 272 |
+
for i in range(batch_size):
|
| 273 |
+
mat = input[i, : input_lengths[i], :]
|
| 274 |
+
lfr_splice_frame_idx = -1
|
| 275 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
| 276 |
+
# update self.lfr_splice_cache in self.apply_lfr
|
| 277 |
+
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
|
| 278 |
+
mat, self.lfr_m, self.lfr_n, is_final
|
| 279 |
+
)
|
| 280 |
+
if self.cmvn_file is not None:
|
| 281 |
+
mat = self.apply_cmvn(mat)
|
| 282 |
+
feat_length = mat.shape[0]
|
| 283 |
+
feats.append(mat)
|
| 284 |
+
feats_lens.append(feat_length)
|
| 285 |
+
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
|
| 286 |
+
|
| 287 |
+
feats_lens = np.array(feats_lens)
|
| 288 |
+
feats_pad = np.array(feats)
|
| 289 |
+
return feats_pad, feats_lens, lfr_splice_frame_idxs
|
| 290 |
+
|
| 291 |
+
def extract_fbank(
|
| 292 |
+
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
| 293 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 294 |
+
batch_size = input.shape[0]
|
| 295 |
+
assert (
|
| 296 |
+
batch_size == 1
|
| 297 |
+
), "we support to extract feature online only when the batch size is equal to 1 now"
|
| 298 |
+
waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
|
| 299 |
+
if feats.shape[0]:
|
| 300 |
+
self.waveforms = (
|
| 301 |
+
waveforms
|
| 302 |
+
if self.reserve_waveforms is None
|
| 303 |
+
else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
|
| 304 |
+
)
|
| 305 |
+
if not self.lfr_splice_cache:
|
| 306 |
+
for i in range(batch_size):
|
| 307 |
+
self.lfr_splice_cache.append(
|
| 308 |
+
np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
|
| 312 |
+
lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
|
| 313 |
+
feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
|
| 314 |
+
feats_lengths += lfr_splice_cache_np[0].shape[0]
|
| 315 |
+
frame_from_waveforms = int(
|
| 316 |
+
(self.waveforms.shape[1] - self.frame_sample_length)
|
| 317 |
+
/ self.frame_shift_sample_length
|
| 318 |
+
+ 1
|
| 319 |
+
)
|
| 320 |
+
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
|
| 321 |
+
feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
|
| 322 |
+
feats, feats_lengths, is_final
|
| 323 |
+
)
|
| 324 |
+
if self.lfr_m == 1:
|
| 325 |
+
self.reserve_waveforms = None
|
| 326 |
+
else:
|
| 327 |
+
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
|
| 328 |
+
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
|
| 329 |
+
# print('frame_frame: ' + str(frame_from_waveforms))
|
| 330 |
+
self.reserve_waveforms = self.waveforms[
|
| 331 |
+
:,
|
| 332 |
+
reserve_frame_idx
|
| 333 |
+
* self.frame_shift_sample_length : frame_from_waveforms
|
| 334 |
+
* self.frame_shift_sample_length,
|
| 335 |
+
]
|
| 336 |
+
sample_length = (
|
| 337 |
+
frame_from_waveforms - 1
|
| 338 |
+
) * self.frame_shift_sample_length + self.frame_sample_length
|
| 339 |
+
self.waveforms = self.waveforms[:, :sample_length]
|
| 340 |
+
else:
|
| 341 |
+
# update self.reserve_waveforms and self.lfr_splice_cache
|
| 342 |
+
self.reserve_waveforms = self.waveforms[
|
| 343 |
+
:, : -(self.frame_sample_length - self.frame_shift_sample_length)
|
| 344 |
+
]
|
| 345 |
+
for i in range(batch_size):
|
| 346 |
+
self.lfr_splice_cache[i] = np.concatenate(
|
| 347 |
+
(self.lfr_splice_cache[i], feats[i]), axis=0
|
| 348 |
+
)
|
| 349 |
+
return np.empty(0, dtype=np.float32), feats_lengths
|
| 350 |
+
else:
|
| 351 |
+
if is_final:
|
| 352 |
+
self.waveforms = (
|
| 353 |
+
waveforms if self.reserve_waveforms is None else self.reserve_waveforms
|
| 354 |
+
)
|
| 355 |
+
feats = np.stack(self.lfr_splice_cache)
|
| 356 |
+
feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
|
| 357 |
+
feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
|
| 358 |
+
if is_final:
|
| 359 |
+
self.cache_reset()
|
| 360 |
+
return feats, feats_lengths
|
| 361 |
+
|
| 362 |
+
def get_waveforms(self):
|
| 363 |
+
return self.waveforms
|
| 364 |
+
|
| 365 |
+
def cache_reset(self):
|
| 366 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 367 |
+
self.reserve_waveforms = None
|
| 368 |
+
self.input_cache = None
|
| 369 |
+
self.lfr_splice_cache = []
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def load_bytes(input):
|
| 373 |
+
middle_data = np.frombuffer(input, dtype=np.int16)
|
| 374 |
+
middle_data = np.asarray(middle_data)
|
| 375 |
+
if middle_data.dtype.kind not in "iu":
|
| 376 |
+
raise TypeError("'middle_data' must be an array of integers")
|
| 377 |
+
dtype = np.dtype("float32")
|
| 378 |
+
if dtype.kind != "f":
|
| 379 |
+
raise TypeError("'dtype' must be a floating point type")
|
| 380 |
+
|
| 381 |
+
i = np.iinfo(middle_data.dtype)
|
| 382 |
+
abs_max = 2 ** (i.bits - 1)
|
| 383 |
+
offset = i.min + abs_max
|
| 384 |
+
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
|
| 385 |
+
return array
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class SinusoidalPositionEncoderOnline:
|
| 389 |
+
"""Streaming Positional encoding."""
|
| 390 |
+
|
| 391 |
+
def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
|
| 392 |
+
batch_size = positions.shape[0]
|
| 393 |
+
positions = positions.astype(dtype)
|
| 394 |
+
log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
|
| 395 |
+
inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
|
| 396 |
+
inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
|
| 397 |
+
scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
|
| 398 |
+
encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
|
| 399 |
+
return encoding.astype(dtype)
|
| 400 |
+
|
| 401 |
+
def forward(self, x, start_idx=0):
|
| 402 |
+
batch_size, timesteps, input_dim = x.shape
|
| 403 |
+
positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
|
| 404 |
+
position_encoding = self.encode(positions, input_dim, x.dtype)
|
| 405 |
+
|
| 406 |
+
return x + position_encoding[:, start_idx : start_idx + timesteps]
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def test():
|
| 410 |
+
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
|
| 411 |
+
import librosa
|
| 412 |
+
|
| 413 |
+
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
|
| 414 |
+
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
|
| 415 |
+
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
|
| 416 |
+
|
| 417 |
+
config = read_yaml(config_file)
|
| 418 |
+
waveform, _ = librosa.load(path, sr=None)
|
| 419 |
+
frontend = WavFrontend(
|
| 420 |
+
cmvn_file=cmvn_file,
|
| 421 |
+
**config["frontend_conf"],
|
| 422 |
+
)
|
| 423 |
+
speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
|
| 424 |
+
feat, feat_len = frontend.lfr_cmvn(
|
| 425 |
+
speech
|
| 426 |
+
) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
|
| 427 |
+
|
| 428 |
+
frontend.reset_status() # clear cache
|
| 429 |
+
return feat, feat_len
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
if __name__ == "__main__":
|
| 433 |
+
test()
|
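A minimal sketch of driving `WavFrontend` the same way `extract_feat()` does in the ASR wrapper, assuming the SenseVoice config and CMVN files shipped under `ax_model/sensevoice/` and a 16 kHz `demo.wav` (both placeholders):

```python
import librosa

from utils.frontend import WavFrontend
from utils.infer_utils import read_yaml

config = read_yaml("ax_model/sensevoice/config.yaml")
config["frontend_conf"]["cmvn_file"] = "ax_model/sensevoice/am.mvn"
frontend = WavFrontend(**config["frontend_conf"])

waveform, _ = librosa.load("demo.wav", sr=16000)
speech, _ = frontend.fbank(waveform)        # (frames, n_mels) log-mel features
feat, feat_len = frontend.lfr_cmvn(speech)  # LFR-stacked + CMVN-normalised features
print(feat.shape, feat_len)
```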
utils/infer_utils.py
ADDED
|
@@ -0,0 +1,312 @@
|
| 1 |
+
# -*- encoding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import numpy as np
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
+
import jieba
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
root_dir = Path(__file__).resolve().parent
|
| 16 |
+
|
| 17 |
+
logger_initialized = {}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def pad_list(xs, pad_value, max_len=None):
|
| 21 |
+
n_batch = len(xs)
|
| 22 |
+
if max_len is None:
|
| 23 |
+
max_len = max(x.shape[0] for x in xs)  # xs holds numpy arrays, not torch tensors
|
| 24 |
+
# pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
| 25 |
+
# numpy format
|
| 26 |
+
pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
|
| 27 |
+
for i in range(n_batch):
|
| 28 |
+
pad[i, : xs[i].shape[0]] = xs[i]
|
| 29 |
+
|
| 30 |
+
return pad
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
|
| 35 |
+
if length_dim == 0:
|
| 36 |
+
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
|
| 37 |
+
|
| 38 |
+
if not isinstance(lengths, list):
|
| 39 |
+
lengths = lengths.tolist()
|
| 40 |
+
bs = int(len(lengths))
|
| 41 |
+
if maxlen is None:
|
| 42 |
+
if xs is None:
|
| 43 |
+
maxlen = int(max(lengths))
|
| 44 |
+
else:
|
| 45 |
+
maxlen = xs.size(length_dim)
|
| 46 |
+
else:
|
| 47 |
+
assert xs is None
|
| 48 |
+
assert maxlen >= int(max(lengths))
|
| 49 |
+
|
| 50 |
+
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
| 51 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
| 52 |
+
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
| 53 |
+
mask = seq_range_expand >= seq_length_expand
|
| 54 |
+
|
| 55 |
+
if xs is not None:
|
| 56 |
+
assert xs.size(0) == bs, (xs.size(0), bs)
|
| 57 |
+
|
| 58 |
+
if length_dim < 0:
|
| 59 |
+
length_dim = xs.dim() + length_dim
|
| 60 |
+
# ind = (:, None, ..., None, :, , None, ..., None)
|
| 61 |
+
ind = tuple(
|
| 62 |
+
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
|
| 63 |
+
)
|
| 64 |
+
mask = mask[ind].expand_as(xs).to(xs.device)
|
| 65 |
+
return mask
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TokenIDConverter:
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
token_list: Union[List, str],
|
| 73 |
+
):
|
| 74 |
+
|
| 75 |
+
self.token_list = token_list
|
| 76 |
+
self.unk_symbol = token_list[-1]
|
| 77 |
+
self.token2id = {v: i for i, v in enumerate(self.token_list)}
|
| 78 |
+
self.unk_id = self.token2id[self.unk_symbol]
|
| 79 |
+
|
| 80 |
+
def get_num_vocabulary_size(self) -> int:
|
| 81 |
+
return len(self.token_list)
|
| 82 |
+
|
| 83 |
+
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
| 84 |
+
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
| 85 |
+
raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
|
| 86 |
+
return [self.token_list[i] for i in integers]
|
| 87 |
+
|
| 88 |
+
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
| 89 |
+
|
| 90 |
+
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class CharTokenizer:
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
symbol_value: Union[Path, str, Iterable[str]] = None,
|
| 97 |
+
space_symbol: str = "<space>",
|
| 98 |
+
remove_non_linguistic_symbols: bool = False,
|
| 99 |
+
):
|
| 100 |
+
|
| 101 |
+
self.space_symbol = space_symbol
|
| 102 |
+
self.non_linguistic_symbols = self.load_symbols(symbol_value)
|
| 103 |
+
self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
|
| 107 |
+
if value is None:
|
| 108 |
+
return set()
|
| 109 |
+
|
| 110 |
+
if not isinstance(value, (Path, str)):  # treat as an already-materialised iterable of symbols
|
| 111 |
+
return set(value)
|
| 112 |
+
|
| 113 |
+
file_path = Path(value)
|
| 114 |
+
if not file_path.exists():
|
| 115 |
+
logging.warning("%s doesn't exist.", file_path)
|
| 116 |
+
return set()
|
| 117 |
+
|
| 118 |
+
with file_path.open("r", encoding="utf-8") as f:
|
| 119 |
+
return set(line.rstrip() for line in f)
|
| 120 |
+
|
| 121 |
+
def text2tokens(self, line: Union[str, list]) -> List[str]:
|
| 122 |
+
tokens = []
|
| 123 |
+
while len(line) != 0:
|
| 124 |
+
for w in self.non_linguistic_symbols:
|
| 125 |
+
if line.startswith(w):
|
| 126 |
+
if not self.remove_non_linguistic_symbols:
|
| 127 |
+
tokens.append(line[: len(w)])
|
| 128 |
+
line = line[len(w) :]
|
| 129 |
+
break
|
| 130 |
+
else:
|
| 131 |
+
t = line[0]
|
| 132 |
+
if t == " ":
|
| 133 |
+
t = "<space>"
|
| 134 |
+
tokens.append(t)
|
| 135 |
+
line = line[1:]
|
| 136 |
+
return tokens
|
| 137 |
+
|
| 138 |
+
def tokens2text(self, tokens: Iterable[str]) -> str:
|
| 139 |
+
tokens = [t if t != self.space_symbol else " " for t in tokens]
|
| 140 |
+
return "".join(tokens)
|
| 141 |
+
|
| 142 |
+
def __repr__(self):
|
| 143 |
+
return (
|
| 144 |
+
f"{self.__class__.__name__}("
|
| 145 |
+
f'space_symbol="{self.space_symbol}"'
|
| 146 |
+
f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
|
| 147 |
+
f")"
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class Hypothesis(NamedTuple):
|
| 152 |
+
"""Hypothesis data type."""
|
| 153 |
+
|
| 154 |
+
yseq: np.ndarray
|
| 155 |
+
score: Union[float, np.ndarray] = 0
|
| 156 |
+
scores: Dict[str, Union[float, np.ndarray]] = dict()
|
| 157 |
+
states: Dict[str, Any] = dict()
|
| 158 |
+
|
| 159 |
+
def asdict(self) -> dict:
|
| 160 |
+
"""Convert data to JSON-friendly dict."""
|
| 161 |
+
return self._replace(
|
| 162 |
+
yseq=self.yseq.tolist(),
|
| 163 |
+
score=float(self.score),
|
| 164 |
+
scores={k: float(v) for k, v in self.scores.items()},
|
| 165 |
+
)._asdict()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class TokenIDConverterError(Exception):
|
| 169 |
+
pass
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class ONNXRuntimeError(Exception):
|
| 173 |
+
pass
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
| 177 |
+
assert word_limit > 1
|
| 178 |
+
if len(words) <= word_limit:
|
| 179 |
+
return [words]
|
| 180 |
+
sentences = []
|
| 181 |
+
length = len(words)
|
| 182 |
+
sentence_len = length // word_limit
|
| 183 |
+
for i in range(sentence_len):
|
| 184 |
+
sentences.append(words[i * word_limit : (i + 1) * word_limit])
|
| 185 |
+
if length % word_limit > 0:
|
| 186 |
+
sentences.append(words[sentence_len * word_limit :])
|
| 187 |
+
return sentences
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def code_mix_split_words(text: str):
|
| 191 |
+
words = []
|
| 192 |
+
segs = text.split()
|
| 193 |
+
for seg in segs:
|
| 194 |
+
# There is no space in seg.
|
| 195 |
+
current_word = ""
|
| 196 |
+
for c in seg:
|
| 197 |
+
if len(c.encode()) == 1:
|
| 198 |
+
# This is an ASCII char.
|
| 199 |
+
current_word += c
|
| 200 |
+
else:
|
| 201 |
+
# This is a Chinese char.
|
| 202 |
+
if len(current_word) > 0:
|
| 203 |
+
words.append(current_word)
|
| 204 |
+
current_word = ""
|
| 205 |
+
words.append(c)
|
| 206 |
+
if len(current_word) > 0:
|
| 207 |
+
words.append(current_word)
|
| 208 |
+
return words
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def isEnglish(text: str):
|
| 212 |
+
if re.search("^[a-zA-Z']+$", text):
|
| 213 |
+
return True
|
| 214 |
+
else:
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def join_chinese_and_english(input_list):
|
| 219 |
+
line = ""
|
| 220 |
+
for token in input_list:
|
| 221 |
+
if isEnglish(token):
|
| 222 |
+
line = line + " " + token
|
| 223 |
+
else:
|
| 224 |
+
line = line + token
|
| 225 |
+
|
| 226 |
+
line = line.strip()
|
| 227 |
+
return line
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def code_mix_split_words_jieba(seg_dict_file: str):
|
| 231 |
+
jieba.load_userdict(seg_dict_file)
|
| 232 |
+
|
| 233 |
+
def _fn(text: str):
|
| 234 |
+
input_list = text.split()
|
| 235 |
+
token_list_all = []
|
| 236 |
+
language_list = []
|
| 237 |
+
token_list_tmp = []
|
| 238 |
+
language_flag = None
|
| 239 |
+
for token in input_list:
|
| 240 |
+
if isEnglish(token) and language_flag == "Chinese":
|
| 241 |
+
token_list_all.append(token_list_tmp)
|
| 242 |
+
langauge_list.append("Chinese")
|
| 243 |
+
token_list_tmp = []
|
| 244 |
+
elif not isEnglish(token) and language_flag == "English":
|
| 245 |
+
token_list_all.append(token_list_tmp)
|
| 246 |
+
langauge_list.append("English")
|
| 247 |
+
token_list_tmp = []
|
| 248 |
+
|
| 249 |
+
token_list_tmp.append(token)
|
| 250 |
+
|
| 251 |
+
if isEnglish(token):
|
| 252 |
+
language_flag = "English"
|
| 253 |
+
else:
|
| 254 |
+
language_flag = "Chinese"
|
| 255 |
+
|
| 256 |
+
if token_list_tmp:
|
| 257 |
+
token_list_all.append(token_list_tmp)
|
| 258 |
+
language_list.append(language_flag)
|
| 259 |
+
|
| 260 |
+
result_list = []
|
| 261 |
+
for token_list_tmp, language_flag in zip(token_list_all, language_list):
|
| 262 |
+
if language_flag == "English":
|
| 263 |
+
result_list.extend(token_list_tmp)
|
| 264 |
+
else:
|
| 265 |
+
seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
|
| 266 |
+
result_list.extend(seg_list)
|
| 267 |
+
|
| 268 |
+
return result_list
|
| 269 |
+
|
| 270 |
+
return _fn
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
|
| 274 |
+
if not Path(yaml_path).exists():
|
| 275 |
+
raise FileExistsError(f"The {yaml_path} does not exist.")
|
| 276 |
+
|
| 277 |
+
with open(str(yaml_path), "rb") as f:
|
| 278 |
+
data = yaml.load(f, Loader=yaml.Loader)
|
| 279 |
+
return data
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@functools.lru_cache()
|
| 283 |
+
def get_logger(name="funasr_onnx"):
|
| 284 |
+
"""Initialize and get a logger by name.
|
| 285 |
+
If the logger has not been initialized, this method will initialize the
|
| 286 |
+
logger by adding one or two handlers, otherwise the initialized logger will
|
| 287 |
+
be directly returned. During initialization, a StreamHandler will always be
|
| 288 |
+
added.
|
| 289 |
+
Args:
|
| 290 |
+
name (str): Logger name.
|
| 291 |
+
Returns:
|
| 292 |
+
logging.Logger: The expected logger.
|
| 293 |
+
"""
|
| 294 |
+
logger = logging.getLogger(name)
|
| 295 |
+
if name in logger_initialized:
|
| 296 |
+
return logger
|
| 297 |
+
|
| 298 |
+
for logger_name in logger_initialized:
|
| 299 |
+
if name.startswith(logger_name):
|
| 300 |
+
return logger
|
| 301 |
+
|
| 302 |
+
formatter = logging.Formatter(
|
| 303 |
+
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
sh = logging.StreamHandler()
|
| 307 |
+
sh.setFormatter(formatter)
|
| 308 |
+
logger.addHandler(sh)
|
| 309 |
+
logger_initialized[name] = True
|
| 310 |
+
logger.propagate = False
|
| 311 |
+
logging.basicConfig(level=logging.ERROR)
|
| 312 |
+
return logger
|
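The tokenizer and text helpers above can be exercised without any model assets; a quick illustration:

```python
from utils.infer_utils import CharTokenizer, code_mix_split_words, split_to_mini_sentence

tok = CharTokenizer()
tokens = tok.text2tokens("hi 你好")   # ['h', 'i', '<space>', '你', '好']
print(tok.tokens2text(tokens))        # "hi 你好"

words = code_mix_split_words("hello你好world")
print(words)                          # ['hello', '你', '好', 'world']
print(split_to_mini_sentence(words, word_limit=3))  # [['hello', '你', '好'], ['world']]
```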
utils/utils/__init__.py
ADDED
|
File without changes
|
utils/utils/e2e_vad.py
ADDED
|
@@ -0,0 +1,711 @@
|
| 1 |
+
# -*- encoding: utf-8 -*-
|
| 2 |
+
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
| 3 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import List, Tuple, Dict, Any
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VadStateMachine(Enum):
|
| 13 |
+
kVadInStateStartPointNotDetected = 1
|
| 14 |
+
kVadInStateInSpeechSegment = 2
|
| 15 |
+
kVadInStateEndPointDetected = 3
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class FrameState(Enum):
|
| 19 |
+
kFrameStateInvalid = -1
|
| 20 |
+
kFrameStateSpeech = 1
|
| 21 |
+
kFrameStateSil = 0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# final voice/unvoice state per frame
|
| 25 |
+
class AudioChangeState(Enum):
|
| 26 |
+
kChangeStateSpeech2Speech = 0
|
| 27 |
+
kChangeStateSpeech2Sil = 1
|
| 28 |
+
kChangeStateSil2Sil = 2
|
| 29 |
+
kChangeStateSil2Speech = 3
|
| 30 |
+
kChangeStateNoBegin = 4
|
| 31 |
+
kChangeStateInvalid = 5
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class VadDetectMode(Enum):
|
| 35 |
+
kVadSingleUtteranceDetectMode = 0
|
| 36 |
+
kVadMutipleUtteranceDetectMode = 1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class VADXOptions:
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
sample_rate: int = 16000,
|
| 43 |
+
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
|
| 44 |
+
snr_mode: int = 0,
|
| 45 |
+
max_end_silence_time: int = 800,
|
| 46 |
+
max_start_silence_time: int = 3000,
|
| 47 |
+
do_start_point_detection: bool = True,
|
| 48 |
+
do_end_point_detection: bool = True,
|
| 49 |
+
window_size_ms: int = 200,
|
| 50 |
+
sil_to_speech_time_thres: int = 150,
|
| 51 |
+
speech_to_sil_time_thres: int = 150,
|
| 52 |
+
speech_2_noise_ratio: float = 1.0,
|
| 53 |
+
do_extend: int = 1,
|
| 54 |
+
lookback_time_start_point: int = 200,
|
| 55 |
+
lookahead_time_end_point: int = 100,
|
| 56 |
+
max_single_segment_time: int = 60000,
|
| 57 |
+
nn_eval_block_size: int = 8,
|
| 58 |
+
dcd_block_size: int = 4,
|
| 59 |
+
snr_thres: float = -100.0,
|
| 60 |
+
noise_frame_num_used_for_snr: int = 100,
|
| 61 |
+
decibel_thres: float = -100.0,
|
| 62 |
+
speech_noise_thres: float = 0.6,
|
| 63 |
+
fe_prior_thres: float = 1e-4,
|
| 64 |
+
silence_pdf_num: int = 1,
|
| 65 |
+
sil_pdf_ids: List[int] = [0],
|
| 66 |
+
speech_noise_thresh_low: float = -0.1,
|
| 67 |
+
speech_noise_thresh_high: float = 0.3,
|
| 68 |
+
output_frame_probs: bool = False,
|
| 69 |
+
frame_in_ms: int = 10,
|
| 70 |
+
frame_length_ms: int = 25,
|
| 71 |
+
):
|
| 72 |
+
self.sample_rate = sample_rate
|
| 73 |
+
self.detect_mode = detect_mode
|
| 74 |
+
self.snr_mode = snr_mode
|
| 75 |
+
self.max_end_silence_time = max_end_silence_time
|
| 76 |
+
self.max_start_silence_time = max_start_silence_time
|
| 77 |
+
self.do_start_point_detection = do_start_point_detection
|
| 78 |
+
self.do_end_point_detection = do_end_point_detection
|
| 79 |
+
self.window_size_ms = window_size_ms
|
| 80 |
+
self.sil_to_speech_time_thres = sil_to_speech_time_thres
|
| 81 |
+
self.speech_to_sil_time_thres = speech_to_sil_time_thres
|
| 82 |
+
self.speech_2_noise_ratio = speech_2_noise_ratio
|
| 83 |
+
self.do_extend = do_extend
|
| 84 |
+
self.lookback_time_start_point = lookback_time_start_point
|
| 85 |
+
self.lookahead_time_end_point = lookahead_time_end_point
|
| 86 |
+
self.max_single_segment_time = max_single_segment_time
|
| 87 |
+
self.nn_eval_block_size = nn_eval_block_size
|
| 88 |
+
self.dcd_block_size = dcd_block_size
|
| 89 |
+
self.snr_thres = snr_thres
|
| 90 |
+
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
|
| 91 |
+
self.decibel_thres = decibel_thres
|
| 92 |
+
self.speech_noise_thres = speech_noise_thres
|
| 93 |
+
self.fe_prior_thres = fe_prior_thres
|
| 94 |
+
self.silence_pdf_num = silence_pdf_num
|
| 95 |
+
self.sil_pdf_ids = sil_pdf_ids
|
| 96 |
+
self.speech_noise_thresh_low = speech_noise_thresh_low
|
| 97 |
+
self.speech_noise_thresh_high = speech_noise_thresh_high
|
| 98 |
+
self.output_frame_probs = output_frame_probs
|
| 99 |
+
self.frame_in_ms = frame_in_ms
|
| 100 |
+
self.frame_length_ms = frame_length_ms
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class E2EVadSpeechBufWithDoa(object):
|
| 104 |
+
def __init__(self):
|
| 105 |
+
self.start_ms = 0
|
| 106 |
+
self.end_ms = 0
|
| 107 |
+
self.buffer = []
|
| 108 |
+
self.contain_seg_start_point = False
|
| 109 |
+
self.contain_seg_end_point = False
|
| 110 |
+
self.doa = 0
|
| 111 |
+
|
| 112 |
+
def Reset(self):
|
| 113 |
+
self.start_ms = 0
|
| 114 |
+
self.end_ms = 0
|
| 115 |
+
self.buffer = []
|
| 116 |
+
self.contain_seg_start_point = False
|
| 117 |
+
self.contain_seg_end_point = False
|
| 118 |
+
self.doa = 0
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class E2EVadFrameProb(object):
|
| 122 |
+
def __init__(self):
|
| 123 |
+
self.noise_prob = 0.0
|
| 124 |
+
self.speech_prob = 0.0
|
| 125 |
+
self.score = 0.0
|
| 126 |
+
self.frame_id = 0
|
| 127 |
+
self.frm_state = 0
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class WindowDetector(object):
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
window_size_ms: int,
|
| 134 |
+
sil_to_speech_time: int,
|
| 135 |
+
speech_to_sil_time: int,
|
| 136 |
+
frame_size_ms: int,
|
| 137 |
+
):
|
| 138 |
+
self.window_size_ms = window_size_ms
|
| 139 |
+
self.sil_to_speech_time = sil_to_speech_time
|
| 140 |
+
self.speech_to_sil_time = speech_to_sil_time
|
| 141 |
+
self.frame_size_ms = frame_size_ms
|
| 142 |
+
|
| 143 |
+
self.win_size_frame = int(window_size_ms / frame_size_ms)
|
| 144 |
+
self.win_sum = 0
|
| 145 |
+
self.win_state = [0] * self.win_size_frame # initialize the smoothing window
|
| 146 |
+
|
| 147 |
+
self.cur_win_pos = 0
|
| 148 |
+
self.pre_frame_state = FrameState.kFrameStateSil
|
| 149 |
+
self.cur_frame_state = FrameState.kFrameStateSil
|
| 150 |
+
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
|
| 151 |
+
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
|
| 152 |
+
|
| 153 |
+
self.voice_last_frame_count = 0
|
| 154 |
+
self.noise_last_frame_count = 0
|
| 155 |
+
self.hydre_frame_count = 0
|
| 156 |
+
|
| 157 |
+
def Reset(self) -> None:
|
| 158 |
+
self.cur_win_pos = 0
|
| 159 |
+
self.win_sum = 0
|
| 160 |
+
self.win_state = [0] * self.win_size_frame
|
| 161 |
+
self.pre_frame_state = FrameState.kFrameStateSil
|
| 162 |
+
self.cur_frame_state = FrameState.kFrameStateSil
|
| 163 |
+
self.voice_last_frame_count = 0
|
| 164 |
+
self.noise_last_frame_count = 0
|
| 165 |
+
self.hydre_frame_count = 0
|
| 166 |
+
|
| 167 |
+
def GetWinSize(self) -> int:
|
| 168 |
+
return int(self.win_size_frame)
|
| 169 |
+
|
| 170 |
+
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
|
| 171 |
+
cur_frame_state = FrameState.kFrameStateSil
|
| 172 |
+
if frameState == FrameState.kFrameStateSpeech:
|
| 173 |
+
cur_frame_state = 1
|
| 174 |
+
elif frameState == FrameState.kFrameStateSil:
|
| 175 |
+
cur_frame_state = 0
|
| 176 |
+
else:
|
| 177 |
+
return AudioChangeState.kChangeStateInvalid
|
| 178 |
+
self.win_sum -= self.win_state[self.cur_win_pos]
|
| 179 |
+
self.win_sum += cur_frame_state
|
| 180 |
+
self.win_state[self.cur_win_pos] = cur_frame_state
|
| 181 |
+
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
|
| 182 |
+
|
| 183 |
+
if (
|
| 184 |
+
self.pre_frame_state == FrameState.kFrameStateSil
|
| 185 |
+
and self.win_sum >= self.sil_to_speech_frmcnt_thres
|
| 186 |
+
):
|
| 187 |
+
self.pre_frame_state = FrameState.kFrameStateSpeech
|
| 188 |
+
return AudioChangeState.kChangeStateSil2Speech
|
| 189 |
+
|
| 190 |
+
if (
|
| 191 |
+
self.pre_frame_state == FrameState.kFrameStateSpeech
|
| 192 |
+
and self.win_sum <= self.speech_to_sil_frmcnt_thres
|
| 193 |
+
):
|
| 194 |
+
self.pre_frame_state = FrameState.kFrameStateSil
|
| 195 |
+
return AudioChangeState.kChangeStateSpeech2Sil
|
| 196 |
+
|
| 197 |
+
if self.pre_frame_state == FrameState.kFrameStateSil:
|
| 198 |
+
return AudioChangeState.kChangeStateSil2Sil
|
| 199 |
+
if self.pre_frame_state == FrameState.kFrameStateSpeech:
|
| 200 |
+
return AudioChangeState.kChangeStateSpeech2Speech
|
| 201 |
+
return AudioChangeState.kChangeStateInvalid
|
| 202 |
+
|
| 203 |
+
def FrameSizeMs(self) -> int:
|
| 204 |
+
return int(self.frame_size_ms)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class E2EVadModel:
|
| 208 |
+
"""
|
| 209 |
+
Author: Speech Lab of DAMO Academy, Alibaba Group
|
| 210 |
+
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
| 211 |
+
https://arxiv.org/abs/1803.05030
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
def __init__(self, vad_post_args: Dict[str, Any]):
|
| 215 |
+
super(E2EVadModel, self).__init__()
|
| 216 |
+
self.vad_opts = VADXOptions(**vad_post_args)
|
| 217 |
+
self.windows_detector = WindowDetector(
|
| 218 |
+
self.vad_opts.window_size_ms,
|
| 219 |
+
self.vad_opts.sil_to_speech_time_thres,
|
| 220 |
+
self.vad_opts.speech_to_sil_time_thres,
|
| 221 |
+
self.vad_opts.frame_in_ms,
|
| 222 |
+
)
|
| 223 |
+
# self.encoder = encoder
|
| 224 |
+
# init variables
|
| 225 |
+
self.is_final = False
|
| 226 |
+
self.data_buf_start_frame = 0
|
| 227 |
+
self.frm_cnt = 0
|
| 228 |
+
self.latest_confirmed_speech_frame = 0
|
| 229 |
+
self.lastest_confirmed_silence_frame = -1
|
| 230 |
+
self.continous_silence_frame_count = 0
|
| 231 |
+
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
| 232 |
+
self.confirmed_start_frame = -1
|
| 233 |
+
self.confirmed_end_frame = -1
|
| 234 |
+
self.number_end_time_detected = 0
|
| 235 |
+
self.sil_frame = 0
|
| 236 |
+
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
| 237 |
+
self.noise_average_decibel = -100.0
|
| 238 |
+
self.pre_end_silence_detected = False
|
| 239 |
+
self.next_seg = True
|
| 240 |
+
|
| 241 |
+
self.output_data_buf = []
|
| 242 |
+
self.output_data_buf_offset = 0
|
| 243 |
+
self.frame_probs = []
|
| 244 |
+
self.max_end_sil_frame_cnt_thresh = (
|
| 245 |
+
self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
|
| 246 |
+
)
|
| 247 |
+
self.speech_noise_thres = self.vad_opts.speech_noise_thres
|
| 248 |
+
self.scores = None
|
| 249 |
+
self.idx_pre_chunk = 0
|
| 250 |
+
self.max_time_out = False
|
| 251 |
+
self.decibel = []
|
| 252 |
+
self.data_buf_size = 0
|
| 253 |
+
self.data_buf_all_size = 0
|
| 254 |
+
self.waveform = None
|
| 255 |
+
self.ResetDetection()
|
| 256 |
+
|
| 257 |
+
def AllResetDetection(self):
|
| 258 |
+
self.is_final = False
|
| 259 |
+
self.data_buf_start_frame = 0
|
| 260 |
+
self.frm_cnt = 0
|
| 261 |
+
self.latest_confirmed_speech_frame = 0
|
| 262 |
+
self.lastest_confirmed_silence_frame = -1
|
| 263 |
+
self.continous_silence_frame_count = 0
|
| 264 |
+
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
| 265 |
+
self.confirmed_start_frame = -1
|
| 266 |
+
self.confirmed_end_frame = -1
|
| 267 |
+
self.number_end_time_detected = 0
|
| 268 |
+
self.sil_frame = 0
|
| 269 |
+
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
| 270 |
+
self.noise_average_decibel = -100.0
|
| 271 |
+
self.pre_end_silence_detected = False
|
| 272 |
+
self.next_seg = True
|
| 273 |
+
|
| 274 |
+
self.output_data_buf = []
|
| 275 |
+
self.output_data_buf_offset = 0
|
| 276 |
+
self.frame_probs = []
|
| 277 |
+
self.max_end_sil_frame_cnt_thresh = (
|
| 278 |
+
self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
|
| 279 |
+
)
|
| 280 |
+
self.speech_noise_thres = self.vad_opts.speech_noise_thres
|
| 281 |
+
self.scores = None
|
| 282 |
+
self.idx_pre_chunk = 0
|
| 283 |
+
self.max_time_out = False
|
| 284 |
+
self.decibel = []
|
| 285 |
+
self.data_buf_size = 0
|
| 286 |
+
self.data_buf_all_size = 0
|
| 287 |
+
self.waveform = None
|
| 288 |
+
self.ResetDetection()
|
| 289 |
+
|
| 290 |
+
def ResetDetection(self):
|
| 291 |
+
self.continous_silence_frame_count = 0
|
| 292 |
+
self.latest_confirmed_speech_frame = 0
|
| 293 |
+
self.lastest_confirmed_silence_frame = -1
|
| 294 |
+
self.confirmed_start_frame = -1
|
| 295 |
+
self.confirmed_end_frame = -1
|
| 296 |
+
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
| 297 |
+
self.windows_detector.Reset()
|
| 298 |
+
self.sil_frame = 0
|
| 299 |
+
self.frame_probs = []
|
| 300 |
+
|
| 301 |
+
def ComputeDecibel(self) -> None:
|
| 302 |
+
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
|
| 303 |
+
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
|
| 304 |
+
if self.data_buf_all_size == 0:
|
| 305 |
+
self.data_buf_all_size = len(self.waveform[0])
|
| 306 |
+
self.data_buf_size = self.data_buf_all_size
|
| 307 |
+
else:
|
| 308 |
+
self.data_buf_all_size += len(self.waveform[0])
|
| 309 |
+
for offset in range(
|
| 310 |
+
0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length
|
| 311 |
+
):
|
| 312 |
+
self.decibel.append(
|
| 313 |
+
10
|
| 314 |
+
* math.log10(
|
| 315 |
+
np.square((self.waveform[0][offset : offset + frame_sample_length])).sum()
|
| 316 |
+
+ 0.000001
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
def ComputeScores(self, scores: np.ndarray) -> None:
|
| 321 |
+
# scores = self.encoder(feats, in_cache) # return B * T * D
|
| 322 |
+
self.vad_opts.nn_eval_block_size = scores.shape[1]
|
| 323 |
+
self.frm_cnt += scores.shape[1] # count total frames
|
| 324 |
+
self.scores = scores
|
| 325 |
+
|
| 326 |
+
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
|
| 327 |
+
while self.data_buf_start_frame < frame_idx:
|
| 328 |
+
if self.data_buf_size >= int(
|
| 329 |
+
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
|
| 330 |
+
):
|
| 331 |
+
self.data_buf_start_frame += 1
|
| 332 |
+
self.data_buf_size = self.data_buf_all_size - self.data_buf_start_frame * int(
|
| 333 |
+
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
def PopDataToOutputBuf(
|
| 337 |
+
self,
|
| 338 |
+
start_frm: int,
|
| 339 |
+
frm_cnt: int,
|
| 340 |
+
first_frm_is_start_point: bool,
|
| 341 |
+
last_frm_is_end_point: bool,
|
| 342 |
+
end_point_is_sent_end: bool,
|
| 343 |
+
) -> None:
|
| 344 |
+
self.PopDataBufTillFrame(start_frm)
|
| 345 |
+
expected_sample_number = int(
|
| 346 |
+
frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
|
| 347 |
+
)
|
| 348 |
+
if last_frm_is_end_point:
|
| 349 |
+
extra_sample = max(
|
| 350 |
+
0,
|
| 351 |
+
int(
|
| 352 |
+
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
|
| 353 |
+
- self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
|
| 354 |
+
),
|
| 355 |
+
)
|
| 356 |
+
expected_sample_number += int(extra_sample)
|
| 357 |
+
if end_point_is_sent_end:
|
| 358 |
+
expected_sample_number = max(expected_sample_number, self.data_buf_size)
|
| 359 |
+
if self.data_buf_size < expected_sample_number:
|
| 360 |
+
print("error in calling pop data_buf\n")
|
| 361 |
+
|
| 362 |
+
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
|
| 363 |
+
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
|
| 364 |
+
self.output_data_buf[-1].Reset()
|
| 365 |
+
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
|
| 366 |
+
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
|
| 367 |
+
self.output_data_buf[-1].doa = 0
|
| 368 |
+
cur_seg = self.output_data_buf[-1]
|
| 369 |
+
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
|
| 370 |
+
print("warning\n")
|
| 371 |
+
out_pos = len(cur_seg.buffer) # cur_seg.buffer is not modified here for now
|
| 372 |
+
data_to_pop = 0
|
| 373 |
+
if end_point_is_sent_end:
|
| 374 |
+
data_to_pop = expected_sample_number
|
| 375 |
+
else:
|
| 376 |
+
data_to_pop = int(
|
| 377 |
+
frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
|
| 378 |
+
)
|
| 379 |
+
if data_to_pop > self.data_buf_size:
|
| 380 |
+
print("VAD data_to_pop is bigger than self.data_buf_size!!!\n")
|
| 381 |
+
data_to_pop = self.data_buf_size
|
| 382 |
+
expected_sample_number = self.data_buf_size
|
| 383 |
+
|
| 384 |
+
cur_seg.doa = 0
|
| 385 |
+
for sample_cpy_out in range(0, data_to_pop):
|
| 386 |
+
# cur_seg.buffer[out_pos ++] = data_buf_.back();
|
| 387 |
+
out_pos += 1
|
| 388 |
+
for sample_cpy_out in range(data_to_pop, expected_sample_number):
|
| 389 |
+
# cur_seg.buffer[out_pos++] = data_buf_.back()
|
| 390 |
+
out_pos += 1
|
| 391 |
+
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
|
| 392 |
+
print("Something wrong with the VAD algorithm\n")
|
| 393 |
+
self.data_buf_start_frame += frm_cnt
|
| 394 |
+
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
|
| 395 |
+
if first_frm_is_start_point:
|
| 396 |
+
cur_seg.contain_seg_start_point = True
|
| 397 |
+
if last_frm_is_end_point:
|
| 398 |
+
cur_seg.contain_seg_end_point = True
|
| 399 |
+
|
| 400 |
+
def OnSilenceDetected(self, valid_frame: int):
|
| 401 |
+
self.lastest_confirmed_silence_frame = valid_frame
|
| 402 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
| 403 |
+
self.PopDataBufTillFrame(valid_frame)
|
| 404 |
+
# silence_detected_callback_
|
| 405 |
+
# pass
|
| 406 |
+
|
| 407 |
+
def OnVoiceDetected(self, valid_frame: int) -> None:
|
| 408 |
+
self.latest_confirmed_speech_frame = valid_frame
|
| 409 |
+
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
|
| 410 |
+
|
| 411 |
+
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
|
| 412 |
+
if self.vad_opts.do_start_point_detection:
|
| 413 |
+
pass
|
| 414 |
+
if self.confirmed_start_frame != -1:
|
| 415 |
+
print("not reset vad properly\n")
|
| 416 |
+
else:
|
| 417 |
+
self.confirmed_start_frame = start_frame
|
| 418 |
+
|
| 419 |
+
if (
|
| 420 |
+
not fake_result
|
| 421 |
+
and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
|
| 422 |
+
):
|
| 423 |
+
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
|
| 424 |
+
|
| 425 |
+
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
|
| 426 |
+
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
|
| 427 |
+
self.OnVoiceDetected(t)
|
| 428 |
+
if self.vad_opts.do_end_point_detection:
|
| 429 |
+
pass
|
| 430 |
+
if self.confirmed_end_frame != -1:
|
| 431 |
+
print("not reset vad properly\n")
|
| 432 |
+
else:
|
| 433 |
+
self.confirmed_end_frame = end_frame
|
| 434 |
+
if not fake_result:
|
| 435 |
+
self.sil_frame = 0
|
| 436 |
+
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
|
| 437 |
+
self.number_end_time_detected += 1
|
| 438 |
+
|
| 439 |
+
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
|
| 440 |
+
if is_final_frame:
|
| 441 |
+
self.OnVoiceEnd(cur_frm_idx, False, True)
|
| 442 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 443 |
+
|
| 444 |
+
def GetLatency(self) -> int:
|
| 445 |
+
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
|
| 446 |
+
|
| 447 |
+
def LatencyFrmNumAtStartPoint(self) -> int:
|
| 448 |
+
vad_latency = self.windows_detector.GetWinSize()
|
| 449 |
+
if self.vad_opts.do_extend:
|
| 450 |
+
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
|
| 451 |
+
return vad_latency
|
| 452 |
+
|
| 453 |
+
def GetFrameState(self, t: int) -> FrameState:
|
| 454 |
+
frame_state = FrameState.kFrameStateInvalid
|
| 455 |
+
cur_decibel = self.decibel[t]
|
| 456 |
+
cur_snr = cur_decibel - self.noise_average_decibel
|
| 457 |
+
# for each frame, calc log posterior probability of each state
|
| 458 |
+
if cur_decibel < self.vad_opts.decibel_thres:
|
| 459 |
+
frame_state = FrameState.kFrameStateSil
|
| 460 |
+
self.DetectOneFrame(frame_state, t, False)
|
| 461 |
+
return frame_state
|
| 462 |
+
|
| 463 |
+
sum_score = 0.0
|
| 464 |
+
noise_prob = 0.0
|
| 465 |
+
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
|
| 466 |
+
if len(self.sil_pdf_ids) > 0:
|
| 467 |
+
assert len(self.scores) == 1 # only batch_size = 1 is supported
|
| 468 |
+
sil_pdf_scores = [
|
| 469 |
+
self.scores[0][t - self.idx_pre_chunk][sil_pdf_id]
|
| 470 |
+
for sil_pdf_id in self.sil_pdf_ids
|
| 471 |
+
]
|
| 472 |
+
sum_score = sum(sil_pdf_scores)
|
| 473 |
+
# added by huyuan: avoid sum_score <= 0
|
| 474 |
+
if sum_score <= 0.0:
|
| 475 |
+
sum_score = 1e-10
|
| 476 |
+
# import pdb
|
| 477 |
+
# pdb.set_trace()
|
| 478 |
+
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
|
| 479 |
+
total_score = 1.0
|
| 480 |
+
sum_score = total_score - sum_score
|
| 481 |
+
speech_prob = math.log(sum_score)
|
| 482 |
+
if self.vad_opts.output_frame_probs:
|
| 483 |
+
frame_prob = E2EVadFrameProb()
|
| 484 |
+
frame_prob.noise_prob = noise_prob
|
| 485 |
+
frame_prob.speech_prob = speech_prob
|
| 486 |
+
frame_prob.score = sum_score
|
| 487 |
+
frame_prob.frame_id = t
|
| 488 |
+
self.frame_probs.append(frame_prob)
|
| 489 |
+
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
|
| 490 |
+
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
|
| 491 |
+
frame_state = FrameState.kFrameStateSpeech
|
| 492 |
+
else:
|
| 493 |
+
frame_state = FrameState.kFrameStateSil
|
| 494 |
+
else:
|
| 495 |
+
frame_state = FrameState.kFrameStateSil
|
| 496 |
+
if self.noise_average_decibel < -99.9:
|
| 497 |
+
self.noise_average_decibel = cur_decibel
|
| 498 |
+
else:
|
| 499 |
+
self.noise_average_decibel = (
|
| 500 |
+
cur_decibel
|
| 501 |
+
+ self.noise_average_decibel * (self.vad_opts.noise_frame_num_used_for_snr - 1)
|
| 502 |
+
) / self.vad_opts.noise_frame_num_used_for_snr
|
| 503 |
+
|
| 504 |
+
return frame_state
|
| 505 |
+
|
| 506 |
+
def __call__(
|
| 507 |
+
self,
|
| 508 |
+
score: np.ndarray,
|
| 509 |
+
waveform: np.ndarray,
|
| 510 |
+
is_final: bool = False,
|
| 511 |
+
max_end_sil: int = 800,
|
| 512 |
+
online: bool = False,
|
| 513 |
+
):
|
| 514 |
+
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
|
| 515 |
+
self.waveform = waveform # compute decibel for each frame
|
| 516 |
+
self.ComputeDecibel()
|
| 517 |
+
self.ComputeScores(score)
|
| 518 |
+
#import pdb
|
| 519 |
+
#pdb.set_trace()
|
| 520 |
+
if not is_final:
|
| 521 |
+
self.DetectCommonFrames()
|
| 522 |
+
else:
|
| 523 |
+
self.DetectLastFrames()
|
| 524 |
+
segments = []
|
| 525 |
+
for batch_num in range(0, score.shape[0]): # only support batch_size = 1 now
|
| 526 |
+
segment_batch = []
|
| 527 |
+
if len(self.output_data_buf) > 0:
|
| 528 |
+
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
| 529 |
+
if online:
|
| 530 |
+
if not self.output_data_buf[i].contain_seg_start_point:
|
| 531 |
+
continue
|
| 532 |
+
if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
|
| 533 |
+
continue
|
| 534 |
+
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
|
| 535 |
+
if self.output_data_buf[i].contain_seg_end_point:
|
| 536 |
+
end_ms = self.output_data_buf[i].end_ms
|
| 537 |
+
self.next_seg = True
|
| 538 |
+
self.output_data_buf_offset += 1
|
| 539 |
+
else:
|
| 540 |
+
end_ms = -1
|
| 541 |
+
self.next_seg = False
|
| 542 |
+
else:
|
| 543 |
+
if not is_final and (
|
| 544 |
+
not self.output_data_buf[i].contain_seg_start_point
|
| 545 |
+
or not self.output_data_buf[i].contain_seg_end_point
|
| 546 |
+
):
|
| 547 |
+
continue
|
| 548 |
+
start_ms = self.output_data_buf[i].start_ms
|
| 549 |
+
end_ms = self.output_data_buf[i].end_ms
|
| 550 |
+
self.output_data_buf_offset += 1
|
| 551 |
+
segment = [start_ms, end_ms]
|
| 552 |
+
segment_batch.append(segment)
|
| 553 |
+
|
| 554 |
+
if segment_batch:
|
| 555 |
+
segments.append(segment_batch)
|
| 556 |
+
if is_final:
|
| 557 |
+
# reset class variables and clear the dict for the next query
|
| 558 |
+
self.AllResetDetection()
|
| 559 |
+
return segments
|
| 560 |
+
|
| 561 |
+
def DetectCommonFrames(self) -> int:
|
| 562 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
| 563 |
+
return 0
|
| 564 |
+
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
| 565 |
+
frame_state = FrameState.kFrameStateInvalid
|
| 566 |
+
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
|
| 567 |
+
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
|
| 568 |
+
self.idx_pre_chunk += self.scores.shape[1]
|
| 569 |
+
return 0
|
| 570 |
+
|
| 571 |
+
def DetectLastFrames(self) -> int:
|
| 572 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
| 573 |
+
return 0
|
| 574 |
+
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
| 575 |
+
frame_state = FrameState.kFrameStateInvalid
|
| 576 |
+
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
|
| 577 |
+
if i != 0:
|
| 578 |
+
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
|
| 579 |
+
else:
|
| 580 |
+
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
|
| 581 |
+
|
| 582 |
+
return 0
|
| 583 |
+
|
| 584 |
+
def DetectOneFrame(
|
| 585 |
+
self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
|
| 586 |
+
) -> None:
|
| 587 |
+
tmp_cur_frm_state = FrameState.kFrameStateInvalid
|
| 588 |
+
if cur_frm_state == FrameState.kFrameStateSpeech:
|
| 589 |
+
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
|
| 590 |
+
tmp_cur_frm_state = FrameState.kFrameStateSpeech
|
| 591 |
+
else:
|
| 592 |
+
tmp_cur_frm_state = FrameState.kFrameStateSil
|
| 593 |
+
elif cur_frm_state == FrameState.kFrameStateSil:
|
| 594 |
+
tmp_cur_frm_state = FrameState.kFrameStateSil
|
| 595 |
+
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
|
| 596 |
+
frm_shift_in_ms = self.vad_opts.frame_in_ms
|
| 597 |
+
if AudioChangeState.kChangeStateSil2Speech == state_change:
|
| 598 |
+
silence_frame_count = self.continous_silence_frame_count
|
| 599 |
+
self.continous_silence_frame_count = 0
|
| 600 |
+
self.pre_end_silence_detected = False
|
| 601 |
+
start_frame = 0
|
| 602 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
| 603 |
+
start_frame = max(
|
| 604 |
+
self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint()
|
| 605 |
+
)
|
| 606 |
+
self.OnVoiceStart(start_frame)
|
| 607 |
+
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
|
| 608 |
+
for t in range(start_frame + 1, cur_frm_idx + 1):
|
| 609 |
+
self.OnVoiceDetected(t)
|
| 610 |
+
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
| 611 |
+
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
|
| 612 |
+
self.OnVoiceDetected(t)
|
| 613 |
+
if (
|
| 614 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
| 615 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
| 616 |
+
):
|
| 617 |
+
self.OnVoiceEnd(cur_frm_idx, False, False)
|
| 618 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 619 |
+
elif not is_final_frame:
|
| 620 |
+
self.OnVoiceDetected(cur_frm_idx)
|
| 621 |
+
else:
|
| 622 |
+
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
| 623 |
+
else:
|
| 624 |
+
pass
|
| 625 |
+
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
|
| 626 |
+
self.continous_silence_frame_count = 0
|
| 627 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
| 628 |
+
pass
|
| 629 |
+
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
| 630 |
+
if (
|
| 631 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
| 632 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
| 633 |
+
):
|
| 634 |
+
self.OnVoiceEnd(cur_frm_idx, False, False)
|
| 635 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 636 |
+
elif not is_final_frame:
|
| 637 |
+
self.OnVoiceDetected(cur_frm_idx)
|
| 638 |
+
else:
|
| 639 |
+
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
| 640 |
+
else:
|
| 641 |
+
pass
|
| 642 |
+
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
|
| 643 |
+
self.continous_silence_frame_count = 0
|
| 644 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
| 645 |
+
if (
|
| 646 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
| 647 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
| 648 |
+
):
|
| 649 |
+
self.max_time_out = True
|
| 650 |
+
self.OnVoiceEnd(cur_frm_idx, False, False)
|
| 651 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 652 |
+
elif not is_final_frame:
|
| 653 |
+
self.OnVoiceDetected(cur_frm_idx)
|
| 654 |
+
else:
|
| 655 |
+
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
| 656 |
+
else:
|
| 657 |
+
pass
|
| 658 |
+
elif AudioChangeState.kChangeStateSil2Sil == state_change:
|
| 659 |
+
self.continous_silence_frame_count += 1
|
| 660 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
| 661 |
+
# silence timeout, return zero length decision
|
| 662 |
+
if (
|
| 663 |
+
(self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
|
| 664 |
+
and (
|
| 665 |
+
self.continous_silence_frame_count * frm_shift_in_ms
|
| 666 |
+
> self.vad_opts.max_start_silence_time
|
| 667 |
+
)
|
| 668 |
+
) or (is_final_frame and self.number_end_time_detected == 0):
|
| 669 |
+
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
|
| 670 |
+
self.OnSilenceDetected(t)
|
| 671 |
+
self.OnVoiceStart(0, True)
|
| 672 |
+
self.OnVoiceEnd(0, True, False)
|
| 673 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 674 |
+
else:
|
| 675 |
+
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
|
| 676 |
+
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
|
| 677 |
+
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
| 678 |
+
if (
|
| 679 |
+
self.continous_silence_frame_count * frm_shift_in_ms
|
| 680 |
+
>= self.max_end_sil_frame_cnt_thresh
|
| 681 |
+
):
|
| 682 |
+
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
|
| 683 |
+
if self.vad_opts.do_extend:
|
| 684 |
+
lookback_frame -= int(
|
| 685 |
+
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
|
| 686 |
+
)
|
| 687 |
+
lookback_frame -= 1
|
| 688 |
+
lookback_frame = max(0, lookback_frame)
|
| 689 |
+
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
|
| 690 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 691 |
+
elif (
|
| 692 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
| 693 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
| 694 |
+
):
|
| 695 |
+
self.OnVoiceEnd(cur_frm_idx, False, False)
|
| 696 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
| 697 |
+
elif self.vad_opts.do_extend and not is_final_frame:
|
| 698 |
+
if self.continous_silence_frame_count <= int(
|
| 699 |
+
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
|
| 700 |
+
):
|
| 701 |
+
self.OnVoiceDetected(cur_frm_idx)
|
| 702 |
+
else:
|
| 703 |
+
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
| 704 |
+
else:
|
| 705 |
+
pass
|
| 706 |
+
|
| 707 |
+
if (
|
| 708 |
+
self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
|
| 709 |
+
and self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value
|
| 710 |
+
):
|
| 711 |
+
self.ResetDetection()
|
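For orientation, here is a minimal sketch of how the streaming interface of `E2EVadModel` above can be driven. This is not the demo code from `ax_speech_translate_demo.py`: the `vad_post_args` values, the `run_fsmn_vad` stub and the chunking constants are assumptions for illustration, and in this repo the per-frame scores would come from the FSMN VAD model under `ax_model/vad`.

```python
import numpy as np
from utils.utils.e2e_vad import E2EVadModel

# Hypothetical post-processing options; in this repo they are read from
# ax_model/vad/config.yaml.
vad_post_args = {"max_end_silence_time": 800, "sil_pdf_ids": [0]}
vad = E2EVadModel(vad_post_args)

def run_fsmn_vad(chunk: np.ndarray) -> np.ndarray:
    # Stand-in for the FSMN VAD forward pass: returns per-frame posteriors
    # of shape (1, frames, pdf_num); here a flat 0.5 silence score.
    frames = (chunk.shape[1] - 400) // 160 + 1  # 25 ms window, 10 ms shift at 16 kHz
    return np.full((1, frames, 2), 0.5, dtype=np.float32)

waveform = np.zeros((1, 16000), dtype=np.float32)  # 1 s of audio, batch size 1
chunk_size = 1600                                   # 100 ms chunks
for start in range(0, waveform.shape[1], chunk_size):
    chunk = waveform[:, start : start + chunk_size]
    is_final = start + chunk_size >= waveform.shape[1]
    scores = run_fsmn_vad(chunk)
    segments = vad(scores, chunk, is_final=is_final, online=True)
    if segments:
        print(segments)  # e.g. [[[start_ms, end_ms], ...]]
```

With real FSMN scores, speech segments are emitted incrementally as `[start_ms, end_ms]` pairs (with `-1` marking a boundary not yet confirmed in online mode), which is what the ASR stage downstream consumes.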
utils/utils/frontend.py
ADDED
|
@@ -0,0 +1,448 @@
|
| 1 |
+
# -*- encoding: utf-8 -*-
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
|
| 4 |
+
import copy
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import kaldi_native_fbank as knf
|
| 9 |
+
|
| 10 |
+
root_dir = Path(__file__).resolve().parent
|
| 11 |
+
|
| 12 |
+
logger_initialized = {}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class WavFrontend:
|
| 16 |
+
"""Conventional frontend structure for ASR."""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
cmvn_file: str = None,
|
| 21 |
+
fs: int = 16000,
|
| 22 |
+
window: str = "hamming",
|
| 23 |
+
n_mels: int = 80,
|
| 24 |
+
frame_length: int = 25,
|
| 25 |
+
frame_shift: int = 10,
|
| 26 |
+
lfr_m: int = 1,
|
| 27 |
+
lfr_n: int = 1,
|
| 28 |
+
dither: float = 1.0,
|
| 29 |
+
**kwargs,
|
| 30 |
+
) -> None:
|
| 31 |
+
|
| 32 |
+
opts = knf.FbankOptions()
|
| 33 |
+
opts.frame_opts.samp_freq = fs
|
| 34 |
+
opts.frame_opts.dither = dither
|
| 35 |
+
opts.frame_opts.window_type = window
|
| 36 |
+
opts.frame_opts.frame_shift_ms = float(frame_shift)
|
| 37 |
+
opts.frame_opts.frame_length_ms = float(frame_length)
|
| 38 |
+
opts.mel_opts.num_bins = n_mels
|
| 39 |
+
opts.energy_floor = 0
|
| 40 |
+
opts.frame_opts.snip_edges = True
|
| 41 |
+
opts.mel_opts.debug_mel = False
|
| 42 |
+
self.opts = opts
|
| 43 |
+
|
| 44 |
+
self.lfr_m = lfr_m
|
| 45 |
+
self.lfr_n = lfr_n
|
| 46 |
+
self.cmvn_file = cmvn_file
|
| 47 |
+
|
| 48 |
+
if self.cmvn_file:
|
| 49 |
+
self.cmvn = load_cmvn(self.cmvn_file)
|
| 50 |
+
self.fbank_fn = None
|
| 51 |
+
self.fbank_beg_idx = 0
|
| 52 |
+
self.reset_status()
|
| 53 |
+
|
| 54 |
+
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 55 |
+
waveform = waveform * (1 << 15)
|
| 56 |
+
fbank_fn = knf.OnlineFbank(self.opts)
|
| 57 |
+
fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 58 |
+
frames = fbank_fn.num_frames_ready
|
| 59 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 60 |
+
for i in range(frames):
|
| 61 |
+
mat[i, :] = fbank_fn.get_frame(i)
|
| 62 |
+
feat = mat.astype(np.float32)
|
| 63 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 64 |
+
return feat, feat_len
|
| 65 |
+
|
| 66 |
+
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 67 |
+
waveform = waveform * (1 << 15)
|
| 68 |
+
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 69 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 70 |
+
frames = self.fbank_fn.num_frames_ready
|
| 71 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 72 |
+
for i in range(self.fbank_beg_idx, frames):
|
| 73 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 74 |
+
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
|
| 75 |
+
feat = mat.astype(np.float32)
|
| 76 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 77 |
+
return feat, feat_len
|
| 78 |
+
|
| 79 |
+
def reset_status(self):
|
| 80 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 81 |
+
self.fbank_beg_idx = 0
|
| 82 |
+
|
| 83 |
+
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 84 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
| 85 |
+
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
|
| 86 |
+
|
| 87 |
+
if self.cmvn_file:
|
| 88 |
+
feat = self.apply_cmvn(feat)
|
| 89 |
+
|
| 90 |
+
feat_len = np.array(feat.shape[0]).astype(np.int32)
|
| 91 |
+
return feat, feat_len
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
|
| 95 |
+
LFR_inputs = []
|
| 96 |
+
|
| 97 |
+
T = inputs.shape[0]
|
| 98 |
+
T_lfr = int(np.ceil(T / lfr_n))
|
| 99 |
+
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
|
| 100 |
+
inputs = np.vstack((left_padding, inputs))
|
| 101 |
+
T = T + (lfr_m - 1) // 2
|
| 102 |
+
for i in range(T_lfr):
|
| 103 |
+
if lfr_m <= T - i * lfr_n:
|
| 104 |
+
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
|
| 105 |
+
else:
|
| 106 |
+
# process last LFR frame
|
| 107 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
| 108 |
+
frame = inputs[i * lfr_n :].reshape(-1)
|
| 109 |
+
for _ in range(num_padding):
|
| 110 |
+
frame = np.hstack((frame, inputs[-1]))
|
| 111 |
+
|
| 112 |
+
LFR_inputs.append(frame)
|
| 113 |
+
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
|
| 114 |
+
return LFR_outputs
|
| 115 |
+
|
| 116 |
+
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
|
| 117 |
+
"""
|
| 118 |
+
Apply CMVN with mvn data
|
| 119 |
+
"""
|
| 120 |
+
frame, dim = inputs.shape
|
| 121 |
+
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
|
| 122 |
+
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
|
| 123 |
+
inputs = (inputs + means) * vars
|
| 124 |
+
return inputs
|
| 125 |
+
|
| 126 |
+
@lru_cache()
|
| 127 |
+
def load_cmvn(cmvn_file: Union[str, Path]) -> np.ndarray:
|
| 128 |
+
"""load cmvn file to numpy array.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
cmvn_file (Union[str, Path]): cmvn file path.
|
| 132 |
+
|
| 133 |
+
Raises:
|
| 134 |
+
FileNotFoundError: cmvn file does not exist.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
np.ndarray: cmvn array with shape (2, dim). The first row holds the means, the second row the variances.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
cmvn_file = Path(cmvn_file)
|
| 141 |
+
if not cmvn_file.exists():
|
| 142 |
+
raise FileNotFoundError("cmvn file does not exist")
|
| 143 |
+
|
| 144 |
+
with open(cmvn_file, "r", encoding="utf-8") as f:
|
| 145 |
+
lines = f.readlines()
|
| 146 |
+
means_list = []
|
| 147 |
+
vars_list = []
|
| 148 |
+
for i in range(len(lines)):
|
| 149 |
+
line_item = lines[i].split()
|
| 150 |
+
if line_item[0] == "<AddShift>":
|
| 151 |
+
line_item = lines[i + 1].split()
|
| 152 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 153 |
+
add_shift_line = line_item[3 : (len(line_item) - 1)]
|
| 154 |
+
means_list = list(add_shift_line)
|
| 155 |
+
continue
|
| 156 |
+
elif line_item[0] == "<Rescale>":
|
| 157 |
+
line_item = lines[i + 1].split()
|
| 158 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 159 |
+
rescale_line = line_item[3 : (len(line_item) - 1)]
|
| 160 |
+
vars_list = list(rescale_line)
|
| 161 |
+
continue
|
| 162 |
+
|
| 163 |
+
means = np.array(means_list).astype(np.float64)
|
| 164 |
+
vars = np.array(vars_list).astype(np.float64)
|
| 165 |
+
cmvn = np.array([means, vars])
|
| 166 |
+
return cmvn
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class WavFrontendOnline(WavFrontend):
|
| 170 |
+
def __init__(self, **kwargs):
|
| 171 |
+
super().__init__(**kwargs)
|
| 172 |
+
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 173 |
+
# add variables
|
| 174 |
+
self.frame_sample_length = int(
|
| 175 |
+
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
|
| 176 |
+
)
|
| 177 |
+
self.frame_shift_sample_length = int(
|
| 178 |
+
self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
|
| 179 |
+
)
|
| 180 |
+
self.waveform = None
|
| 181 |
+
self.reserve_waveforms = None
|
| 182 |
+
self.input_cache = None
|
| 183 |
+
self.lfr_splice_cache = []
|
| 184 |
+
|
| 185 |
+
@staticmethod
|
| 186 |
+
# inputs has catted the cache
|
| 187 |
+
def apply_lfr(
|
| 188 |
+
inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
|
| 189 |
+
) -> Tuple[np.ndarray, np.ndarray, int]:
|
| 190 |
+
"""
|
| 191 |
+
Apply lfr with data
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
LFR_inputs = []
|
| 195 |
+
T = inputs.shape[0] # include the right context
|
| 196 |
+
T_lfr = int(
|
| 197 |
+
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
|
| 198 |
+
) # minus the right context: (lfr_m - 1) // 2
|
| 199 |
+
splice_idx = T_lfr
|
| 200 |
+
for i in range(T_lfr):
|
| 201 |
+
if lfr_m <= T - i * lfr_n:
|
| 202 |
+
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
|
| 203 |
+
else: # process last LFR frame
|
| 204 |
+
if is_final:
|
| 205 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
| 206 |
+
frame = (inputs[i * lfr_n :]).reshape(-1)
|
| 207 |
+
for _ in range(num_padding):
|
| 208 |
+
frame = np.hstack((frame, inputs[-1]))
|
| 209 |
+
LFR_inputs.append(frame)
|
| 210 |
+
else:
|
| 211 |
+
# update splice_idx and break the circle
|
| 212 |
+
splice_idx = i
|
| 213 |
+
break
|
| 214 |
+
splice_idx = min(T - 1, splice_idx * lfr_n)
|
| 215 |
+
lfr_splice_cache = inputs[splice_idx:, :]
|
| 216 |
+
LFR_outputs = np.vstack(LFR_inputs)
|
| 217 |
+
return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def compute_frame_num(
|
| 221 |
+
sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
|
| 222 |
+
) -> int:
|
| 223 |
+
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
|
| 224 |
+
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
|
| 225 |
+
|
| 226 |
+
def fbank(
|
| 227 |
+
self, input: np.ndarray, input_lengths: np.ndarray
|
| 228 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 229 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 230 |
+
batch_size = input.shape[0]
|
| 231 |
+
if self.input_cache is None:
|
| 232 |
+
self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
|
| 233 |
+
input = np.concatenate((self.input_cache, input), axis=1)
|
| 234 |
+
frame_num = self.compute_frame_num(
|
| 235 |
+
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
|
| 236 |
+
)
|
| 237 |
+
# update self.in_cache
|
| 238 |
+
self.input_cache = input[
|
| 239 |
+
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
|
| 240 |
+
]
|
| 241 |
+
waveforms = np.empty(0, dtype=np.float32)
|
| 242 |
+
feats_pad = np.empty(0, dtype=np.float32)
|
| 243 |
+
feats_lens = np.empty(0, dtype=np.int32)
|
| 244 |
+
if frame_num:
|
| 245 |
+
waveforms = []
|
| 246 |
+
feats = []
|
| 247 |
+
feats_lens = []
|
| 248 |
+
for i in range(batch_size):
|
| 249 |
+
waveform = input[i]
|
| 250 |
+
waveforms.append(
|
| 251 |
+
waveform[
|
| 252 |
+
: (
|
| 253 |
+
(frame_num - 1) * self.frame_shift_sample_length
|
| 254 |
+
+ self.frame_sample_length
|
| 255 |
+
)
|
| 256 |
+
]
|
| 257 |
+
)
|
| 258 |
+
waveform = waveform * (1 << 15)
|
| 259 |
+
|
| 260 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 261 |
+
frames = self.fbank_fn.num_frames_ready
|
| 262 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 263 |
+
for i in range(frames):
|
| 264 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 265 |
+
feat = mat.astype(np.float32)
|
| 266 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 267 |
+
feats.append(feat)
|
| 268 |
+
feats_lens.append(feat_len)
|
| 269 |
+
|
| 270 |
+
waveforms = np.stack(waveforms)
|
| 271 |
+
feats_lens = np.array(feats_lens)
|
| 272 |
+
feats_pad = np.array(feats)
|
| 273 |
+
self.fbanks = feats_pad
|
| 274 |
+
self.fbanks_lens = copy.deepcopy(feats_lens)
|
| 275 |
+
return waveforms, feats_pad, feats_lens
|
| 276 |
+
|
| 277 |
+
def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
|
| 278 |
+
return self.fbanks, self.fbanks_lens
|
| 279 |
+
|
| 280 |
+
def lfr_cmvn(
|
| 281 |
+
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
| 282 |
+
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
|
| 283 |
+
batch_size = input.shape[0]
|
| 284 |
+
feats = []
|
| 285 |
+
feats_lens = []
|
| 286 |
+
lfr_splice_frame_idxs = []
|
| 287 |
+
for i in range(batch_size):
|
| 288 |
+
mat = input[i, : input_lengths[i], :]
|
| 289 |
+
lfr_splice_frame_idx = -1
|
| 290 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
| 291 |
+
# update self.lfr_splice_cache in self.apply_lfr
|
| 292 |
+
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
|
| 293 |
+
mat, self.lfr_m, self.lfr_n, is_final
|
| 294 |
+
)
|
| 295 |
+
if self.cmvn_file is not None:
|
| 296 |
+
mat = self.apply_cmvn(mat)
|
| 297 |
+
feat_length = mat.shape[0]
|
| 298 |
+
feats.append(mat)
|
| 299 |
+
feats_lens.append(feat_length)
|
| 300 |
+
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
|
| 301 |
+
|
| 302 |
+
feats_lens = np.array(feats_lens)
|
| 303 |
+
feats_pad = np.array(feats)
|
| 304 |
+
return feats_pad, feats_lens, lfr_splice_frame_idxs
|
| 305 |
+
|
| 306 |
+
def extract_fbank(
|
| 307 |
+
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
| 308 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 309 |
+
batch_size = input.shape[0]
|
| 310 |
+
assert (
|
| 311 |
+
batch_size == 1
|
| 312 |
+
), "we support to extract feature online only when the batch size is equal to 1 now"
|
| 313 |
+
waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
|
| 314 |
+
if feats.shape[0]:
|
| 315 |
+
self.waveforms = (
|
| 316 |
+
waveforms
|
| 317 |
+
if self.reserve_waveforms is None
|
| 318 |
+
else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
|
| 319 |
+
)
|
| 320 |
+
if not self.lfr_splice_cache:
|
| 321 |
+
for i in range(batch_size):
|
| 322 |
+
self.lfr_splice_cache.append(
|
| 323 |
+
np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
|
| 327 |
+
lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
|
| 328 |
+
feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
|
| 329 |
+
feats_lengths += lfr_splice_cache_np[0].shape[0]
|
| 330 |
+
frame_from_waveforms = int(
|
| 331 |
+
(self.waveforms.shape[1] - self.frame_sample_length)
|
| 332 |
+
/ self.frame_shift_sample_length
|
| 333 |
+
+ 1
|
| 334 |
+
)
|
| 335 |
+
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
|
| 336 |
+
feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
|
| 337 |
+
feats, feats_lengths, is_final
|
| 338 |
+
)
|
| 339 |
+
if self.lfr_m == 1:
|
| 340 |
+
self.reserve_waveforms = None
|
| 341 |
+
else:
|
| 342 |
+
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
|
| 343 |
+
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
|
| 344 |
+
# print('frame_frame: ' + str(frame_from_waveforms))
|
| 345 |
+
self.reserve_waveforms = self.waveforms[
|
| 346 |
+
:,
|
| 347 |
+
reserve_frame_idx
|
| 348 |
+
* self.frame_shift_sample_length : frame_from_waveforms
|
| 349 |
+
* self.frame_shift_sample_length,
|
| 350 |
+
]
|
| 351 |
+
sample_length = (
|
| 352 |
+
frame_from_waveforms - 1
|
| 353 |
+
) * self.frame_shift_sample_length + self.frame_sample_length
|
| 354 |
+
self.waveforms = self.waveforms[:, :sample_length]
|
| 355 |
+
else:
|
| 356 |
+
# update self.reserve_waveforms and self.lfr_splice_cache
|
| 357 |
+
self.reserve_waveforms = self.waveforms[
|
| 358 |
+
:, : -(self.frame_sample_length - self.frame_shift_sample_length)
|
| 359 |
+
]
|
| 360 |
+
for i in range(batch_size):
|
| 361 |
+
self.lfr_splice_cache[i] = np.concatenate(
|
| 362 |
+
(self.lfr_splice_cache[i], feats[i]), axis=0
|
| 363 |
+
)
|
| 364 |
+
return np.empty(0, dtype=np.float32), feats_lengths
|
| 365 |
+
else:
|
| 366 |
+
if is_final:
|
| 367 |
+
self.waveforms = (
|
| 368 |
+
waveforms if self.reserve_waveforms is None else self.reserve_waveforms
|
| 369 |
+
)
|
| 370 |
+
feats = np.stack(self.lfr_splice_cache)
|
| 371 |
+
feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
|
| 372 |
+
feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
|
| 373 |
+
if is_final:
|
| 374 |
+
self.cache_reset()
|
| 375 |
+
return feats, feats_lengths
|
| 376 |
+
|
| 377 |
+
def get_waveforms(self):
|
| 378 |
+
return self.waveforms
|
| 379 |
+
|
| 380 |
+
def cache_reset(self):
|
| 381 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 382 |
+
self.reserve_waveforms = None
|
| 383 |
+
self.input_cache = None
|
| 384 |
+
self.lfr_splice_cache = []
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def load_bytes(input):
|
| 388 |
+
middle_data = np.frombuffer(input, dtype=np.int16)
|
| 389 |
+
middle_data = np.asarray(middle_data)
|
| 390 |
+
if middle_data.dtype.kind not in "iu":
|
| 391 |
+
raise TypeError("'middle_data' must be an array of integers")
|
| 392 |
+
dtype = np.dtype("float32")
|
| 393 |
+
if dtype.kind != "f":
|
| 394 |
+
raise TypeError("'dtype' must be a floating point type")
|
| 395 |
+
|
| 396 |
+
i = np.iinfo(middle_data.dtype)
|
| 397 |
+
abs_max = 2 ** (i.bits - 1)
|
| 398 |
+
offset = i.min + abs_max
|
| 399 |
+
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
|
| 400 |
+
return array
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class SinusoidalPositionEncoderOnline:
|
| 404 |
+
"""Streaming Positional encoding."""
|
| 405 |
+
|
| 406 |
+
def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
|
| 407 |
+
batch_size = positions.shape[0]
|
| 408 |
+
positions = positions.astype(dtype)
|
| 409 |
+
log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
|
| 410 |
+
inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
|
| 411 |
+
inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
|
| 412 |
+
scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
|
| 413 |
+
encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
|
| 414 |
+
return encoding.astype(dtype)
|
| 415 |
+
|
| 416 |
+
def forward(self, x, start_idx=0):
|
| 417 |
+
batch_size, timesteps, input_dim = x.shape
|
| 418 |
+
positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
|
| 419 |
+
position_encoding = self.encode(positions, input_dim, x.dtype)
|
| 420 |
+
|
| 421 |
+
return x + position_encoding[:, start_idx : start_idx + timesteps]
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def test():
|
| 425 |
+
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
|
| 426 |
+
import librosa
|
| 427 |
+
|
| 428 |
+
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
|
| 429 |
+
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
|
| 430 |
+
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
|
| 431 |
+
|
| 432 |
+
config = read_yaml(config_file)
|
| 433 |
+
waveform, _ = librosa.load(path, sr=None)
|
| 434 |
+
frontend = WavFrontend(
|
| 435 |
+
cmvn_file=cmvn_file,
|
| 436 |
+
**config["frontend_conf"],
|
| 437 |
+
)
|
| 438 |
+
speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
|
| 439 |
+
feat, feat_len = frontend.lfr_cmvn(
|
| 440 |
+
speech
|
| 441 |
+
) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
|
| 442 |
+
|
| 443 |
+
frontend.reset_status() # clear cache
|
| 444 |
+
return feat, feat_len
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
if __name__ == "__main__":
|
| 448 |
+
test()
|
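As a companion to the frontend above, the sketch below shows one plausible offline use of `WavFrontend`. The CMVN path and the LFR settings are assumptions based on typical Paraformer/SenseVoice setups, not values verified against this repo's `config.yaml`.

```python
import numpy as np
from utils.utils.frontend import WavFrontend

# Assumed path and settings; the real values live in ax_model/sensevoice/am.mvn
# and ax_model/sensevoice/config.yaml.
frontend = WavFrontend(
    cmvn_file="ax_model/sensevoice/am.mvn",
    fs=16000,
    n_mels=80,
    lfr_m=7,
    lfr_n=6,
    dither=0.0,
)

waveform = (np.random.randn(16000) * 0.01).astype(np.float32)  # 1 s of dummy audio
feat, _ = frontend.fbank(waveform)        # (frames, 80) log-mel filterbank
feat, feat_len = frontend.lfr_cmvn(feat)  # LFR stacking + CMVN -> (frames', 7 * 80)
print(feat.shape, int(feat_len))
```

The streaming variant `WavFrontendOnline.extract_fbank` follows the same shape conventions but keeps waveform and LFR splice caches between calls, so it should be fed consecutive chunks of one utterance and reset (via `cache_reset`, or by passing `is_final=True`) before the next one.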