import logging
import os  # Keep os here
import time
from datetime import timedelta
from typing import Any
from typing import Dict

# NOTE: all heavy GPU/ML dependencies (torch, diffusers, transformers, torchao,
# PIL, and the package-local pipeline/offload modules) are imported lazily
# inside the methods that need them, so importing this module stays cheap and
# does not require CUDA to be available.

logger = logging.getLogger("SkyReelsVideoInfer")
logger.setLevel(logging.DEBUG)
# Guard against duplicate handlers: without this, re-importing the module (or
# importing it under two names) would attach a second StreamHandler and every
# log record would be emitted twice.
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    # Plain string, not an f-string: the %(...)s tokens are logging-Formatter
    # placeholders, not f-string fields, so the f prefix did nothing.
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
    )
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)


class SkyReelsVideoInfer:
    """Inference wrapper around the SkyReels HunyuanVideo pipeline.

    Loads the transformer + text encoder (optionally float8-quantized),
    assembles a ``SkyreelsVideoPipeline``, optionally applies CPU offloading,
    and exposes a single :meth:`inference` entry point for T2V/I2V generation.
    """

    def __init__(
        self,
        task_type,  # package-local TaskType enum; untyped so the module import stays lazy
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config=None,  # package-local OffloadConfig; untyped for the same reason
        use_multiprocessing: bool = False,
    ):
        """Store configuration and eagerly build the pipeline on CUDA.

        Args:
            task_type: TaskType.I2V or TaskType.T2V — selects the inference path.
            model_id: HF repo or local path of the SkyReels transformer weights.
            quant_model: if True, float8-weight-only quantize both sub-models.
            is_offload: if True (and a config is given), enable CPU offloading.
            offload_config: OffloadConfig consumed by ``Offload.offload``.
            use_multiprocessing: recorded for callers; not acted on here.
        """
        self.task_type = task_type
        self.model_id = model_id
        self.quant_model = quant_model
        self.is_offload = is_offload
        self.offload_config = offload_config
        # BUG FIX: this flag was previously accepted but silently discarded.
        self.use_multiprocessing = use_multiprocessing
        self._initialize_pipeline()

    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        device: str = "cuda",
    ):
        """Build and return a SkyreelsVideoPipeline on ``device``.

        The text encoder comes from the community base model, the transformer
        from ``model_id``; both are bfloat16 and optionally float8-quantized.
        VAE tiling is enabled to bound peak VRAM during decode.
        """
        # Delayed imports — see module-level NOTE. (The previously imported
        # DiffusionPipeline and PIL.Image were never used and are dropped.)
        import torch
        from diffusers import HunyuanVideoTransformer3DModel
        from transformers import LlamaModel
        from torchao.quantization import float8_weight_only
        from torchao.quantization import quantize_

        from .pipelines import SkyreelsVideoPipeline  # local import

        # Lazy %-style args: formatting only happens if the record is emitted.
        logger.info(
            "load model model_id:%s quan_model:%s device:%s",
            model_id, quant_model, device,
        )
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to(device)
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        ).to(device)
        if quant_model:
            # float8 weight-only quantization halves weight memory with
            # minimal quality impact on these bf16 models.
            quantize_(text_encoder, float8_weight_only(), device=device)
            quantize_(transformer, float8_weight_only(), device=device)
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to(device)
        pipe.vae.enable_tiling()
        return pipe

    def _initialize_pipeline(self):
        """Create ``self.pipe`` and apply offloading when configured."""
        # Delayed local import — see module-level NOTE.
        from .offload import Offload

        # Return type is SkyreelsVideoPipeline (not annotated to keep the
        # import lazy).
        self.pipe = self._load_model(
            model_id=self.model_id,
            quant_model=self.quant_model,
            device="cuda",
        )
        # Offloading requires both the flag and an actual config object.
        if self.is_offload and self.offload_config:
            Offload.offload(
                pipeline=self.pipe,
                config=self.offload_config,
            )

    def inference(self, kwargs: Dict[str, Any]):
        """Run the pipeline and return the generated frames.

        Args:
            kwargs: pipeline arguments passed through to ``self.pipe``; for
                I2V tasks it must contain an ``"image"`` entry (popped here
                and forwarded as the ``image`` keyword). NOTE: the dict is
                mutated in place for I2V (``image`` is removed).

        Returns:
            ``output.frames`` from the pipeline call.
        """
        # Delayed local import — see module-level NOTE.
        from . import TaskType

        if self.task_type == TaskType.I2V:
            image = kwargs.pop("image")
            output = self.pipe(image=image, **kwargs)
        else:
            output = self.pipe(**kwargs)
        return output.frames