from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

from loss import FourierLoss
from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
from mae_utils import flatten_images
from normalizer import Normalizer
from vit import (
    generate_2d_sincos_pos_embeddings,
    sincos_positional_encoding_vit,
    vit_small_patch16_256,
)

TensorDict = Dict[str, torch.Tensor]


class MAEConfig(PretrainedConfig):
    """HuggingFace config for a channel-agnostic masked autoencoder (CA-MAE)."""

    model_type = "MAE"

    def __init__(
        self,
        mask_ratio=0.75,
        encoder=None,
        decoder=None,
        loss=None,
        optimizer=None,
        input_norm=None,
        fourier_loss=None,
        fourier_loss_weight=0.0,
        lr_scheduler=None,
        use_MAE_weight_init=False,
        crop_size=-1,
        mask_fourier_loss=True,
        return_channelwise_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mask_ratio = mask_ratio
        self.encoder = encoder
        self.decoder = decoder
        self.loss = loss
        self.optimizer = optimizer
        self.input_norm = input_norm
        self.fourier_loss = fourier_loss
        self.fourier_loss_weight = fourier_loss_weight
        self.lr_scheduler = lr_scheduler
        self.use_MAE_weight_init = use_MAE_weight_init
        self.crop_size = crop_size
        self.mask_fourier_loss = mask_fourier_loss
        self.return_channelwise_embeddings = return_channelwise_embeddings

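
# Usage sketch: MAEConfig round-trips through the standard PretrainedConfig
# serialization (the directory path below is hypothetical):
#   cfg = MAEConfig(mask_ratio=0.75, fourier_loss_weight=0.01)
#   cfg.save_pretrained("./mae_ckpt")
#   cfg = MAEConfig.from_pretrained("./mae_ckpt")
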

class MAEModel(PreTrainedModel):
    """Channel-agnostic MAE: ViT-S/16 encoder plus a per-modality decoder."""

    config_class = MAEConfig

    # Keys used for loss reporting.
    TOTAL_LOSS = "loss"
    RECON_LOSS = "reconstruction_loss"
    FOURIER_LOSS = "fourier_loss"

    def __init__(self, config: MAEConfig):
        super().__init__(config)

        self.mask_ratio = config.mask_ratio

        # Channel-agnostic ViT-S/16 encoder at 256x256 resolution with fixed
        # 2D sin-cos position embeddings, accepting up to 11 input channels.
        self.encoder = MAEEncoder(
            vit_backbone=sincos_positional_encoding_vit(
                vit_backbone=vit_small_patch16_256(global_pool="avg")
            ),
            max_in_chans=11,
            channel_agnostic=True,
        )
        # Channel-agnostic decoder: 6 modalities with 256 tokens each
        # (a 256x256 image in 16x16 patches yields a 16x16 token grid).
        self.decoder = CAMAEDecoder(
            depth=8,
            embed_dim=512,
            mlp_ratio=4,
            norm_layer=nn.LayerNorm,
            num_heads=16,
            num_modalities=6,
            qkv_bias=True,
            tokens_per_modality=256,
        )
        self.input_norm = torch.nn.Sequential(
            Normalizer(),
            # num_features is unused when affine=False and
            # track_running_stats=False, so None is acceptable here.
            nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
        )

        self.fourier_loss_weight = config.fourier_loss_weight
        self.mask_fourier_loss = config.mask_fourier_loss
        self.return_channelwise_embeddings = config.return_channelwise_embeddings
        self.tokens_per_channel = 256
        # Mirrors max_in_chans above; only consulted when the encoder is not
        # channel-agnostic.
        self.in_chans = 11

        # Per-element MSE so the loss can be masked and averaged manually.
        self.loss = torch.nn.MSELoss(reduction="none")

        self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
        if self.fourier_loss_weight > 0 and self.fourier_loss is None:
            raise ValueError(
                "FourierLoss weight is activated but no fourier_loss was defined in the constructor"
            )
        elif self.fourier_loss_weight >= 1:
            raise ValueError(
                "FourierLoss weight is used as a mixing factor and must be < 1"
            )

        self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])

        # Project encoder tokens into the decoder's embedding space.
        self.encoder_decoder_proj = nn.Linear(
            self.encoder.embed_dim, self.decoder.embed_dim, bias=True
        )

        # Predict one patch per token: a channel-agnostic encoder emits one
        # token per channel-patch, so each token maps to patch_size**2 values.
        self.decoder_pred = nn.Linear(
            self.decoder.embed_dim,
            self.patch_size**2
            * (1 if self.encoder.channel_agnostic else self.in_chans),
            bias=True,
        )

        # Fixed 2D sin-cos position embeddings for the decoder, repeated per
        # modality when the encoder is channel-agnostic.
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(
            self.decoder.embed_dim,
            length=self.encoder.vit_backbone.patch_embed.grid_size[0],
            use_class_token=self.encoder.vit_backbone.cls_token is not None,
            num_modality=(
                self.decoder.num_modalities if self.encoder.channel_agnostic else 1
            ),
        )
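        # The sin-cos encoding is the standard fixed form (a sketch; the exact
        # construction lives in generate_2d_sincos_pos_embeddings):
        #   pe[p, 2i]   = sin(p / 10000 ** (2i / d))
        #   pe[p, 2i+1] = cos(p / 10000 ** (2i / d))
        # computed separately for the x and y patch coordinates and concatenated.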

        if config.use_MAE_weight_init:
            # Initialize patch_embed like an nn.Linear rather than an nn.Conv2d,
            # following the original MAE recipe.
            w = self.encoder.vit_backbone.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

            torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
            torch.nn.init.normal_(self.decoder.mask_token, std=0.02)

            self.apply(self._MAE_init_weights)

    def setup(self, stage: str) -> None:
        # Lightning-style hook kept for trainer compatibility; PreTrainedModel
        # defines no setup(), so there is nothing to delegate to.
        pass

    def _MAE_init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def decode_to_reconstruction(
        encoder_latent: torch.Tensor,
        ind_restore: torch.Tensor,
        proj: torch.nn.Module,
        decoder: Union[MAEDecoder, CAMAEDecoder],
        pred: torch.nn.Module,
    ) -> torch.Tensor:
        """Feed the encoder latent through the decoder's projection, transformer
        blocks, and prediction head, returning per-patch reconstructions."""
        decoder_latent_projection = proj(encoder_latent)  # encoder dim -> decoder dim
        decoder_tokens = decoder.forward_masked(decoder_latent_projection, ind_restore)
        predicted_reconstruction = pred(decoder_tokens)  # decoder dim -> patch pixels
        return predicted_reconstruction[:, 1:, :]  # drop the class token

    def forward(
        self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        imgs = self.input_norm(imgs)
        # Encode only the visible patches; mask flags the hidden ones and
        # ind_restore records how to unshuffle tokens for the decoder.
        latent, mask, ind_restore = self.encoder.forward_masked(
            imgs, self.mask_ratio, constant_noise
        )
        reconstruction = self.decode_to_reconstruction(
            latent,
            ind_restore,
            self.encoder_decoder_proj,
            self.decoder,
            self.decoder_pred,
        )
        return latent, reconstruction, mask

    def compute_MAE_loss(
        self,
        reconstruction: torch.Tensor,
        img: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the final loss and return component losses for metric reporting."""
        loss_dict = {}
        img = self.input_norm(img)
        target_flattened = flatten_images(
            img,
            patch_size=self.patch_size,
            channel_agnostic=self.encoder.channel_agnostic,
        )

        loss: torch.Tensor = self.loss(reconstruction, target_flattened)
        loss = loss.mean(dim=-1)  # per-patch mean over pixel values
        # Average only over masked patches, as in the original MAE recipe.
        loss = (loss * mask).sum() / mask.sum()
        loss_dict[self.RECON_LOSS] = loss.item()

        # Optionally mix in the Fourier-space reconstruction loss.
        if self.fourier_loss_weight > 0:
            floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
            if not self.mask_fourier_loss:
                floss = floss.mean()
            else:
                floss = floss.mean(dim=-1)
                floss = (floss * mask).sum() / mask.sum()
            loss_dict[self.FOURIER_LOSS] = floss.item()
            loss = (1 - self.fourier_loss_weight) * loss + (
                self.fourier_loss_weight * floss
            )
        return loss, loss_dict

    def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        img = batch["pixels"]
        latent, reconstruction, mask = self(img.clone())
        full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
        return {
            self.TOTAL_LOSS: full_loss,
            **loss_dict,
        }

    def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        return self.training_step(batch, batch_idx)

    def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
        # Assumes a training harness that attaches `self.metrics` and
        # `self.lr_scheduler`; neither is created in this module.
        self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
        for key, value in outputs.items():
            if key.endswith("loss"):
                self.metrics[key].update(value)

    def on_validation_batch_end(
        self,
        outputs: TensorDict,
        batch: TensorDict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Lightning-style hook; PreTrainedModel defines no such method, so
        # there is nothing to delegate to here.
        pass

    def predict(self, imgs: torch.Tensor) -> torch.Tensor:
        imgs = self.input_norm(imgs)
        X = self.encoder.vit_backbone.forward_features(imgs)  # (N, 1 + C*T, D)
        if self.return_channelwise_embeddings:
            N, _, d = X.shape
            num_channels = imgs.shape[1]
            # Drop the class token, then mean-pool each channel's tokens
            # separately and concatenate the channel embeddings.
            X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
            pooled_segments = X_reshaped.mean(dim=2)
            latent = pooled_segments.view(N, num_channels * d).contiguous()
        else:
            # Mean-pool all patch tokens (class token excluded) into one vector.
            latent = X[:, 1:, :].mean(dim=1)
        return latent

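    # Example output shapes for predict (a sketch, assuming the ViT-S embedding
    # width D = 384 and C = 6 input channels):
    #   return_channelwise_embeddings=True  -> (N, 6 * 384) = (N, 2304)
    #   return_channelwise_embeddings=False -> (N, 384)
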
    def save_pretrained(self, save_directory: str, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        modelpath = f"{save_directory}/{filename}"
        self.config.save_pretrained(save_directory)
        # Note: despite the default filename, the checkpoint is written with
        # torch.save, not in safetensors format.
        torch.save({"state_dict": self.state_dict()}, modelpath)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        modelpath = f"{pretrained_model_name_or_path}/{filename}"
        config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        state_dict = torch.load(modelpath, map_location="cpu")
        model = cls(config)
        model.load_state_dict(state_dict["state_dict"])
        return model
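

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: 6-channel 256x256 float inputs
    # (matching the hard-coded decoder modalities and ViT input size), and the
    # local modules imported at the top of this file being available.
    config = MAEConfig(mask_ratio=0.75)
    model = MAEModel(config)
    imgs = torch.randn(2, 6, 256, 256)
    latent, reconstruction, mask = model(imgs)
    loss, loss_dict = model.compute_MAE_loss(reconstruction, imgs, mask)
    print(loss.item(), loss_dict)
    with torch.no_grad():
        embeddings = model.predict(imgs)  # (2, 384), assuming ViT-S width 384
    print(embeddings.shape)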