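"""FSMN voice activity detection (VAD) inference wrapper for an axengine model."""
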
import os.path
from typing import List, Tuple

import numpy as np
import axengine as axe

from utils.utils.utils import read_yaml
from utils.utils.frontend import WavFrontend
from utils.utils.e2e_vad import E2EVadModel


class AX_Fsmn_vad:
    def __init__(self, model_dir, batch_size=1, max_end_sil=None):
"""Initialize VAD model for inference""" |
|
|
|
|
|
|
|
|
model_file = os.path.join(model_dir, "vad.axmodel") |
|
|
|
|
|
|
|
|
config_file = os.path.join(model_dir, "vad/config.yaml") |
|
|
cmvn_file = os.path.join(model_dir, "vad/am.mvn") |
|
|
self.config = read_yaml(config_file) |
|
|
self.frontend = WavFrontend(cmvn_file=cmvn_file, **self.config["frontend_conf"]) |
|
|
|
|
|
self.session = axe.InferenceSession(model_file) |
|
|
self.batch_size = batch_size |
|
|
self.vad_scorer = E2EVadModel(self.config["model_conf"]) |
|
|
self.max_end_sil = max_end_sil if max_end_sil is not None else self.config["model_conf"]["max_end_silence_time"] |
|
|
|
|
|

    def extract_feat(self, waveform_list):
        """Extract fbank + LFR/CMVN features from a list of waveforms."""
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        # Pad every feature matrix to the longest one so the batch stacks cleanly.
        max_len = max(feats_len)
        padded_feats = [np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), "constant") for f in feats]
        feats = np.array(padded_feats).astype(np.float32)
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    def infer(self, feats: List) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Run one forward pass through the axengine session."""
        input_names = [inp.name for inp in self.session.get_inputs()]
        output_names = [out.name for out in self.session.get_outputs()]

        # Bind the feature chunk and the FSMN caches to the model inputs by position.
        input_dict = dict(zip(input_names, feats))

        outputs = self.session.run(output_names, input_dict)
        # The first output is the frame-level score tensor; the rest are updated caches.
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches

    def __call__(self, wav_file, **kwargs):
        """Process an input waveform with a sliding-window approach."""
        waveform_list = [wav_file]
        waveform_nums = len(waveform_list)
        is_final = kwargs.get("is_final", False)
        # One independent segment list per batch item (a shared `[[]] * n` would alias).
        segments = [[] for _ in range(self.batch_size)]

        for beg_idx in range(0, waveform_nums, self.batch_size):
            # A fresh scorer per batch so decoding state does not carry over.
            vad_scorer = E2EVadModel(self.config["model_conf"])
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            waveform = np.array(waveform)
            param_dict = kwargs.get("param_dict", dict())
            in_cache = param_dict.get("in_cache", list())
            in_cache = self.prepare_cache(in_cache)

            # Slide over the features in windows of at most 6000 frames.
            total_frames = int(feats_len.max())
            step = int(min(total_frames, 6000))
            for t_offset in range(0, total_frames, step):
                # The last window may be shorter than the nominal step size.
                if t_offset + step >= total_frames - 1:
                    step = total_frames - t_offset
                    is_final = True
                else:
                    is_final = False

                feats_package = feats[:, t_offset : t_offset + step, :]

                if is_final:
                    # Pad the final, shorter window out to the full 6000 frames.
                    pad_length = 6000 - int(step)
                    feats_package = np.pad(
                        feats_package,
                        ((0, 0), (0, pad_length), (0, 0)),
                        mode="constant",
                        constant_values=0,
                    )

                # Map frame indices back to samples (160-sample frame shift, 400-sample frame length).
                waveform_package = waveform[
                    :,
                    t_offset * 160 : min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400),
                ]

                if is_final:
                    # Pad the final waveform window to match the padded feature window.
                    expected_wave_length = 6000 * 160 + 240
                    current_wave_length = waveform_package.shape[-1]
                    pad_wave_length = expected_wave_length - current_wave_length
                    if pad_wave_length > 0:
                        waveform_package = np.pad(
                            waveform_package,
                            ((0, 0), (0, pad_wave_length)),
                            mode="constant",
                            constant_values=0,
                        )

                # Feed the feature window plus the FSMN caches; carry the updated caches forward.
                inputs = [feats_package]
                inputs.extend(in_cache)
                scores, out_caches = self.infer(inputs)
                in_cache = out_caches

                segments_part = vad_scorer(
                    scores,
                    waveform_package,
                    is_final=is_final,
                    max_end_sil=self.max_end_sil,
                    online=False,
                )

                if segments_part:
                    for batch_num in range(self.batch_size):
                        segments[batch_num] += segments_part[batch_num]

        return segments

    def prepare_cache(self, in_cache=None):
        """Return the given caches, or build zero-initialized FSMN caches."""
        if in_cache:
            return in_cache
        # Build fresh caches here instead of using a mutable default argument.
        in_cache = []
        fsmn_layers = 4
        proj_dim = 128
        lorder = 20
        for _ in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache
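

# Minimal usage sketch (assumptions: sys.argv provides a model directory holding
# vad.axmodel, vad/config.yaml and vad/am.mvn, plus a wav path; soundfile is one
# possible loader, and the dtype/scaling your WavFrontend expects may differ).
if __name__ == "__main__":
    import sys

    import soundfile as sf  # hypothetical loader choice, not required by the class

    model_dir, wav_path = sys.argv[1], sys.argv[2]
    waveform, _sample_rate = sf.read(wav_path, dtype="float32")

    vad = AX_Fsmn_vad(model_dir)
    segments = vad(waveform)
    # segments[0] is expected to hold the [start, end] pairs for the single input.
    print(segments[0])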