1inkusFace committed on
Commit 3262ed7 · verified · 1 Parent(s): 40b1047

Update skyreelsinfer/skyreels_video_infer.py

Files changed (1)
  1. skyreelsinfer/skyreels_video_infer.py +208 -53
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -1,34 +1,26 @@
 import logging
-import os # Keep os here
+import os
+import threading
 import time
 from datetime import timedelta
 from typing import Any
 from typing import Dict
+
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 from diffusers import HunyuanVideoTransformer3DModel
-from diffusers import DiffusionPipeline
 from PIL import Image
-from transformers import LlamaModel
 from torchao.quantization import float8_weight_only
 from torchao.quantization import quantize_
-from .pipelines import SkyreelsVideoPipeline # Local import
+from transformers import LlamaModel
+
+from . import TaskType
 from .offload import Offload
 from .offload import OffloadConfig
-from . import TaskType
-
-# DELAY ALL THESE IMPORTS:
-# import torch
-# from diffusers import HunyuanVideoTransformer3DModel
-# from diffusers import DiffusionPipeline
-# from PIL import Image
-# from transformers import LlamaModel
-
-# from . import TaskType
-# from .offload import Offload
-# from .offload import OffloadConfig
-# from .pipelines import SkyreelsVideoPipeline
+from .pipelines import SkyreelsVideoPipeline
 
-logger = logging.getLogger("SkyReelsVideoInfer")
+logger = logging.getLogger("SkyreelsVideoInfer")
 logger.setLevel(logging.DEBUG)
 console_handler = logging.StreamHandler()
 console_handler.setLevel(logging.DEBUG)
@@ -38,66 +30,229 @@ formatter = logging.Formatter(
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 
-class SkyReelsVideoInfer:
-    def __init__(
-        self,
-        task_type, # No TaskType.
-        model_id: str,
-        quant_model: bool = True,
-        is_offload: bool = True,
-        offload_config: OffloadConfig = OffloadConfig(),
-        use_multiprocessing: bool = False,
-    ):
-        self.task_type = task_type
-        self.model_id = model_id
-        self.quant_model = quant_model
-        self.is_offload = is_offload
-        self.offload_config = offload_config
-        self._initialize_pipeline()
 
+class SkyReelsVideoSingleGpuInfer:
     def _load_model(
         self,
        model_id: str,
         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
         quant_model: bool = True,
-        device: str = "cuda",
-    ):
-        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
+        gpu_device: str = "cuda:0",
+    ) -> SkyreelsVideoPipeline:
+        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
         text_encoder = LlamaModel.from_pretrained(
             base_model_id,
             subfolder="text_encoder",
             torch_dtype=torch.bfloat16,
-        ).to(device)
+        ).to("cpu")
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
             model_id,
+            # subfolder="transformer",
             torch_dtype=torch.bfloat16,
-        ).to(device)
+            device="cpu",
+        ).to("cpu")
         if quant_model:
-            quantize_(text_encoder, float8_weight_only(), device=device)
-            quantize_(transformer, float8_weight_only(), device=device)
+            quantize_(text_encoder, float8_weight_only(), device=gpu_device)
+            text_encoder.to("cpu")
+            torch.cuda.empty_cache()
+            quantize_(transformer, float8_weight_only(), device=gpu_device)
+            transformer.to("cpu")
+            torch.cuda.empty_cache()
         pipe = SkyreelsVideoPipeline.from_pretrained(
             base_model_id,
             transformer=transformer,
             text_encoder=text_encoder,
             torch_dtype=torch.bfloat16,
-        ).to(device)
+        ).to("cpu")
         pipe.vae.enable_tiling()
+        torch.cuda.empty_cache()
         return pipe
 
-    def _initialize_pipeline(self):
-        self.pipe = self._load_model( #No : SkyreelsVideoPipeline
-            model_id=self.model_id, quant_model=self.quant_model, device="cuda"
+    def __init__(
+        self,
+        task_type: TaskType,
+        model_id: str,
+        quant_model: bool = True,
+        local_rank: int = 0,
+        world_size: int = 1,
+        is_offload: bool = True,
+        offload_config: OffloadConfig = OffloadConfig(),
+        enable_cfg_parallel: bool = True,
+    ):
+        self.task_type = task_type
+        self.gpu_rank = local_rank
+        dist.init_process_group(
+            backend="nccl",
+            init_method="tcp://127.0.0.1:23456",
+            timeout=timedelta(seconds=600),
+            world_size=world_size,
+            rank=local_rank,
+        )
+        os.environ["LOCAL_RANK"] = str(local_rank)
+        logger.info(f"rank:{local_rank} Distributed backend: {dist.get_backend()}")
+        torch.cuda.set_device(dist.get_rank())
+        torch.backends.cuda.enable_cudnn_sdp(False)
+        gpu_device = f"cuda:{dist.get_rank()}"
+
+        self.pipe: SkyreelsVideoPipeline = self._load_model(
+            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
         )
-        if self.is_offload and self.offload_config:
+
+        from para_attn.context_parallel import init_context_parallel_mesh
+        from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+        from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+        max_batch_dim_size = 2 if enable_cfg_parallel and world_size > 1 else 1
+        max_ulysses_dim_size = int(world_size / max_batch_dim_size)
+        logger.info(f"max_batch_dim_size: {max_batch_dim_size}, max_ulysses_dim_size:{max_ulysses_dim_size}")
+
+        mesh = init_context_parallel_mesh(
+            self.pipe.device.type,
+            max_ring_dim_size=1,
+            max_batch_dim_size=max_batch_dim_size,
+        )
+        parallelize_pipe(self.pipe, mesh=mesh)
+        parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
+
+        if is_offload:
             Offload.offload(
                 pipeline=self.pipe,
-                config=self.offload_config,
+                config=offload_config,
+            )
+        else:
+            self.pipe.to(gpu_device)
+
+        if offload_config.compiler_transformer:
+            torch._dynamo.config.suppress_errors = True
+            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_{world_size}"
+            self.pipe.transformer = torch.compile(
+                self.pipe.transformer,
+                mode="max-autotune-no-cudagraphs",
+                dynamic=True,
             )
+        self.warm_up()
 
-    def inference(self, kwargs):
+    def warm_up(self):
+        init_kwargs = {
+            "prompt": "A woman is dancing in a room",
+            "height": 544,
+            "width": 960,
+            "guidance_scale": 6,
+            "num_inference_steps": 1,
+            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+            "num_frames": 97,
+            "generator": torch.Generator("cuda").manual_seed(42),
+            "embedded_guidance_scale": 1.0,
+        }
         if self.task_type == TaskType.I2V:
-            image = kwargs.pop("image")
-            output = self.pipe(image=image, **kwargs)
-        else:
-            output = self.pipe(**kwargs)
-        return output.frames
+            init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
+        self.pipe(**init_kwargs)
+
+    def damon_inference(self, request_queue: mp.Queue, response_queue: mp.Queue):
+        response_queue.put(f"rank:{self.gpu_rank} ready")
+        logger.info(f"rank:{self.gpu_rank} finish init pipe")
+        while True:
+            logger.info(f"rank:{self.gpu_rank} waiting for request")
+            kwargs = request_queue.get()
+            logger.info(f"rank:{self.gpu_rank} kwargs: {kwargs}")
+            if "seed" in kwargs:
+                kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
+                del kwargs["seed"]
+            start_time = time.time()
+            assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+            out = self.pipe(**kwargs).frames[0]
+            logger.info(f"rank:{dist.get_rank()} inference time: {time.time() - start_time}")
+            if dist.get_rank() == 0:
+                response_queue.put(out)
+
+
+def single_gpu_run(
+    rank,
+    task_type: TaskType,
+    model_id: str,
+    request_queue: mp.Queue,
+    response_queue: mp.Queue,
+    quant_model: bool = True,
+    world_size: int = 1,
+    is_offload: bool = True,
+    offload_config: OffloadConfig = OffloadConfig(),
+    enable_cfg_parallel: bool = True,
+):
+    pipe = SkyReelsVideoSingleGpuInfer(
+        task_type=task_type,
+        model_id=model_id,
+        quant_model=quant_model,
+        local_rank=rank,
+        world_size=world_size,
+        is_offload=is_offload,
+        offload_config=offload_config,
+        enable_cfg_parallel=enable_cfg_parallel,
+    )
+    pipe.damon_inference(request_queue, response_queue)
+
+
+class SkyReelsVideoInfer:
+    def __init__(
+        self,
+        task_type: TaskType,
+        model_id: str,
+        quant_model: bool = True,
+        world_size: int = 1,
+        is_offload: bool = True,
+        offload_config: OffloadConfig = OffloadConfig(),
+        enable_cfg_parallel: bool = True,
+    ):
+        self.world_size = world_size
+        smp = mp.get_context("spawn")
+        self.REQ_QUEUES: mp.Queue = smp.Queue()
+        self.RESP_QUEUE: mp.Queue = smp.Queue()
+        assert self.world_size > 0, "gpu_num must be greater than 0"
+        spawn_thread = threading.Thread(
+            target=self.lauch_single_gpu_infer,
+            args=(task_type, model_id, quant_model, world_size, is_offload, offload_config, enable_cfg_parallel),
+            daemon=True,
+        )
+        spawn_thread.start()
+        logger.info(f"Started multi-GPU thread with GPU_NUM: {world_size}")
+        print(f"Started multi-GPU thread with GPU_NUM: {world_size}")
+        # Block and wait for the prediction process to start
+        for _ in range(world_size):
+            msg = self.RESP_QUEUE.get()
+            logger.info(f"launch_multi_gpu get init msg: {msg}")
+            print(f"launch_multi_gpu get init msg: {msg}")
+
+    def lauch_single_gpu_infer(
+        self,
+        task_type: TaskType,
+        model_id: str,
+        quant_model: bool = True,
+        world_size: int = 1,
+        is_offload: bool = True,
+        offload_config: OffloadConfig = OffloadConfig(),
+        enable_cfg_parallel: bool = True,
+    ):
+        mp.spawn(
+            single_gpu_run,
+            nprocs=world_size,
+            join=True,
+            daemon=True,
+            args=(
+                task_type,
+                model_id,
+                self.REQ_QUEUES,
+                self.RESP_QUEUE,
+                quant_model,
+                world_size,
+                is_offload,
+                offload_config,
+                enable_cfg_parallel,
+            ),
+        )
+        logger.info(f"finish lanch multi gpu infer, world_size:{world_size}")
+
+    def inference(self, kwargs: Dict[str, Any]):
+        # put request to singlegpuinfer
+        for _ in range(self.world_size):
+            self.REQ_QUEUES.put(kwargs)
+        return self.RESP_QUEUE.get()
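
For orientation, a minimal driver sketch for the multi-process `SkyReelsVideoInfer` front end added by this commit. The model id, the step count, and the `__main__` guard are illustrative assumptions, not values from the diff; the other parameter names mirror the `warm_up` defaults shown above.

```python
# Hypothetical usage sketch: assumes the skyreelsinfer package from this repo
# is importable and at least one CUDA GPU is available.
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer

if __name__ == "__main__":
    # Spawns one worker process per GPU (world_size) and blocks until every
    # worker reports that its pipeline is initialized and warmed up.
    infer = SkyReelsVideoInfer(
        task_type=TaskType.T2V,  # TaskType.I2V also needs an "image" kwarg below
        model_id="Skywork/SkyReels-V1-Hunyuan-T2V",  # illustrative model id, not taken from the diff
        quant_model=True,  # float8 weight-only quantization via torchao
        world_size=1,
        is_offload=True,
        offload_config=OffloadConfig(),
    )

    # The kwargs dict is forwarded to SkyreelsVideoPipeline by the worker;
    # "seed" is converted into a torch.Generator inside damon_inference.
    frames = infer.inference({
        "prompt": "A woman is dancing in a room",
        "height": 544,
        "width": 960,
        "num_frames": 97,
        "num_inference_steps": 30,
        "guidance_scale": 6,
        "embedded_guidance_scale": 1.0,
        "seed": 42,
    })
```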