# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch import nn
import torch.nn.functional as F

class LengthRegulator(torch.nn.Module):
    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        self.pad_value = pad_value
    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
            3. token_mask = [[1,1,0,0,0,0,0],
                             [0,0,1,1,0,0,0],
                             [0,0,0,0,1,1,1]]
            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
                                         [0,0,2,2,0,0,0],
                                         [0,0,0,0,3,3,3]]
            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]

        :param dur: Batch of durations of each token (B, T_txt)
        :param dur_padding: Batch of padding flags of each token (B, T_txt), 1 means padded
        :param alpha: duration rescale coefficient
        :return:
            mel2token (B, T_speech), mapping each output frame to its 1-based token index
        """
        assert alpha > 0
        # Rescale durations by alpha and zero out padded tokens.
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        # 1-based token indices, shape (1, T_txt, 1); 0 is left for padding frames.
        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
        # End and start frame of each token's span, both (B, T_txt).
        dur_cumsum = torch.cumsum(dur, 1)
        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
        # Output frame positions, shape (1, 1, T_speech).
        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
        # token_mask[b, t, p] is True iff frame p lies in token t's span.
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2token = (token_idx * token_mask.long()).sum(1)
        return mel2token
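

# Quick sanity check for LengthRegulator (an illustrative addition, not part of
# the original module; the underscored names are made up for the demo). It
# reproduces the docstring's worked example: durations [2, 2, 3] expand to the
# frame-to-token map [1, 1, 2, 2, 3, 3, 3].
if __name__ == "__main__":
    _regulator = LengthRegulator()
    _dur = torch.tensor([[2, 2, 3]])  # (B=1, T_txt=3), durations in frames
    print(_regulator(_dur))  # expected: tensor([[1, 1, 2, 2, 3, 3, 3]])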


class PosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # Inverse frequencies of the sinusoidal embedding: 10000^(-i / (half_dim - 1)).
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.emb = emb  # TODO: plain tensor, not a registered buffer; moved to the input device in forward
    def forward(self, x):
        # x: (B, T) positions -> (B, T, dim) embeddings, concatenating sin and cos channels.
        emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
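

# Quick sanity check for PosEmb (likewise an illustrative addition with made-up
# names): embeds positions 0..4 with dim=8 and prints the (B, T, dim) output
# shape. Each position gets half_dim sin channels followed by half_dim cos channels.
if __name__ == "__main__":
    _pos_emb = PosEmb(dim=8)
    _positions = torch.arange(5).float()[None]  # (B=1, T=5)
    print(_pos_emb(_positions).shape)  # expected: torch.Size([1, 5, 8])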