import os.path
from pathlib import Path
from typing import List, Union, Tuple

import torch
import numpy as np

import axengine as axe
from funasr.utils.postprocess_utils import rich_transcription_postprocess

try:
    import librosa
except ImportError:
    print("Warning: librosa not found. Please install it using 'pip install librosa'.")


def load_wav_fallback(path, sr=None):
    """Minimal WAV loader used when librosa is unavailable.

    Reads 16-bit PCM samples, scales them to float32 in [-1, 1], and returns
    (waveform, native_sample_rate). No resampling is performed.
    """
    import wave

    with wave.open(path, 'rb') as wf:
        num_frames = wf.getnframes()
        frames = wf.readframes(num_frames)
        return np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0, wf.getframerate()


from utils.infer_utils import (
    CharTokenizer,
    get_logger,
    read_yaml,
)
from utils.frontend import WavFrontend
from utils.ctc_alignment import ctc_forced_align

logging = get_logger()


# Build a [batch, maxlen] mask with ones over valid positions (t < length) and zeros elsewhere.
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
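

# Illustrative note (an assumption added for clarity, not from the original code): for
# lengths = torch.IntTensor([3]) and maxlen = 5, sequence_mask returns
# tensor([[1., 1., 1., 0., 0.]]), i.e. ones over the valid frames and zeros over the padding.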


class AX_SenseVoiceSmall:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        seq_len: int = 68,
    ):
        model_file = os.path.join(model_dir, "sensevoice.axmodel")
        config_file = os.path.join(model_dir, "sensevoice/config.yaml")
        cmvn_file = os.path.join(model_dir, "sensevoice/am.mvn")
        config = read_yaml(config_file)
        self.model_dir = model_dir

        self.tokenizer = CharTokenizer()
        config["frontend_conf"]['cmvn_file'] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])

        self.session = axe.InferenceSession(model_file)
        self.batch_size = batch_size
        self.blank_id = 0
        self.seq_len = seq_len

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 language: str,
                 withitn: bool,
                 position_encoding: np.ndarray,
                 tokenizer=None,
                 **kwargs) -> str:
        """Enhanced model inference with additional features from model.py

        Args:
            wav_content: Audio data or path
            language: Language code for processing
            withitn: Whether to use ITN (inverse text normalization)
            position_encoding: Position encoding tensor
            tokenizer: Tokenizer for text conversion
            **kwargs: Additional arguments
        """
        import time

        meta_data = {}
        time_start = time.perf_counter()

        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        time_load = time.perf_counter()
        meta_data["load_data"] = f"{time_load - time_start:0.3f}"

        # Prompt query embeddings for language, event/emotion, and text normalization.
        language_query = np.load(os.path.join(self.model_dir, f"{language}.npy"))
        textnorm_query = np.load(os.path.join(self.model_dir, "withitn.npy") if withitn
                                 else os.path.join(self.model_dir, "woitn.npy"))
        event_emo_query = np.load(os.path.join(self.model_dir, "event_emo.npy"))

        input_query = np.concatenate((language_query, event_emo_query, textnorm_query), axis=1)

        results = ""

        # The first 4 frames of every model input are reserved for the prompt queries.
        slice_len = self.seq_len - 4
        time_pre = time.perf_counter()
        meta_data["preprocess"] = f"{time_pre - time_load:0.3f}"
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])

            time_feat = time.perf_counter()
            meta_data["extract_feat"] = f"{time_feat - time_pre:0.3f}"

            for i in range(int(np.ceil(feats.shape[1] / slice_len))):
                sub_feats = np.concatenate([input_query, feats[:, i * slice_len: (i + 1) * slice_len, :]], axis=1)
                feats_len[0] = sub_feats.shape[1]

                # Pad the final slice up to the fixed sequence length expected by the compiled model.
                if feats_len[0] < self.seq_len:
                    sub_feats = np.concatenate(
                        [sub_feats, np.zeros((1, self.seq_len - feats_len[0], 560), dtype=np.float32)], axis=1)

                masks = sequence_mask(torch.IntTensor([self.seq_len]), maxlen=self.seq_len, dtype=torch.float32)[:, None, :]
                masks = masks.numpy()

                ctc_logits, encoder_out_lens = self.infer(sub_feats, masks, position_encoding)

                ctc_logits = torch.from_numpy(ctc_logits).float()

                b, _, _ = ctc_logits.size()

                # Greedy CTC decoding: argmax, collapse repeats, drop blanks,
                # then strip the 4 outputs that correspond to the prompt queries.
                for j in range(b):
                    x = ctc_logits[j, : encoder_out_lens[j].item(), :]
                    yseq = x.argmax(dim=-1)
                    yseq = torch.unique_consecutive(yseq, dim=-1)

                    mask = yseq != self.blank_id
                    token_int = yseq[mask].tolist()[4:]

                    # `text` is already a string in both branches, so it can be
                    # appended directly whether or not a tokenizer was provided.
                    text = tokenizer.decode(token_int) if tokenizer is not None else str(token_int)
                    results += text
        return results

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            try:
                if 'librosa' in globals():
                    waveform, _ = librosa.load(path, sr=fs)
                else:
                    waveform, native_sr = load_wav_fallback(path)
                    if fs is not None and native_sr != fs:
                        print(f"Warning: Resampling from {native_sr} to {fs} is not implemented in fallback mode")
                return waveform
            except Exception as e:
                print(f"Error loading audio file {path}: {e}")
                # Return a short block of silence so downstream feature extraction does not crash.
                return np.zeros(1600, dtype=np.float32)

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self,
              feats: np.ndarray,
              masks: np.ndarray,
              position_encoding: np.ndarray,
              ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.session.run(None, {
            'speech': feats,
            'masks': masks,
            'position_encoding': position_encoding
        })
        return outputs
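

# A minimal usage sketch behind a __main__ guard, added for illustration only.
# The model directory layout, the "position_encoding.npy" file name, and
# "example.wav" are assumptions, not part of the original script.
if __name__ == "__main__":
    model_dir = "models"  # hypothetical directory holding sensevoice.axmodel and the *.npy query files
    position_encoding = np.load(os.path.join(model_dir, "position_encoding.npy"))  # hypothetical file name

    model = AX_SenseVoiceSmall(model_dir=model_dir, batch_size=1, seq_len=68)
    text = model(
        "example.wav",              # hypothetical input wav file
        language="zh",
        withitn=True,
        position_encoding=position_encoding,
        tokenizer=model.tokenizer,  # CharTokenizer from utils.infer_utils; swap in your own if needed
    )
    print(text)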