from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.utils.torch_utils import randn_tensor
|
|
|
|
class AutoencoderKLNextStep(AutoencoderKL):
    """``AutoencoderKL`` variant with optional latent normalization.

    Extends the stock diffusers KL autoencoder with three extra knobs:

    * ``deterministic`` -- build the posterior as a point mass (the
      ``DiagonalGaussianDistribution`` ignores the variance half).
    * ``normalize_latents`` -- apply a parameter-free layer norm over the
      channel dimension of the posterior mean.
    * ``patch_size`` -- when normalizing, fold ``p x p`` spatial patches into
      the channel dimension first (and unfold afterwards), so the norm is
      taken jointly over each patch's channels.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        # Extra configuration on top of the stock AutoencoderKL; these are
        # also recorded in the model config by @register_to_config.
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold non-overlapping ``p x p`` patches into the channel dim.

        ``(b, c, h, w) -> (b, c * p * p, h // p, w // p)`` with
        ``p = self.patch_size``. Assumes ``h`` and ``w`` are divisible by
        ``p`` (integer division below silently truncates otherwise).
        """
        b, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(b, c, h_, p, w_, p)
        x = torch.einsum("bchpwq->bcpqhw", x)
        x = x.reshape(b, c * p**2, h_, w_)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`patchify`.

        ``(b, c * p * p, h_, w_) -> (b, c, h_ * p, w_ * p)`` with
        ``p = self.patch_size``.
        """
        b, _, h_, w_ = x.shape
        p = self.patch_size
        c = x.shape[1] // (p**2)

        x = x.reshape(b, c, p, p, h_, w_)
        x = torch.einsum("bcpqhw->bchpwq", x)
        x = x.reshape(b, c, h_ * p, w_ * p)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode ``x`` into a latent posterior.

        Args:
            x: Input images, ``(b, in_channels, h, w)``.
            return_dict: If ``False``, return a plain ``(posterior,)`` tuple.

        Returns:
            ``AutoencoderKLOutput`` (or tuple) wrapping a
            ``DiagonalGaussianDistribution``; deterministic when
            ``self.deterministic`` is set.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Encode one sample at a time to bound peak memory.
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.normalize_latents:
            # Only patchify when we actually normalize: without the layer
            # norm in between, patchify followed by unpatchify is a no-op.
            if self.patch_size is not None:
                mean = self.patchify(mean)
            # Channels-last, affine-free layer norm over the channel dim.
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
            if self.patch_size is not None:
                mean = self.unpatchify(mean)
        h = torch.cat([mean, logvar], dim=1).contiguous()
        posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """Encode ``sample`` and decode the (optionally perturbed) latent.

        Args:
            sample: Input images, ``(b, in_channels, h, w)``.
            sample_posterior: Sample from the posterior instead of taking
                its mode.
            return_dict: If ``False``, return a plain ``(decoded,)`` tuple.
            generator: Optional RNG for posterior sampling and latent noise.
            noise_strength: If > 0, add Gaussian noise to the latent with a
                per-sample scale drawn uniformly from ``[0, noise_strength)``.

        Returns:
            ``DecoderOutput`` (or tuple) with the reconstructed images.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if noise_strength > 0.0:
            # Per-sample noise scale; cast to z's dtype so fp16/bf16 latents
            # are not silently promoted to float32 by the arithmetic below.
            scale = torch.distributions.Uniform(0.0, noise_strength).sample((z.shape[0],))
            scale = scale.reshape(-1, 1, 1, 1).to(device=z.device, dtype=z.dtype)
            # Thread the caller's generator through for reproducible noise.
            z = z + scale * randn_tensor(
                z.shape, generator=generator, device=z.device, dtype=z.dtype
            )
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
|
|