Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.fft as fft | |
| from torch import nn | |
| from torch.nn import functional | |
| from math import sqrt | |
| from einops import rearrange | |
| import math | |
| import numbers | |
| from typing import List | |
| # adapted from https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 | |
| # and https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/19 | |
def gaussian_smoothing_kernel(shape, kernel_size, sigma, dim=2):
    """
    Build a normalized Gaussian kernel and zero-pad it to the last ``dim``
    entries of ``shape``.

    The kernel is the outer product of one 1-D Gaussian per dimension,
    normalized to sum to 1, then symmetrically zero-padded so its shape
    matches ``shape[-dim:]`` (e.g. for use as a frequency-domain mask the
    same size as an FFT volume).

    Arguments:
        shape (sequence of int): Target size; only the last ``dim`` entries
            are used as the padded output shape.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the kernel.
            Default value is 2 (spatial).

    Returns:
        torch.Tensor of shape ``shape[-dim:]`` whose values sum to 1.
    """
    if isinstance(kernel_size, numbers.Number):
        kernel_size = [kernel_size] * dim
    if isinstance(sigma, numbers.Number):
        sigma = [sigma] * dim
    # The gaussian kernel is the product of the
    # gaussian function of each dimension.
    kernel = 1
    meshgrids = torch.meshgrid(
        [
            torch.arange(size, dtype=torch.float32)
            for size in kernel_size
        ],
        indexing="ij",  # explicit; matches the legacy default, avoids the warning
    )
    for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
        mean = (size - 1) / 2
        kernel *= torch.exp(-((mgrid - mean) / std) ** 2 / 2)
    # Make sure sum of values in gaussian kernel equals 1.
    kernel = kernel / torch.sum(kernel)
    # Symmetric zero-padding up to shape[-dim:]. functional.pad expects pairs
    # in last-dimension-first order. Using floor on one side and the remainder
    # on the other also handles odd size differences (the previous floor/floor
    # split under-padded by one element and tripped the assert), and padding
    # the last `dim` axes (instead of a hard-coded 3) supports dim != 3.
    pad = []
    for axis in range(1, dim + 1):
        diff = shape[-axis] - kernel_size[-axis]
        pad.extend([diff // 2, diff - diff // 2])
    kernel = functional.pad(kernel, pad)
    assert kernel.shape == tuple(shape[-dim:])
    return kernel
# NOTE(review): leftover from the nn.Module-based GaussianSmoothing this file
# was adapted from (see links above). It is a bare string literal, so it is
# never executed; kept verbatim for reference.
'''
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
if dim == 1:
    self.conv = functional.conv1d
elif dim == 2:
    self.conv = functional.conv2d
elif dim == 3:
    self.conv = functional.conv3d
else:
    raise RuntimeError(
        'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(
            dim)
    )
'''
class NoiseGenerator():
    """
    Samples initial diffusion noise for video latents.

    ``mode`` is substring-matched: plain i.i.d. Gaussian noise ("vanilla"),
    frame-correlated noise ("mixed_noise": each frame is the sum of a shared
    component and an individual component), and/or frequency-domain mixing
    with noised content frames ("consistI2V" — that path is still
    work-in-progress, see the TODOs in sample_noise).
    """

    def __init__(self, alpha: float = 0.0, shared_noise_across_chunks: bool = False, mode="vanilla", forward_steps: int = 850, radius: List[float] = None) -> None:
        self.mode = mode
        # Mixing coefficient between shared and per-frame noise
        # ("mixed_noise" modes only).
        self.alpha = alpha
        # When True, the shared component is cached on self.e_shared and
        # reused across calls until reset_noise_generator_state().
        self.shared_noise_across_chunks = shared_noise_across_chunks
        # DDPM forward steps for the consistI2V path (not used yet — the
        # ddpm_forward call in sample_noise is still a TODO).
        self.forward_steps = forward_steps
        # Per-axis cutoff radii for the consistI2V frequency mask (WIP).
        self.radius = radius

    def set_seed(self, seed: int):
        # Stored only; nothing in this class reads self.seed yet.
        self.seed = seed

    def reset_seed(self, seed: int):
        # Intentional no-op hook (kept for interface compatibility).
        pass

    def reset_noise_generator_state(self):
        # Drop the cached shared component so the next call resamples it.
        if hasattr(self, "e_shared"):
            del self.e_shared

    def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None, content=None):
        """
        Draw a noise tensor of shape ``z_0.shape`` or ``shape``.

        Exactly one of ``z_0`` (whose shape/device/dtype are copied) and
        ``shape`` must be provided. ``generator`` makes sampling
        reproducible; ``content`` is only consumed by the consistI2V path.
        """
        assert (z_0 is not None) != (
            shape is not None), f"either z_0 must be None, or shape must be None. Both provided."
        kwargs = {}
        if z_0 is None:
            if device is not None:
                kwargs["device"] = device
            if dtype is not None:
                kwargs["dtype"] = dtype
        else:
            kwargs["device"] = z_0.device
            kwargs["dtype"] = z_0.dtype
            shape = z_0.shape
        if generator is not None:
            kwargs["generator"] = generator
        # BUGFIX: draw the base noise only after shape/device/dtype/generator
        # are resolved. Previously torch.randn ran first with empty kwargs,
        # crashing when z_0 was given (shape was still None) and ignoring
        # device/dtype/generator otherwise.
        noise = torch.randn(shape, **kwargs)
        B, F, C, W, H = shape
        # Heuristic layout detection: treat the tensor as (B, C, F, W, H)
        # when the second axis looks like 4 latent channels; otherwise as
        # (B, F, C, W, H).
        if F == 4 and C > 4:
            frame_idx = 2
            F, C = C, F
        else:
            frame_idx = 1
        if "mixed_noise" in self.mode:
            # Shape of a single frame (frame axis collapsed to 1).
            shape_per_frame = [dim for dim in shape]
            shape_per_frame[frame_idx] = 1
            # kwargs.get(...) instead of kwargs[...]: device/dtype may be
            # absent when only `shape` was passed (previously a KeyError).
            zero_mean = torch.zeros(
                shape_per_frame, device=kwargs.get("device"), dtype=kwargs.get("dtype"))
            std = torch.ones(
                shape_per_frame, device=kwargs.get("device"), dtype=kwargs.get("dtype"))
            alpha = self.alpha
            # Variance split so shared + individual variance sums to 1.
            std_coeff_shared = (alpha**2) / (1 + alpha**2)
            if self.shared_noise_across_chunks and hasattr(self, "e_shared"):
                e_shared = self.e_shared
            else:
                e_shared = torch.normal(mean=zero_mean, std=sqrt(
                    std_coeff_shared)*std, generator=kwargs.get("generator"))
                if self.shared_noise_across_chunks:
                    self.e_shared = e_shared
            std_coeff_ind = 1 / (1 + alpha**2)  # loop-invariant, hoisted
            e_inds = []
            for frame in range(shape[frame_idx]):
                e_ind = torch.normal(
                    mean=zero_mean, std=sqrt(std_coeff_ind)*std, generator=kwargs.get("generator"))
                e_inds.append(e_ind)
            noise = torch.cat(
                [e_shared + e_ind for e_ind in e_inds], dim=frame_idx)
        if "consistI2V" in self.mode and content is not None:
            # NOTE(review): this branch is work-in-progress and currently
            # raises NameError ('content_noisy' and 'center' are undefined
            # below); kept as-is apart from comments. If mode is
            # "mixed_noise_consistI2V" we reuse 'noise' from the block above,
            # otherwise it is plain randn noise.
            if frame_idx == 1:
                assert content.shape[0] == noise.shape[0] and content.shape[2:] == noise.shape[2:]
                # Extend content with its last frame up to the noise length.
                content = torch.concat([content, content[:, -1:].repeat(
                    1, noise.shape[1]-content.shape[1], 1, 1, 1)], dim=1)
                noise = rearrange(noise, "B F C W H -> (B C) F W H")
                content = rearrange(content, "B F C W H -> (B C) F W H")
            else:
                assert content.shape[:2] == noise.shape[:
                                                        2] and content.shape[3:] == noise.shape[3:]
                content = torch.concat(
                    [content, content[:, :, -1:].repeat(1, 1, noise.shape[2]-content.shape[2], 1, 1)], dim=2)
                noise = rearrange(noise, "B C F W H -> (B C) F W H")
                content = rearrange(content, "B C F W H -> (B C) F W H")
            # TODO implement DDPM_forward using diffusers framework
            '''
            content_noisy = ddpm_forward(
                content, noise, self.forward_steps)
            '''
            # Frequency-domain mixing: low frequencies come from the noised
            # content, high frequencies from the noise. A 2D low pass filter
            # reference:
            # https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/
            # do we have to specify more (s,dim,norm?)
            noise_fft = fft.fftn(noise)
            content_noisy_fft = fft.fftn(content_noisy)
            # shift low frequency parts to center
            noise_fft_shifted = fft.fftshift(noise_fft)
            content_noisy_fft_shifted = fft.fftshift(content_noisy_fft)
            # Gaussian low-pass mask over the trailing (F, W, H) volume.
            gaussian_3d = gaussian_smoothing_kernel(noise_fft.shape, kernel_size=(
                noise_fft.shape[-3], noise_fft.shape[-2], noise_fft.shape[-1]), sigma=1, dim=3).to(noise.device)
            # TODO define 'center' and apply a hard cutoff of self.radius
            # around it (e.g. for 16 x 32 x 32 the center is (7.5, 15.5, 15.5));
            # rounding (ceil?) and the exact "normalized space-time stop
            # frequency" semantics are still to be decided.
            radius = self.radius
            gaussian_3d[:center[0]-radius[0], :center[1] -
                        radius[1], :center[2]-radius[2]] = 0.0
            gaussian_3d[center[0]+radius[0]:,
                        center[1]+radius[1]:, center[2]+radius[2]:] = 0.0
            noise_fft_shifted_hp = noise_fft_shifted * (1 - gaussian_3d)
            content_noisy_fft_shifted_lp = content_noisy_fft_shifted * gaussian_3d
            noise = fft.ifftn(fft.ifftshift(
                noise_fft_shifted_hp+content_noisy_fft_shifted_lp))
            if frame_idx == 1:
                noise = rearrange(
                    noise, "(B C) F W H -> B F C W H", B=B)
            else:
                noise = rearrange(
                    noise, "(B C) F W H -> B C F W H", B=B)
        assert noise.shape == shape
        return noise