import queue
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances, each one as [text, wav_file_path, speaker_name].
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): number of speakers sampled per batch.
            storage_size (int): storage capacity, in multiples of the number of speakers per batch.
            sample_from_storage_p (float): probability of reusing a stored sample instead of loading a new one.
            additive_noise (float): standard deviation of the Gaussian noise added to each waveform (0 disables it).
            num_utter_per_speaker (int): number of utterances sampled per speaker.
            skip_speakers (bool): skip speakers with fewer than `num_utter_per_speaker` utterances.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
        self.sample_from_storage_p = float(sample_from_storage_p)
        self.additive_noise = float(additive_noise)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
            print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
            print(f" | > Noise added : {self.additive_noise}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")

        assert len(text) > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        sample = {
            "mel": mel,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        """Group utterance paths by speaker and optionally drop speakers with too few utterances."""
        self.speaker_to_utters = {}
        for i in self.items:
            path_ = i[1]
            speaker_ = i[2]
            if speaker_ in self.speaker_to_utters:
                self.speaker_to_utters[speaker_].append(path_)
            else:
                self.speaker_to_utters[speaker_] = [path_]

        if self.skip_speakers:
            self.speaker_to_utters = {
                k: v for (k, v) in self.speaker_to_utters.items()
                if len(v) >= self.num_utter_per_speaker
            }

        self.speakers = list(self.speaker_to_utters.keys())

    def __len__(self):
        # Batches are sampled on the fly, so report a virtually infinite length.
        return int(1e10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            # Not enough utterances for this speaker: sample with replacement.
            utters = random.choices(
                self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
            )
        else:
            # Enough utterances: sample without replacement.
            utters = random.sample(
                self.speaker_to_utters[speaker], self.num_utter_per_speaker
            )
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.
        """
        wavs = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    # The speaker ran out of usable utterances: drop it and pick another one.
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # The utterance is shorter than one segment: discard it and resample.
                self.speaker_to_utters[speaker].remove(utter)

            wavs.append(wav)
            labels.append(speaker)
        return wavs, labels

    def __getitem__(self, idx):
        # `idx` is ignored; every item is a randomly drawn speaker.
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        labels = []
        feats = []
        for speaker in batch:
            if random.random() < self.sample_from_storage_p and self.storage.full():
                # Reuse a sample from the storage to avoid loading fresh data from disk.
                wavs_, labels_ = random.choice(self.storage.queue)
            else:
                # Load fresh utterances for this speaker.
                wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                # If the storage is full, pop the oldest element before pushing the new one.
                if self.storage.full():
                    _ = self.storage.get_nowait()
                self.storage.put_nowait((wavs_, labels_))

            # Add Gaussian noise to each waveform.
            if self.additive_noise > 0:
                noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
                wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]

            # Crop a random fixed-length segment from each waveform and compute its mel spectrogram.
            offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
            mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len])
                     for i in range(len(wavs_))]
            feats_ = [torch.FloatTensor(mel) for mel in mels_]

            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
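
# Minimal usage sketch (commented out; assumptions, not part of this module).
# The `AudioProcessor` import path follows the docstring above and its
# constructor arguments are hypothetical; `meta_data` rows use the
# [text, wav_file_path, speaker_name] layout this class consumes.
#
#     from torch.utils.data import DataLoader
#     from TTS.tts.utils.audio import AudioProcessor  # assumed path
#
#     ap = AudioProcessor(sample_rate=16000)  # hypothetical config
#     meta_data = [["hello world", "/data/spk1/utt1.wav", "speaker1"], ...]
#     dataset = MyDataset(ap, meta_data, voice_len=1.6,
#                         num_speakers_in_batch=64, num_utter_per_speaker=10)
#     # One DataLoader "batch" is a list of speakers; collate_fn expands each
#     # speaker into `num_utter_per_speaker` mel segments.
#     loader = DataLoader(dataset, batch_size=dataset.num_speakers_in_batch,
#                         collate_fn=dataset.collate_fn)
#     feats, labels = next(iter(loader))
#     # feats shape: [num_speakers_in_batch * num_utter_per_speaker, T, num_mels]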