from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
import logging
import torchaudio


class AudioEncoderModel:
    """Wraps a supported audio encoder (wav2vec2 or Whisper large-v3) with
    ComfyUI's device handling for loading, offloading and dtype casting."""

    def __init__(self, config):
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        model_type = config.pop("model_type")
        model_config = dict(config)
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast
        })

        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        else:
            raise RuntimeError("ERROR: unknown audio encoder model_type: {}".format(model_type))
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000  # both supported encoders expect 16 kHz audio

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        comfy.model_management.load_model_gpu(self.patcher)
        # Resample the input to the rate the encoder was trained on.
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        return {
            "encoded_audio": out,
            "encoded_audio_all_layers": all_layers,
        }


def load_audio_encoder_from_sd(sd, prefix=""):
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    if "encoder.layer_norm.bias" in sd: #wav2vec2
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        if embed_dim == 1024:# large
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            }
        elif embed_dim == 768:  # base
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False, # chinese-wav2vec2-base has this False
                "do_stable_layer_norm": False
            }
        else:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")

    audio_encoder = AudioEncoderModel(config)
    m, u = audio_encoder.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))
    if len(u) > 0:
        logging.warning("unexpected audio encoder: {}".format(u))

    return audio_encoder
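

# Minimal usage sketch (illustrative; not part of the module proper): loads a
# checkpoint with comfy.utils.load_torch_file and encodes a waveform. The file
# names below are hypothetical placeholders, and the unsqueeze(0) assumes
# ComfyUI's [batch, channels, samples] audio tensor layout.
if __name__ == "__main__":
    sd = comfy.utils.load_torch_file("wav2vec2_large.safetensors")  # hypothetical path
    encoder = load_audio_encoder_from_sd(sd)
    waveform, sample_rate = torchaudio.load("speech.wav")  # hypothetical input file
    features = encoder.encode_audio(waveform.unsqueeze(0), sample_rate)
    print(features["encoded_audio"].shape)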