pcunwa committed
Commit 3cfa53e · verified · 1 Parent(s): dffeed4

Upload 3 files

Files changed (3)
  1. bs_hyperace.ckpt +3 -0
  2. bs_roformer.py +1029 -0
  3. config.yaml +129 -0
bs_hyperace.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:907ac3c899a1ecdfcfa7491bb31936d26a4d472bdb5efb258237fdca666c0d6a
+size 274890855
bs_roformer.py ADDED
@@ -0,0 +1,1029 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from models.bs_roformer.attend import Attend
+try:
+    from models.bs_roformer.attend_sage import Attend as AttendSage
+except:
+    pass
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange
+import torchaudio
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+
+def default(v, d):
+    return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
+
+# norm
+
+def l2norm(t):
+    return F.normalize(t, dim = -1, p = 2)
+
+
+class RMSNorm(Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+class FeedForward(Module):
+    def __init__(
+            self,
+            dim,
+            mult=4,
+            dropout=0.
+    ):
+        super().__init__()
+        dim_inner = int(dim * mult)
+        self.net = nn.Sequential(
+            RMSNorm(dim),
+            nn.Linear(dim, dim_inner),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_inner, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(Module):
+    def __init__(
+            self,
+            dim,
+            heads=8,
+            dim_head=64,
+            dropout=0.,
+            rotary_embed=None,
+            flash=True,
+            sage_attention=False,
+    ):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        dim_inner = heads * dim_head
+
+        self.rotary_embed = rotary_embed
+
+        if sage_attention:
+            self.attend = AttendSage(flash=flash, dropout=dropout)
+        else:
+            self.attend = Attend(flash=flash, dropout=dropout)
+
+        self.norm = RMSNorm(dim)
+        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+        self.to_gates = nn.Linear(dim, heads)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(dim_inner, dim, bias=False),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+        if exists(self.rotary_embed):
+            q = self.rotary_embed.rotate_queries_or_keys(q)
+            k = self.rotary_embed.rotate_queries_or_keys(k)
+
+        out = self.attend(q, k, v)
+
+        gates = self.to_gates(x)
+        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+
+class LinearAttention(Module):
+    """
+    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+    """
+
+    @beartype
+    def __init__(
+            self,
+            *,
+            dim,
+            dim_head=32,
+            heads=8,
+            scale=8,
+            flash=False,
+            dropout=0.,
+            sage_attention=False,
+    ):
+        super().__init__()
+        dim_inner = dim_head * heads
+        self.norm = RMSNorm(dim)
+
+        self.to_qkv = nn.Sequential(
+            nn.Linear(dim, dim_inner * 3, bias=False),
+            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+        )
+
+        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+        if sage_attention:
+            self.attend = AttendSage(
+                scale=scale,
+                dropout=dropout,
+                flash=flash
+            )
+        else:
+            self.attend = Attend(
+                scale=scale,
+                dropout=dropout,
+                flash=flash
+            )
+
+        self.to_out = nn.Sequential(
+            Rearrange('b h d n -> b n (h d)'),
+            nn.Linear(dim_inner, dim, bias=False)
+        )
+
+    def forward(
+            self,
+            x
+    ):
+        x = self.norm(x)
+
+        q, k, v = self.to_qkv(x)
+
+        q, k = map(l2norm, (q, k))
+        q = q * self.temperature.exp()
+
+        out = self.attend(q, k, v)
+
+        return self.to_out(out)
+
+
+@torch.compile()
+class Transformer(Module):
+    def __init__(
+            self,
+            *,
+            dim,
+            depth,
+            dim_head=64,
+            heads=8,
+            attn_dropout=0.,
+            ff_dropout=0.,
+            ff_mult=4,
+            norm_output=True,
+            rotary_embed=None,
+            flash_attn=True,
+            linear_attn=False,
+            sage_attention=False,
+    ):
+        super().__init__()
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+            if linear_attn:
+                attn = LinearAttention(
+                    dim=dim,
+                    dim_head=dim_head,
+                    heads=heads,
+                    dropout=attn_dropout,
+                    flash=flash_attn,
+                    sage_attention=sage_attention
+                )
+            else:
+                attn = Attention(
+                    dim=dim,
+                    dim_head=dim_head,
+                    heads=heads,
+                    dropout=attn_dropout,
+                    rotary_embed=rotary_embed,
+                    flash=flash_attn,
+                    sage_attention=sage_attention
+                )
+
+            self.layers.append(ModuleList([
+                attn,
+                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+            ]))
+
+        self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+    def forward(self, x):
+
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+
+        return self.norm(x)
+
+
+# bandsplit module
+
+class BandSplit(Module):
+    @beartype
+    def __init__(
+            self,
+            dim,
+            dim_inputs: Tuple[int, ...]
+    ):
+        super().__init__()
+        self.dim_inputs = dim_inputs
+        self.to_features = ModuleList([])
+
+        for dim_in in dim_inputs:
+            net = nn.Sequential(
+                RMSNorm(dim_in),
+                nn.Linear(dim_in, dim)
+            )
+
+            self.to_features.append(net)
+
+    def forward(self, x):
+
+        x = x.split(self.dim_inputs, dim=-1)
+
+        outs = []
+        for split_input, to_feature in zip(x, self.to_features):
+            split_output = to_feature(split_input)
+            outs.append(split_output)
+
+        x = torch.stack(outs, dim=-2)
+
+        return x
+
+class Conv(nn.Module):
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
+        super().__init__()
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
+        self.act = nn.SiLU() if act else nn.Identity()
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+
+def autopad(k, p=None):
+    if p is None:
+        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
+    return p
+
+
+class DSConv(nn.Module):
+    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
+        super().__init__()
+        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
+        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
+        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
+        self.act = nn.SiLU() if act else nn.Identity()
+
+    def forward(self, x):
+        return self.act(self.bn(self.pwconv(self.dwconv(x))))
+
+
+class DS_Bottleneck(nn.Module):
+    def __init__(self, c1, c2, k=3, shortcut=True):
+        super().__init__()
+        c_ = c1
+        self.dsconv1 = DSConv(c1, c_, k=3, s=1)
+        self.dsconv2 = DSConv(c_, c2, k=k, s=1)
+        self.shortcut = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))
+
+
+class DS_C3k(nn.Module):
+    def __init__(self, c1, c2, n=1, k=3, e=0.5):
+        super().__init__()
+        c_ = int(c2 * e)
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.cv3 = Conv(2 * c_, c2, 1, 1)
+        self.m = nn.Sequential(*[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)])
+
+    def forward(self, x):
+        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+class DS_C3k2(nn.Module):
+    def __init__(self, c1, c2, n=1, k=3, e=0.5):
+        super().__init__()
+        c_ = int(c2 * e)
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
+        self.cv2 = Conv(c_, c2, 1, 1)
+
+    def forward(self, x):
+        x_ = self.cv1(x)
+        x_ = self.m(x_)
+        return self.cv2(x_)
+
+
+class AdaptiveHyperedgeGeneration(nn.Module):
+    def __init__(self, in_channels, num_hyperedges, num_heads=8):
+        super().__init__()
+        self.num_hyperedges = num_hyperedges
+        self.num_heads = num_heads
+        self.head_dim = in_channels // num_heads
+
+        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
+
+        self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
+
+        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
+
+        self.scale = self.head_dim ** -0.5
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
+        f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
+        f_ctx = torch.cat((f_avg, f_max), dim=1)
+
+        delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
+        P = self.global_proto.unsqueeze(0) + delta_P
+
+        z = self.query_proj(x)
+
+        z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
+
+        sim = (z @ P) * self.scale
+
+        s_bar = sim.mean(dim=1)
+
+        A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)
+
+        return A
+
+
+class HypergraphConvolution(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
+        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
+        self.act = nn.SiLU()
+
+    def forward(self, x, A):
+        f_m = torch.bmm(A, x)
+        f_m = self.act(self.W_e(f_m))
+
+        x_out = torch.bmm(A.transpose(1, 2), f_m)
+        x_out = self.act(self.W_v(x_out))
+
+        return x + x_out
+
+
+class AdaptiveHypergraphComputation(nn.Module):
+    def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
+        super().__init__()
+        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
+            in_channels, num_hyperedges, num_heads
+        )
+        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x_flat = x.flatten(2).permute(0, 2, 1)
+
+        A = self.adaptive_hyperedge_gen(x_flat)
+
+        x_out_flat = self.hypergraph_conv(x_flat, A)
+
+        x_out = x_out_flat.permute(0, 2, 1).view(B, -1, H, W)
+        return x_out
+
+
+class C3AH(nn.Module):
+    def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
+        super().__init__()
+        c_ = int(c1 * e)
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.ahc = AdaptiveHypergraphComputation(
+            c_, c_, num_hyperedges, num_heads
+        )
+        self.cv3 = Conv(2 * c_, c2, 1, 1)
+
+    def forward(self, x):
+        x_lateral = self.cv1(x)
+        x_ahc = self.ahc(self.cv2(x))
+        return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
+
+class HyperACE(nn.Module):
+    def __init__(self, in_channels: List[int], out_channels: int,
+                 num_hyperedges=8, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
+        super().__init__()
+
+        c2, c3, c4, c5 = in_channels
+        c_mid = c4
+
+        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
+
+        self.c_h = int(c_mid * c_h)
+        self.c_l = int(c_mid * c_l)
+        self.c_s = c_mid - self.c_h - self.c_l
+        assert self.c_s > 0, "Channel split error"
+
+        self.high_order_branch = nn.ModuleList(
+            [C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0) for _ in range(k)]
+        )
+        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
+
+        self.low_order_branch = nn.Sequential(
+            *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
+        )
+
+        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
+
+    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+        B2, B3, B4, B5 = x
+
+        B, _, H4, W4 = B4.shape
+
+        B2_resized = F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False)
+        B3_resized = F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False)
+        B5_resized = F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
+
+        x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))
+
+        x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)
+
+        x_h_outs = [m(x_h) for m in self.high_order_branch]
+        x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))
+
+        x_l_out = self.low_order_branch(x_l)
+
+        y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))
+
+        return y
+
+
+class GatedFusion(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+
+    def forward(self, f_in, h):
+        if f_in.shape[1] != h.shape[1]:
+            raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
+        return f_in + self.gamma * h
+
+
+class Backbone(nn.Module):
+    def __init__(self, in_channels=256, base_channels=64, base_depth=3):
+        super().__init__()
+        c = base_channels
+        c2 = base_channels
+        c3 = 256
+        c4 = 384
+        c5 = 512
+        c6 = 768
+
+        self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
+
+        self.p2 = nn.Sequential(
+            DSConv(c2, c3, k=3, s=(2, 1), p=1),
+            DS_C3k2(c3, c3, n=base_depth)
+        )
+
+        self.p3 = nn.Sequential(
+            DSConv(c3, c4, k=3, s=(2, 1), p=1),
+            DS_C3k2(c4, c4, n=base_depth*2)
+        )
+
+        self.p4 = nn.Sequential(
+            DSConv(c4, c5, k=3, s=(2, 1), p=1),
+            DS_C3k2(c5, c5, n=base_depth*2)
+        )
+
+        self.p5 = nn.Sequential(
+            DSConv(c5, c6, k=3, s=(2, 1), p=1),
+            DS_C3k2(c6, c6, n=base_depth)
+        )
+
+        self.out_channels = [c3, c4, c5, c6]
+
+    def forward(self, x):
+        x = self.stem(x)
+        x2 = self.p2(x)
+        x3 = self.p3(x2)
+        x4 = self.p4(x3)
+        x5 = self.p5(x4)
+        return [x2, x3, x4, x5]
+
+
+class Decoder(nn.Module):
+    def __init__(self, encoder_channels: List[int], hyperace_out_c: int, decoder_channels: List[int]):
+        super().__init__()
+        c_p2, c_p3, c_p4, c_p5 = encoder_channels
+        c_d2, c_d3, c_d4, c_d5 = decoder_channels
+
+        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
+        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
+        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
+        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
+
+        self.fusion_d5 = GatedFusion(c_d5)
+        self.fusion_d4 = GatedFusion(c_d4)
+        self.fusion_d3 = GatedFusion(c_d3)
+        self.fusion_d2 = GatedFusion(c_d2)
+
+        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
+        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
+        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
+        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
+
+        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
+        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
+        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
+
+        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
+
+    def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
+        p2, p3, p4, p5 = enc_feats
+
+        d5 = self.skip_p5(p5)
+        h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear'))
+        d5 = self.fusion_d5(d5, h_d5)
+
+        d5_up = F.interpolate(d5, size=p4.shape[2:], mode='bilinear')
+        d4_skip = self.skip_p4(p4)
+        d4 = self.up_d5(d5_up) + d4_skip
+
+        h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear'))
+        d4 = self.fusion_d4(d4, h_d4)
+
+        d4_up = F.interpolate(d4, size=p3.shape[2:], mode='bilinear')
+        d3_skip = self.skip_p3(p3)
+        d3 = self.up_d4(d4_up) + d3_skip
+
+        h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear'))
+        d3 = self.fusion_d3(d3, h_d3)
+
+        d3_up = F.interpolate(d3, size=p2.shape[2:], mode='bilinear')
+        d2_skip = self.skip_p2(p2)
+        d2 = self.up_d3(d3_up) + d2_skip
+
+        h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear'))
+        d2 = self.fusion_d2(d2, h_d2)
+
+        d2_final = self.final_d2(d2)
+
+        return d2_final
+
+class FreqPixelShuffle(nn.Module):
+    def __init__(self, in_channels, out_channels, scale=2):
+        super().__init__()
+        self.scale = scale
+        self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        B, C_r, H, W = x.shape
+        out_c = C_r // self.scale
+
+        x = x.view(B, out_c, self.scale, H, W)
+
+        x = x.permute(0, 1, 3, 4, 2).contiguous()
+        x = x.view(B, out_c, H, W * self.scale)
+
+        return x
+
+
+class ProgressiveUpsampleHead(nn.Module):
+    def __init__(self, in_channels, out_channels, target_bins=1025):
+        super().__init__()
+        self.target_bins = target_bins
+
+        c = in_channels
+
+        self.block1 = FreqPixelShuffle(c, c, scale=2)
+        self.block2 = FreqPixelShuffle(c, c // 2, scale=2)
+        self.block3 = FreqPixelShuffle(c // 2, c // 2, scale=2)
+        self.block4 = FreqPixelShuffle(c // 2, c // 4, scale=2)
+
+        self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)
+
+    def forward(self, x):
+
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        x = self.block4(x)
+
+        if x.shape[-1] != self.target_bins:
+            x = F.interpolate(x, size=(x.shape[2], self.target_bins), mode='bilinear', align_corners=False)
+
+        x = self.final_conv(x)
+        return x
+
+
+class SegmModel(nn.Module):
+    def __init__(self, in_bands=62, in_dim=256, out_bins=1025, out_channels=4,
+                 base_channels=64, base_depth=2,
+                 num_hyperedges=16, num_heads=8):
+        super().__init__()
+
+        self.backbone = Backbone(in_channels=in_dim, base_channels=base_channels, base_depth=base_depth)
+        enc_channels = self.backbone.out_channels
+        c2, c3, c4, c5 = enc_channels
+
+        hyperace_in_channels = enc_channels
+        hyperace_out_channels = c4
+        self.hyperace = HyperACE(
+            hyperace_in_channels, hyperace_out_channels,
+            num_hyperedges, num_heads, k=3, l=2
+        )
+
+        decoder_channels = [c2, c3, c4, c5]
+        self.decoder = Decoder(
+            enc_channels, hyperace_out_channels, decoder_channels
+        )
+
+        self.upsample_head = ProgressiveUpsampleHead(
+            in_channels=decoder_channels[0],
+            out_channels=out_channels,
+            target_bins=out_bins
+        )
+
+    def forward(self, x):
+        H, W = x.shape[2:]
+
+        enc_feats = self.backbone(x)
+
+        h_ace_feats = self.hyperace(enc_feats)
+
+        dec_feat = self.decoder(enc_feats, h_ace_feats)
+
+        feat_time_restored = F.interpolate(dec_feat, size=(H, dec_feat.shape[-1]), mode='bilinear', align_corners=False)
+
+        out = self.upsample_head(feat_time_restored)
+
+        return out
+
+
+def MLP(
+        dim_in,
+        dim_out,
+        dim_hidden=None,
+        depth=1,
+        activation=nn.Tanh
+):
+    dim_hidden = default(dim_hidden, dim_in)
+
+    net = []
+    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+        is_last = ind == (len(dims) - 2)
+
+        net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+        if is_last:
+            continue
+
+        net.append(activation())
+
+    return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+    @beartype
+    def __init__(
+            self,
+            dim,
+            dim_inputs: Tuple[int, ...],
+            depth,
+            mlp_expansion_factor=4
+    ):
+        super().__init__()
+        self.dim_inputs = dim_inputs
+        self.to_freqs = ModuleList([])
+        dim_hidden = dim * mlp_expansion_factor
+
+        for dim_in in dim_inputs:
+            net = []
+
+            mlp = nn.Sequential(
+                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+                nn.GLU(dim=-1)
+            )
+
+            self.to_freqs.append(mlp)
+
+        self.segm = SegmModel(in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs)//4)
+
+    def forward(self, x):
+        y = rearrange(x, 'b t f c -> b c t f')
+        y = self.segm(y)
+        y = rearrange(y, 'b c t f -> b t (f c)')
+
+        x = x.unbind(dim=-2)
+
+        outs = []
+
+        for band_features, mlp in zip(x, self.to_freqs):
+            freq_out = mlp(band_features)
+            outs.append(freq_out)
+
+        return torch.cat(outs, dim=-1) + y
+
+
+# main class
+
+DEFAULT_FREQS_PER_BANDS = (
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+    2, 2, 2, 2,
+    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+    12, 12, 12, 12, 12, 12, 12, 12,
+    24, 24, 24, 24, 24, 24, 24, 24,
+    48, 48, 48, 48, 48, 48, 48, 48,
+    128, 129,
+)
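+
+# note: these default band widths sum to 1025 frequency bins, i.e. stft_n_fft // 2 + 1 for the
+# default stft_n_fft = 2048; BSRoformer.__init__ below asserts this against the configured STFT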
+
+
+class BSRoformer(Module):
+
+    @beartype
+    def __init__(
+            self,
+            dim,
+            *,
+            depth,
+            stereo=False,
+            num_stems=1,
+            time_transformer_depth=2,
+            freq_transformer_depth=2,
+            linear_transformer_depth=0,
+            freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+            # in the paper, they divide into ~60 bands, test with 1 for starters
+            dim_head=64,
+            heads=8,
+            attn_dropout=0.,
+            ff_dropout=0.,
+            flash_attn=True,
+            dim_freqs_in=1025,
+            stft_n_fft=2048,
+            stft_hop_length=512,
+            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+            stft_win_length=2048,
+            stft_normalized=False,
+            stft_window_fn: Optional[Callable] = None,
+            mask_estimator_depth=2,
+            multi_stft_resolution_loss_weight=1.,
+            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+            multi_stft_hop_size=147,
+            multi_stft_normalized=False,
+            multi_stft_window_fn: Callable = torch.hann_window,
+            mlp_expansion_factor=4,
+            use_torch_checkpoint=False,
+            skip_connection=False,
+            sage_attention=False,
+    ):
+        super().__init__()
+
+        self.stereo = stereo
+        self.audio_channels = 2 if stereo else 1
+        self.num_stems = num_stems
+        self.use_torch_checkpoint = use_torch_checkpoint
+        self.skip_connection = skip_connection
+
+        self.layers = ModuleList([])
+
+        if sage_attention:
+            print("Use Sage Attention")
+
+        transformer_kwargs = dict(
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            flash_attn=flash_attn,
+            norm_output=False,
+            sage_attention=sage_attention,
+        )
+
+        time_rotary_embed = RotaryEmbedding(dim=dim_head)
+        freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+        for _ in range(depth):
+            tran_modules = []
+            tran_modules.append(
+                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+            )
+            tran_modules.append(
+                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+            )
+            self.layers.append(nn.ModuleList(tran_modules))
+
+        self.final_norm = RMSNorm(dim)
+
+        self.stft_kwargs = dict(
+            n_fft=stft_n_fft,
+            hop_length=stft_hop_length,
+            win_length=stft_win_length,
+            normalized=stft_normalized
+        )
+
+        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
+
+        assert len(freqs_per_bands) > 1
+        assert sum(freqs_per_bands) == freqs, \
+            f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+
+        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
+
+        self.band_split = BandSplit(
+            dim=dim,
+            dim_inputs=freqs_per_bands_with_complex
+        )
+
+        self.mask_estimators = nn.ModuleList([])
+
+        for _ in range(num_stems):
+            mask_estimator = MaskEstimator(
+                dim=dim,
+                dim_inputs=freqs_per_bands_with_complex,
+                depth=mask_estimator_depth,
+                mlp_expansion_factor=mlp_expansion_factor,
+            )
+
+            self.mask_estimators.append(mask_estimator)
+
+        # for the multi-resolution stft loss
+
+        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+        self.multi_stft_n_fft = stft_n_fft
+        self.multi_stft_window_fn = multi_stft_window_fn
+
+        self.multi_stft_kwargs = dict(
+            hop_length=multi_stft_hop_size,
+            normalized=multi_stft_normalized
+        )
+
+    def forward(
+            self,
+            raw_audio,
+            target=None,
+            return_loss_breakdown=False
+    ):
+        """
+        einops
+
+        b - batch
+        f - freq
+        t - time
+        s - audio channel (1 for mono, 2 for stereo)
+        n - number of 'stems'
+        c - complex (2)
+        d - feature dimension
+        """
+
+        device = raw_audio.device
+
+        # defining whether model is loaded on MPS (MacOS GPU accelerator)
+        x_is_mps = True if device.type == "mps" else False
+
+        if raw_audio.ndim == 2:
+            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+        channels = raw_audio.shape[1]
+        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), \
+            'stereo must be True for stereo input (channel dimension of 2) and False for mono input (channel dimension of 1)'
+
+        # to stft
+
+        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+        stft_window = self.stft_window_fn(device=device)
+
+        # RuntimeError: FFT operations are only supported on MacOS 14+
+        # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
+        try:
+            stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+        except:
+            stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs,
+                                   window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
+        stft_repr = torch.view_as_real(stft_repr)
+
+        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+
+        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')
+
+        x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+
+        x = self.band_split(x)
+
+        # axial / hierarchical attention
+
+        for i, transformer_block in enumerate(self.layers):
+
+            time_transformer, freq_transformer = transformer_block
+
+            x = rearrange(x, 'b t f d -> b f t d')
+            x, ps = pack([x], '* t d')
+
+            x = time_transformer(x)
+
+            x, = unpack(x, ps, '* t d')
+            x = rearrange(x, 'b f t d -> b t f d')
+            x, ps = pack([x], '* f d')
+
+            x = freq_transformer(x)
+
+            x, = unpack(x, ps, '* f d')
+
+        x = self.final_norm(x)
+
+        num_stems = len(self.mask_estimators)
+
+        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+
+        # modulate frequency representation
+
+        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+        stft_repr = torch.view_as_complex(stft_repr)
+        mask = torch.view_as_complex(mask)
+
+        stft_repr = stft_repr * mask
+
+        # istft
+
+        # split the merged (freq, audio channel) axis back out, then invert the STFT to per-stem
+        # waveforms, with the same MPS fallback as the forward stft above
+        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+        try:
+            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window,
+                                      return_complex=False, length=raw_audio.shape[-1])
+        except:
+            recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs,
+                                      window=stft_window.cpu() if x_is_mps else stft_window,
+                                      return_complex=False, length=raw_audio.shape[-1]).to(device)
+
+        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+
+        if num_stems == 1:
+            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+        # if a target is passed in, calculate loss for learning
+
+        if not exists(target):
+            return recon_audio
+
+        if self.num_stems > 1:
+            assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+        if target.ndim == 2:
+            target = rearrange(target, '... t -> ... 1 t')
+
+        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
+
+        loss = F.l1_loss(recon_audio, target)
+
+        multi_stft_resolution_loss = 0.
+
+        for window_size in self.multi_stft_resolutions_window_sizes:
+            res_stft_kwargs = dict(
+                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
+                win_length=window_size,
+                return_complex=True,
+                window=self.multi_stft_window_fn(window_size, device=device),
+                **self.multi_stft_kwargs,
+            )
+
+            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+        total_loss = loss + weighted_multi_resolution_loss
+
+        if not return_loss_breakdown:
+            return total_loss
+
+        return total_loss, (loss, multi_stft_resolution_loss)
config.yaml ADDED
@@ -0,0 +1,129 @@
+audio:
+  chunk_size: 960000
+  dim_f: 1024
+  dim_t: 801 # not used; the model derives the time dimension from its own STFT settings
+  hop_length: 441 # not used; the model uses stft_hop_length from the model section
+  n_fft: 2048
+  num_channels: 2
+  sample_rate: 44100
+  min_mean_abs: 0.0001
+
+model:
+  dim: 256
+  depth: 12
+  stereo: true
+  num_stems: 1
+  time_transformer_depth: 1
+  freq_transformer_depth: 1
+  linear_transformer_depth: 0
+  freqs_per_bands: !!python/tuple
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 2
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 4
+    - 12
+    - 12
+    - 12
+    - 12
+    - 12
+    - 12
+    - 12
+    - 12
+    - 24
+    - 24
+    - 24
+    - 24
+    - 24
+    - 24
+    - 24
+    - 24
+    - 48
+    - 48
+    - 48
+    - 48
+    - 48
+    - 48
+    - 48
+    - 48
+    - 128
+    - 129
+  dim_head: 64
+  heads: 8
+  attn_dropout: 0.0
+  ff_dropout: 0.0
+  flash_attn: true
+  dim_freqs_in: 1025
+  stft_n_fft: 2048
+  stft_hop_length: 512
+  stft_win_length: 2048
+  stft_normalized: false
+  mask_estimator_depth: 2
+  multi_stft_resolution_loss_weight: 1.0
+  multi_stft_resolutions_window_sizes: !!python/tuple
+    - 4096
+    - 2048
+    - 1024
+    - 512
+    - 256
+  multi_stft_hop_size: 147
+  multi_stft_normalized: false
+  mlp_expansion_factor: 4
+  use_torch_checkpoint: true
+  skip_connection: false
+
+
+training:
+  batch_size: 1
+  gradient_accumulation_steps: 1
+  grad_clip: 0
+  instruments: ['vocals', 'instrument']
+  lr: 1.0e-5
+  patience: 5
+  reduce_factor: 0.9
+  target_instrument: instrument
+  num_epochs: 1000
+  num_steps: 1000
+  q: 0.95
+  coarse_loss_clip: true
+  ema_momentum: 0.999
+  optimizer: adam
+  other_fix: false # needed when checking on the multisong dataset whether 'other' is actually instrumental
+  use_amp: true # enable mixed precision (float16); normally this should stay true
+
+
+inference:
+  batch_size: 2
+  dim_t: 1876
+  num_overlap: 4
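
For reference, a minimal sketch of how the three uploaded files could fit together. The import path (the models/bs_roformer/ package layout that bs_roformer.py itself imports from), the assumption that bs_hyperace.ckpt holds a plain state_dict for this class, and the use of PyYAML's unsafe loader for the !!python/tuple tags are assumptions, not something this upload states:

import torch
import yaml

from models.bs_roformer.bs_roformer import BSRoformer  # assumed location of the uploaded bs_roformer.py

with open('config.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.UnsafeLoader)  # config relies on !!python/tuple tags

model = BSRoformer(**cfg['model'])

state = torch.load('bs_hyperace.ckpt', map_location='cpu')  # assumed to be a raw state_dict
model.load_state_dict(state)
model.eval()

with torch.no_grad():
    chunk = torch.randn(1, 2, cfg['audio']['chunk_size'])  # (batch, stereo channels, samples)
    separated = model(chunk)  # (batch, stereo channels, samples), since num_stems == 1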