1inkusFace committed
Commit 8f1996d · verified · Parent: 03ff4c9

Update app.py

Files changed (1): app.py (+72 lines, -40 lines)
app.py CHANGED
@@ -1,4 +1,3 @@
-# app.py
 import spaces
 import gradio as gr
 import argparse
@@ -9,17 +8,17 @@ import subprocess
 from PIL import Image
 import numpy as np
 
-subprocess.run(['sh', './sky.sh'])
-sys.path.append("./SkyReels-V1")
+# subprocess.run(['sh', './sky.sh'])  # Removed as it's likely environment-specific
+# sys.path.append("./SkyReels-V1")  # Removed as it's likely environment-specific
 
-from skyreelsinfer import TaskType
-from skyreelsinfer.offload import OffloadConfig
-from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
+# from skyreelsinfer import TaskType  # Dummy classes cover this
+# from skyreelsinfer.offload import OffloadConfig  # Dummy classes cover this
+# from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer  # Dummy classes cover this
 from diffusers.utils import export_to_video
 
 import torch
 import logging
-from collections import OrderedDict # Import OrderedDict here
+from collections import OrderedDict
 
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
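
Note: with the skyreelsinfer imports commented out, the Space depends entirely on the dummy classes defined below. If the real package may be present in some environments, a guarded import is a common alternative — a minimal sketch, assuming the module layout from the original import lines:

    # Prefer the real package; fall back to the local dummy classes below.
    try:
        from skyreelsinfer import TaskType
        from skyreelsinfer.offload import OffloadConfig
    except ImportError:
        pass  # the dummy TaskType / OffloadConfig defined below take over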
@@ -31,32 +30,46 @@ torch.set_float32_matmul_precision("highest")
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 logger = logging.getLogger(__name__)
+
+
 # --- Dummy Classes (Keep for standalone execution) ---
 class OffloadConfig:
-    def __init__(self, high_cpu_memory=False, parameters_level=False, compiler_transformer=False, compiler_cache=""):
+    def __init__(
+        self,
+        high_cpu_memory: bool = False,
+        parameters_level: bool = False,
+        compiler_transformer: bool = False,
+        compiler_cache: str = "",
+    ):
         self.high_cpu_memory = high_cpu_memory
         self.parameters_level = parameters_level
         self.compiler_transformer = compiler_transformer
         self.compiler_cache = compiler_cache
 
-class TaskType: #Keep here for infer
+
+class TaskType:  # Keep here for infer
     T2V = 0
     I2V = 1
 
+
 class LlamaModel:
     @staticmethod
     def from_pretrained(*args, **kwargs):
         return LlamaModel()
+
     def to(self, device):
         return self
 
+
 class HunyuanVideoTransformer3DModel:
     @staticmethod
     def from_pretrained(*args, **kwargs):
         return HunyuanVideoTransformer3DModel()
+
     def to(self, device):
         return self
 
+
 class SkyreelsVideoPipeline:
     @staticmethod
     def from_pretrained(*args, **kwargs):
@@ -76,36 +89,45 @@ class SkyreelsVideoPipeline:
             image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
             image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
 
-            # Create video by repeating the image and adding noise
+            # Create video by repeating the image
             frames = image_tensor.repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
-            frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise.
+            frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise
+            frames = frames.permute(0, 2, 1, 3, 4)  # Change to (1, T, C, H, W)
 
         else:  # T2V
-            frames = torch.randn(1, 3, num_frames, height, width)  # Use correct dims
+            frames = torch.randn(1, num_frames, 3, height, width)  # Use correct dims: (1, T, C, H, W)
 
-        return type('obj', (object,), {'frames' : frames})()  # No longer a list!
+        return type("obj", (object,), {"frames": frames})()  # No longer a list!
 
     def __init__(self):
-        super().__init__()
-        self._modules = OrderedDict()
-        self.vae = self.VAE()
-        self._modules["vae"] = self.vae
+        super().__init__()
+        self._modules = OrderedDict()
+        self.vae = self.VAE()
+        self._modules["vae"] = self.vae
 
     def named_children(self):
-        return self._modules.items()
+        return self._modules.items()
+
     class VAE:
         def enable_tiling(self):
             pass
 
+
 def quantize_(*args, **kwargs):
     return
 
+
 def float8_weight_only():
     return
 
+
 # --- End Dummy Classes ---
+
+
 class SkyReelsVideoSingleGpuInfer:
-    def _load_model(self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True):
+    def _load_model(
+        self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
+    ):
         logger.info(f"load model model_id:{model_id} quan_model:{quant_model}")
         text_encoder = LlamaModel.from_pretrained(
             base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
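
The I2V branch of the dummy pipeline now targets a (1, T, C, H, W) layout. One caveat: calling .repeat(1, 1, num_frames, 1, 1) on a (1, C, H, W) tensor yields (1, 1, C * num_frames, H, W), not the commented (1, C, T, H, W); a time axis must be unsqueezed first. A minimal shape check with hypothetical sizes:

    import torch

    image = torch.rand(1, 3, 544, 960)                  # (1, C, H, W)
    frames = image.unsqueeze(2).repeat(1, 1, 97, 1, 1)  # (1, C, T, H, W)
    frames = frames.permute(0, 2, 1, 3, 4)              # -> (1, T, C, H, W)
    assert frames.shape == (1, 97, 3, 544, 960)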
@@ -160,7 +182,7 @@
         self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
 
         if self.is_offload:
-            pass
+            pass  # Offloading logic (if any) would go here
         else:
             self.pipe.to(self.gpu_device)
 
@@ -176,10 +198,10 @@
         self.is_initialized = True
 
     def warm_up(self):
-        if not self.is_initialized:
-            raise RuntimeError("Model must be initialized before warm-up.")
+        if not self.is_initialized:
+            raise RuntimeError("Model must be initialized before warm-up.")
 
-        init_kwargs = {
+        init_kwargs = {
             "prompt": "A woman is dancing in a room",
             "height": 544,
             "width": 960,
@@ -190,26 +212,38 @@
             "generator": torch.Generator(self.gpu_device).manual_seed(42),
             "embedded_guidance_scale": 1.0,
         }
-        if self.task_type == TaskType.I2V:
-            init_kwargs["image"] = Image.new("RGB",(544,960), color="black")
-        self.pipe(**init_kwargs)
-        logger.info("Warm-up complete.")
+        if self.task_type == TaskType.I2V:
+            init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
+        self.pipe(**init_kwargs)
+        logger.info("Warm-up complete.")
 
     def infer(self, **kwargs):
         """Handles inference requests."""
         if not self.is_initialized:
-            self.initialize()
+            self.initialize()
         if "seed" in kwargs:
             kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
             del kwargs["seed"]
         assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
-        result = self.pipe(**kwargs).frames  # Return the tensor directly
+        result = self.pipe(**kwargs).frames  # Return the tensor directly
         return result
 
+
 _predictor = None
 
+
 @spaces.GPU(duration=90)
-def generate_video(prompt, seed, image=None):
+def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
+    """Generates a video based on the given prompt and seed.
+
+    Args:
+        prompt: The text prompt to guide video generation.
+        seed: The random seed for reproducibility.
+        image: Optional path to an image for Image-to-Video.
+
+    Returns:
+        A tuple containing the path to the generated video and the parameters used.
+    """
     global _predictor
 
     if seed == -1:
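
For reference, infer() turns a plain integer seed into a torch.Generator bound to the GPU device before calling the pipeline. Factored out, the pattern looks like this (the helper name is illustrative, not part of the commit):

    import torch

    def seed_to_generator(kwargs: dict, device: str = "cuda:0") -> dict:
        """Replace a 'seed' entry with an equivalent torch.Generator."""
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator(device).manual_seed(kwargs.pop("seed"))
        return kwargs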
@@ -232,9 +266,6 @@
     else:
         task_type = TaskType.I2V
         model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
-    seed = 43
-    # generator = torch.Generator(device="cuda").manual_seed(seed)
-
     kwargs = {
         "prompt": prompt,
         "image": Image.open(image),
@@ -243,11 +274,10 @@
         "num_frames": 97,
         "num_inference_steps": 30,
         "seed": seed,
-        # "generator": generator,
         "guidance_scale": 6.0,
         "embedded_guidance_scale": 1.0,
         "negative_prompt": "Aerial view, low quality, bad hands",
-        "cfg_for": False,
+        "cfg_for": False,  # Keep if present in the original
     }
 
     if _predictor is None:
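
One caveat around this hunk: kwargs always contains "image": Image.open(image), so a T2V call with image=None would crash inside Image.open before the predictor is ever reached. A guard along these lines would avoid that (my suggestion, not part of the commit):

    kwargs = {"prompt": prompt, "seed": seed}
    if image is not None:
        kwargs["image"] = Image.open(image)  # only the I2V path supplies an image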
@@ -264,12 +294,13 @@
         )
         _predictor.initialize()
         logger.info("Predictor initialized")
-    out_samples = []
+
     with torch.no_grad():
-        output = _predictor.infer(**kwargs)
-        # out_samples.extend(output.frames[0])
-        output = (output.cpu().numpy() * 255).astype(np.uint8)
-        output = output.transpose(0, 2, 3, 4, 1)
+        output = _predictor.infer(**kwargs)  # Removed [0]
+
+        output = (output.cpu().numpy() * 255).astype(np.uint8)  # .cpu() is required before .numpy() on a CUDA tensor
+        output = output.transpose(0, 2, 3, 4, 1)  # Keep this
+        output = output[0]  # Remove batch dimension, now (T, H, W, C)
 
     save_dir = f"./result"
     os.makedirs(save_dir, exist_ok=True)
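
The post-processing assumes the real pipeline returns a float tensor in [0, 1] laid out as (B, C, T, H, W): transpose(0, 2, 3, 4, 1) followed by output[0] then yields (T, H, W, C) uint8 frames for export_to_video. A standalone shape check with dummy data (note the dummy pipeline's (1, T, C, H, W) output would not match this transpose):

    import numpy as np
    import torch

    output = torch.rand(1, 3, 97, 544, 960)        # (B, C, T, H, W), values in [0, 1]
    frames = (output.cpu().numpy() * 255).astype(np.uint8)
    frames = frames.transpose(0, 2, 3, 4, 1)[0]    # -> (T, H, W, C)
    assert frames.shape == (97, 544, 960, 3)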
@@ -278,6 +309,7 @@
     export_to_video(output, video_out_file, fps=24)
     return video_out_file, kwargs
 
+
 def create_gradio_interface():
     with gr.Blocks() as demo:
         with gr.Row():
 
 