from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
import logging
import torchaudio


class AudioEncoderModel():
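    """Wrapper around a speech feature encoder (wav2vec2 or Whisper large-v3).

    Picks dtype and devices through comfy.model_management and wraps the model
    in a ModelPatcher so it can be loaded and offloaded like the text encoders.
    """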
    def __init__(self, config):
        # Audio encoders follow the same placement/offload policy as text encoders.
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        model_type = config.pop("model_type")
        model_config = dict(config)
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast  # cast weights to the input dtype on the fly
        })

        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        else:
            raise ValueError("unsupported audio encoder model_type: {}".format(model_type))
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
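        """Resample the waveform to the encoder's expected rate (16 kHz) and
        return the final features together with all intermediate layer outputs."""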
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        return outputs


def load_audio_encoder_from_sd(sd, prefix=""):
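    """Construct an AudioEncoderModel from a raw state dict, inferring the
    architecture and its hyperparameters from the weights present."""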
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
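    # Detect the architecture from the state dict layout: wav2vec2 checkpoints
    # expose "encoder.layer_norm.bias", whose width separates the large (1024-dim,
    # 24-layer) and base (768-dim, 12-layer) variants, while Whisper checkpoints
    # carry "model.encoder.embed_positions.weight".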
    if "encoder.layer_norm.bias" in sd:
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        if embed_dim == 1024:
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            }
        elif embed_dim == 768:
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False,
                "do_stable_layer_norm": False
            }
        else:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")

    audio_encoder = AudioEncoderModel(config)
    m, u = audio_encoder.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))
    if len(u) > 0:
        logging.warning("unexpected audio encoder: {}".format(u))

    return audio_encoder
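
# Example usage (an illustrative sketch, not part of this module's API; the
# checkpoint and audio file paths below are hypothetical):
#
#   import comfy.utils
#   import torchaudio
#
#   sd = comfy.utils.load_torch_file("models/audio_encoders/wav2vec2_large.safetensors")
#   encoder = load_audio_encoder_from_sd(sd)
#   waveform, sample_rate = torchaudio.load("speech.wav")  # waveform: [channels, samples]
#   out = encoder.encode_audio(waveform, sample_rate)
#   features = out["encoded_audio"]                 # final-layer features
#   all_layers = out["encoded_audio_all_layers"]    # per-layer hidden states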