adityas129 committed on
Commit
cf5fc2d
·
verified ·
1 Parent(s): 7dcdc7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +436 -146
app.py CHANGED
@@ -205,6 +205,123 @@ _patch_t5x_for_gpu_coords()
205
  jam_registry: dict[str, JamWorker] = {}
206
  jam_lock = threading.Lock()
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  @contextmanager
209
  def mrt_overrides(mrt, **kwargs):
210
  """Temporarily set attributes on MRT if they exist; restore after."""
@@ -687,43 +804,60 @@ def generate(
687
  loop_weight: float = Form(1.0),
688
  loudness_mode: str = Form("auto"),
689
  loudness_headroom_db: float = Form(1.0),
690
- guidance_weight: float = Form(5.0),
691
- temperature: float = Form(1.1),
692
- topk: int = Form(40),
693
  target_sample_rate: int | None = Form(None),
694
  intro_bars_to_drop: int = Form(0), # <— NEW
695
  ):
696
- # Read file
697
- data = loop_audio.file.read()
698
- if not data:
699
- return {"error": "Empty file"}
700
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
701
- tmp.write(data)
702
- tmp_path = tmp.name
703
-
704
- # Parse styles + weights
705
- extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
706
- weights = [float(x) for x in style_weights.split(",")] if style_weights else None
707
-
708
- mrt = get_mrt() # warm once, in this worker thread
709
- # Temporarily override MRT inference knobs for this request
710
- with mrt_overrides(mrt,
711
- guidance_weight=guidance_weight,
712
- temperature=temperature,
713
- topk=topk):
714
- wav, loud_stats = generate_loop_continuation_with_mrt(
715
- mrt,
716
- input_wav_path=tmp_path,
717
- bpm=bpm,
718
- extra_styles=extra_styles,
719
- style_weights=weights,
720
- bars=bars,
721
- beats_per_bar=beats_per_bar,
722
- loop_weight=loop_weight,
723
- loudness_mode=loudness_mode,
724
- loudness_headroom_db=loudness_headroom_db,
725
- intro_bars_to_drop=intro_bars_to_drop, # <— pass through
726
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
  # 1) Figure out the desired SR
729
  inp_info = sf.info(tmp_path)
@@ -771,9 +905,9 @@ def generate_style(
771
  beats_per_bar: int = Form(4),
772
  styles: str = Form("warmup"),
773
  style_weights: str = Form(""),
774
- guidance_weight: float = Form(1.1),
775
- temperature: float = Form(1.1),
776
- topk: int = Form(40),
777
  target_sample_rate: int | None = Form(None),
778
  intro_bars_to_drop: int = Form(0),
779
  ):
@@ -781,26 +915,42 @@ def generate_style(
781
  Style-only, bar-aligned generation (no input audio).
782
  Seeds with 10s of silent context; outputs exactly `bars` at the requested BPM.
783
  """
784
- mrt = get_mrt()
785
 
786
- # Override sampling knobs just for this request
787
- with mrt_overrides(mrt,
788
- guidance_weight=guidance_weight,
789
- temperature=temperature,
790
- topk=topk):
791
- wav, _ = generate_style_only_with_mrt(
792
- mrt,
793
- bpm=bpm,
794
- bars=bars,
795
- beats_per_bar=beats_per_bar,
796
- styles=styles,
797
- style_weights=style_weights,
798
- intro_bars_to_drop=intro_bars_to_drop,
799
- )
800
 
801
- # Determine target SR (defaults to model SR = 48k)
802
- cur_sr = int(mrt.sample_rate)
803
- target_sr = int(target_sample_rate or cur_sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
  x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
805
 
806
  seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
@@ -849,87 +999,102 @@ def jam_start(
849
 
850
  loudness_mode: str = Form("auto"),
851
  loudness_headroom_db: float = Form(1.0),
852
- guidance_weight: float = Form(1.1),
853
- temperature: float = Form(1.1),
854
- topk: int = Form(40),
855
  target_sample_rate: int | None = Form(None),
856
  ):
857
- asset_manager.ensure_assets_loaded(get_mrt())
858
 
859
- # enforce single active jam per GPU
860
- with jam_lock:
861
- for sid, w in list(jam_registry.items()):
862
- if w.is_alive():
863
- raise HTTPException(status_code=429, detail="A jam is already running. Try again later.")
864
 
865
- # read input + prep context/style (reuse your existing code)
866
- data = loop_audio.file.read()
867
- if not data: raise HTTPException(status_code=400, detail="Empty file")
868
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
869
- tmp.write(data); tmp_path = tmp.name
870
 
871
- mrt = get_mrt()
872
- loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
 
 
 
873
 
874
- # build tail context + style vec (tail-biased)
875
- codec_fps = float(mrt.codec.frame_rate)
876
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
877
- loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
 
878
 
879
- # Parse client style fields (preserves your semantics)
880
- text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
881
- try:
882
- tw = [float(x) for x in style_weights.split(",")] if style_weights else []
883
- except ValueError:
884
- tw = []
885
- try:
886
- cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
887
- except ValueError:
888
- cw = []
889
 
890
- # Compute loop-tail embed once (same as before)
891
- loop_tail_embed = mrt.embed_style(loop_tail)
 
 
892
 
893
- # Build final style vector:
894
- # - identical to your previous mix when mean==0 and cw is empty
895
- # - otherwise includes mean and centroid components (weights auto-normalized)
896
- style_vec = build_style_vector(
897
- mrt,
898
- text_styles=text_list,
899
- text_weights=tw,
900
- loop_embed=loop_tail_embed,
901
- loop_weight=float(loop_weight),
902
- mean_weight=float(mean),
903
- centroid_weights=cw,
904
- ).astype(np.float32, copy=False)
905
 
906
- # target SR (default input SR)
907
- inp_info = sf.info(tmp_path)
908
- input_sr = int(inp_info.samplerate)
909
- target_sr = int(target_sample_rate or input_sr)
910
 
911
- params = JamParams(
912
- bpm=bpm,
913
- beats_per_bar=beats_per_bar,
914
- bars_per_chunk=bars_per_chunk,
915
- target_sr=target_sr,
916
- loudness_mode=loudness_mode,
917
- headroom_db=loudness_headroom_db,
918
- style_vec=style_vec,
919
- ref_loop=loop_tail, # For loudness matching
920
- combined_loop=loop, # NEW: Full loop for context setup
921
- guidance_weight=guidance_weight,
922
- temperature=temperature,
923
- topk=topk
924
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925
 
926
- worker = JamWorker(mrt, params)
927
- sid = str(uuid.uuid4())
928
- with jam_lock:
929
- jam_registry[sid] = worker
930
- worker.start()
 
 
 
 
 
931
 
932
- return {"session_id": sid}
 
 
 
933
 
934
  @app.get("/jam/next")
935
  def jam_next(session_id: str):
@@ -938,13 +1103,17 @@ def jam_next(session_id: str):
938
  This ensures chunks are delivered in order without gaps.
939
  """
940
  with jam_lock:
941
- worker = jam_registry.get(session_id)
942
- if worker is None or not worker.is_alive():
 
 
 
 
943
  raise HTTPException(status_code=404, detail="Session not found")
944
 
945
  # Get the next sequential chunk (this blocks until ready)
946
  chunk = worker.get_next_chunk()
947
-
948
  if chunk is None:
949
  raise HTTPException(status_code=408, detail="Chunk not ready within timeout")
950
 
@@ -963,12 +1132,16 @@ def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
963
  This helps the worker manage its buffer and generation flow.
964
  """
965
  with jam_lock:
966
- worker = jam_registry.get(session_id)
967
- if worker is None or not worker.is_alive():
 
 
 
 
968
  raise HTTPException(status_code=404, detail="Session not found")
969
 
970
  worker.mark_chunk_consumed(chunk_index)
971
-
972
  return {"consumed": chunk_index}
973
 
974
 
@@ -976,16 +1149,22 @@ def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
976
  @app.post("/jam/stop")
977
  def jam_stop(session_id: str = Body(..., embed=True)):
978
  with jam_lock:
979
- worker = jam_registry.get(session_id)
980
- if worker is None:
981
  raise HTTPException(status_code=404, detail="Session not found")
982
 
 
 
 
983
  worker.stop()
984
  worker.join(timeout=5.0)
985
  if worker.is_alive():
986
- # Its daemon=True, so it wont block process exit, but report it
987
  print(f"⚠️ JamWorker {session_id} did not stop within timeout")
988
 
 
 
 
989
  with jam_lock:
990
  jam_registry.pop(session_id, None)
991
  return {"stopped": True}
@@ -994,13 +1173,19 @@ def jam_stop(session_id: str = Body(..., embed=True)):
994
  def jam_stop_all():
995
  """Force stop all active jam sessions (nuclear option for cleanup)"""
996
  stopped_sessions = []
997
-
998
  with jam_lock:
999
- for session_id, worker in list(jam_registry.items()):
 
 
 
1000
  if worker.is_alive():
1001
  worker.stop()
1002
  worker.join(timeout=2.0)
1003
  stopped_sessions.append(session_id)
 
 
 
1004
  jam_registry.pop(session_id, None)
1005
 
1006
  return {"stopped_sessions": stopped_sessions, "count": len(stopped_sessions)}
@@ -1024,13 +1209,19 @@ def jam_update(
1024
  mean: Optional[float] = Form(None),
1025
  centroid_weights: str = Form(""),
1026
  ):
1027
- asset_manager.ensure_assets_loaded(get_mrt())
1028
-
1029
  with jam_lock:
1030
- worker = jam_registry.get(session_id)
1031
- if worker is None or not worker.is_alive():
 
 
 
 
1032
  raise HTTPException(status_code=404, detail="Session not found")
1033
 
 
 
 
 
1034
  # 1) fast knob updates
1035
  if any(v is not None for v in (guidance_weight, temperature, topk)):
1036
  worker.update_knobs(
@@ -1098,8 +1289,12 @@ def jam_update(
1098
  @app.post("/jam/reseed")
1099
  def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
1100
  with jam_lock:
1101
- worker = jam_registry.get(session_id)
1102
- if worker is None or not worker.is_alive():
 
 
 
 
1103
  raise HTTPException(status_code=404, detail="Session not found")
1104
 
1105
  # Option 1: use uploaded new “combined” bounce from the app
@@ -1129,8 +1324,13 @@ def jam_reseed_splice(
1129
  anchor_bars: float = Form(2.0), # how much of the original to re-inject
1130
  combined_audio: UploadFile = File(None), # preferred: Swift supplies the current combined mix
1131
  ):
1132
- worker = jam_registry.get(session_id)
1133
- if worker is None or not worker.is_alive():
 
 
 
 
 
1134
  raise HTTPException(status_code=404, detail="Session not found")
1135
 
1136
  # Build a waveform to reseed from
@@ -1160,11 +1360,12 @@ def jam_reseed_splice(
1160
  @app.get("/jam/status")
1161
  def jam_status(session_id: str):
1162
  with jam_lock:
1163
- worker = jam_registry.get(session_id)
1164
 
1165
- if worker is None:
1166
  raise HTTPException(status_code=404, detail="Session not found")
1167
 
 
1168
  running = worker.is_alive()
1169
 
1170
  # Snapshot safely
@@ -1284,8 +1485,12 @@ async def ws_jam(websocket: WebSocket):
1284
  # attach or create
1285
  if sid:
1286
  with jam_lock:
1287
- worker = jam_registry.get(sid)
1288
- if worker is None or not worker.is_alive():
 
 
 
 
1289
  await send_json({"type":"error","error":"Session not found"})
1290
  continue
1291
  else:
@@ -1645,6 +1850,91 @@ def read_root():
1645
  """
1646
  return Response(content=html_content, media_type="text/html")
1647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1648
  @app.get("/lil_demo_540p.mp4")
1649
  def demo_video():
1650
  return FileResponse(Path(__file__).parent / "lil_demo_540p.mp4", media_type="video/mp4")
 
205
  jam_registry: dict[str, JamWorker] = {}
206
  jam_lock = threading.Lock()
207
 
208
+ # ============================================================================
209
+ # Global Generation Parameters
210
+ # ============================================================================
211
+
212
+ class GlobalGenParams:
213
+ """Global defaults for temperature, topk, guidance_weight.
214
+ Applied at MRT initialization. Changes require pool restart."""
215
+
216
+ def __init__(self):
217
+ self._lock = threading.RLock()
218
+ self.temperature = 1.1
219
+ self.topk = 40
220
+ self.guidance_weight = 1.1
221
+
222
+ def get(self):
223
+ with self._lock:
224
+ return {
225
+ 'temperature': self.temperature,
226
+ 'topk': self.topk,
227
+ 'guidance_weight': self.guidance_weight
228
+ }
229
+
230
+ def update(self, temperature=None, topk=None, guidance_weight=None):
231
+ """Update requires MRT pool restart to take effect"""
232
+ with self._lock:
233
+ if temperature is not None:
234
+ self.temperature = float(temperature)
235
+ if topk is not None:
236
+ self.topk = int(topk)
237
+ if guidance_weight is not None:
238
+ self.guidance_weight = float(guidance_weight)
239
+ return self.get()
240
+
241
+ _GLOBAL_GEN_PARAMS = GlobalGenParams()
242
+
243
+ # ============================================================================
244
+ # MRT Instance Pool (for parallel requests)
245
+ # ============================================================================
246
+
247
+ _MRT_POOL = []
248
+ _MRT_POOL_LOCK = threading.Lock()
249
+ _MRT_AVAILABLE = []
250
+ _POOL_INITIALIZED = False
251
+ _POOL_INIT_LOCK = threading.Lock()
252
+
253
def init_mrt_pool(pool_size=2):
    """(Re)build the MRT pool from the current global generation defaults.

    Creates ``pool_size`` ``system.MagentaRT`` instances, sets their
    ``temperature``/``topk`` from ``_GLOBAL_GEN_PARAMS``, and publishes them
    in ``_MRT_POOL`` with every slot marked available.

    The new instances are collected in a local list and swapped in only
    after all of them were created. If construction raises part-way, the
    previously published pool is left exactly as it was rather than being
    emptied or half-filled.

    This function deliberately does NOT take ``_MRT_POOL_LOCK``:
    ``reset_mrt_pool`` calls it while already holding that non-reentrant
    lock, so acquiring it here would deadlock.
    """
    defaults = _GLOBAL_GEN_PARAMS.get()

    new_pool = []
    for _ in range(pool_size):
        ckpt_dir = CheckpointManager.resolve_checkpoint_dir()
        mrt = system.MagentaRT(
            tag=os.getenv("MRT_SIZE", "large"),
            guidance_weight=defaults['guidance_weight'],
            device="gpu",
            checkpoint_dir=ckpt_dir,
            lazy=True,
        )
        # temperature/topk are not constructor arguments here; set them
        # as attributes after construction.
        mrt.temperature = defaults['temperature']
        mrt.topk = defaults['topk']

        # Load finetune assets only while the asset manager has none yet.
        if asset_manager.mean_embed is None and asset_manager.centroids is None:
            repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
            if repo:
                asset_manager.load_finetune_assets_from_hf(repo, None)
                _sync_assets_globals_from_manager()

        new_pool.append(mrt)

    # Slice-assign instead of rebinding so no `global` statement is needed
    # and any other reference to these lists keeps seeing the live pool.
    # NOTE(review): the old instances stay referenced until this swap;
    # confirm there is enough memory headroom during a rebuild.
    _MRT_POOL[:] = new_pool
    _MRT_AVAILABLE[:] = [True] * len(new_pool)
284
+
285
def ensure_pool_initialized():
    """Build the MRT pool the first time any request needs it.

    The flag is tested once without the lock (cheap path for every later
    call) and once more after acquiring ``_POOL_INIT_LOCK``, so that only
    one caller actually runs ``init_mrt_pool``.
    """
    global _POOL_INITIALIZED

    if _POOL_INITIALIZED:
        return

    with _POOL_INIT_LOCK:
        # Another thread may have finished the init while we were waiting.
        if _POOL_INITIALIZED:
            return
        init_mrt_pool(pool_size=2)
        _POOL_INITIALIZED = True
293
+
294
def get_available_mrt():
    """Check out the lowest-numbered free MRT slot.

    Returns:
        ``(index, mrt)`` for the slot that was claimed (it is marked busy
        before returning), or ``(None, None)`` when every slot is in use.
    """
    with _MRT_POOL_LOCK:
        free_slot = next(
            (slot for slot, is_free in enumerate(_MRT_AVAILABLE) if is_free),
            None,
        )
        if free_slot is None:
            return (None, None)

        _MRT_AVAILABLE[free_slot] = False
        return (free_slot, _MRT_POOL[free_slot])
302
+
303
def release_mrt(index: int):
    """Mark pool slot ``index`` as free again.

    An index outside the current pool bounds is ignored.
    """
    with _MRT_POOL_LOCK:
        out_of_range = index < 0 or index >= len(_MRT_AVAILABLE)
        if out_of_range:
            return
        _MRT_AVAILABLE[index] = True
308
+
309
def reset_mrt_pool():
    """Recreate the pool with the current global params.

    Callers are expected to have stopped all sessions first: every slot is
    marked available after the rebuild, regardless of who held it before.

    The initialised flag is dropped *before* rebuilding and only raised
    again once ``init_mrt_pool`` returned. If the rebuild raises, the flag
    stays False, so the next ``ensure_pool_initialized()`` call retries
    instead of serving from whatever state the failed rebuild left behind.

    Lock order is ``_POOL_INIT_LOCK`` then ``_MRT_POOL_LOCK``; holding the
    pool lock keeps slot checkout/release out while the lists are replaced.

    NOTE(review): requests that checked out a slot before the reset will
    still call ``release_mrt`` with their old index afterwards — confirm
    no one-shot generation can be in flight when this runs.
    """
    global _POOL_INITIALIZED

    with _POOL_INIT_LOCK:
        with _MRT_POOL_LOCK:
            _POOL_INITIALIZED = False
            init_mrt_pool(pool_size=2)
            _POOL_INITIALIZED = True
317
+
318
+ # ============================================================================
319
+ # Legacy single MRT support (for backward compatibility)
320
+ # ============================================================================
321
+
322
+ _MRT = None
323
+ _MRT_LOCK = threading.Lock()
324
+
325
  @contextmanager
326
  def mrt_overrides(mrt, **kwargs):
327
  """Temporarily set attributes on MRT if they exist; restore after."""
 
804
  loop_weight: float = Form(1.0),
805
  loudness_mode: str = Form("auto"),
806
  loudness_headroom_db: float = Form(1.0),
807
+ guidance_weight: Optional[float] = Form(None),
808
+ temperature: Optional[float] = Form(None),
809
+ topk: Optional[int] = Form(None),
810
  target_sample_rate: int | None = Form(None),
811
  intro_bars_to_drop: int = Form(0), # <— NEW
812
  ):
813
+ ensure_pool_initialized()
814
+
815
+ # Get available MRT from pool
816
+ mrt_index, mrt = get_available_mrt()
817
+ if mrt is None:
818
+ raise HTTPException(status_code=503, detail="All slots busy, retry shortly")
819
+
820
+ try:
821
+ # Apply global defaults if not specified
822
+ defaults = _GLOBAL_GEN_PARAMS.get()
823
+ guidance_weight = guidance_weight if guidance_weight is not None else defaults['guidance_weight']
824
+ temperature = temperature if temperature is not None else defaults['temperature']
825
+ topk = topk if topk is not None else defaults['topk']
826
+
827
+ # Read file
828
+ data = loop_audio.file.read()
829
+ if not data:
830
+ return {"error": "Empty file"}
831
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
832
+ tmp.write(data)
833
+ tmp_path = tmp.name
834
+
835
+ # Parse styles + weights
836
+ extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
837
+ weights = [float(x) for x in style_weights.split(",")] if style_weights else None
838
+
839
+ # Temporarily override MRT inference knobs for this request
840
+ with mrt_overrides(mrt,
841
+ guidance_weight=guidance_weight,
842
+ temperature=temperature,
843
+ topk=topk):
844
+ wav, loud_stats = generate_loop_continuation_with_mrt(
845
+ mrt,
846
+ input_wav_path=tmp_path,
847
+ bpm=bpm,
848
+ extra_styles=extra_styles,
849
+ style_weights=weights,
850
+ bars=bars,
851
+ beats_per_bar=beats_per_bar,
852
+ loop_weight=loop_weight,
853
+ loudness_mode=loudness_mode,
854
+ loudness_headroom_db=loudness_headroom_db,
855
+ intro_bars_to_drop=intro_bars_to_drop, # <— pass through
856
+ )
857
+
858
+ finally:
859
+ # Always release MRT back to pool
860
+ release_mrt(mrt_index)
861
 
862
  # 1) Figure out the desired SR
863
  inp_info = sf.info(tmp_path)
 
905
  beats_per_bar: int = Form(4),
906
  styles: str = Form("warmup"),
907
  style_weights: str = Form(""),
908
+ guidance_weight: Optional[float] = Form(None),
909
+ temperature: Optional[float] = Form(None),
910
+ topk: Optional[int] = Form(None),
911
  target_sample_rate: int | None = Form(None),
912
  intro_bars_to_drop: int = Form(0),
913
  ):
 
915
  Style-only, bar-aligned generation (no input audio).
916
  Seeds with 10s of silent context; outputs exactly `bars` at the requested BPM.
917
  """
918
+ ensure_pool_initialized()
919
 
920
+ # Get available MRT from pool
921
+ mrt_index, mrt = get_available_mrt()
922
+ if mrt is None:
923
+ raise HTTPException(status_code=503, detail="All slots busy, retry shortly")
 
 
 
 
 
 
 
 
 
 
924
 
925
+ try:
926
+ # Apply global defaults if not specified
927
+ defaults = _GLOBAL_GEN_PARAMS.get()
928
+ guidance_weight = guidance_weight if guidance_weight is not None else defaults['guidance_weight']
929
+ temperature = temperature if temperature is not None else defaults['temperature']
930
+ topk = topk if topk is not None else defaults['topk']
931
+
932
+ # Override sampling knobs just for this request
933
+ with mrt_overrides(mrt,
934
+ guidance_weight=guidance_weight,
935
+ temperature=temperature,
936
+ topk=topk):
937
+ wav, _ = generate_style_only_with_mrt(
938
+ mrt,
939
+ bpm=bpm,
940
+ bars=bars,
941
+ beats_per_bar=beats_per_bar,
942
+ styles=styles,
943
+ style_weights=style_weights,
944
+ intro_bars_to_drop=intro_bars_to_drop,
945
+ )
946
+
947
+ # Determine target SR (defaults to model SR = 48k)
948
+ cur_sr = int(mrt.sample_rate)
949
+ target_sr = int(target_sample_rate or cur_sr)
950
+
951
+ finally:
952
+ # Always release MRT back to pool
953
+ release_mrt(mrt_index)
954
  x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
955
 
956
  seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
 
999
 
1000
  loudness_mode: str = Form("auto"),
1001
  loudness_headroom_db: float = Form(1.0),
1002
+ guidance_weight: Optional[float] = Form(None),
1003
+ temperature: Optional[float] = Form(None),
1004
+ topk: Optional[int] = Form(None),
1005
  target_sample_rate: int | None = Form(None),
1006
  ):
1007
+ ensure_pool_initialized()
1008
 
1009
+ # Get available MRT from pool
1010
+ mrt_index, mrt = get_available_mrt()
1011
+ if mrt is None:
1012
+ raise HTTPException(status_code=429, detail="All slots busy (max 2 concurrent JAM sessions)")
 
1013
 
1014
+ try:
1015
+ asset_manager.ensure_assets_loaded(mrt)
 
 
 
1016
 
1017
+ # Apply global defaults if not specified
1018
+ defaults = _GLOBAL_GEN_PARAMS.get()
1019
+ guidance_weight = guidance_weight if guidance_weight is not None else defaults['guidance_weight']
1020
+ temperature = temperature if temperature is not None else defaults['temperature']
1021
+ topk = topk if topk is not None else defaults['topk']
1022
 
1023
+ # read input + prep context/style (reuse your existing code)
1024
+ data = loop_audio.file.read()
1025
+ if not data: raise HTTPException(status_code=400, detail="Empty file")
1026
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
1027
+ tmp.write(data); tmp_path = tmp.name
1028
 
1029
+ loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
 
 
 
 
 
 
 
 
 
1030
 
1031
+ # build tail context + style vec (tail-biased)
1032
+ codec_fps = float(mrt.codec.frame_rate)
1033
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1034
+ loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
1035
 
1036
+ # Parse client style fields (preserves your semantics)
1037
+ text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
1038
+ try:
1039
+ tw = [float(x) for x in style_weights.split(",")] if style_weights else []
1040
+ except ValueError:
1041
+ tw = []
1042
+ try:
1043
+ cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
1044
+ except ValueError:
1045
+ cw = []
 
 
1046
 
1047
+ # Compute loop-tail embed once (same as before)
1048
+ loop_tail_embed = mrt.embed_style(loop_tail)
 
 
1049
 
1050
+ # Build final style vector:
1051
+ # - identical to your previous mix when mean==0 and cw is empty
1052
+ # - otherwise includes mean and centroid components (weights auto-normalized)
1053
+ style_vec = build_style_vector(
1054
+ mrt,
1055
+ text_styles=text_list,
1056
+ text_weights=tw,
1057
+ loop_embed=loop_tail_embed,
1058
+ loop_weight=float(loop_weight),
1059
+ mean_weight=float(mean),
1060
+ centroid_weights=cw,
1061
+ ).astype(np.float32, copy=False)
1062
+
1063
+ # target SR (default input SR)
1064
+ inp_info = sf.info(tmp_path)
1065
+ input_sr = int(inp_info.samplerate)
1066
+ target_sr = int(target_sample_rate or input_sr)
1067
+
1068
+ params = JamParams(
1069
+ bpm=bpm,
1070
+ beats_per_bar=beats_per_bar,
1071
+ bars_per_chunk=bars_per_chunk,
1072
+ target_sr=target_sr,
1073
+ loudness_mode=loudness_mode,
1074
+ headroom_db=loudness_headroom_db,
1075
+ style_vec=style_vec,
1076
+ ref_loop=loop_tail, # For loudness matching
1077
+ combined_loop=loop, # NEW: Full loop for context setup
1078
+ guidance_weight=guidance_weight,
1079
+ temperature=temperature,
1080
+ topk=topk
1081
+ )
1082
 
1083
+ worker = JamWorker(mrt, params)
1084
+ sid = str(uuid.uuid4())
1085
+ with jam_lock:
1086
+ jam_registry[sid] = {
1087
+ 'worker': worker,
1088
+ 'mrt_index': mrt_index
1089
+ }
1090
+ worker.start()
1091
+
1092
+ return {"session_id": sid, "slot": mrt_index}
1093
 
1094
+ except Exception as e:
1095
+ # Release MRT back to pool on failure
1096
+ release_mrt(mrt_index)
1097
+ raise
1098
 
1099
  @app.get("/jam/next")
1100
  def jam_next(session_id: str):
 
1103
  This ensures chunks are delivered in order without gaps.
1104
  """
1105
  with jam_lock:
1106
+ session_info = jam_registry.get(session_id)
1107
+ if session_info is None:
1108
+ raise HTTPException(status_code=404, detail="Session not found")
1109
+
1110
+ worker = session_info['worker']
1111
+ if not worker.is_alive():
1112
  raise HTTPException(status_code=404, detail="Session not found")
1113
 
1114
  # Get the next sequential chunk (this blocks until ready)
1115
  chunk = worker.get_next_chunk()
1116
+
1117
  if chunk is None:
1118
  raise HTTPException(status_code=408, detail="Chunk not ready within timeout")
1119
 
 
1132
  This helps the worker manage its buffer and generation flow.
1133
  """
1134
  with jam_lock:
1135
+ session_info = jam_registry.get(session_id)
1136
+ if session_info is None:
1137
+ raise HTTPException(status_code=404, detail="Session not found")
1138
+
1139
+ worker = session_info['worker']
1140
+ if not worker.is_alive():
1141
  raise HTTPException(status_code=404, detail="Session not found")
1142
 
1143
  worker.mark_chunk_consumed(chunk_index)
1144
+
1145
  return {"consumed": chunk_index}
1146
 
1147
 
 
1149
  @app.post("/jam/stop")
1150
  def jam_stop(session_id: str = Body(..., embed=True)):
1151
  with jam_lock:
1152
+ session_info = jam_registry.get(session_id)
1153
+ if session_info is None:
1154
  raise HTTPException(status_code=404, detail="Session not found")
1155
 
1156
+ worker = session_info['worker']
1157
+ mrt_index = session_info['mrt_index']
1158
+
1159
  worker.stop()
1160
  worker.join(timeout=5.0)
1161
  if worker.is_alive():
1162
+ # It's daemon=True, so it won't block process exit, but report it
1163
  print(f"⚠️ JamWorker {session_id} did not stop within timeout")
1164
 
1165
+ # Release MRT back to pool
1166
+ release_mrt(mrt_index)
1167
+
1168
  with jam_lock:
1169
  jam_registry.pop(session_id, None)
1170
  return {"stopped": True}
 
1173
  def jam_stop_all():
1174
  """Force stop all active jam sessions (nuclear option for cleanup)"""
1175
  stopped_sessions = []
1176
+
1177
  with jam_lock:
1178
+ for session_id, session_info in list(jam_registry.items()):
1179
+ worker = session_info['worker']
1180
+ mrt_index = session_info['mrt_index']
1181
+
1182
  if worker.is_alive():
1183
  worker.stop()
1184
  worker.join(timeout=2.0)
1185
  stopped_sessions.append(session_id)
1186
+
1187
+ # Release MRT back to pool
1188
+ release_mrt(mrt_index)
1189
  jam_registry.pop(session_id, None)
1190
 
1191
  return {"stopped_sessions": stopped_sessions, "count": len(stopped_sessions)}
 
1209
  mean: Optional[float] = Form(None),
1210
  centroid_weights: str = Form(""),
1211
  ):
 
 
1212
  with jam_lock:
1213
+ session_info = jam_registry.get(session_id)
1214
+ if session_info is None:
1215
+ raise HTTPException(status_code=404, detail="Session not found")
1216
+
1217
+ worker = session_info['worker']
1218
+ if not worker.is_alive():
1219
  raise HTTPException(status_code=404, detail="Session not found")
1220
 
1221
+ # Get MRT from the worker's assigned instance
1222
+ mrt = _MRT_POOL[session_info['mrt_index']]
1223
+ asset_manager.ensure_assets_loaded(mrt)
1224
+
1225
  # 1) fast knob updates
1226
  if any(v is not None for v in (guidance_weight, temperature, topk)):
1227
  worker.update_knobs(
 
1289
  @app.post("/jam/reseed")
1290
  def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
1291
  with jam_lock:
1292
+ session_info = jam_registry.get(session_id)
1293
+ if session_info is None:
1294
+ raise HTTPException(status_code=404, detail="Session not found")
1295
+
1296
+ worker = session_info['worker']
1297
+ if not worker.is_alive():
1298
  raise HTTPException(status_code=404, detail="Session not found")
1299
 
1300
  # Option 1: use uploaded new “combined” bounce from the app
 
1324
  anchor_bars: float = Form(2.0), # how much of the original to re-inject
1325
  combined_audio: UploadFile = File(None), # preferred: Swift supplies the current combined mix
1326
  ):
1327
+ with jam_lock:
1328
+ session_info = jam_registry.get(session_id)
1329
+ if session_info is None:
1330
+ raise HTTPException(status_code=404, detail="Session not found")
1331
+
1332
+ worker = session_info['worker']
1333
+ if not worker.is_alive():
1334
  raise HTTPException(status_code=404, detail="Session not found")
1335
 
1336
  # Build a waveform to reseed from
 
1360
  @app.get("/jam/status")
1361
  def jam_status(session_id: str):
1362
  with jam_lock:
1363
+ session_info = jam_registry.get(session_id)
1364
 
1365
+ if session_info is None:
1366
  raise HTTPException(status_code=404, detail="Session not found")
1367
 
1368
+ worker = session_info['worker']
1369
  running = worker.is_alive()
1370
 
1371
  # Snapshot safely
 
1485
  # attach or create
1486
  if sid:
1487
  with jam_lock:
1488
+ session_info = jam_registry.get(sid)
1489
+ if session_info is None:
1490
+ await send_json({"type":"error","error":"Session not found"})
1491
+ continue
1492
+ worker = session_info['worker']
1493
+ if not worker.is_alive():
1494
  await send_json({"type":"error","error":"Session not found"})
1495
  continue
1496
  else:
 
1850
  """
1851
  return Response(content=html_content, media_type="text/html")
1852
 
1853
+ # ============================================================================
1854
+ # Global Generation Configuration Endpoints
1855
+ # ============================================================================
1856
+
1857
@app.get("/config/generation")
async def get_generation_config():
    """
    Report the global defaults for temperature, topk, and guidance_weight.

    The response is a snapshot of the values currently stored in the
    global parameter object.
    """
    current_defaults = _GLOBAL_GEN_PARAMS.get()
    return current_defaults
1864
+
1865
@app.put("/config/generation")
async def update_generation_config(
    temperature: Optional[float] = None,
    topk: Optional[int] = None,
    guidance_weight: Optional[float] = None
):
    """
    Change the global defaults for temperature, topk, and guidance_weight.

    Only the parameters that are supplied are changed; omitted ones keep
    their current value.

    NOTE: MRT instances already in the pool keep the values they were
    built with until the pool is rebuilt. Call
    POST /config/generation/apply afterwards to rebuild it.

    Explicit parameters on an individual request still take precedence
    over these global defaults.
    """
    new_defaults = _GLOBAL_GEN_PARAMS.update(
        temperature=temperature,
        topk=topk,
        guidance_weight=guidance_weight,
    )
    restart_hint = "Changes require pool restart. Call POST /config/generation/apply to apply."
    return {"updated": new_defaults, "note": restart_hint}
1888
+
1889
@app.post("/config/generation/apply")
async def apply_generation_config():
    """
    Restart MRT pool with new global parameters.

    This will:
    1. Check if any JAM sessions are active
    2. If active sessions exist, return 409 error
    3. If no active sessions, recreate MRT pool with new parameters

    The rebuild itself is blocking work, so it is run in a worker thread
    rather than directly in this coroutine; otherwise the event loop (and
    every other async handler on it) would stall until the pool is ready.

    NOTE(review): only JAM sessions are checked. A session started, or a
    one-shot generation holding a slot, between the check and the rebuild
    is not detected — confirm whether that window matters in practice.

    Raises:
        HTTPException(409): if at least one JAM worker is still alive.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    # Check for active sessions
    with jam_lock:
        active_sessions = [
            sid
            for sid, session_info in jam_registry.items()
            if session_info['worker'].is_alive()
        ]

    if active_sessions:
        raise HTTPException(
            status_code=409,
            detail=f"Cannot restart: {len(active_sessions)} active JAM session(s). Stop them first via /jam/stop"
        )

    # Restart pool with new parameters, off the event loop.
    await asyncio.to_thread(reset_mrt_pool)

    return {
        "status": "applied",
        "params": _GLOBAL_GEN_PARAMS.get(),
        "message": "MRT pool restarted with new parameters"
    }
1922
+
1923
@app.get("/config/generation/pool_status")
async def get_pool_status():
    """Report pool size, per-slot availability, init flag and current params.

    Everything is read while holding the pool lock so the slot list and
    pool size come from the same moment.
    """
    with _MRT_POOL_LOCK:
        snapshot = {
            "pool_size": len(_MRT_POOL),
            # Copy so the response cannot alias the live list.
            "available": list(_MRT_AVAILABLE),
            "initialized": _POOL_INITIALIZED,
            "params": _GLOBAL_GEN_PARAMS.get(),
        }
        return snapshot
1933
+
1934
+ # ============================================================================
1935
+ # Static Files
1936
+ # ============================================================================
1937
+
1938
  @app.get("/lil_demo_540p.mp4")
1939
  def demo_video():
1940
  return FileResponse(Path(__file__).parent / "lil_demo_540p.mp4", media_type="video/mp4")