| | import glob |
| | import json |
| | import os |
| | import cv2 |
| | import pickle |
| | import random |
| | import re |
| | import subprocess |
| | from functools import partial |
| |
|
| | import librosa.core |
| | import numpy as np |
| | import torch |
| | import torch.distributions |
| | import torch.distributed as dist |
| | import torch.optim |
| | import torch.utils.data |
| |
|
| | from utils.commons.indexed_datasets import IndexedDataset |
| | from torch.utils.data import Dataset, DataLoader |
| |
|
| | import torch.nn.functional as F |
| | import pandas as pd |
| | from tqdm import tqdm |
| | import csv |
| | from utils.commons.hparams import hparams, set_hparams |
| | from utils.commons.meters import Timer |
| | from data_util.face3d_helper import Face3DHelper |
| | from utils.audio import librosa_wav2mfcc |
| | from utils.commons.dataset_utils import collate_xd |
| | from utils.commons.tensor_utils import convert_to_tensor |
| | from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image |
| | from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic |
| | from utils.commons.image_utils import load_image_as_uint8_tensor |
| | from modules.eg3ds.camera_utils.pose_sampler import UnifiedCameraPoseSampler |
| |
|
| |
|
def sample_idx(img_dir, num_frames):
    """Randomly sample a frame index that has an image in every required sub-dir.

    A frame is usable only if the corresponding image exists in gt_imgs,
    head_imgs, inpaint_torso_imgs and com_imgs. Indices are re-drawn until a
    usable one is found.

    Args:
        img_dir: path of the clip's `gt_imgs` directory; sibling dirs are
            derived by string replacement.
        num_frames: number of frames; indices are drawn from [0, num_frames-1].
    Returns:
        int: a frame index valid in every required image directory.
    Raises:
        RuntimeError: if no valid index is found after a very large number of
            attempts (the previous version looped forever while printing the
            warning on every iteration).
    """
    # every view of the frame that must exist for the sample to be usable
    required_dirs = [
        img_dir,
        img_dir.replace("/gt_imgs/", "/head_imgs/"),
        img_dir.replace("/gt_imgs/", "/inpaint_torso_imgs/"),
        img_dir.replace("/gt_imgs/", "/com_imgs/"),
    ]
    cnt = 0
    while True:
        cnt += 1
        if cnt > 1000:
            print(f"recycle for more than 1000 times, check this {img_dir}")
        if cnt > 100000:
            # bail out instead of spinning forever on a broken directory
            raise RuntimeError(f"failed to sample a valid frame index from {img_dir}")
        idx = random.randint(0, num_frames - 1)
        if all(find_img_name(d, idx) != 'None' for d in required_dirs):
            return idx
| | |
| |
|
def find_img_name(img_dir, idx):
    """Locate the image file for frame `idx` inside `img_dir`.

    Tries several filename conventions in a fixed priority order (5-digit
    zero-padded, plain, 8-digit zero-padded stems; .jpg before .png) and
    returns the first path that exists on disk.

    Args:
        img_dir: directory holding the per-frame images.
        idx: integer frame index.
    Returns:
        str: the existing file path, or the literal string 'None' when no
        candidate exists. (Callers compare against the string 'None', so the
        sentinel is kept for backward compatibility.)
    """
    # order must match the historical lookup chain exactly
    candidates = (
        format(idx, "05d") + ".jpg",
        str(idx) + ".jpg",
        format(idx, "08d") + ".jpg",
        format(idx, "08d") + ".png",
        format(idx, "05d") + ".png",
        str(idx) + ".png",
    )
    for basename in candidates:
        fname = os.path.join(img_dir, basename)
        if os.path.exists(fname):
            return fname
    return 'None'
| | |
| | |
def get_win_from_arr(arr, index, win_size):
    """Extract a `win_size`-row window centred at `index`, zero-padded at edges.

    The window spans [index - win_size//2, index + (win_size - win_size//2));
    rows falling outside the array are filled with zeros so the result always
    has exactly `win_size` rows.

    Args:
        arr: np.ndarray or torch.Tensor, windowed along the first axis.
        index: centre position of the window.
        win_size: number of rows of the returned window.
    Returns:
        Same type as `arr`, shape [win_size, *arr.shape[1:]].
    """
    left = index - win_size // 2
    right = index + (win_size - win_size // 2)
    pad_left = max(0, -left)
    pad_right = max(0, right - arr.shape[0])
    win = arr[max(left, 0):min(right, arr.shape[0])]
    # Build zero pads with an explicit length: the previous
    # `zeros_like(win[:pad])` silently produced a short pad whenever the
    # requested padding exceeded the number of in-range rows, yielding a
    # window shorter than win_size.
    tail_shape = tuple(arr.shape[1:])
    if isinstance(arr, np.ndarray):
        if pad_left > 0:
            win = np.concatenate([np.zeros((pad_left,) + tail_shape, dtype=arr.dtype), win], axis=0)
        if pad_right > 0:
            win = np.concatenate([win, np.zeros((pad_right,) + tail_shape, dtype=arr.dtype)], axis=0)
    else:
        if pad_left > 0:
            win = torch.cat([arr.new_zeros((pad_left,) + tail_shape), win], dim=0)
        if pad_right > 0:
            win = torch.cat([win, arr.new_zeros((pad_right,) + tail_shape)], dim=0)
    return win
| |
|
| |
|
class Img2Plane_Dataset(Dataset):
    """Dataset yielding camera poses for Img2Plane (EG3D-convention) training.

    Each item carries three 25-dim camera vectors (flattened 4x4 c2w + 3x3
    intrinsics, or a pose sampled on an orbit around the head):
      - 'ws_camera':   a mildly jittered pose;
      - 'real_camera': the pose of an actual frame, or (half the time, when
                       hparams['random_sample_pose'] is enabled) a wide
                       random pose;
      - 'fake_camera': a second, independently sampled pose.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix
        self.ds = None  # IndexedDataset handle, opened lazily (worker-safe)
        self.sizes = None
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.pose_sampler = UnifiedCameraPoseSampler()
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        # Open lazily and reuse the handle; the previous implementation
        # reopened the IndexedDataset on every len() call.
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(self.ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def _sample_random_camera(self, max_pitch_deg, max_yaw_deg, min_distance, max_distance):
        """Sample a camera looking at (0, 0, 0.2): pitch/yaw uniform in
        +-max_*_deg degrees, distance uniform in [min_distance, max_distance]."""
        max_pitch = max_pitch_deg / 180 * 3.1415926
        min_pitch = -max_pitch
        pitch = random.random() * (max_pitch - min_pitch) + min_pitch
        max_yaw = max_yaw_deg / 180 * 3.1415926
        min_yaw = -max_yaw
        yaw = random.random() * (max_yaw - min_yaw) + min_yaw
        distance = random.random() * (max_distance - min_distance) + min_distance
        return self.pose_sampler.get_camera_pose(pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]

    def _camera_from_frame(self, raw_item, frame_idx):
        """Build the 25-dim camera tensor (16 c2w + 9 intrinsics) of a real frame."""
        c2w = raw_item['c2w'][frame_idx]
        intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([c2w.reshape([16, ]), intrinsics.reshape([9, ])], axis=0)
        return convert_to_tensor(camera)

    def __getitem__(self, idx):
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        hparams = self.hparams
        # convert per-frame euler/trans into EG3D-convention c2w + intrinsics
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': convert_to_tensor(raw_item['euler']).cpu(), 'trans': convert_to_tensor(raw_item['trans']).cpu()})
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        # narrow jitter for the ws camera (pitch <= 10 deg, yaw <= 16 deg)
        ws_camera = self._sample_random_camera(10, 16, 2.7, 3.2)

        # real/fake cameras: half the time (when enabled) a wide random pose,
        # otherwise the pose of an actual sampled frame
        if hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            real_camera = self._sample_random_camera(26, 38, 2.7, 4.0)
        else:
            real_idx = sample_idx(img_dir, num_frames)
            real_camera = self._camera_from_frame(raw_item, real_idx)

        if hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            fake_camera = self._sample_random_camera(26, 38, 2.7, 4.0)
        else:
            fake_idx = sample_idx(img_dir, num_frames)
            fake_camera = self._camera_from_frame(raw_item, fake_idx)

        item.update({
            'ws_camera': ws_camera,
            'real_camera': real_camera,
            'fake_camera': fake_camera,
        })

        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Return a DataLoader over this dataset using `self.collater`."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-sample cameras into [B, 25] batch tensors."""
        hparams = self.hparams
        if len(samples) == 0:
            return {}
        batch = {}

        batch['ffhq_ws_cameras'] = torch.stack([s['ws_camera'] for s in samples], dim=0)
        batch['ffhq_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['ffhq_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        return batch
| |
|
| |
|
| |
|
class Motion2Video_Dataset(Dataset):
    """Dataset pairing a reference ('real') frame with a motion-transfer
    ('fake') frame from the same clip, with images, segmentation masks,
    cameras, and 3DMM coefficients for both.

    Items are read from a binarized IndexedDataset; images are loaded from
    the on-disk frame directories referenced by each item's `img_dir`.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix
        self.ds = None  # IndexedDataset handle, opened lazily (worker-safe)
        self.sizes = None
        self.x_maxframes = 200
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        # binarized dataset root; hparams['binary_data_dir'] unless overridden
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        # NOTE(review): this reopens the IndexedDataset on each len() call
        # and replaces any handle already opened by _get_item.
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def __getitem__(self, idx):
        """Build one training item with a real/fake frame pair.

        Returns None if the binary record cannot be loaded; the collater is
        expected to cope with that upstream.
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }

        # Convert per-frame euler/trans into EG3D-convention c2w + intrinsics,
        # stored back onto raw_item for indexing below.
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':convert_to_tensor(raw_item['euler']).cpu(), 'trans':convert_to_tensor(raw_item['trans']).cpu()})
        c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
        raw_item['c2w'] = c2w
        raw_item['intrinsics'] = intrinsics

        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        # ---- reference ('real') frame ----
        real_idx = sample_idx(img_dir, num_frames)
        real_c2w = raw_item['c2w'][real_idx]
        real_intrinsics = raw_item['intrinsics'][real_idx]
        # 25-dim camera: flattened 4x4 c2w followed by flattened 3x3 intrinsics
        real_camera = np.concatenate([real_c2w.reshape([16,]) , real_intrinsics.reshape([9,])], axis=0)
        real_camera = convert_to_tensor(real_camera)
        item['real_camera'] = real_camera

        # images are normalized from uint8 [0,255] to float [-1,1]
        gt_img_fname = find_img_name(img_dir, real_idx)
        gt_img = load_image_as_uint8_tensor(gt_img_fname)[..., :3]
        item['real_gt_img'] = gt_img.float() / 127.5 - 1

        # head-only / composite / inpainted-torso variants of the same frame
        for key in ['head', 'com', 'inpaint_torso']:
            key_img_dir = img_dir.replace("/gt_imgs/",f"/{key}_imgs/")
            key_img_fname = find_img_name(key_img_dir, real_idx)
            key_img = load_image_as_uint8_tensor(key_img_fname)[..., :3]
            item[f'real_{key}_img'] = key_img.float() / 127.5 - 1
        # static background image of the clip (one per video, not per frame)
        bg_img_name = img_dir.replace("/gt_imgs/",f"/bg_img/") + '.jpg'
        bg_img = load_image_as_uint8_tensor(bg_img_name)[..., :3]
        item[f'bg_img'] = bg_img.float() / 127.5 - 1

        # segmentation map; cv2 loads BGR, so reverse the channel axis to RGB
        seg_img_name = gt_img_fname.replace("/gt_imgs/",f"/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)[:,:, ::-1]
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img))
        item[f'real_segmap'] = segmap
        # presumably channels [1,3,5] are head-region classes and [2,4] are
        # torso-region classes — confirm against decode_segmap_mask_from_image
        item[f'real_head_mask'] = segmap[[1,3,5]].sum(dim=0)
        item[f'real_torso_mask'] = segmap[[2,4]].sum(dim=0)
        item.update({
            # identity (80-dim) is shared across the clip
            'real_identity': convert_to_tensor(raw_item['id']).reshape([80,]),
            # per-frame expression / pose coefficients
            'real_expression': convert_to_tensor(raw_item['exp'][real_idx]).reshape([64,]),
            'real_euler': convert_to_tensor(raw_item['euler'][real_idx]).reshape([3,]),
            'real_trans': convert_to_tensor(raw_item['trans'][real_idx]).reshape([3,]),
        })

        # Perturbed expressions from a temporal neighbour: exp_1 is the
        # neighbour frame, exp_2 mirrors it around the current expression.
        pertube_idx_candidates = [idx for idx in [real_idx-1, real_idx+1] if (idx>=0 and idx <= num_frames-1 )]
        pertube_idx = random.choice(pertube_idx_candidates)
        item[f'real_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64,])
        item[f'real_pertube_expression_2'] = item['real_expression'] * 2 - item[f'real_pertube_expression_1']

        # ---- motion-transfer ('fake') frame ----
        # Resample until the fake frame is temporally distant from the real
        # one; min_offset shrinks near clip boundaries so the loop terminates.
        fake_idx = sample_idx(img_dir, num_frames)
        min_offset = min(50, max((num_frames-1-fake_idx)//2, (fake_idx)//2))
        while abs(fake_idx - real_idx) < min_offset:
            fake_idx = sample_idx(img_dir, num_frames)
            min_offset = min(50, max((num_frames-1-fake_idx)//2, (fake_idx)//2))
        fake_c2w = raw_item['c2w'][fake_idx]

        fake_intrinsics = raw_item['intrinsics'][fake_idx]
        fake_camera = np.concatenate([fake_c2w.reshape([16,]) , fake_intrinsics.reshape([9,])], axis=0)
        fake_camera = convert_to_tensor(fake_camera)
        item['fake_camera'] = fake_camera

        gt_img_fname = find_img_name(img_dir, fake_idx)
        gt_img = load_image_as_uint8_tensor(gt_img_fname)[..., :3]
        item['fake_gt_img'] = gt_img.float() / 127.5 - 1
        seg_img_name = gt_img_fname.replace("/gt_imgs/",f"/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)[:,:, ::-1]
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img))
        item[f'fake_segmap'] = segmap
        item[f'fake_head_mask'] = segmap[[1,3,5]].sum(dim=0)
        item[f'fake_torso_mask'] = segmap[[2,4]].sum(dim=0)

        for key in ['head', 'com', 'inpaint_torso']:
            key_img_dir = img_dir.replace("/gt_imgs/",f"/{key}_imgs/")
            key_img_fname = find_img_name(key_img_dir, fake_idx)
            key_img = load_image_as_uint8_tensor(key_img_fname)[..., :3]
            item[f'fake_{key}_img'] = key_img.float() / 127.5 - 1

        item.update({
            f'fake_identity': convert_to_tensor(raw_item['id']).reshape([80,]),
            f'fake_expression': convert_to_tensor(raw_item['exp'][fake_idx]).reshape([64,]),
            f'fake_euler': convert_to_tensor(raw_item['euler'][fake_idx]).reshape([3,]),
            f'fake_trans': convert_to_tensor(raw_item['trans'][fake_idx]).reshape([3,]),
        })

        # same neighbour-based expression perturbation for the fake frame
        pertube_idx_candidates = [idx for idx in [fake_idx-1, fake_idx+1] if (idx>=0 and idx <= num_frames-1 )]
        pertube_idx = random.choice(pertube_idx_candidates)
        item[f'fake_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64,])
        item[f'fake_pertube_expression_2'] = item['fake_expression'] * 2 - item[f'fake_pertube_expression_1']

        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Return a DataLoader over this dataset using `self.collater`."""
        loader = DataLoader(self, pin_memory=True,collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Collate per-sample dicts into batch tensors.

        Images are stacked and permuted from [B,H,W,C] to [B,C,H,W]; batch
        keys use the th1kh_ prefix with 'ref' = real frame, 'mv' = fake frame.
        """
        hparams = self.hparams
        if len(samples) == 0:
            return {}
        batch = {}

        # reference ('real') frame tensors
        batch['th1kh_item_names'] = [s['item_name'] for s in samples]
        batch['th1kh_ref_gt_imgs'] = torch.stack([s['real_gt_img'] for s in samples]).permute(0,3,1,2)

        batch['th1kh_ref_head_masks'] = torch.stack([s['real_head_mask'] for s in samples])
        batch['th1kh_ref_torso_masks'] = torch.stack([s['real_torso_mask'] for s in samples])
        batch['th1kh_ref_segmaps'] = torch.stack([s['real_segmap'] for s in samples])

        for key in ['head', 'com', 'inpaint_torso']:
            batch[f'th1kh_ref_{key}_imgs'] = torch.stack([s[f'real_{key}_img'] for s in samples]).permute(0,3,1,2)
        batch[f'th1kh_bg_imgs'] = torch.stack([s[f'bg_img'] for s in samples]).permute(0,3,1,2)

        batch['th1kh_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['th1kh_ref_ids'] = torch.stack([s['real_identity'] for s in samples], dim=0)
        batch['th1kh_ref_exps'] = torch.stack([s['real_expression'] for s in samples], dim=0)
        batch['th1kh_ref_eulers'] = torch.stack([s['real_euler'] for s in samples], dim=0)
        batch['th1kh_ref_trans'] = torch.stack([s['real_trans'] for s in samples], dim=0)

        # motion-transfer ('fake') frame tensors
        batch['th1kh_mv_gt_imgs'] = torch.stack([s['fake_gt_img'] for s in samples]).permute(0,3,1,2)

        for key in ['head', 'com', 'inpaint_torso']:
            batch[f'th1kh_mv_{key}_imgs'] = torch.stack([s[f'fake_{key}_img'] for s in samples]).permute(0,3,1,2)

        batch['th1kh_mv_head_masks'] = torch.stack([s['fake_head_mask'] for s in samples])
        batch['th1kh_mv_torso_masks'] = torch.stack([s['fake_torso_mask'] for s in samples])
        batch['th1kh_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        batch['th1kh_mv_ids'] = torch.stack([s['fake_identity'] for s in samples], dim=0)
        batch['th1kh_mv_exps'] = torch.stack([s['fake_expression'] for s in samples], dim=0)
        batch['th1kh_mv_eulers'] = torch.stack([s['fake_euler'] for s in samples], dim=0)
        batch['th1kh_mv_trans'] = torch.stack([s['fake_trans'] for s in samples], dim=0)

        # perturbed expressions for temporal-consistency losses
        batch['th1kh_ref_pertube_exps_1'] = torch.stack([s['real_pertube_expression_1'] for s in samples], dim=0)
        batch['th1kh_ref_pertube_exps_2'] = torch.stack([s['real_pertube_expression_2'] for s in samples], dim=0)
        batch['th1kh_mv_pertube_exps_1'] = torch.stack([s['fake_pertube_expression_1'] for s in samples], dim=0)
        batch['th1kh_mv_pertube_exps_2'] = torch.stack([s['fake_pertube_expression_2'] for s in samples], dim=0)

        return batch
| |
|
if __name__ == '__main__':
    # keep OpenMP single-threaded so DataLoader workers do not oversubscribe cores
    os.environ["OMP_NUM_THREADS"] = "1"

    dataset = Img2Plane_Dataset("train", 'data/binary/th1kh')
    loader = dataset.get_dataloader()
    # smoke test: iterate the whole split once, discarding batches
    for _batch in tqdm(loader):
        pass