Spaces:
Paused
Paused
import logging
import os  # Keep os here
import time
from datetime import timedelta
from typing import Any
from typing import Dict

# DELAY ALL THESE IMPORTS:
# import torch
# from diffusers import HunyuanVideoTransformer3DModel
# from diffusers import DiffusionPipeline
# from PIL import Image
# from transformers import LlamaModel
# from . import TaskType
# from .offload import Offload
# from .offload import OffloadConfig
# from .pipelines import SkyreelsVideoPipeline
# Module-level logger: DEBUG-level console output with file/line/function context.
logger = logging.getLogger("SkyReelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
# Plain string, not an f-string: the %(...)s tokens are logging format
# directives, and the original f-prefix had no interpolation fields at all.
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class SkyReelsVideoInfer:
    """Inference wrapper that builds a SkyreelsVideoPipeline on construction.

    Heavy dependencies (torch, diffusers, transformers, torchao, local
    package modules) are imported lazily inside methods so that importing
    this module stays cheap and avoids circular imports.
    """

    def __init__(
        self,
        task_type,  # TaskType enum value; left unannotated to keep the import delayed
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config=None,  # OffloadConfig; unannotated for the same reason
        use_multiprocessing: bool = False,
    ):
        """Store configuration and immediately load the pipeline.

        Args:
            task_type: Task selector; compared against ``TaskType.I2V`` in
                :meth:`inference`.
            model_id: Hub id / path of the transformer weights to load.
            quant_model: If True, quantize text encoder and transformer to
                float8 weight-only.
            is_offload: Apply offloading when ``offload_config`` is provided.
            offload_config: Configuration forwarded to ``Offload.offload``.
            use_multiprocessing: Accepted for interface compatibility.
                NOTE(review): previously dropped on the floor; now stored,
                but nothing in this class reads it yet.
        """
        self.task_type = task_type
        self.model_id = model_id
        self.quant_model = quant_model
        self.is_offload = is_offload
        self.offload_config = offload_config
        # Fix: the original accepted this parameter but never kept it.
        self.use_multiprocessing = use_multiprocessing
        self._initialize_pipeline()

    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        device: str = "cuda",
    ):
        """Build and return a ``SkyreelsVideoPipeline`` placed on ``device``.

        Loads the Llama text encoder and the HunyuanVideo transformer in
        bfloat16, optionally quantizes both to float8 weight-only, then
        assembles the pipeline and enables VAE tiling.

        Args:
            model_id: Weights for the 3D transformer.
            base_model_id: Base repo providing text encoder / remaining parts.
            quant_model: Apply torchao float8 weight-only quantization.
            device: Target device string (e.g. ``"cuda"``).

        Returns:
            The assembled pipeline instance.
        """
        # Delayed imports: keep module import fast. The original also
        # imported DiffusionPipeline and PIL.Image here; both were unused
        # and have been removed.
        import torch
        from diffusers import HunyuanVideoTransformer3DModel
        from torchao.quantization import float8_weight_only
        from torchao.quantization import quantize_
        from transformers import LlamaModel

        from .pipelines import SkyreelsVideoPipeline  # Local import

        # Lazy %-args: message is only rendered if INFO is enabled.
        logger.info(
            "load model model_id:%s quan_model:%s device:%s",
            model_id, quant_model, device,
        )
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to(device)
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        ).to(device)
        if quant_model:
            # float8 weight-only quantization shrinks weight memory roughly 2x
            # versus bf16 without changing the module interfaces.
            quantize_(text_encoder, float8_weight_only(), device=device)
            quantize_(transformer, float8_weight_only(), device=device)
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to(device)
        # Tile the VAE decode to bound peak VRAM on long/large videos.
        pipe.vae.enable_tiling()
        return pipe

    def _initialize_pipeline(self):
        """Create ``self.pipe`` and optionally apply offloading to it."""
        # More delayed imports.
        from .offload import Offload

        # No return annotation on purpose: SkyreelsVideoPipeline stays delayed.
        self.pipe = self._load_model(
            model_id=self.model_id, quant_model=self.quant_model, device="cuda"
        )
        # Offload only when both the flag and a config are present.
        if self.is_offload and self.offload_config:
            Offload.offload(
                pipeline=self.pipe,
                config=self.offload_config,
            )

    def inference(self, kwargs):
        """Run the pipeline and return its generated frames.

        Args:
            kwargs: dict of pipeline call arguments (NOTE: a positional dict
                argument, not ``**kwargs``; it is mutated — ``"image"`` is
                popped for I2V tasks and must be present there).

        Returns:
            ``output.frames`` from the pipeline call.
        """
        # Delayed import avoids importing the package eagerly at module load.
        from . import TaskType

        if self.task_type == TaskType.I2V:
            # Pop so the image is not passed twice via **kwargs.
            image = kwargs.pop("image")
            output = self.pipe(image=image, **kwargs)
        else:
            output = self.pipe(**kwargs)
        return output.frames