Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import open_clip | |
| import torch | |
| import taming.models.vqgan | |
| import ml_collections | |
| import einops | |
| import random | |
| import pathlib | |
| import subprocess | |
| import shlex | |
| import wget | |
| # Model | |
| from libs.muse import MUSE | |
| import utils | |
| import numpy as np | |
| from PIL import Image | |
# Log the visible CUDA environment so the startup log shows what hardware is in use.
print("cuda available:",torch.cuda.is_available())
print("cuda device count:",torch.cuda.device_count())
# get_device_name(0) raises when no CUDA device is present; the rest of this
# script falls back to CPU, so only query the name when a device exists.
if torch.cuda.is_available():
    print("cuda device name:",torch.cuda.get_device_name(0))
else:
    print("cuda device name:",None)
# os.system returns the command's exit status; the tools write their own
# output straight to stdout, so the printed value is just that status.
print(os.system("nvidia-smi"))
print(os.system("nvcc --version"))

# Context array used by cfg_nnet for the unconditional branch.
empty_context = np.load("assets/contexts/empty_context.npy")

print("downloading cc3m-285000.ckpt")
# The checkpoint is a directory of .pth files, despite the ".ckpt" suffix.
os.makedirs("assets/ckpts/cc3m-285000.ckpt",exist_ok=True)
os.system("ls")

# (url, destination) pairs for every file this script fetches at startup.
_DOWNLOADS = (
    ("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth",
     "assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth"),
    ("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/optimizer.pth",
     "assets/ckpts/cc3m-285000.ckpt/optimizer.pth"),
    ("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet.pth",
     "assets/ckpts/cc3m-285000.ckpt/nnet.pth"),
    ("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth",
     "assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth"),
    ("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/step.pth",
     "assets/ckpts/cc3m-285000.ckpt/step.pth"),
    ("https://huggingface.co/zideliu/vqgan/resolve/main/vqgan_jax_strongaug.ckpt",
     "assets/vqgan_jax_strongaug.ckpt"),
)
for _url, _target in _DOWNLOADS:
    # Skip files that are already on disk so a restart does not fetch the
    # whole checkpoint again.
    # NOTE(review): the `wget` package is believed to save under a
    # " (1)"-suffixed name when the target already exists, which would make a
    # repeated download pure waste - confirm against the installed version.
    if not os.path.exists(_target):
        wget.download(_url, _target)
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def d(**kwargs):
    """Build an ``ml_collections.ConfigDict`` whose fields are the given keyword arguments."""
    section = ml_collections.ConfigDict(initial_dictionary=kwargs)
    return section
def get_config():
    """Assemble the configuration used by this demo.

    Returns:
        An ``ml_collections.ConfigDict`` with the top-level fields ``seed``,
        ``z_shape``, ``autoencoder``, ``resume_root``, ``adapter_path``,
        ``optimizer``, ``lr_scheduler``, ``nnet``, ``muse`` and ``sample``.
    """
    # Build each nested section first, then attach them in a fixed order.
    autoencoder_cfg = d(config_file='vq-f16-jax.yaml')

    optimizer_cfg = d(
        name='adamw',
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )

    scheduler_cfg = d(name='customized', warmup_steps=5000)

    nnet_cfg = d(
        name='uvit_t2i_vq',
        img_size=16,
        codebook_size=1024,
        in_chans=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        clip_dim=1280,
        num_clip_token=77,
        use_checkpoint=True,
        skip=True,
        d_prj=32,
        is_shared=False,
    )

    muse_cfg = d(ignore_ind=-1, smoothing=0.1, gen_temp=4.5)

    sample_cfg = d(
        sample_steps=36,
        n_samples=50,
        mini_batch_size=8,
        cfg=True,
        linear_inc_scale=True,
        scale=10.,
        path='',
        lambdaA=2.0,  # Stage I: 2.0; Stage II: TODO
        lambdaB=5.0,  # Stage I: 5.0; Stage II: TODO
    )

    cfg = ml_collections.ConfigDict()
    cfg.seed = 1234
    cfg.z_shape = (8, 16, 16)
    cfg.autoencoder = autoencoder_cfg
    cfg.resume_root = "assets/ckpts/cc3m-285000.ckpt"
    cfg.adapter_path = None
    cfg.optimizer = optimizer_cfg
    cfg.lr_scheduler = scheduler_cfg
    cfg.nnet = nnet_cfg
    cfg.muse = muse_cfg
    cfg.sample = sample_cfg
    return cfg
def cfg_nnet(x, context, scale=None, lambdaA=None, lambdaB=None):
    """Run ``nnet_ema`` with guidance and return the combined prediction.

    Two combination rules are supported:

    * ``lambdaA is None``: ``cond + scale * (cond - uncond)``.
    * otherwise: ``cond_w_adapter + lambdaA * (cond_w_adapter - cond)
      + lambdaB * (cond - uncond)``.

    Args:
        x: Input forwarded to ``nnet_ema``; ``x.size(0)`` is used as the batch size.
        context: Conditioning passed to ``nnet_ema`` as ``context=``.
        scale: Guidance weight; required only when ``lambdaA`` is None.
        lambdaA: Weight on the adapter-vs-plain difference; None disables the adapter path.
        lambdaB: Weight on the conditional-vs-unconditional difference; required
            when ``lambdaA`` is given.

    Returns:
        The combined output of the ``nnet_ema`` forward passes.
    """
    use_adapter = lambdaA is not None

    cond = nnet_ema(x, context=context)
    # Only pay for the adapter forward pass when its output is actually used.
    # NOTE(review): assumes nnet_ema's forward has no side effects that later
    # calls rely on - confirm.
    cond_w_adapter = nnet_ema(x, context=context, use_adapter=True) if use_adapter else None

    # Broadcast the stored empty context (laid out as 'L D') across the batch.
    uncond_context = torch.tensor(empty_context, device=device)
    uncond_context = einops.repeat(uncond_context, 'L D -> B L D', B=x.size(0))
    uncond = nnet_ema(x, context=uncond_context)

    if use_adapter:
        return cond_w_adapter + lambdaA * (cond_w_adapter - cond) + lambdaB * (cond - uncond)
    # Evaluated only on this path, so scale=None no longer breaks adapter calls.
    return cond + scale * (cond - uncond)
def unprocess(x):
    """Clamp ``x`` in place to the range [0, 1] and return the same object."""
    lower, upper = 0., 1.
    x.clamp_(lower, upper)
    return x
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
config = get_config()

# Text encoder (open_clip) and its tokenizer; the two returned transforms are unused.
prompt_model, _, _ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k')
prompt_model = prompt_model.to(device).eval()
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')

# VQ model: evaluation mode, gradients disabled, moved to the target device.
vq_model = taming.models.vqgan.get_model('vq-f16-jax.yaml')
vq_model.eval().requires_grad_(False).to(device)

# MUSE helper plus the network restored from the checkpoint directory.
muse = MUSE(codebook_size=vq_model.n_embed, device=device, **config.muse)

train_state = utils.initialize_train_state(config, device)
train_state.resume(ckpt_root=config.resume_root)

# Only the EMA network is used below, in eval mode with gradients disabled.
nnet_ema = train_state.nnet_ema
nnet_ema.eval().requires_grad_(False).to(device)
# One row per style key: (key, adapter weights path or None, text appended to the prompt).
_STYLE_ROWS = (
    ("None", None, ""),
    ("0102", "style_adapter/0102.pth", " in watercolor painting style"),
    ("0103", "style_adapter/0103.pth", " in watercolor painting style"),
    ("0106", "style_adapter/0106.pth", " in line drawing style"),
    ("0108", "style_adapter/0108.pth", " in oil painting style"),
    ("0301", "style_adapter/0301.pth", " in 3d rendering style"),
    ("0305", "style_adapter/0305.pth", " in kid crayon drawing style"),
)

# Style key -> adapter weights file (None means no adapter).
style_ref = {key: weights for key, weights, _suffix in _STYLE_ROWS}

# Style key -> suffix appended to the user's prompt.
style_postfix = {key: suffix for key, _weights, suffix in _STYLE_ROWS}
def decode(_batch):
    """Pass ``_batch`` to ``vq_model.decode_code`` and return its result."""
    decoded = vq_model.decode_code(_batch)
    return decoded
def process(prompt,num_samples,lambdaA,lambdaB,style,seed,sample_steps,image=None):
    """Generate images for ``prompt`` in the selected style.

    Args:
        prompt: User text; the suffix from ``style_postfix[style]`` is appended.
        num_samples: Number of images to generate.
        lambdaA: Written to ``config.sample.lambdaA`` (reset to None when the
            style has no adapter).
        lambdaB: Written to ``config.sample.lambdaB`` (reset to None when the
            style has no adapter).
        style: Key into ``style_ref`` / ``style_postfix``.
        seed: RNG seed; ``-1`` picks a random seed in ``[0, 65535]``.
        sample_steps: Written to ``config.sample.sample_steps``.
        image: Unused; accepted so the UI can pass its image component.

    Returns:
        A list of ``num_samples`` uint8 arrays taken from the generated batch.

    Raises:
        KeyError: If ``style`` is not a key of ``style_ref``.
    """
    # UI sliders may hand over numeric values that are not Python ints, while
    # range(), the seeding functions and the int-typed config fields need ints.
    num_samples = int(num_samples)
    sample_steps = int(sample_steps)
    seed = int(seed)

    config.sample.lambdaA = lambdaA
    config.sample.lambdaB = lambdaB
    config.sample.sample_steps = sample_steps
    print(style)
    adapter_path = style_ref[style]
    adapter_postfix = style_postfix[style]
    print(f"load adapter path: {adapter_path}")
    if adapter_path is not None:
        # map_location keeps loading working on the CPU fallback even when the
        # adapter weights were saved from a CUDA device.
        nnet_ema.adapter.load_state_dict(torch.load(adapter_path, map_location=device))
    else:
        # No adapter selected: clear both lambdas in the sampling config.
        config.sample.lambdaA=None
        config.sample.lambdaB=None
    print("load adapter Done!")

    # Encode prompt
    prompt = prompt+adapter_postfix
    text_tokens = tokenizer(prompt).to(device)
    # Inference only: avoid building an autograd graph through the text encoder.
    with torch.no_grad():
        text_embedding = prompt_model.encode_text(text_tokens)
    text_embedding = text_embedding.repeat(num_samples, 1, 1) # B 77 1280
    print(text_embedding.shape)
    print(f"lambdaA: {lambdaA}, lambdaB: {lambdaB}, sample_steps: {sample_steps}")

    if seed==-1:
        seed = random.randint(0,65535)
    config.seed = seed
    print(f"seed: {seed}")
    set_seed(config.seed)

    with torch.no_grad():
        res = muse.generate(config,num_samples,cfg_nnet,decode,is_eval=True,context=text_embedding)
    print(res.shape)
    # Scale to 0..255, move the channel axis last, and convert to a uint8 array on the CPU.
    res = (res*255+0.5).clamp_(0,255).permute(0,2,3,1).to('cpu',torch.uint8).numpy()
    im = [res[i] for i in range(num_samples)]
    return im
# NOTE(review): the layout nesting below was reconstructed from context because
# the source had lost its indentation - confirm against the intended UI.
with gr.Blocks() as block:
    with gr.Row():
        gr.Markdown("## StyleDrop based on Muse (Inference Only) ")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            num_samples = gr.Slider(
                label="Images", minimum=1, maximum=12, value=1, step=1
            )
            seed = gr.Slider(
                label="Seed", minimum=-1, maximum=2147483647, step=1, value=1234
            )
            # NOTE(review): "0301" exists in style_ref/style_postfix but is not
            # offered here - confirm whether that is intentional.
            style = gr.Radio(
                choices=["0102","0103","0106","0108","0305","None"],
                type="value",
                value="None",
                label="Style",
            )
            with gr.Accordion("Advanced options",open=False):
                lambdaA = gr.Slider(
                    label="lambdaA", minimum=0.0, maximum=5.0, value=2.0, step=0.01
                )
                lambdaB = gr.Slider(
                    label="lambdaB", minimum=0.0, maximum=10.0, value=5.0, step=0.01
                )
                sample_steps = gr.Slider(
                    label="Sample steps", minimum=1, maximum=50, value=36, step=1
                )
                image=gr.Image(value=None)

        # Right column: generated images.
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output', show_label=False, elem_id="gallery"
            ).style(columns=2, height='auto')

    # Component order matches the positional parameters of process().
    ips = [prompt,num_samples,lambdaA,lambdaB,style,seed,sample_steps,image]

    with gr.Row():
        # Each example row: prompt, images, lambdaA, lambdaB, style, seed, steps, image path.
        examples = [
            [example_prompt, 1, 2.0, 5.0, example_style, 1234, 36, example_image]
            for example_prompt, example_style, example_image in (
                ("A banana on the table", "0103", "data/image_01_03.jpg"),
                ("A cow", "0102", "data/image_01_02.jpg"),
                ("A portrait of tabby cat", "0106", "data/image_01_06.jpg"),
                ("A church in the field", "0108", "data/image_01_08.jpg"),
                ("A Christmas tree", "0305", "data/image_03_05.jpg"),
            )
        ]
        gr.Examples(
            examples=examples,
            fn=process,
            inputs=ips,
            outputs=result_gallery,
            # Example caching is enabled only when SYSTEM == 'spaces'.
            cache_examples=os.getenv('SYSTEM') == 'spaces',
        )

    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

block.queue().launch(share=False)