SaoYear committed on
Commit be3e940 · 1 Parent(s): 51e7c22

customize mamba modules, remove CUDA dependency

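The changes below drop the pinned CUDA wheels, comment out the compiled-kernel install, vendor the Mamba modules, and replace the hard-coded GPU device with a fallback. As a rough illustration (not code from the commit), the pattern the app and the vendored modules rely on looks like this:

import torch

try:
    from causal_conv1d import causal_conv1d_fn   # optional compiled CUDA kernel
except ImportError:
    causal_conv1d_fn = None                      # fall back to the eager PyTorch path

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")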
app.py CHANGED
@@ -5,15 +5,12 @@ import shlex
 subprocess.check_call(["apt-get", "update"])
 subprocess.check_call([sys.executable,"-m","pip","install",
                        "torch==2.2.0",
-                       "torchvision==0.17.0",
-                       "torchaudio==2.2.0",
-                       "--index-url",
-                       "https://download.pytorch.org/whl/cu121"])
+                       "torchaudio==2.2.0"])
 def install_mamba():
     subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
     subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v1.2.0.post1/mamba_ssm-1.2.0.post1+cu122torch2.2cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"))
 
-install_mamba()
+# install_mamba()
 
 import torch
 import spaces
@@ -28,7 +25,7 @@ from model.cleanmel import CleanMel
 from model.vocos.pretrained import Vocos
 from model.stft import InputSTFT, TargetMel
 
-DEVICE = torch.device("cuda:5")
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 def read_audio(file_path):
     audio, sample_rate = sf.read(file_path)
@@ -73,15 +70,15 @@ def mel_transform(audio, X_norm):
     return transform(audio, X_norm)
 
 def load_cleanmel(model_name):
-    model_config = f"../configs/cleanmel_offline.yaml"
+    model_config = f"./configs/cleanmel_offline.yaml"
     model_config = yaml.safe_load(open(model_config, "r"))["model"]["arch"]["init_args"]
     cleanmel = CleanMel(**model_config)
-    cleanmel.load_state_dict(torch.load(f"../ckpts/CleanMel/{model_name}.ckpt"))
+    cleanmel.load_state_dict(torch.load(f"./ckpts/CleanMel/{model_name}.ckpt"))
    return cleanmel.eval()
 
 def load_vocos():
-    vocos = Vocos.from_hparams(config_path="../configs/vocos_offline.yaml")
-    vocos = Vocos.from_pretrained(None, model_path=f"../ckpts/Vocos/vocos_offline.pt", model=vocos)
+    vocos = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
+    vocos = Vocos.from_pretrained(None, model_path=f"./ckpts/Vocos/vocos_offline.pt", model=vocos)
     return vocos.eval()
 
 def get_mrm_pred(Y_hat, x, X_norm):
@@ -182,4 +179,4 @@ with gr.Blocks(title="CleanMel Demo") as demo:
         outputs=[output_audio, output_mel, output_np]
     )
 
-demo.launch(debug=False, share=True)
+demo.launch(debug=False)
model/cleanmel.py CHANGED
@@ -11,8 +11,8 @@ from torch.nn import Parameter, init
 from torch.nn.common_types import _size_1_t
 
 
-from mamba_ssm import Mamba
-from mamba_ssm.utils.generation import InferenceParams
+from model.mamba.mamba import Mamba
+from model.mamba.utils.generation import InferenceParams
 
 class LinearGroup(nn.Module):
 
model/mamba/__init__.py ADDED
File without changes
model/mamba/mamba.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from model.mamba.selective_scan_inferface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
63
+
64
+ self.conv1d = nn.Conv1d(
65
+ in_channels=self.d_inner,
66
+ out_channels=self.d_inner,
67
+ bias=conv_bias,
68
+ kernel_size=d_conv,
69
+ groups=self.d_inner,
70
+ padding=d_conv - 1,
71
+ **factory_kwargs,
72
+ )
73
+
74
+ self.activation = "silu"
75
+ self.act = nn.SiLU()
76
+
77
+ self.x_proj = nn.Linear(
78
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
79
+ )
80
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
81
+
82
+ # Initialize special dt projection to preserve variance at initialization
83
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
84
+ if dt_init == "constant":
85
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
86
+ elif dt_init == "random":
87
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
92
+ dt = torch.exp(
93
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
94
+ + math.log(dt_min)
95
+ ).clamp(min=dt_init_floor)
96
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
97
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
98
+ with torch.no_grad():
99
+ self.dt_proj.bias.copy_(inv_dt)
100
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
101
+ self.dt_proj.bias._no_reinit = True
102
+
103
+ # S4D real initialization
104
+ A = repeat(
105
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
106
+ "n -> d n",
107
+ d=self.d_inner,
108
+ ).contiguous()
109
+ A_log = torch.log(A) # Keep A_log in fp32
110
+ self.A_log = nn.Parameter(A_log)
111
+ self.A_log._no_weight_decay = True
112
+
113
+ # D "skip" parameter
114
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
115
+ self.D._no_weight_decay = True
116
+
117
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
118
+
119
+ def forward(self, hidden_states, inference_params=None):
120
+ """
121
+ hidden_states: (B, L, D)
122
+ Returns: same shape as hidden_states
123
+ """
124
+ batch, seqlen, dim = hidden_states.shape
125
+
126
+ conv_state, ssm_state = None, None
127
+ if inference_params is not None:
128
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
129
+ if inference_params.seqlen_offset > 0:
130
+ # The states are updated inplace
131
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
132
+ return out
133
+
134
+ # We do matmul and transpose BLH -> HBL at the same time
135
+ xz = rearrange(
136
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
137
+ "d (b l) -> b d l",
138
+ l=seqlen,
139
+ )
140
+ if self.in_proj.bias is not None:
141
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
142
+
143
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
144
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
145
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
146
+ out = mamba_inner_fn(
147
+ xz,
148
+ self.conv1d.weight,
149
+ self.conv1d.bias,
150
+ self.x_proj.weight,
151
+ self.dt_proj.weight,
152
+ self.out_proj.weight,
153
+ self.out_proj.bias,
154
+ A,
155
+ None, # input-dependent B
156
+ None, # input-dependent C
157
+ self.D.float(),
158
+ delta_bias=self.dt_proj.bias.float(),
159
+ delta_softplus=True,
160
+ )
161
+ else:
162
+ x, z = xz.chunk(2, dim=1)
163
+ # Compute short convolution
164
+ if conv_state is not None:
165
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
166
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
167
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
168
+ if causal_conv1d_fn is None:
169
+ x = self.act(self.conv1d(x)[..., :seqlen])
170
+ else:
171
+ assert self.activation in ["silu", "swish"]
172
+ x = causal_conv1d_fn(
173
+ x=x,
174
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
175
+ bias=self.conv1d.bias,
176
+ activation=self.activation,
177
+ )
178
+
179
+ # We're careful here about the layout, to avoid extra transposes.
180
+ # We want dt to have d as the slowest moving dimension
181
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
182
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
183
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
184
+ dt = self.dt_proj.weight @ dt.t()
185
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
186
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
187
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
188
+ assert self.activation in ["silu", "swish"]
189
+ y = selective_scan_fn(
190
+ x,
191
+ dt,
192
+ A,
193
+ B,
194
+ C,
195
+ self.D.float(),
196
+ z=z,
197
+ delta_bias=self.dt_proj.bias.float(),
198
+ delta_softplus=True,
199
+ return_last_state=ssm_state is not None,
200
+ )
201
+ if ssm_state is not None:
202
+ y, last_state = y
203
+ ssm_state.copy_(last_state)
204
+ y = rearrange(y, "b d l -> b l d")
205
+ out = self.out_proj(y)
206
+ return out
207
+
208
+ def step(self, hidden_states, conv_state, ssm_state):
209
+ dtype = hidden_states.dtype
210
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
211
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
212
+ x, z = xz.chunk(2, dim=-1) # (B D)
213
+
214
+ # Conv step
215
+ if causal_conv1d_update is None:
216
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
217
+ conv_state[:, :, -1] = x
218
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
219
+ if self.conv1d.bias is not None:
220
+ x = x + self.conv1d.bias
221
+ x = self.act(x).to(dtype=dtype)
222
+ else:
223
+ x = causal_conv1d_update(
224
+ x,
225
+ conv_state,
226
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
227
+ self.conv1d.bias,
228
+ self.activation,
229
+ )
230
+
231
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
232
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
233
+ # Don't add dt_bias here
234
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
235
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
236
+
237
+ # SSM step
238
+ if selective_state_update is None:
239
+ # Discretize A and B
240
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
241
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
242
+ dB = torch.einsum("bd,bn->bdn", dt, B)
243
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
244
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
245
+ y = y + self.D.to(dtype) * x
246
+ y = y * self.act(z) # (B D)
247
+ else:
248
+ y = selective_state_update(
249
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
250
+ )
251
+
252
+ out = self.out_proj(y)
253
+ return out.unsqueeze(1), conv_state, ssm_state
254
+
255
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
256
+ device = self.out_proj.weight.device
257
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
258
+ conv_state = torch.zeros(
259
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
260
+ )
261
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
262
+ # ssm_dtype = torch.float32
263
+ ssm_state = torch.zeros(
264
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
265
+ )
266
+ return conv_state, ssm_state
267
+
268
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
269
+ assert self.layer_idx is not None
270
+ if self.layer_idx not in inference_params.key_value_memory_dict:
271
+ batch_shape = (batch_size,)
272
+ conv_state = torch.zeros(
273
+ batch_size,
274
+ self.d_model * self.expand,
275
+ self.d_conv,
276
+ device=self.conv1d.weight.device,
277
+ dtype=self.conv1d.weight.dtype,
278
+ )
279
+ ssm_state = torch.zeros(
280
+ batch_size,
281
+ self.d_model * self.expand,
282
+ self.d_state,
283
+ device=self.dt_proj.weight.device,
284
+ dtype=self.dt_proj.weight.dtype,
285
+ # dtype=torch.float32,
286
+ )
287
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
288
+ else:
289
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
290
+ # TODO: What if batch size changes between generation, and we reuse the same states?
291
+ if initialize_states:
292
+ conv_state.zero_()
293
+ ssm_state.zero_()
294
+ return conv_state, ssm_state
295
+
296
+
297
+ class Block(nn.Module):
298
+ def __init__(
299
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
300
+ ):
301
+ """
302
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
303
+
304
+ This Block has a slightly different structure compared to a regular
305
+ prenorm Transformer block.
306
+ The standard block is: LN -> MHA/MLP -> Add.
307
+ [Ref: https://arxiv.org/abs/2002.04745]
308
+ Here we have: Add -> LN -> Mixer, returning both
309
+ the hidden_states (output of the mixer) and the residual.
310
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
311
+ The residual needs to be provided (except for the very first block).
312
+ """
313
+ super().__init__()
314
+ self.residual_in_fp32 = residual_in_fp32
315
+ self.fused_add_norm = fused_add_norm
316
+ self.mixer = mixer_cls(dim)
317
+ self.norm = norm_cls(dim)
318
+ if self.fused_add_norm:
319
+ assert RMSNorm is not None, "RMSNorm import fails"
320
+ assert isinstance(
321
+ self.norm, (nn.LayerNorm, RMSNorm)
322
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
323
+
324
+ def forward(
325
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
326
+ ):
327
+ r"""Pass the input through the encoder layer.
328
+
329
+ Args:
330
+ hidden_states: the sequence to the encoder layer (required).
331
+ residual: hidden_states = Mixer(LN(residual))
332
+ """
333
+ if not self.fused_add_norm:
334
+ residual = (hidden_states + residual) if residual is not None else hidden_states
335
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
336
+ if self.residual_in_fp32:
337
+ residual = residual.to(torch.float32)
338
+ else:
339
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
340
+ hidden_states, residual = fused_add_norm_fn(
341
+ hidden_states,
342
+ self.norm.weight,
343
+ self.norm.bias,
344
+ residual=residual,
345
+ prenorm=True,
346
+ residual_in_fp32=self.residual_in_fp32,
347
+ eps=self.norm.eps,
348
+ )
349
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
350
+ return hidden_states, residual
351
+
352
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
353
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
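For context, a small usage sketch of the vendored block above (assumptions: the optional causal-conv1d and mamba_ssm kernels are not installed, so the eager PyTorch branches are taken; the sizes are illustrative):

import torch
from model.mamba.mamba import Mamba

block = Mamba(d_model=64, d_state=16, d_conv=4, expand=2)  # eager path when no CUDA kernels are present
x = torch.randn(2, 100, 64)                                # (batch, seq_len, d_model)
y = block(x)                                               # output has the same shape: (2, 100, 64)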
model/mamba/selective_scan_inferface.py ADDED
@@ -0,0 +1,369 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.cuda.amp import custom_bwd, custom_fwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ # try:
17
+ # import selective_scan_cuda
18
+ # except ImportError:
19
+ selective_scan_cuda = None
20
+
21
+
22
+ class SelectiveScanFn(torch.autograd.Function):
23
+
24
+ @staticmethod
25
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
26
+ return_last_state=False):
27
+ if u.stride(-1) != 1:
28
+ u = u.contiguous()
29
+ if delta.stride(-1) != 1:
30
+ delta = delta.contiguous()
31
+ if D is not None:
32
+ D = D.contiguous()
33
+ if B.stride(-1) != 1:
34
+ B = B.contiguous()
35
+ if C.stride(-1) != 1:
36
+ C = C.contiguous()
37
+ if z is not None and z.stride(-1) != 1:
38
+ z = z.contiguous()
39
+ if B.dim() == 3:
40
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
41
+ ctx.squeeze_B = True
42
+ if C.dim() == 3:
43
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
44
+ ctx.squeeze_C = True
45
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
46
+ ctx.delta_softplus = delta_softplus
47
+ ctx.has_z = z is not None
48
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
49
+ if not ctx.has_z:
50
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
51
+ return out if not return_last_state else (out, last_state)
52
+ else:
53
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
54
+ out_z = rest[0]
55
+ return out_z if not return_last_state else (out_z, last_state)
56
+
57
+ @staticmethod
58
+ def backward(ctx, dout, *args):
59
+ if not ctx.has_z:
60
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
61
+ z = None
62
+ out = None
63
+ else:
64
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
65
+ if dout.stride(-1) != 1:
66
+ dout = dout.contiguous()
67
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
68
+ # backward of selective_scan_cuda with the backward of chunk).
69
+ # Here we just pass in None and dz will be allocated in the C++ code.
70
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
71
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
72
+ False # option to recompute out_z, not used here
73
+ )
74
+ dz = rest[0] if ctx.has_z else None
75
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
76
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
77
+ return (du, ddelta, dA, dB, dC,
78
+ dD if D is not None else None,
79
+ dz,
80
+ ddelta_bias if delta_bias is not None else None,
81
+ None,
82
+ None)
83
+
84
+
85
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
86
+ return_last_state=False):
87
+ """if return_last_state is True, returns (out, last_state)
88
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
89
+ not considered in the backward pass.
90
+ """
91
+ if selective_scan_cuda is None:
92
+ return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
93
+ else:
94
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
95
+
96
+
97
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
98
+ return_last_state=False):
99
+ """
100
+ u: r(B D L)
101
+ delta: r(B D L)
102
+ A: c(D N) or r(D N)
103
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
104
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
105
+ D: r(D)
106
+ z: r(B D L)
107
+ delta_bias: r(D), fp32
108
+
109
+ out: r(B D L)
110
+ last_state (optional): r(B D dstate) or c(B D dstate)
111
+ """
112
+ dtype_in = u.dtype
113
+ u = u.float()
114
+ delta = delta.float()
115
+ if delta_bias is not None:
116
+ delta = delta + delta_bias[..., None].float()
117
+ if delta_softplus:
118
+ delta = F.softplus(delta)
119
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
120
+ is_variable_B = B.dim() >= 3
121
+ is_variable_C = C.dim() >= 3
122
+ if A.is_complex():
123
+ if is_variable_B:
124
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
125
+ if is_variable_C:
126
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
127
+ else:
128
+ B = B.float()
129
+ C = C.float()
130
+ x = A.new_zeros((batch, dim, dstate))
131
+ ys = []
132
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
133
+ if not is_variable_B:
134
+ deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
135
+ else:
136
+ if B.dim() == 3:
137
+ deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
138
+ else:
139
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
140
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
141
+ if is_variable_C and C.dim() == 4:
142
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
143
+ last_state = None
144
+ for i in range(u.shape[2]):
145
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
146
+ if not is_variable_C:
147
+ y = torch.einsum('bdn,dn->bd', x, C)
148
+ else:
149
+ if C.dim() == 3:
150
+ y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
151
+ else:
152
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
153
+ if i == u.shape[2] - 1:
154
+ last_state = x
155
+ if y.is_complex():
156
+ y = y.real * 2
157
+ ys.append(y)
158
+ y = torch.stack(ys, dim=2) # (batch dim L)
159
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
160
+ if z is not None:
161
+ out = out * F.silu(z)
162
+ out = out.to(dtype=dtype_in)
163
+ return out if not return_last_state else (out, last_state)
164
+
165
+
166
+ class MambaInnerFn(torch.autograd.Function):
167
+
168
+ @staticmethod
169
+ @custom_fwd
170
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
171
+ out_proj_weight, out_proj_bias,
172
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
173
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
174
+ """
175
+ xz: (batch, dim, seqlen)
176
+ """
177
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
178
+ assert checkpoint_lvl in [0, 1]
179
+ L = xz.shape[-1]
180
+ delta_rank = delta_proj_weight.shape[1]
181
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
182
+ if torch.is_autocast_enabled():
183
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
184
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
185
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
186
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
187
+ if out_proj_bias is not None else None)
188
+ if xz.stride(-1) != 1:
189
+ xz = xz.contiguous()
190
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
191
+ x, z = xz.chunk(2, dim=1)
192
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
193
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
194
+ x, conv1d_weight, conv1d_bias, None, None, None, True
195
+ )
196
+ # We're being very careful here about the layout, to avoid extra transposes.
197
+ # We want delta to have d as the slowest moving dimension
198
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
199
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
200
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
201
+ ctx.is_variable_B = B is None
202
+ ctx.is_variable_C = C is None
203
+ ctx.B_proj_bias_is_None = B_proj_bias is None
204
+ ctx.C_proj_bias_is_None = C_proj_bias is None
205
+ if B is None: # variable B
206
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
207
+ if B_proj_bias is not None:
208
+ B = B + B_proj_bias.to(dtype=B.dtype)
209
+ if not A.is_complex():
210
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
211
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
212
+ else:
213
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
214
+ else:
215
+ if B.stride(-1) != 1:
216
+ B = B.contiguous()
217
+ if C is None: # variable C
218
+ C = x_dbl[:, -d_state:] # (bl dstate)
219
+ if C_proj_bias is not None:
220
+ C = C + C_proj_bias.to(dtype=C.dtype)
221
+ if not A.is_complex():
222
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
223
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
224
+ else:
225
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
226
+ else:
227
+ if C.stride(-1) != 1:
228
+ C = C.contiguous()
229
+ if D is not None:
230
+ D = D.contiguous()
231
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
232
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
233
+ )
234
+ ctx.delta_softplus = delta_softplus
235
+ ctx.out_proj_bias_is_None = out_proj_bias is None
236
+ ctx.checkpoint_lvl = checkpoint_lvl
237
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
238
+ conv1d_out, delta = None, None
239
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
240
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
241
+ A, B, C, D, delta_bias, scan_intermediates, out)
242
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
243
+
244
+ @staticmethod
245
+ @custom_bwd
246
+ def backward(ctx, dout):
247
+ # dout: (batch, seqlen, dim)
248
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
249
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
250
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
251
+ L = xz.shape[-1]
252
+ delta_rank = delta_proj_weight.shape[1]
253
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
254
+ x, z = xz.chunk(2, dim=1)
255
+ if dout.stride(-1) != 1:
256
+ dout = dout.contiguous()
257
+ if ctx.checkpoint_lvl == 1:
258
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
259
+ x, conv1d_weight, conv1d_bias, None, None, None, True
260
+ )
261
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
262
+ "d (b l) -> b d l", l = L)
263
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
264
+ # backward of selective_scan_cuda with the backward of chunk).
265
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
266
+ dx, dz = dxz.chunk(2, dim=1)
267
+ dout = rearrange(dout, "b l e -> e (b l)")
268
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
269
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
270
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
271
+ ctx.delta_softplus,
272
+ True # option to recompute out_z
273
+ )
274
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
275
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
276
+ dD = dD if D is not None else None
277
+ dx_dbl = torch.empty_like(x_dbl)
278
+ dB_proj_bias = None
279
+ if ctx.is_variable_B:
280
+ if not A.is_complex():
281
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
282
+ else:
283
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
284
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
285
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
286
+ dB = None
287
+ dC_proj_bias = None
288
+ if ctx.is_variable_C:
289
+ if not A.is_complex():
290
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
291
+ else:
292
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
293
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
294
+ dx_dbl[:, -d_state:] = dC # (bl d)
295
+ dC = None
296
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
297
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
298
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
299
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
300
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
301
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
302
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
303
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
304
+ # backward of conv1d with the backward of chunk).
305
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
306
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
307
+ )
308
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
309
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
310
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
311
+ dout_proj_weight, dout_proj_bias,
312
+ dA, dB, dC, dD,
313
+ ddelta_bias if delta_bias is not None else None,
314
+ dB_proj_bias, dC_proj_bias, None)
315
+
316
+
317
+ def mamba_inner_fn(
318
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
319
+ out_proj_weight, out_proj_bias,
320
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
321
+ C_proj_bias=None, delta_softplus=True
322
+ ):
323
+ if causal_conv1d_cuda is None:
324
+ return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
325
+ out_proj_weight, out_proj_bias,
326
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
327
+ else:
328
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
329
+ out_proj_weight, out_proj_bias,
330
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
331
+
332
+
333
+
334
+ def mamba_inner_ref(
335
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
336
+ out_proj_weight, out_proj_bias,
337
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
338
+ C_proj_bias=None, delta_softplus=True
339
+ ):
340
+ assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
341
+ L = xz.shape[-1]
342
+ delta_rank = delta_proj_weight.shape[1]
343
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
344
+ x, z = xz.chunk(2, dim=1)
345
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
346
+ # We're being very careful here about the layout, to avoid extra transposes.
347
+ # We want delta to have d as the slowest moving dimension
348
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
349
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
350
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
351
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
352
+ if B is None: # variable B
353
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
354
+ if B_proj_bias is not None:
355
+ B = B + B_proj_bias.to(dtype=B.dtype)
356
+ if not A.is_complex():
357
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
358
+ else:
359
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
360
+ if C is None: # variable C
361
+ C = x_dbl[:, -d_state:] # (bl d)
362
+ if C_proj_bias is not None:
363
+ C = C + C_proj_bias.to(dtype=C.dtype)
364
+ if not A.is_complex():
365
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
366
+ else:
367
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
368
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
369
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
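A hypothetical shape check for the fallback scan (assumption: selective_scan_cuda stays None as set above, so selective_scan_fn dispatches to selective_scan_ref; all sizes are made up):

import torch
from model.mamba.selective_scan_inferface import selective_scan_fn

b, d, l, n = 2, 8, 32, 16
u     = torch.randn(b, d, l)        # input sequence (B D L)
delta = torch.rand(b, d, l)         # time steps (B D L)
A     = -torch.rand(d, n)           # negative real S4D-style state matrix (D N)
B     = torch.randn(b, n, l)        # input-dependent B (B N L)
C     = torch.randn(b, n, l)        # input-dependent C (B N L)
D     = torch.ones(d)               # skip parameter (D)
out = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True)
print(out.shape)                    # torch.Size([2, 8, 32])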
model/mamba/utils/generation.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficienly calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
37
+ def modify_logits_for_min_p_filtering(logits, min_p):
38
+ """Set the logits for none min_p values to -inf. Done in-place."""
39
+ if min_p <= 0.0 or min_p >= 1.0:
40
+ return
41
+ indices_to_remove = logits < min_p
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
+ def modify_logits_for_top_k_filtering(logits, top_k):
46
+ """Set the logits for none top-k values to -inf. Done in-place."""
47
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
49
+
50
+
51
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
+ def modify_logits_for_top_p_filtering(logits, top_p):
54
+ """Set the logits for none top-p values to -inf. Done in-place."""
55
+ if top_p <= 0.0 or top_p >= 1.0:
56
+ return
57
+ # First sort and calculate cumulative sum of probabilities.
58
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
61
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
+ # scatter sorted tensors to original indexing
63
+ indices_to_remove = sorted_indices_to_remove.scatter(
64
+ 1, sorted_indices, sorted_indices_to_remove
65
+ )
66
+ logits.masked_fill_(indices_to_remove, float("-inf"))
67
+
68
+
69
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
+ logits: (batch_size, vocab_size)
72
+ prev_output_tokens: (batch_size, seq_len)
73
+ """
74
+ if repetition_penalty == 1.0:
75
+ return logits
76
+ score = torch.gather(logits, 1, prev_output_tokens)
77
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
78
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
+ logits.scatter_(1, prev_output_tokens, score)
80
+ return logits
81
+
82
+
83
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
+ """Sample from top-k logits.
85
+ Arguments:
86
+ logits: Tensor of shape (batch_size, vocab_size)
87
+ """
88
+ if top_k == 1: # Short-circuit for greedy decoding
89
+ return logits.argmax(dim=-1)
90
+ else:
91
+ if top_p > 0.0:
92
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
93
+ if top_k > 0:
94
+ top_k = min(top_k, logits.size(-1)) # Safety check
95
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
+ if temperature != 1.0:
97
+ logits_top /= temperature
98
+ modify_logits_for_top_p_filtering(logits_top, top_p)
99
+ return indices[
100
+ torch.arange(indices.shape[0], device=indices.device),
101
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
+ ]
103
+ else:
104
+ if min_p > 0.0:
105
+ logits_top = logits.clone()
106
+ max_prob = logits_top[..., 0].item()
107
+ min_prob = max_prob * min_p
108
+ modify_logits_for_min_p_filtering(logits_top, min_p)
109
+ if temperature != 1.0:
110
+ logits_top /= temperature
111
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
+ # Clone so that when we modify for top_p we don't change the original logits
113
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
+ modify_logits_for_top_p_filtering(logits_top, top_p)
115
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
+ dim=-1
117
+ )
118
+
119
+
120
+ @torch.inference_mode()
121
+ def decode(
122
+ input_ids,
123
+ model,
124
+ max_length,
125
+ top_k=1,
126
+ top_p=0.0,
127
+ min_p=0.0,
128
+ temperature=1.0,
129
+ repetition_penalty=1.0,
130
+ eos_token_id=None,
131
+ teacher_outputs=None,
132
+ vocab_size=None,
133
+ cg=False,
134
+ enable_timing=False,
135
+ streamer: Optional[TextStreamer] = None
136
+ ):
137
+ """Decoding, either greedy or with top-k or top-p sampling.
138
+ If top-k = 0, don't limit the number of candidates (pure sampling).
139
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
140
+ then top-p.
141
+ We assume that all sequences in the same batch have the same length.
142
+
143
+ Arguments:
144
+ input_ids: (batch, seq_len)
145
+ max_length: int
146
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
147
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
148
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
149
+ sequences: (batch, max_length)
150
+ scores: tuples of (batch, vocab_size)
151
+ """
152
+ if streamer is not None:
153
+ streamer.put(input_ids.cpu())
154
+
155
+ batch_size, seqlen_og = input_ids.shape
156
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
157
+ if cg:
158
+ if not hasattr(model, "_decoding_cache"):
159
+ model._decoding_cache = None
160
+ model._decoding_cache = update_graph_cache(
161
+ model,
162
+ model._decoding_cache,
163
+ batch_size,
164
+ seqlen_og,
165
+ max_length,
166
+ )
167
+ inference_params = model._decoding_cache.inference_params
168
+ inference_params.reset(max_length, batch_size)
169
+ else:
170
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
171
+
172
+ def get_logits(input_ids, inference_params):
173
+ decoding = inference_params.seqlen_offset > 0
174
+ if decoding:
175
+ position_ids = torch.full(
176
+ (batch_size, 1),
177
+ inference_params.seqlen_offset,
178
+ dtype=torch.long,
179
+ device=input_ids.device,
180
+ )
181
+ else:
182
+ position_ids = None
183
+ if not cg or not decoding:
184
+ logits = model(
185
+ input_ids,
186
+ position_ids=position_ids,
187
+ inference_params=inference_params,
188
+ num_last_tokens=1,
189
+ ).logits.squeeze(dim=1)
190
+ else:
191
+ logits = model._decoding_cache.run(
192
+ input_ids, position_ids, inference_params.seqlen_offset
193
+ ).squeeze(dim=1)
194
+ return logits[..., :vocab_size] if vocab_size is not None else logits
195
+
196
+ def sample_tokens(logits, inference_params):
197
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
198
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
199
+ else:
200
+ token = teacher_outputs[:, inference_params.seqlen_offset]
201
+ # return rearrange(token, "b -> b 1")
202
+ return token.unsqueeze(1)
203
+
204
+ def should_stop(current_token, inference_params):
205
+ if inference_params.seqlen_offset == 0:
206
+ return False
207
+ if eos_token_id is not None and (current_token == eos_token_id).all():
208
+ return True
209
+ if inference_params.seqlen_offset >= max_length - 1:
210
+ return True
211
+ return False
212
+
213
+ start = torch.cuda.Event(enable_timing=enable_timing)
214
+ end = torch.cuda.Event(enable_timing=enable_timing)
215
+
216
+ if enable_timing:
217
+ start.record()
218
+ scores, sequences = [], [input_ids]
219
+ sequences_cat = input_ids
220
+ while not should_stop(sequences[-1], inference_params):
221
+ scores.append(get_logits(sequences[-1], inference_params))
222
+ inference_params.seqlen_offset += sequences[-1].shape[1]
223
+ if repetition_penalty == 1.0:
224
+ sampled_tokens = sample_tokens(scores[-1], inference_params)
225
+ else:
226
+ logits = modify_logit_for_repetition_penalty(
227
+ scores[-1].clone(), sequences_cat, repetition_penalty
228
+ )
229
+ sampled_tokens = sample_tokens(logits, inference_params)
230
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
231
+ sequences.append(sampled_tokens)
232
+ if streamer is not None:
233
+ streamer.put(sampled_tokens.cpu())
234
+ if streamer is not None:
235
+ streamer.end()
236
+ if enable_timing:
237
+ end.record()
238
+ torch.cuda.synchronize()
239
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
240
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
241
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
242
+
243
+
244
+ class GenerationMixin:
245
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
246
+ raise NotImplementedError
247
+
248
+ def generate(
249
+ self,
250
+ input_ids,
251
+ max_length,
252
+ top_k=1,
253
+ top_p=0.0,
254
+ min_p=0.0,
255
+ temperature=1.0,
256
+ return_dict_in_generate=False,
257
+ output_scores=False,
258
+ **kwargs,
259
+ ):
260
+ output = decode(
261
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
262
+ )
263
+ if not output_scores:
264
+ output.scores = None
265
+ return output if return_dict_in_generate else output.sequences
266
+
267
+
268
+ @dataclass
269
+ class DecodingCGCache:
270
+ max_batch_size: int = 0
271
+ max_seqlen: int = 0
272
+ device = None
273
+ dtype = None
274
+ callables: dict = field(default_factory=dict)
275
+ mempool = None
276
+ inference_params: Optional[InferenceParams] = None
277
+ run: Optional[Callable] = None
278
+
279
+
280
+ @torch.inference_mode()
281
+ def update_graph_cache(
282
+ model,
283
+ cache,
284
+ batch_size,
285
+ seqlen_og,
286
+ max_seqlen,
287
+ decoding_seqlens=(1,),
288
+ dtype=None,
289
+ n_warmups=2,
290
+ ):
291
+ if cache is None:
292
+ cache = DecodingCGCache()
293
+ param_example = next(iter(model.parameters()))
294
+ device = param_example.device
295
+ if dtype is None:
296
+ dtype = param_example.dtype
297
+ if (
298
+ (device, dtype) != (cache.device, cache.dtype)
299
+ or batch_size > cache.max_batch_size
300
+ or max_seqlen > cache.max_seqlen
301
+ ): # Invalidate the cache
302
+ cache.callables = {}
303
+ cache.mempool = None
304
+ cache.inference_params = None
305
+ gc.collect()
306
+ cache.device, cache.dtype = device, dtype
307
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
308
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
309
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
310
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
311
+ cache.inference_params = InferenceParams(
312
+ max_seqlen=max_seqlen,
313
+ max_batch_size=batch_size,
314
+ seqlen_offset=seqlen_og,
315
+ key_value_memory_dict=inf_cache,
316
+ lengths_per_sample=lengths_per_sample,
317
+ )
318
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
319
+ for decoding_seqlen in decoding_seqlens:
320
+ if (batch_size, decoding_seqlen) not in cache.callables:
321
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
322
+ model,
323
+ cache.inference_params,
324
+ batch_size,
325
+ max_seqlen,
326
+ decoding_seqlen=decoding_seqlen,
327
+ mempool=cache.mempool,
328
+ n_warmups=n_warmups,
329
+ )
330
+
331
+ def dispatch(input_ids, position_ids, seqlen):
332
+ batch_size, decoding_seqlen = input_ids.shape[:2]
333
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
334
+
335
+ cache.run = dispatch
336
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
337
+ return cache
338
+
339
+
340
+ def capture_graph(
341
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
342
+ ):
343
+ device = next(iter(model.parameters())).device
344
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
345
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
346
+ seqlen_offset_og = inference_params.seqlen_offset
347
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
348
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
349
+
350
+ # Warmup before capture
351
+ s = torch.cuda.Stream()
352
+ s.wait_stream(torch.cuda.current_stream())
353
+ with torch.cuda.stream(s):
354
+ for _ in range(n_warmups):
355
+ logits = model(
356
+ input_ids,
357
+ position_ids=position_ids,
358
+ inference_params=inference_params,
359
+ num_last_tokens=decoding_seqlen,
360
+ ).logits
361
+ s.synchronize()
362
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
363
+ # which requires that graph launch and non-captured launch to not overlap (I think,
364
+ # that's how I interpret the documentation). I'm not sure if this is required.
365
+ if torch.distributed.is_initialized():
366
+ torch.distributed.barrier()
367
+ torch.cuda.current_stream().wait_stream(s)
368
+ # Captures the graph
369
+ # To allow capture, automatically sets a side stream as the current stream in the context
370
+ graph = torch.cuda.CUDAGraph()
371
+ with torch.cuda.graph(graph, pool=mempool):
372
+ logits = model(
373
+ input_ids,
374
+ position_ids=position_ids,
375
+ inference_params=inference_params,
376
+ num_last_tokens=decoding_seqlen,
377
+ ).logits
378
+
379
+ def run(new_input_ids, new_position_ids, seqlen):
380
+ inference_params.lengths_per_sample[:] = seqlen
381
+ input_ids.copy_(new_input_ids)
382
+ position_ids.copy_(new_position_ids)
383
+ graph.replay()
384
+ return logits.clone()
385
+
386
+ inference_params.seqlen_offset = seqlen_offset_og
387
+ return run
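Only InferenceParams from this file is imported by model/cleanmel.py. A rough sketch of stepwise (streaming-style) use together with the vendored Mamba block (assumptions: CPU execution, illustrative sizes; layer_idx must be set so the per-layer state cache can be keyed):

import torch
from model.mamba.mamba import Mamba
from model.mamba.utils.generation import InferenceParams

block = Mamba(d_model=64, layer_idx=0)
params = InferenceParams(max_seqlen=256, max_batch_size=1)

prefix = torch.randn(1, 10, 64)
y = block(prefix, inference_params=params)        # prefill: full forward pass, conv/ssm states are cached
params.seqlen_offset += prefix.shape[1]

frame = torch.randn(1, 1, 64)
y_next = block(frame, inference_params=params)    # afterwards: one frame at a time via the step() path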
requirements.txt CHANGED
@@ -5,4 +5,4 @@ PyYAML==6.0.2
 scipy==1.15.3
 soundfile==0.12.1
 spaces==0.37.0
-transformers==4.40.1
+transformers