File size: 2,772 Bytes
cca49ed
227bc73
33134ab
27b9ec6
a5c228f
27b9ec6
227bc73
df19679
5af50c2
33acb21
df19679
 
8acb7dc
bff2418
 
a4a2927
20c3c08
30904d8
 
 
33134ab
 
16c2e7a
c9be0fc
33134ab
 
 
 
2e5533e
30904d8
7627325
cca49ed
33134ab
 
 
 
 
ecea5f9
 
33134ab
 
 
a5c228f
33134ab
 
ecea5f9
33134ab
ecea5f9
 
33134ab
 
bff2418
33134ab
 
47d7323
33134ab
 
 
 
a4a2927
33134ab
 
 
 
 
 
 
 
 
 
 
 
 
 
eedc8c2
859df7f
2cf3b99
a4a2927
 
5f0f430
 
2cf3b99
33134ab
350e89b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image

#predictor = None
#task_type = None

#@spaces.GPU(duration=120)
def init_predictor():
    """Create the SkyReels image-to-video predictor and store it globally.

    Side effect: binds the module-level name ``predictor``, which
    ``generate_video`` reads. The model is loaded from
    ``Skywork/SkyReels-V1-Hunyuan-I2V`` with quantisation and offloading
    both disabled. An ``OffloadConfig`` (high CPU memory, parameter-level,
    no transformer compilation) is still passed to the constructor.
    """
    global predictor
    offload_settings = OffloadConfig(
        high_cpu_memory=True,
        parameters_level=True,
        compiler_transformer=False,
    )
    predictor = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.I2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-I2V",
        quant_model=False,
        is_offload=False,
        offload_config=offload_settings,
    )
    
@spaces.GPU(duration=80)
def generate_video(prompt, seed, image=None):
    """Generate a 512x512, 97-frame video from an input image and a prompt.

    Args:
        prompt: Text prompt forwarded to the predictor. Its first 100
            characters are also used in the output filename.
        seed: Random seed. ``-1`` (or an empty field, ``None``) picks a fresh
            random seed. Gradio's ``gr.Number`` delivers floats, so the value
            is coerced to ``int``.
        image: Conditioning image, passed through ``load_image``. Required
            despite the ``None`` default.

    Returns:
        Tuple ``(video_path, kwargs)``: the path of the written ``.mp4`` file
        and the dict of generation arguments sent to the predictor.

    Raises:
        gr.Error: If no image was supplied.
    """
    print(f"image:{type(image)}")
    # Explicit raise instead of ``assert``: asserts vanish under ``python -O``,
    # and gr.Error is displayed to the user in the UI.
    if image is None:
        raise gr.Error("please input image")
    # gr.Number yields floats (e.g. -1.0, 42.0). Normalise so that the model
    # and the filename both see an integer seed.
    seed = -1 if seed is None else int(seed)
    if seed == -1:
        # Private RNG: avoids reseeding the process-global ``random`` state.
        seed = random.SystemRandom().randrange(4294967294)
    kwargs = {
        "prompt": prompt,
        "height": 512,
        "width": 512,
        "num_frames": 97,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False,
    }
    kwargs["image"] = load_image(image=image)
    # ``predictor`` is the module global bound by init_predictor().
    output = predictor.inference(kwargs)
    # The former ``task_type`` global is commented out at module level, so
    # referencing it raised NameError. This app only runs the I2V task.
    save_dir = "./result/i2v"
    os.makedirs(save_dir, exist_ok=True)
    # Drop path separators and NUL so the prompt cannot escape save_dir or
    # break the filename.
    safe_prompt = "".join(ch for ch in prompt[:100] if ch not in "/\\\0")
    video_out_file = f"{save_dir}/{safe_prompt}_{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs

def create_gradio_interface():
    """Build (but do not launch) the Gradio Blocks UI.

    Layout: a row containing the image upload, prompt box and seed field
    (default -1), followed by a "Generate Video" button, the output video
    and a text box showing the generation parameters. The button is wired
    to ``generate_video``.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    with gr.Blocks() as ui:
        with gr.Row():
            image_input = gr.Image(label="Upload Image", type="filepath")
            prompt_box = gr.Textbox(label="Input Prompt")
            seed_box = gr.Number(label="Random Seed", value=-1)
        run_button = gr.Button("Generate Video")
        video_out = gr.Video(label="Generated Video")
        params_out = gr.Textbox(label="Output Parameters")
        # Input order must match generate_video(prompt, seed, image); output
        # order matches its (video_path, kwargs) return tuple.
        handler_inputs = [prompt_box, seed_box, image_input]
        handler_outputs = [video_out, params_out]
        run_button.click(
            fn=generate_video,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )
    return ui
    
if __name__ == "__main__":
    # Load the model first: generate_video reads the global ``predictor``
    # that init_predictor() binds.
    init_predictor()
    demo = create_gradio_interface()
    demo.launch()