pbotsaris committed on
Commit
3ce0400
·
1 Parent(s): da50f54

added debug returns to check model's output

Browse files
Files changed (1) hide show
  1. handler.py +18 -9
handler.py CHANGED
@@ -73,9 +73,9 @@ class EndpointHandler:
73
  # load model and processor
74
  self.processor = AutoProcessor.from_pretrained(path)
75
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
76
- self.model.to('cuda')
77
 
78
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
79
  """
80
  Args:
81
  data (:dict:):
@@ -87,25 +87,34 @@ class EndpointHandler:
87
  params = data.pop("parameters", None)
88
 
89
  inputs = self.processor(
90
- text=[inputs],
91
  padding=True,
92
  return_tensors="pt"
93
- ).to('cuda')
94
 
95
- params = create_params(params, self.model.config.audio_encoder.frame_rate)
96
 
97
- with torch.cuda.amp.autocast():
98
- outputs = self.model.generate(**inputs.to('cuda'), do_sample=True, guidance_scale=3, max_new_tokens=256)
99
 
100
  pred = outputs[0, 0].cpu().numpy()
101
- sr = self.model.config.audio_encoder.sampling_rate
102
 
103
  wav_buffer = io.BytesIO()
104
  wavfile.write(wav_buffer, rate=sr, data=pred)
105
  wav_data = wav_buffer.getvalue()
106
 
 
 
 
 
 
 
 
 
 
107
  base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
108
- return [{"audio": base64_encoded_wav}]
109
 
110
 
111
  if __name__ == "__main__":
 
73
  # load model and processor
74
  self.processor = AutoProcessor.from_pretrained(path)
75
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
76
+ self.model.to('cuda:0') #type: ignore
77
 
78
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
79
  """
80
  Args:
81
  data (:dict:):
 
87
  params = data.pop("parameters", None)
88
 
89
  inputs = self.processor(
90
+ text=["80s pop track with a bassy synth"],
91
  padding=True,
92
  return_tensors="pt"
93
+ )
94
 
95
+ params = create_params(params, self.model.config.audio_encoder.frame_rate) #type: ignore
96
 
97
+ with torch.cuda.amp.autocast(): #type: ignore
98
+ outputs = self.model.generate(**inputs.to('cuda:0'), do_sample=True, guidance_scale=3, max_new_tokens=256) #type: ignore
99
 
100
  pred = outputs[0, 0].cpu().numpy()
101
+ sr = self.model.config.audio_encoder.sampling_rate #type: ignore
102
 
103
  wav_buffer = io.BytesIO()
104
  wavfile.write(wav_buffer, rate=sr, data=pred)
105
  wav_data = wav_buffer.getvalue()
106
 
107
+
108
+ w_len = len(wav_data)
109
+ p_len = len(pred)
110
+
111
+ shape = ""
112
+
113
+ for v in outputs.shape: #type: ignore
114
+ shape += ":" + str(v)
115
+
116
  base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
117
+ return [{"audio": base64_encoded_wav, "wav_len": w_len, "pred_len": p_len, "shape": shape}]
118
 
119
 
120
  if __name__ == "__main__":