Spaces:
Runtime error
Runtime error
| import os | |
| import glob | |
| import numpy as np | |
| from numpy import linalg | |
| import PIL.Image as Image | |
| import torch | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
| from argparse import Namespace | |
| import easydict | |
| import legacy | |
| import dnnlib | |
| from opensimplex import OpenSimplex | |
| from configs import data_configs | |
| from models.psp import pSp | |
def build_stylegan2(
    increment=0.01,
    network_pkl='pretrained/ohayou_face2.pkl',
    process='image',  # ['image', 'interpolation', 'truncation', 'interpolation-truncation']
    random_seed=0,
    diameter=100.0,
    scale_type='pad',  # ['pad', 'padside', 'symm', 'symmside']
    size=None,
    seeds=None,
    space='z',  # ['z', 'w']
    fps=24,
    frames=240,
    noise_mode='none',  # ['const', 'random', 'none']
    outdir='path',
    projected_w='path',
    easing='linear',
    device='cpu',
):
    """Load a StyleGAN2 generator pickle and return its synthesis network.

    Only ``network_pkl``, ``scale_type``, ``size`` and ``device`` are read by
    this function. The other parameters (``increment``, ``process``,
    ``random_seed``, ``diameter``, ``seeds``, ``space``, ``fps``, ``frames``,
    ``noise_mode``, ``outdir``, ``projected_w``, ``easing``) are accepted but
    ignored here. NOTE(review): presumably kept for signature compatibility
    with a generation script; confirm before removing.

    Args:
        network_pkl: Path or URL opened with ``dnnlib.util.open_url``.
        scale_type: Forwarded to ``legacy.load_network_pkl`` as ``scale_type``.
        size: Forwarded to ``legacy.load_network_pkl`` as ``size``. ``None``
            (the default) means a fresh ``[512, 512]`` list on every call.
        seeds: Unused. ``None`` (the default) stands for the former ``[0]``.
        device: Device spec passed to ``torch.device``. The loaded ``G_ema``
            module is moved to this device.

    Returns:
        The ``synthesis`` attribute of the loaded ``G_ema`` network.
    """
    # A list literal as a default value is created once and shared by every
    # call; build a fresh one per call so the loader can never leak state.
    if size is None:
        size = [512, 512]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = size
    G_kwargs.scale_type = scale_type

    device = torch.device(device)
    # NOTE: loading a network pickle unpickles arbitrary objects; only use
    # trusted files/URLs for network_pkl.
    with dnnlib.util.open_url(network_pkl) as f:
        # custom=True plus the size/scale_type kwargs are interpreted by the
        # project-local legacy.load_network_pkl (not visible here).
        G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device)  # type: ignore
    return G.synthesis
def build_psp(checkpoint_path='pretrained/ohayou_face.pt'):
    """Build a pSp network in eval mode from a training checkpoint.

    The training options stored in the checkpoint under ``'opts'`` are
    merged with the inference options defined below. Inference options win
    on key conflicts. The result is wrapped in an ``argparse.Namespace`` and
    passed to ``pSp``.

    Args:
        checkpoint_path: Checkpoint file to read. Defaults to the previously
            hard-coded ``'pretrained/ohayou_face.pt'``. It is also stored as
            ``opts.checkpoint_path`` for ``pSp``.

    Returns:
        The constructed ``pSp`` network after ``eval()``, with
        ``opts.device`` set to ``'cpu'``.
    """
    test_opts = easydict.EasyDict({
        # arguments for inference script
        'checkpoint_path': checkpoint_path,
        'couple_outputs': False,
        'resize_outputs': False,
        'test_batch_size': 1,
        'test_workers': 1,
        # arguments for style-mixing script
        'n_images': None,
        'n_outputs_to_generate': 5,
        'mix_alpha': None,
        'latent_mask': None,
        # arguments for super-resolution
        'resize_factors': None,
    })

    # Update the inference options with those used during training.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    # Drop the rest of the checkpoint (e.g. weights) before building the
    # network, to keep peak memory down. NOTE(review): presumably pSp reads
    # the weights itself via opts.checkpoint_path -- confirm.
    del ckpt

    opts.update(vars(test_opts))
    # Older checkpoints may predate the 'learn_in_w' option.
    opts.setdefault('learn_in_w', False)

    opts = Namespace(**opts)
    opts.device = 'cpu'
    net = pSp(opts)
    net.eval()
    return net
def img_preprocess(img, transform):
    """Flatten a PIL image to RGB, apply ``transform``, and add a batch dim.

    Images carrying transparency (``'RGBA'``, ``'LA'``, ``'P'``) are
    composited onto a white background using their alpha channel. Any other
    non-RGB mode (e.g. ``'L'``, ``'CMYK'``) is converted with
    ``Image.convert('RGB')``.

    Args:
        img: A ``PIL.Image.Image``. It is not modified; a new image is
            created when conversion is needed.
        transform: Callable applied to the RGB image. Its result must support
            ``.unsqueeze(dim=0)``, presumably a torchvision transform
            pipeline returning a tensor -- confirm against callers.

    Returns:
        ``transform(rgb_img).unsqueeze(dim=0)``.
    """
    if img.mode in ('RGBA', 'LA', 'P'):
        # A palette ('P') image has a single band, so indexing band 3 on it
        # would raise IndexError. Converting to RGBA first turns any palette
        # transparency into a real alpha band and guarantees four bands.
        rgba = img.convert('RGBA')
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.split()[3])  # 3 is the alpha channel
        img = background
    elif img.mode != 'RGB':
        # Previously an assert (stripped under -O); convert instead of failing.
        img = img.convert('RGB')

    img = transform(img)
    img = img.unsqueeze(dim=0)
    return img