import math
from copy import deepcopy

import torch
from einops import repeat


class FrameConditioning():

    def __init__(self,
                 add_frame_to_input: bool = False,
                 add_frame_to_layers: bool = False,
                 fill_zero: bool = False,
                 randomize_mask: bool = False,
                 concatenate_mask: bool = False,
                 injection_probability: float = 0.9,
                 ) -> None:
        self.use = None  # routed through the 'use' property setter, which ignores None
        self.add_frame_to_input = add_frame_to_input
        self.add_frame_to_layers = add_frame_to_layers
        self.fill_zero = fill_zero
        self.randomize_mask = randomize_mask
        self.concatenate_mask = concatenate_mask
        self.injection_probability = injection_probability
        # Frames can be injected either at the input or at intermediate layers, not both.
        assert not (add_frame_to_layers and add_frame_to_input)

    def set_random_mask(self, random_mask: bool):
        # Return a copy of this configuration with mask randomization toggled.
        frame_conditioning = deepcopy(self)
        frame_conditioning.randomize_mask = random_mask
        return frame_conditioning

    @property
    def use(self):
        return self.add_frame_to_input or self.add_frame_to_layers

    @use.setter
    def use(self, value):
        if value is not None:
            raise NotImplementedError("Direct access not allowed")

    def attach_video_frames(self, pl_module, z_0: torch.Tensor = None,
                            batch: torch.Tensor = None, random_mask: bool = False):
        assert self.fill_zero, "Not filling with zero not implemented yet"
        n_frames_inference = pl_module.inference_params.video_length
        with torch.no_grad():
            if z_0 is None:
                assert batch is not None
                z_0 = pl_module.encode_frame(batch)
            assert n_frames_inference == z_0.shape[1], (
                "For frame injection, the number of frames sampled by the dataloader "
                "must match the number of frames used for video generation")
            shape = list(z_0.shape)
            shape[1] = pl_module.inference_params.video_length
            M = torch.zeros(shape, dtype=z_0.dtype,
                            device=pl_module.device)  # [B F C W H]
            bsz = z_0.shape[0]

            if random_mask:
                # Inject a conditioning frame with probability p at a random position.
                p_inject_frame = self.injection_probability
                use_masks = torch.bernoulli(
                    torch.tensor(p_inject_frame).repeat(bsz)).long()
                keep_frame_idx = torch.randint(
                    0, n_frames_inference, (bsz,), device=pl_module.device).long()
            else:
                # Always inject, and keep only the first frame.
                use_masks = torch.ones((bsz,), device=pl_module.device).long()
                keep_frame_idx = 0 * use_masks

            # frame_idx stores the injected frame per sample, or -1 if none.
            frame_idx = []
            for batch_idx, (keep_frame, use_mask) in enumerate(zip(keep_frame_idx, use_masks)):
                M[batch_idx, keep_frame] = use_mask
                frame_idx.append(keep_frame if use_mask == 1 else -1)

            x0 = z_0 * M
            if self.concatenate_mask:
                # Reduce the mask to a single channel and append it to the frames.
                M = M[:, :, 0, None]
                x0 = torch.cat([x0, M], dim=2)

            if getattr(pl_module.opt_params.noise_decomposition, "use", False) and random_mask:
                assert x0.shape[0] == 1, (
                    "randomizing frame injection with noise decomposition "
                    "not implemented for batch size >1")
            return x0, frame_idx
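

# A minimal, self-contained usage sketch for FrameConditioning (the helper name
# is illustrative). It only exercises the configuration surface; calling
# attach_video_frames additionally requires a Lightning-style module exposing
# encode_frame, inference_params and opt_params, which lives outside this file.
def _frame_conditioning_example():
    conditioning = FrameConditioning(add_frame_to_input=True,
                                     fill_zero=True,
                                     concatenate_mask=True,
                                     injection_probability=0.9)
    assert conditioning.use  # True because add_frame_to_input is set
    # Derive a copy that injects frames at random positions during training.
    randomized = conditioning.set_random_mask(True)
    assert randomized.randomize_mask and not conditioning.randomize_mask
    return conditioning, randomized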


class NoiseDecomposition():

    def __init__(self,
                 use: bool = False,
                 random_frame: bool = False,
                 lambda_f: float = 0.5,
                 use_base_model: bool = True,
                 ):
        self.use = use
        self.random_frame = random_frame
        # Weight of the shared base noise in the composed noise:
        # composed = sqrt(lambda_f) * base + sqrt(1 - lambda_f) * residual.
        self.lambda_f = lambda_f
        self.use_base_model = use_base_model

    def get_loss(self, x0, unet_base, unet, noise_scheduler, frame_idx, z_t_base,
                 timesteps, encoder_hidden_states, base_noise, z_t_residual,
                 composed_noise, reduction="mean"):
        # note: "mean" is an assumed default for `reduction`.
        lambda_f = self.lambda_f

        if x0 is not None:
            # x0.shape = [B,F,C,W,H]; if extrapolation_params.fill_zero=true,
            # only one frame per batch is non-zero.
            assert not self.random_frame
            # TODO add x0 injection
            x0_base = []
            for batch_idx, frame in enumerate(frame_idx):
                x0_base.append(x0[batch_idx, frame, None, None])
            x0_base = torch.cat(x0_base, dim=0)
            x0_residual = repeat(
                x0[:, 0], "B C W H -> B F C W H", F=x0.shape[1] - 1)
        else:
            x0_base = None
            x0_residual = None

        if self.use_base_model:
            base_pred = unet_base(z_t_base, timesteps,
                                  encoder_hidden_states, x0=x0_base).sample
        else:
            base_pred = base_noise

        timesteps_alphas = [
            noise_scheduler.alphas_cumprod[t.cpu()] for t in timesteps]
        timesteps_alphas = torch.stack(timesteps_alphas).to(base_pred.device)
        timesteps_alphas = repeat(timesteps_alphas, "B -> B F C W H",
                                  F=base_pred.shape[1], C=base_pred.shape[2],
                                  W=base_pred.shape[3], H=base_pred.shape[4])

        # Remove the scaled base prediction from the residual latent before it
        # is passed to the residual UNet.
        base_correction = math.sqrt(
            lambda_f) * torch.sqrt(1 - timesteps_alphas) * base_pred
        z_t_residual_dash = z_t_residual - base_correction

        residual_pred = unet(
            z_t_residual_dash, timesteps, encoder_hidden_states, x0=x0_residual).sample
        composed_pred = math.sqrt(lambda_f) * base_pred.detach() \
            + math.sqrt(1 - lambda_f) * residual_pred

        loss_residual = torch.nn.functional.mse_loss(
            composed_noise.float(), composed_pred.float(), reduction=reduction)
        if self.use_base_model:
            loss_base = torch.nn.functional.mse_loss(
                base_noise.float(), base_pred.float(), reduction=reduction)
            loss = loss_residual + loss_base
        else:
            loss = loss_residual
        return loss

    def add_noise(self, z_base, base_noise, z_residual, composed_noise,
                  noise_scheduler, timesteps):
        z_t_base = noise_scheduler.add_noise(z_base, base_noise, timesteps)
        z_t_residual = noise_scheduler.add_noise(
            z_residual, composed_noise, timesteps)
        return z_t_base, z_t_residual
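
    # Worked relation (assuming a DDPM-style scheduler such as diffusers'
    # DDPMScheduler, whose add_noise computes
    #     z_t = sqrt(alpha_bar_t) * z + sqrt(1 - alpha_bar_t) * eps):
    # the residual latent is noised with the composed noise
    #     eps = sqrt(lambda_f) * b + sqrt(1 - lambda_f) * r,
    # which is why get_loss subtracts
    #     sqrt(lambda_f) * sqrt(1 - alpha_bar_t) * base_pred
    # from z_t_residual before the residual UNet sees it.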

    def split_latent_into_base_residual(self, z_0, pl_module, noise_generator,
                                        frame_idx=None):
        # `frame_idx` is assumed to be the per-sample index of the injected frame
        # (as returned by FrameConditioning.attach_video_frames); it is only
        # needed when randomize_mask is enabled.
        if self.random_frame:
            raise NotImplementedError("Must be synced with x0 mask!")
            bsz = z_0.shape[0]
            fr_select = torch.randint(
                0, z_0.shape[1], (bsz,), device=pl_module.device).long()
            z_base = z_0[:, fr_select, None]
            fr_residual = [fr for fr in range(
                z_0.shape[1]) if fr != fr_select]
            z_residual = z_0[:, fr_residual, None]
        else:
            if not pl_module.unet_params.frame_conditioning.randomize_mask:
                # Base = first frame, residual = the remaining frames.
                z_base = z_0[:, 0, None]
                z_residual = z_0[:, 1:]
            else:
                # Base = the injected frame of each sample, residual = all others.
                z_base = []
                for batch_idx, frame_at_batch in enumerate(frame_idx):
                    z_base.append(
                        z_0[batch_idx, frame_at_batch, None, None])
                z_base = torch.cat(z_base, dim=0)
                z_residual = []
                for batch_idx, frame_idx_batch in enumerate(frame_idx):
                    z_residual_batch = []
                    for frame in range(z_0.shape[1]):
                        if frame_idx_batch != frame:
                            z_residual_batch.append(
                                z_0[batch_idx, frame, None, None])
                    z_residual_batch = torch.cat(z_residual_batch, dim=1)
                    z_residual.append(z_residual_batch)
                z_residual = torch.cat(z_residual, dim=0)

        base_noise = noise_generator.sample_noise(z_base)  # b_t
        residual_noise = noise_generator.sample_noise(z_residual)  # r^f_t
        lambda_f = self.lambda_f
        # base_noise has a single frame and broadcasts across the residual frames.
        composed_noise = math.sqrt(lambda_f) * base_noise \
            + math.sqrt(1 - lambda_f) * residual_noise
        return z_base, base_noise, z_residual, composed_noise
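

# A hedged sketch of how the three NoiseDecomposition methods compose into a
# training step. The helper name and its arguments are illustrative: the caller
# is assumed to provide the UNets, scheduler and noise generator, to have
# obtained x0/frame_idx from FrameConditioning.attach_video_frames, and to have
# sampled `timesteps` itself.
def _noise_decomposition_step(pl_module, unet_base, unet, noise_scheduler,
                              noise_generator, z_0, x0, frame_idx, timesteps,
                              encoder_hidden_states, lambda_f=0.5):
    decomposition = NoiseDecomposition(use=True, lambda_f=lambda_f)
    z_base, base_noise, z_residual, composed_noise = \
        decomposition.split_latent_into_base_residual(
            z_0, pl_module, noise_generator, frame_idx=frame_idx)
    z_t_base, z_t_residual = decomposition.add_noise(
        z_base, base_noise, z_residual, composed_noise, noise_scheduler, timesteps)
    return decomposition.get_loss(
        x0, unet_base, unet, noise_scheduler, frame_idx, z_t_base, timesteps,
        encoder_hidden_states, base_noise, z_t_residual, composed_noise)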


class NoiseGenerator():

    def __init__(self, mode="vanilla") -> None:
        self.mode = mode

    def set_seed(self, seed: int):
        self.seed = seed

    def reset_seed(self, seed: int):
        pass

    def sample_noise(self, z_0: torch.Tensor = None, shape=None, device=None,
                     dtype=None, generator=None):
        assert (z_0 is not None) != (shape is not None), \
            "Exactly one of z_0 and shape must be provided."
        kwargs = {}
        if z_0 is None:
            if device is not None:
                kwargs["device"] = device
            if dtype is not None:
                kwargs["dtype"] = dtype
        else:
            kwargs["device"] = z_0.device
            kwargs["dtype"] = z_0.dtype
            shape = z_0.shape
        if generator is not None:
            kwargs["generator"] = generator

        B, F, C, W, H = shape

        if self.mode == "vanilla":
            # Independent Gaussian noise per frame.
            noise = torch.randn(shape, **kwargs)
        elif self.mode == "free_noise":
            # Reuse a prefix of the noise sequence followed by a shuffled copy.
            noise = torch.randn(shape, **kwargs)
            if noise.shape[1] > 4:
                # HARD CODED: keep 8 frames and append a frame-shuffled copy.
                noise = noise[:, :8]
                noise = torch.cat(
                    [noise, noise[:, torch.randperm(noise.shape[1])]], dim=1)
            elif noise.shape[2] > 4:
                noise = noise[:, :, :8]
                noise = torch.cat(
                    [noise, noise[:, :, torch.randperm(noise.shape[2])]], dim=2)
            else:
                raise NotImplementedError(
                    f"Shape of noise vector not as expected {noise.shape}")
        elif self.mode == "equal":
            # The same noise tensor is repeated for every frame.
            shape = list(shape)
            shape[1] = 1
            noise_init = torch.randn(shape, **kwargs)
            shape[1] = F
            noise = torch.zeros(
                shape, device=noise_init.device, dtype=noise_init.dtype)
            for fr in range(F):
                noise[:, fr] = noise_init[:, 0]
        elif self.mode == "fusion":
            # Each frame mixes a spatially shifted copy of the first frame's
            # noise with fresh noise.
            shape = list(shape)
            shape[1] = 1
            noise_init = torch.randn(shape, **kwargs)
            noises = [noise_init]
            for fr in range(F - 1):
                shift = 2 * (fr + 1)
                local_copy = noise_init
                shifted_noise = torch.cat(
                    [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3)
                # assumption: Gaussian noise (torch.randn) is intended here
                # rather than uniform noise (torch.rand).
                noises.append(math.sqrt(0.2) * shifted_noise +
                              math.sqrt(1 - 0.2) * torch.randn(shape, **kwargs))
            noise = torch.cat(noises, dim=1)
        elif self.mode == "motion_dynamics" or self.mode == "equal_noise_per_sequence":
            shape = list(shape)
            normal_frames = 1
            shape[1] = normal_frames
            init_noise = torch.randn(shape, **kwargs)
            noises = [init_noise]
            init_noise = init_noise[:, -1, None]
            print(f"UPDATE with noise = {init_noise.shape}")
            if self.mode == "motion_dynamics":
                # Subsequent frames reuse the last initial frame's noise,
                # shifted spatially by a growing offset.
                for fr in range(F - normal_frames):
                    shift = 2 * (fr + 1)
                    print(fr, shift)
                    local_copy = init_noise
                    shifted_noise = torch.cat(
                        [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3)
                    noises.append(shifted_noise)
            elif self.mode == "equal_noise_per_sequence":
                for fr in range(F - 1):
                    noises.append(init_noise)
            else:
                raise NotImplementedError()
            # noises[0] = noises[0] * 0
            noise = torch.cat(noises, dim=1)
            print(noise.shape)
        else:
            raise NotImplementedError(f"Unknown noise sampling mode: {self.mode}")
        return noise
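

# A self-contained usage sketch for NoiseGenerator (the helper name is
# illustrative): sample [B, F, C, W, H] noise with two of the modes above,
# once from a reference latent and once from an explicit shape.
def _noise_generator_example():
    z_0 = torch.zeros(2, 16, 4, 32, 32)  # dummy latent video batch
    vanilla = NoiseGenerator(mode="vanilla").sample_noise(z_0=z_0)
    equal = NoiseGenerator(mode="equal").sample_noise(
        shape=(2, 16, 4, 32, 32), device=z_0.device, dtype=z_0.dtype)
    assert vanilla.shape == equal.shape == z_0.shape
    # In "equal" mode every frame shares the same noise tensor.
    assert torch.equal(equal[:, 0], equal[:, 1])
    return vanilla, equal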