pbotsaris committed on
Commit
e3598b0
·
1 Parent(s): aa06a1d

Removed autocast, as it creates float precision issues. Removed debug code.

Browse files
Files changed (1) hide show
  1. handler.py +3 -78
handler.py CHANGED
@@ -4,73 +4,9 @@ from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
  import io
6
  import base64
7
- import wave
8
- import array
9
- import math
10
-
11
- def generate_sine_wave(freq, duration, sample_rate, amplitude):
12
- n_samples = int(sample_rate * duration)
13
- samples = []
14
-
15
- for x in range(n_samples):
16
- value = amplitude * math.sin(2 * math.pi * freq * x / sample_rate)
17
- samples.append(int(value)) # rounding to the nearest integer
18
-
19
- return array.array("h", samples) # array of short integers
20
-
21
-
22
- def sine_to_base64():
23
- frequency = 440.0 # Frequency in Hz
24
- duration = 1.0 # seconds
25
- volume = 0.5 # 0.0 to 1.0
26
- sample_rate = 44100
27
- amplitude = int(volume * 32767) # 16-bit audio
28
-
29
- sine_wave = generate_sine_wave(frequency, duration, sample_rate, amplitude)
30
-
31
- wav_buffer = io.BytesIO()
32
- with wave.open(wav_buffer, "w") as wav_file:
33
- n_channels = 1
34
- sampwidth = 2
35
- n_frames = len(sine_wave)
36
- comptype = "NONE"
37
- compname = "not compressed"
38
- wav_file.setparams((n_channels, sampwidth, int(sample_rate), n_frames, comptype, compname))
39
- wav_file.writeframes(sine_wave.tobytes())
40
-
41
- base64_string = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')
42
- return base64_string
43
-
44
-
45
- def create_params(params, fr):
46
- # default
47
- out = { "do_sample": True,
48
- "guidance_scale": 3,
49
- "max_new_tokens": 256
50
- }
51
-
52
- has_tokens = False
53
-
54
- if params is None:
55
- return out
56
-
57
- if 'duration' in params:
58
- out['max_new_tokens'] = params['duration'] * fr
59
- has_tokens = True
60
-
61
- for k, p in params.items():
62
- if k in out:
63
- if has_tokens and k == 'max_new_tokens':
64
- continue
65
-
66
- out[k] = p
67
-
68
- return out
69
-
70
 
71
  class EndpointHandler:
72
  def __init__(self, path="pbotsaris/musicgen-small"):
73
- # load model and processor
74
  self.processor = AutoProcessor.from_pretrained(path)
75
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
76
  self.model.to('cuda:0') #type: ignore
@@ -87,15 +23,14 @@ class EndpointHandler:
87
  params = data.pop("parameters", None)
88
 
89
  inputs = self.processor(
90
- text=["80s pop track with a bassy synth"],
91
  padding=True,
92
  return_tensors="pt"
93
  )
94
 
95
  params = create_params(params, self.model.config.audio_encoder.frame_rate) #type: ignore
96
 
97
- with torch.cuda.amp.autocast(): #type: ignore
98
- outputs = self.model.generate(**inputs.to('cuda:0'), do_sample=True, guidance_scale=3, max_new_tokens=256) #type: ignore
99
 
100
  pred = outputs[0, 0].cpu().numpy()
101
  sr = self.model.config.audio_encoder.sampling_rate #type: ignore
@@ -104,18 +39,8 @@ class EndpointHandler:
104
  wavfile.write(wav_buffer, rate=sr, data=pred)
105
  wav_data = wav_buffer.getvalue()
106
 
107
-
108
- w_len = len(wav_data)
109
- p_len = len(pred)
110
-
111
- shape = ""
112
-
113
- for v in outputs.shape: #type: ignore
114
- shape += ":" + str(v)
115
-
116
  base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
117
- return [{"audio": base64_encoded_wav, "wav_len": w_len, "pred_len": p_len, "shape": shape, "sr": sr, 'dtype': str(pred.dtype)}]
118
-
119
 
120
  if __name__ == "__main__":
121
  handler = EndpointHandler()
 
4
  import torch
5
  import io
6
  import base64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path="pbotsaris/musicgen-small"):
 
10
  self.processor = AutoProcessor.from_pretrained(path)
11
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
12
  self.model.to('cuda:0') #type: ignore
 
23
  params = data.pop("parameters", None)
24
 
25
  inputs = self.processor(
26
+ text=[inputs],
27
  padding=True,
28
  return_tensors="pt"
29
  )
30
 
31
  params = create_params(params, self.model.config.audio_encoder.frame_rate) #type: ignore
32
 
33
+ outputs = self.model.generate(**inputs.to('cuda:0'), **params) #type: ignore
 
34
 
35
  pred = outputs[0, 0].cpu().numpy()
36
  sr = self.model.config.audio_encoder.sampling_rate #type: ignore
 
39
  wavfile.write(wav_buffer, rate=sr, data=pred)
40
  wav_data = wav_buffer.getvalue()
41
 
 
 
 
 
 
 
 
 
 
42
  base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
43
+ return [{"audio": base64_encoded_wav, "sr": sr}]
 
44
 
45
  if __name__ == "__main__":
46
  handler = EndpointHandler()