#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple
import torch
import numpy as np
import axengine as axe
from funasr.utils.postprocess_utils import rich_transcription_postprocess
try:
    import librosa
except ImportError:
    print("Warning: librosa not found. Please install it using 'pip install librosa'.")

    # Minimal fallback: reads 16-bit PCM WAV files and returns (samples scaled to [-1, 1], sample rate).
    # It does not resample and assumes mono audio; the `sr` argument is accepted only for
    # signature compatibility and is ignored.
    def load_wav_fallback(path, sr=None):
        import wave
        with wave.open(path, 'rb') as wf:
            frames = wf.readframes(wf.getnframes())
            return np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0, wf.getframerate()

from utils.infer_utils import (
    CharTokenizer,
    get_logger,
    read_yaml,
)
from utils.frontend import WavFrontend
from utils.ctc_alignment import ctc_forced_align

logging = get_logger()


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
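
# Example (illustrative): for lengths [2, 3] with maxlen 4, sequence_mask returns
#     tensor([[1., 1., 0., 0.],
#             [1., 1., 1., 0.]])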


class AX_SenseVoiceSmall:
    """
    SenseVoiceSmall inference wrapper for the AXEngine runtime.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        seq_len: int = 68
    ):

        model_file = os.path.join(model_dir, "sensevoice.axmodel")
        config_file = os.path.join(model_dir, "sensevoice/config.yaml")
        cmvn_file = os.path.join(model_dir, "sensevoice/am.mvn")
        config = read_yaml(config_file)
        self.model_dir = model_dir
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)

        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        config["frontend_conf"]['cmvn_file'] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        # self.ort_infer = OrtInferSession(
        #     model_file, device_id, intra_op_num_threads=intra_op_num_threads
        # )
        #self.session = axe.InferenceSession(model_file, providers='AxEngineExecutionProvider')
        self.session = axe.InferenceSession(model_file)
        self.batch_size = batch_size
        self.blank_id = 0
        self.seq_len = seq_len

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
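
        # Note: __call__ loads the matching query embeddings from pre-exported .npy files
        # ("<language>.npy", "withitn.npy"/"woitn.npy", "event_emo.npy") rather than indexing
        # these id tables directly; they are kept here for reference.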

    def __call__(self, 
                 wav_content: Union[str, np.ndarray, List[str]], 
                 language: str,
                 withitn: bool,
                 position_encoding: np.ndarray,
                 tokenizer=None,
                 **kwargs) -> str:
        """Run SenseVoice inference on the given audio and return the decoded text.

        Args:
            wav_content: Audio samples (np.ndarray), a file path, or a list of file paths
            language: Language key (e.g. "auto", "zh", "en") selecting the language query embedding
            withitn: Whether to apply inverse text normalization (ITN)
            withitn: Whether to use ITN (inverse text normalization)
            position_encoding: Precomputed positional-encoding array fed to the model
            tokenizer: Tokenizer used to convert decoded token ids back to text
            **kwargs: Additional arguments (currently unused)

        Returns:
            The concatenated transcription string.
        """
        # Start time tracking for metadata
        import time
        meta_data = {}
        time_start = time.perf_counter()
        
        # Load waveform data
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        time_load = time.perf_counter()
        meta_data["load_data"] = f"{time_load - time_start:0.3f}"
        
        # Load queries from saved numpy files
        language_query = np.load(os.path.join(self.model_dir, f"{language}.npy"))
        textnorm_query = np.load(os.path.join(self.model_dir, "withitn.npy") if withitn 
                                 else os.path.join(self.model_dir, "woitn.npy"))
        event_emo_query = np.load(os.path.join(self.model_dir, "event_emo.npy"))

        # Concatenate queries to form input_query
        input_query = np.concatenate((language_query, event_emo_query, textnorm_query), axis=1)
        
        # Process features
        results = ""

        # Each fixed-length inference window reserves 4 frames for the prepended query
        # (language + event/emotion + textnorm), so only seq_len - 4 feature frames fit per slice.
        slice_len = self.seq_len - 4
        time_pre = time.perf_counter()
        meta_data["preprocess"] = f"{time_pre - time_load:0.3f}"
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            
            time_feat = time.perf_counter()
            meta_data["extract_feat"] = f"{time_feat - time_pre:0.3f}"

            for i in range(int(np.ceil(feats.shape[1] / slice_len))):
                sub_feats = np.concatenate([input_query, feats[:, i*slice_len : (i+1)*slice_len, :]], axis=1)
                feats_len[0] = sub_feats.shape[1]
                

                if feats_len[0] < self.seq_len:
                    sub_feats = np.concatenate([sub_feats, np.zeros((1, self.seq_len - feats_len[0], 560), dtype=np.float32)], axis=1)

                masks = sequence_mask(torch.IntTensor([self.seq_len]), maxlen=self.seq_len, dtype=torch.float32)[:, None, :]
                masks = masks.numpy()
                
                # Run inference

                ctc_logits, encoder_out_lens = self.infer(sub_feats, masks, position_encoding)
                # Convert to torch tensor for processing
                ctc_logits = torch.from_numpy(ctc_logits).float()
                              
                # Process results for each batch
                b, _, _ = ctc_logits.size()
                    
                for j in range(b):
                    x = ctc_logits[j, : encoder_out_lens[j].item(), :]
                    yseq = x.argmax(dim=-1)
                    yseq = torch.unique_consecutive(yseq, dim=-1)

                    mask = yseq != self.blank_id
                    token_int = yseq[mask].tolist()[4:]  # drop the first 4 prompt tokens, e.g. <|zh|><|ANGRY|><|Speech|><|withitn|>

                    # Convert token ids to text; fall back to the raw id list's string form
                    # when no tokenizer is provided (appending the list itself would raise a TypeError).
                    text = tokenizer.decode(token_int) if tokenizer is not None else str(token_int)
                    results += text
        return results

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            try:
                # Use librosa if available
                if 'librosa' in globals():
                    waveform, _ = librosa.load(path, sr=fs)
                else:
                    # Use fallback implementation
                    waveform, native_sr = load_wav_fallback(path)
                    if fs is not None and native_sr != fs:
                        # Implement resampling if needed
                        print(f"Warning: Resampling from {native_sr} to {fs} is not implemented in fallback mode")
                return waveform
            except Exception as e:
                print(f"Error loading audio file {path}: {e}")
                # Return empty audio in case of error
                return np.zeros(1600, dtype=np.float32)

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)

            feat, feat_len = self.frontend.lfr_cmvn(speech)

            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, 
              feats: np.ndarray, 
              masks: np.ndarray,
              position_encoding: np.ndarray,
              ) -> Tuple[np.ndarray, np.ndarray]:
        #outputs = self.ort_infer([feats, masks, position_encoding])
        outputs = self.session.run(None, {
            'speech': feats,
            'masks': masks,
            'position_encoding': position_encoding
        })
        return outputs
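

# --- Usage sketch ----------------------------------------------------------------------------
# Minimal example of driving AX_SenseVoiceSmall end to end. The directory layout and the
# "position_encoding.npy" file name are assumptions for illustration; the model only needs a
# float32 positional-encoding array matching its "position_encoding" input, and any object with
# a decode(token_ids) -> str method can be passed as the tokenizer.
if __name__ == "__main__":
    import sys

    model_dir = sys.argv[1] if len(sys.argv) > 1 else "./models"      # hypothetical layout
    wav_path = sys.argv[2] if len(sys.argv) > 2 else "./example.wav"  # hypothetical input file

    model = AX_SenseVoiceSmall(model_dir=model_dir, batch_size=1, seq_len=68)

    # Assumed to have been exported alongside the .axmodel (hypothetical file name).
    position_encoding = np.load(os.path.join(model_dir, "position_encoding.npy"))

    raw_text = model(
        wav_path,
        language="zh",              # any key of model.lid_dict, e.g. "auto", "zh", "en"
        withitn=True,               # apply inverse text normalization
        position_encoding=position_encoding,
        tokenizer=model.tokenizer,  # assumes the tokenizer exposes decode(token_ids) -> str
    )
    print(rich_transcription_postprocess(raw_text))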