# SkyReels_B / skyreelsinfer / skyreels_video_infer.py
# (source mirrored from the Hugging Face repo; original upload note:
#  "Update skyreelsinfer/skyreels_video_infer.py", commit fae740e, 3.59 kB)
import logging
import os # Keep os here
import time
from datetime import timedelta
from typing import Any
from typing import Dict
# DELAY ALL THESE IMPORTS:
# import torch
# from diffusers import HunyuanVideoTransformer3DModel
# from diffusers import DiffusionPipeline
# from PIL import Image
# from transformers import LlamaModel
# from . import TaskType
# from .offload import Offload
# from .offload import OffloadConfig
# from .pipelines import SkyreelsVideoPipeline
# Module-level logger with a verbose console format (file/line/function) to aid
# debugging of the delayed-import / offload flow below.
logger = logging.getLogger("SkyReelsVideoInfer")
logger.setLevel(logging.DEBUG)
# Guard against duplicate handlers: without this check, re-importing or
# re-executing the module attaches a second StreamHandler and every log
# line is emitted twice. (Also dropped the pointless f-prefix: the format
# string uses %-style logging placeholders, not f-string interpolation.)
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
    )
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
class SkyReelsVideoInfer:
    """Thin inference wrapper around ``SkyreelsVideoPipeline``.

    Loads the HunyuanVideo text encoder and a SkyReels transformer checkpoint,
    optionally quantizes both to float8 (weight-only) via torchao, and exposes
    a single :meth:`inference` entry point for both text-to-video and
    image-to-video task types.

    Heavy dependencies (torch, diffusers, transformers, torchao, local
    package modules) are imported lazily inside methods so that merely
    importing this module stays cheap and dependency-free.
    """

    def __init__(
        self,
        task_type,  # expected to be a TaskType member (I2V / T2V); untyped to keep import lazy
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config=None,  # expected OffloadConfig; untyped to keep import lazy
        use_multiprocessing: bool = False,
    ):
        """Store configuration and eagerly build the pipeline on CUDA.

        Args:
            task_type: Task selector; compared against ``TaskType.I2V`` in
                :meth:`inference`.
            model_id: Checkpoint id/path for the SkyReels transformer weights.
            quant_model: If True, quantize text encoder and transformer to
                float8 weight-only.
            is_offload: If True (and ``offload_config`` is provided), apply
                CPU offloading via ``Offload.offload``.
            offload_config: Offload configuration object, or None to disable.
            use_multiprocessing: Reserved flag. NOTE(review): previously this
                argument was accepted but silently discarded; it is now stored
                for callers/subclasses, though nothing here consumes it yet.
        """
        self.task_type = task_type
        self.model_id = model_id
        self.quant_model = quant_model
        self.is_offload = is_offload
        self.offload_config = offload_config
        # Bug fix: this flag was dropped on the floor before.
        self.use_multiprocessing = use_multiprocessing
        self._initialize_pipeline()

    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        device: str = "cuda",
    ):
        """Build and return a ``SkyreelsVideoPipeline`` on ``device``.

        The text encoder comes from ``base_model_id`` while the transformer is
        loaded from ``model_id``; both are bfloat16 and optionally float8
        weight-only quantized.

        Returns:
            SkyreelsVideoPipeline with VAE tiling enabled (reduces peak VRAM
            during decode).
        """
        # Delayed imports: keep module import light; only pay the cost when
        # a model is actually loaded. (Removed unused DiffusionPipeline and
        # PIL.Image imports that were dead code here.)
        import torch
        from diffusers import HunyuanVideoTransformer3DModel
        from transformers import LlamaModel
        from torchao.quantization import float8_weight_only
        from torchao.quantization import quantize_

        from .pipelines import SkyreelsVideoPipeline  # local import

        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to(device)
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        ).to(device)
        if quant_model:
            # float8 weight-only: halves weight memory vs bf16 with minimal
            # quality impact; applied in place by torchao.
            quantize_(text_encoder, float8_weight_only(), device=device)
            quantize_(transformer, float8_weight_only(), device=device)
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to(device)
        pipe.vae.enable_tiling()
        return pipe

    def _initialize_pipeline(self):
        """Create ``self.pipe`` and optionally wire up CPU offloading."""
        # Delayed import — see _load_model for rationale.
        from .offload import Offload

        self.pipe = self._load_model(
            model_id=self.model_id, quant_model=self.quant_model, device="cuda"
        )
        # Offloading only happens when both the flag AND a config are given;
        # a truthy config object is required, not just is_offload=True.
        if self.is_offload and self.offload_config:
            Offload.offload(
                pipeline=self.pipe,
                config=self.offload_config,
            )

    def inference(self, kwargs):
        """Run the pipeline and return generated frames.

        Args:
            kwargs: Dict of pipeline arguments. For I2V tasks it MUST contain
                an ``"image"`` key, which is popped and passed explicitly;
                everything else is forwarded verbatim to the pipeline call.

        Returns:
            The ``frames`` attribute of the pipeline output.
        """
        # Delayed import: TaskType lives in the package __init__.
        from . import TaskType

        if self.task_type == TaskType.I2V:
            image = kwargs.pop("image")
            output = self.pipe(image=image, **kwargs)
        else:
            output = self.pipe(**kwargs)
        return output.frames