Epoch 29
Browse files- samples/.ipynb_checkpoints/0009-checkpoint.png +0 -0
- samples/.ipynb_checkpoints/0019-checkpoint.png +0 -0
- samples/.ipynb_checkpoints/0029-checkpoint.png +0 -0
- samples/.ipynb_checkpoints/0039-checkpoint.png +0 -0
- samples/.ipynb_checkpoints/0049-checkpoint.png +0 -0
- samples/.ipynb_checkpoints/overview-checkpoint.ipynb +562 -0
- samples/0009.png +0 -0
- samples/0019.png +0 -0
- samples/0029.png +0 -0
- samples/0039.png +0 -0
- samples/0049.png +0 -0
samples/.ipynb_checkpoints/0009-checkpoint.png
ADDED
|
samples/.ipynb_checkpoints/0019-checkpoint.png
ADDED
|
samples/.ipynb_checkpoints/0029-checkpoint.png
ADDED
|
samples/.ipynb_checkpoints/0039-checkpoint.png
ADDED
|
samples/.ipynb_checkpoints/0049-checkpoint.png
ADDED
|
samples/.ipynb_checkpoints/overview-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 19,
|
| 6 |
+
"id": "9074d4b4",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 12 |
+
"import scipy.io\n",
|
| 13 |
+
"from torchvision import transforms, utils\n",
|
| 14 |
+
"import os\n",
|
| 15 |
+
"from PIL import Image"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": 20,
|
| 21 |
+
"id": "c85ff158",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"outputs": [],
|
| 24 |
+
"source": [
|
| 25 |
+
"class GravityDataset(Dataset):\n",
|
| 26 |
+
" def __init__(self, folder, image_size, exts=['mat']):\n",
|
| 27 |
+
" super().__init__()\n",
|
| 28 |
+
" self.folder = folder\n",
|
| 29 |
+
" self.image_size = image_size\n",
|
| 30 |
+
" self.paths = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.mat')]\n",
|
| 31 |
+
"\n",
|
| 32 |
+
" # Define transformations that are independent of scaling\n",
|
| 33 |
+
" self.transform = transforms.Compose([\n",
|
| 34 |
+
" transforms.Resize((int(image_size * 1.12), int(image_size * 1.12))), # Resize slightly larger\n",
|
| 35 |
+
" transforms.RandomCrop(image_size), # Then crop to the target size\n",
|
| 36 |
+
" transforms.RandomHorizontalFlip(), # Random horizontal flip\n",
|
| 37 |
+
" transforms.ToTensor() # Convert to tensor\n",
|
| 38 |
+
" ])\n",
|
| 39 |
+
"\n",
|
| 40 |
+
" def scale_to_minus1_1(self, tensor):\n",
|
| 41 |
+
" \"\"\"Dynamically scale the tensor to the range [-1, 1] based on its own min and max values.\"\"\"\n",
|
| 42 |
+
" min_val = tensor.min()\n",
|
| 43 |
+
" max_val = tensor.max()\n",
|
| 44 |
+
" # Avoid division by zero if min and max are the same\n",
|
| 45 |
+
" if max_val > min_val:\n",
|
| 46 |
+
" return 2 * ((tensor - min_val) / (max_val - min_val)) - 1\n",
|
| 47 |
+
" else:\n",
|
| 48 |
+
" return tensor # If min and max are the same, return tensor as is.\n",
|
| 49 |
+
" # tensor = tensor\n",
|
| 50 |
+
" return tensor # If min and max are the same, return tensor as is.\n",
|
| 51 |
+
"\n",
|
| 52 |
+
" def __len__(self):\n",
|
| 53 |
+
" return len(self.paths)\n",
|
| 54 |
+
"\n",
|
| 55 |
+
" def __getitem__(self, index):\n",
|
| 56 |
+
" file_path = self.paths[index]\n",
|
| 57 |
+
" # Load the .mat file\n",
|
| 58 |
+
" data_loc = scipy.io.loadmat(file_path)\n",
|
| 59 |
+
" data_val = data_loc['d']\n",
|
| 60 |
+
"\n",
|
| 61 |
+
" data_val = data_val.reshape(32, 32)\n",
|
| 62 |
+
" # Convert numpy array as a PIL image\n",
|
| 63 |
+
" img = Image.fromarray(data_val)\n",
|
| 64 |
+
"\n",
|
| 65 |
+
" # Apply transformations\n",
|
| 66 |
+
" if self.transform:\n",
|
| 67 |
+
" img = self.transform(img)\n",
|
| 68 |
+
" \n",
|
| 69 |
+
" # Scale the image to [-1, 1] based on its own min and max values\n",
|
| 70 |
+
" img = self.scale_to_minus1_1(img)\n",
|
| 71 |
+
" return img"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": 21,
|
| 77 |
+
"id": "c16dc0d1",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [],
|
| 80 |
+
"source": [
|
| 81 |
+
"data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
|
| 82 |
+
"dataset = GravityDataset(data, image_size=32)\n",
|
| 83 |
+
"dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"execution_count": 22,
|
| 89 |
+
"id": "a8bd15c0",
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"outputs": [
|
| 92 |
+
{
|
| 93 |
+
"name": "stderr",
|
| 94 |
+
"output_type": "stream",
|
| 95 |
+
"text": [
|
| 96 |
+
"Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 28.56it/s]\n"
|
| 97 |
+
]
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"data": {
|
| 101 |
+
"text/plain": [
|
| 102 |
+
"4"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
"execution_count": 22,
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"output_type": "execute_result"
|
| 108 |
+
}
|
| 109 |
+
],
|
| 110 |
+
"source": [
|
| 111 |
+
"from diffusers import StableDiffusionPipeline\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"pipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n",
|
| 114 |
+
"pipeline.unet.config[\"in_channels\"]\n",
|
| 115 |
+
"4"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "markdown",
|
| 120 |
+
"id": "72b20338",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"source": [
|
| 123 |
+
"**Training Configuration**"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "code",
|
| 128 |
+
"execution_count": 23,
|
| 129 |
+
"id": "e01c3fae",
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"outputs": [],
|
| 132 |
+
"source": [
|
| 133 |
+
"from dataclasses import dataclass\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"@dataclass\n",
|
| 136 |
+
"class TrainingConfig:\n",
|
| 137 |
+
" image_size = 32 # the generated image resolution\n",
|
| 138 |
+
" train_batch_size = 32\n",
|
| 139 |
+
" eval_batch_size = 16 # how many images to sample during evaluation\n",
|
| 140 |
+
" num_epochs = 50\n",
|
| 141 |
+
" gradient_accumulation_steps = 1\n",
|
| 142 |
+
" learning_rate = 1e-4\n",
|
| 143 |
+
" lr_warmup_steps = 500\n",
|
| 144 |
+
" save_image_epochs = 10\n",
|
| 145 |
+
" save_model_epochs = 30\n",
|
| 146 |
+
" mixed_precision = \"fp16\" # `no` for float32, `fp16` for automatic mixed precision\n",
|
| 147 |
+
" output_dir = \"ddpm-butterflies-128\" # the model name locally and on the HF Hub\n",
|
| 148 |
+
"\n",
|
| 149 |
+
" push_to_hub = True # whether to upload the saved model to the HF Hub\n",
|
| 150 |
+
" hub_model_id = \"jainadarsh/trial\" # the name of the repository to create on the HF Hub\n",
|
| 151 |
+
" hub_private_repo = None\n",
|
| 152 |
+
" overwrite_output_dir = True # overwrite the old model when re-running the notebook\n",
|
| 153 |
+
" seed = 0\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"config = TrainingConfig()"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "markdown",
|
| 161 |
+
"id": "9f52a27e",
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"source": [
|
| 164 |
+
"**Load Dataset**"
|
| 165 |
+
]
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"cell_type": "code",
|
| 169 |
+
"execution_count": 74,
|
| 170 |
+
"id": "b4a98567",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"outputs": [],
|
| 173 |
+
"source": [
|
| 174 |
+
"# from datasets import load_dataset\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"# config.dataset_name = \"huggan/smithsonian_butterflies_subset\"\n",
|
| 177 |
+
"# dataset = load_dataset(config.dataset_name, split=\"train\")\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
|
| 180 |
+
"dataset = GravityDataset(data, image_size=32)\n",
|
| 181 |
+
"train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
|
| 182 |
+
]
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"cell_type": "code",
|
| 186 |
+
"execution_count": 75,
|
| 187 |
+
"id": "e3aed193",
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [
|
| 190 |
+
{
|
| 191 |
+
"data": {
|
| 192 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAABOwAAAEhCAYAAADMCz9IAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAISNJREFUeJzt2smSI7e1MGAkZ7JY1ZNaLdmK8MZb77T2+/hF/RDeOMJ2eAi1ZXWruyYO+S/+xb03bONASlQSZH3f9qAOkCCABE+x6/u+TwAAAABAEyanHgAAAAAA8D8U7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGzE49gBb1fX/qIaSU6oyjhRwlfx+1aSXH8Xh88j7G0nXdoHiJ7XY7OEdr3r17l43/61//CnPsdrtsfDKJ/5eyWq2y8c1mE+aIPp+XL19m419++WXYxy9/+cts/Fe/+lWY45tvvhk8juhZStZqNOfz+Twbn06nYR819l10xhwOhzBH1CZawyV9RDmieEop7ff7weOI5qvkM4n27G9/+9swxzmqsV55nqyd5yu6S5+jkvdVC1rZd2OM4zk9aw3nMs7nyC/sAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIbMTj2A56rv+1MPoZroWUqetUaO4/E4KF7Szxh91FgbXdcNblOSYzJ5fjX/7Xabje/3+zDH4+NjNj6dTsMcq9UqG4/GmVJK19fX2firV6+y8ZcvX4Z93NzcZOPr9TrMsVgssvHZLH6V1VjvkaHnR8k4apyFh8MhzBGt42gNR/GUUnp4eBgUL+mnZD/WOHNL1iBckhpnJlyKc/leVzLOoXu75O/H+K4zxrOWiMbhLCXy/L5tAwAAAEDDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPYBT6Pv+1EOoosZzlOSI2gyNl7Q5Ho9hjqhNSY7D4TAoR/T3JW1qfK5d14VtJpN8vX46nQ7OcYlev36djZfMyePj4+Acm80mG99ut2GOly9fZuOvXr3Kxr/44ouwj2i+bm5uwhzr9Tobn8/nYY7ZLP+6K9kzQ411FkZnzH6/D3NEa/Th4SEbv7u7C/uI2pTkuL+/z8Z3u12YI5qvkv1YsgbhXIxxHj4n5vPyXcp3y5TiZ4nWc8lcRDlK9szQcZbkKDF0f9eYrxpaGQf/7vl92wYAAACAhinYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaMjv1AC5V3/enHkJKaZxxRH2UjOF4PA6Kl7Q5HA5hjv1+n43vdrtBf1/SpmSckckkrsXPZvntP51Owxzz+bx4TJfi7du32fhyuQxzPDw8ZOMln996vc7Gr6+vwxyvXr0aFH/z5k3Yx+vXr7Pxm5ubMMdqtcrGF4tFmCNa7yVz3nVdNl7jvI1y1DjHHh8fwxzRGr29vR0UTymlT58+DYqX9BM9R0rxnJasjWiNwliic4r/y3xRQ8n3lHNhT/yPkrmI7m2XNJ/P6Vlb4hd2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIbNTD6C2vu9PPYRqxniWkj6GjqPk74/H4+Ach8MhG9/v92GO3W6XjT88PGTjd3d3YR9RjpJxRmazeGsvl8tsfLVaDR7HJfrqq6+y8aurqzDH4+NjNt51XZgj+vyur6/DHC9fvnzSeEmb7XYb5ojmNJqLlOI9MZ1OwxzR5xLFa5yF0TmXUnyGROdcSind399n47e3t9n4x48fwz5+/PHHJ88RPUdK8ZyWnKfr9TpsAzWUvB+eC3Pxf5mP0yl5N5+LyST/e56hd6GW1Li3jfG80TjOac756fzCDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGjI7NQDOEd93596CCml8cYR9TM0XuJ4PA5uczgcwhyPj4/Z+N3dXTb+448/hn18+vRp0BhKLJfLsM12u83GS+a8lb0wpl/84hfZePT5plTnM14sFtn4ZrMJc7x48SIbv7m5ycavr6/DPqIcV1dXYY7VapWNz+fzMMdsln/dTSbx/6+6rgvbDBXtqZJ9ud/vs/GHh4cwx/39fTZ+e3ubjZfsgx9++GFQPKX4zI3O7JTiOY32Wkop7Xa7sA1ExjhjWnFJz3pJz8LPU/IdI1LjTh2txZK1Go0jylHjPlUyzjH2XY0+hs5njT5q9TN0HM7Kn8cv7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx7Ac9X3fRP91BhHlON4PIY5ojYlOQ6HQza+3+/DHLvdLhu/v7/Pxj9//hz28fHjx8E5IldXV4NzTKfTKm0uzVdffZWN397ehjkeHx8Hj2OxWGTjq9UqzBGtkyi+3W7DPjabTTa+Xq/DHMvlMhuP5iKllGaz/OtuMon/fxW1qXHeRudYFE8pPsdK1t/d3V02Hp1TP/74Y9hHdBZ++PAhzPHp06dsPDqzS5TMebS+IKWUuq479RBGcy7PapzUEH3HGOt7X7ROStZRdNeJcpQ8a5Tjkr5fjDFfXDa/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaMjv1AH6qvu9PPYRnJ5rzGp9JlON4PIY5ojaHwyHMEbXZ7/fZ+G63C/t4fHwcnCOar+l0GuZYLpfZ+Gq1CnPMZmd3hAz29u3bbPzh4SHMEa2jEtFnvFgswhzRZ7xerwf9fUmbaB2mlNJ8Ps/GS9ZhNF+TyXn8/6rkHKtxTkXr+Pb2Nhv//Plz2MenT58G9VHSpmS+os++5P3TdV3Yhst2SWvgXJ6llXG2Mo7IuYzzHEV3+5LvSiXvmkiNu07UJuqj5FmjHCXv7pLvOkPV2DPRfIy1L1sYR8nacE79u/P4hgIAAAAAz4SCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIbNTD4D/ru/7QfFaOYYaa5zH43FQvKRNNI6u68I+ptPpoHhKw8eZUkqHwyEb3+12YY6SNpfmzZs32fjj42OYI5r7EtFam8/nYY7FYjEox3K5fPI+UkppNsu/qiaT+H9PUZuSvRspOWMiNc7CGnv74eEhG7+/v8/G7+7uwj6iHCV7KZqPkrURrdHVahXmKGnDeatxRrSihWdpYQwpjTOO5/Ssz1X0Piu5H9T4Tha980q+Y0RtontbSR+Rknd3pMY4anwmNfZdje+fnC+/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoALer7/tRDaEo0H8fjcdDfl7QpyVFjHFGOruuy8el0GvYxn8+z8cViEebY7/fZ+GQS1+KjZ436KG1zaW5ubrLxkjmJ5r6G2Sw+3qP1Gq3Vkj6iNiV7JmpTst6jvVsiOkNq9DF0DCVtDodDmCNax7vdblC8ZBwl81ljja5Wq2z8+vo6zFHShraNsX/H0MpzjDGOsZ51aD/nMk5+vs+fP2fjJXfDGneM6L5U8k6M3qvRHTb6+xIl44zmo+SuE81XKzWBVu6XzpjT8As7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw/gUvV9f+ohFGllnNE4aoxzjGedTOIa+HQ6zcZns+HbsmQckePxGLbZ7/eD+zk3m80mGy9ZZyVzO1TJGui6LhuP1mK0lkvalIwzalNjvZd8bkPPqWi+S9ucg5LPZD6fZ+Ml+ySar+VyGeaI9vT19XWY48WLF2EbTudS9lVK7TzLGOOo0UcLOVoYQ2v9XJoffvghG394eAhzHA6HbLzGe7XknbharbLx9Xqdjbfy3bLkjlrj+2cLe+Zcxlmixl360viFHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaMjs1AOgfX3fn/Tva+Wooeu6QfGUUppM8nXyKF7SpmQc0Zwej8fBOS7RZrPJxmvMW8m81liLUZtW1uoY671E9NlGfdSYrxo5ptNpmGM2y18PFotFNr5arcI+DofDoD5Sip+lZBzb7XZQPKWUbm5uwjY8nZIz4By08hw1xjHGs4w1zqH9jNFHSY5W1tcl+tvf/paNf/r0Kcyx3++z8ZL3/3K5zMavrq7CHNfX19n4ixcvsvGxvj+McQ+ucb+07xjKL+wAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPQBIKaW+7588xxh9jKXGsx4Oh2x8t9uFOfb7fdjm0iwWi2y8ZO7HWEdd11VpkzOZxP/zifoYY5wlxjgfSuYrajObxa/tqE20hlNKabVaZeObzSYbj86XlOJxHo/HwTmi50gpfpbr6+swx9XVVdgGxjjLStQYx9AcLYyhNMfQ99gYfZS2GSPHc/TnP/85G3///n2Y4/HxMRsvuUOs1+tsvOR99ubNm2w8GmfJ+7/GnSuajxp3rhp3+lbu/Jwvv7ADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANmZ16AP9b3/enHsJoxnrWS5nTGs9RkiNqE8WPx2PYx+FwyMb3+32YY7fbZeMlzxqNdTKJ6/nROC7RYrHIxmuss7F0XTco3kofJYbu7RLRnppOp2GO2Sz/Wo7iKaW0XC6z8ZJ9u9lssvHoHCt51qiPks8kmo9oLlJKab1eZ+PROEvb8POMdUYM1co4xzi3W8kx1jiHjqPkPlXjWaN+xnjW5+r9+/fZ+N///vcwx93d3eBxzOfzbLzkXfX9999n458/f87Gb29vwz6++OKLbPzVq1dhjkjJvovuKiXf60r6uRTRvcz58DSezwoDAAAAgDOgYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaMjs1APg/PV9PyheK8fxeHzyHFH8cDiEfez3+2z88fExzHF/fz94HJNJvl4fjfO5ms3yx2bJOqsh6qfruicfwxh91FLjjIlEa6NEdMbM5/MwR7T/1+v14HFE58dyuQz7iM6Yks9kOp1m44vFIswRjXW1WoU5StpADTXO3RZylPx91OZcckTnZUmbVnKc03t/TNF7t+ReHt3tHx4ewhwfPnzIxt+/fx/m2G632fjbt2+z8a+//jrs45tvvsnGf/nLX4Y5vv3222y85P0f3alK9gw8NasQAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0ZHbqAfDz9X0/So7j8Ti4n6HjKBln1KbkOaI2h8NhUDyllPb7fTb+8PAQ5vj06VM2fnt7G+aIxjqbxcfDy5cvwzaXJpqXGmv1knRdNzjHGGddSR9Dn2U6nYZtxjgLS0TPOp/Ps/H1eh32EZ1BJc8RzWnJORY9y2KxCHOUtOE/q3FGjGGMcZ7LXJSInqXkWVvJMZnkf98wNJ5SfJaVvD9qnIc1noV/V7LOonde9P0hpZS+//77bPwPf/hDmCNaJ9Hd/927d2Ef3333XTb+4cOHMMevf/3rweNo4X451rnfyjj46Zy6AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx4A/13f96ceQpFonCXPEbU5Ho9hjqjNfr8Pc0RthsZL2jw8PIQ5Pn78mI2/f/8+zPHdd99l43/5y1/CHG/fvs3Gf/e734U5zs1kMvz/HOeyt1tRY77GOKda0XXdoHhKKU2n02x8sVhk4yVn4eFwCNtEomeJniOllGaz/FUoiqeU0nw+D9vAGEr2d40cQ8+ZMfooaVPyTo/aRPGxzqGozXK5DHNEZ1nJszxH//znP7Pxv/71r2GO6G5f8v3g/v4+G1+tVmGOaA1E8ZJ9Gd0RdrtdmCP63ncudzaI+IUdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0JDZqQfA5ev7fnCb4/EY5tjv99n44XB48hwl46zxrI+Pj9n4x48fwxx/+tOfsvHf//73YY63b9+GbS7NZPL0/+coWQPPSdd1g3NE+67GOVVD9Kwlc1Ejx2yWvx7M5/NsvOS8HWOdT6fTwW1q5HiuauzdsYwx1hp9jJGjlc+txjijd3bJOz1qE+3/6DxNKT5TF4tFmGO1Wg2Kp5TSer3OxqNxPlf/+Mc/svE//vGPYY67u7tsvGS9R2vxm2++CXNE6+Tq6iobv7m5Cft4+fJlNh6tw5TifVWyt52nnAO/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoA/1vXdWGbvu9HGAk/RfSZlHxmx+MxGz8cDmGO/X6fjT8+PoY5drvdoD5KTCb5OvlisQhzrFarbHy9Xoc5rq6usvF3796FOd68eRO2uTTR51ci2hPT6XRwH2No5TyOzo+U4vfLGO+fkj5qjDNqU7KGozmdzfLXh5LPpMb6qTFf0XyUzFeNcwFqKFnzY/TTyllWY39H7+ToPJzP52Ef0d0vuvelFN/9rq+vwxw3NzeDx/EcRXfmh4eHMEf0HaNkHUWfz2azCXNE6yjqY7vdhn1EbV6/fh3miNZqtC9TGu+8hCHcMAEAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGzE49AC5f3/dhm8PhkI3v9/swx8PDQzZ+d3cX5oja7Ha7MEdksVhk45vNJszx5s2bweOYz+fZ+Ndffx3muLm5GTyOc9N1XRM5SvbVU2vlOabT6eB+jsdjmCN63skk/z+wkmeN+iiZ86HjTCmej9ksf30omc+hY0gpfpYa81WSo2ROL1GNM2AMY4yzlXdDjX7G2Dcle2aMvVny/ojaROdhSR/RnWy1WoU5rq+vs/HXr1+HOaI22+02zPEcffvtt9n4b37zmzBHtJ6Xy2WYY71eZ+Ml3zGitRb1UTLOaBwl44zWYrSnUorPh+f6bv9vzuWdf2msQgAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0JDZqQfA+ev7Phs/Ho9hjqjNfr8Pc9zd3WXjHz58GJxjOp1m4/P5POxjsVhk47NZvC2jHMvlMsyx2Wyy8c+fP4c5VqtV2ObSTCb5/3NE+6GWrusG5xhrrDklzzHGOKO9XTKO6Byr8awlOaI20RouGcfQeGmboWrMV40cUEsra23ovil5juisKjnLauSI3g/RvS26s6UU36fW63WYY7vdZuOvXr0Kc7x79y4bf/36dZijhbvF2KJ5KfkuFK2zkjt3tE5KckTfD2p8B4najPV9aoz3vzsGQ/mFHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANCQ2akH8Fx1XTc4R9/3o+SI2gyNp5TS8XjMxne7XZjj9vY2G//+++/DHHd3d9n4ZrPJxq+vr8M+lstlNj6dTsMcq9UqG4/GmVJK2+02G398fAxzLBaLsM2lqbF3GV/0uZWcU1GOGn1EZ2HJ+qsxjhrn+tA+aqgxX7X64emMMf/nsk5qrPkxnrXGOCeT+HcHUZuSO9dslv+6FMVL7kpRm/V6HeaI7nUvXrwIc3zxxRfZ+JdffhnmGONsb00099G7PaWU5vN5Nh7d/Uva1MgRrdXoOUraRHuqpE3J3o7Oh7HuEJDjF3YAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhs1MPgKfV9/3J+zgej2GOqM3hcAhz7Ha7bPzu7i7McXt7m413XZeNr1arsI/IcrkM28xm+a1bMl/X19fZeMnnFo3jOYrWCM9XydqI2pSc6dHerbFGx3i3jMWepRWtrMUaZ9XQeEopTSb53xVE8ZRSmk6ng+IpxXedKL5YLMI+ojbr9TrMcXV1lY3f3NyEOV6+fJmNv3r1KsxxSe+HUtHc11jvJd8xou8QJd8xorUYrff5fB72EeUo2Zcl+z9S45xqwbmMk5/HL+wAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPYAWdV2Xjfd9P9JInt7xeBzc5nA4DIqnlNJ+vx+co6RNJHrW+/v7bPzu7i7sY7PZZOOzWbwt1+t1Nj6ZjFOLH6uflkTnQytKzqkWnuVcxtmKkrmoMV9jvOdK3j9DjXVGWaOMJVprY50RQ8dRY5wl+ztqM51OwxzRvWyxWAyKpxTf61arVZgjul9eXV2FOa6vr7Pxm5ubMMcYZ3trttttNl6y3qO1OJ/PwxxRm5K1GI0jylGyp6I2JTmiOS3JEZ0PNc6pGmfhGIyjXc/v2zYAAAAANEzBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQkNmpB/BTdV2Xjfd9P9JI8s5lnCWOx2M2vtvtsvG7u7uwj9vb28E5DodDNj6ZxPXpqE30uUVjKOljNou35XK5zMbn8/ngcURruCTHJSqZl6FqnA9jjLNE9CytjLOGGuf+ueSooUYfrayfVsbBzzf0M7ykNVDyLFGbKF7jTlaSYzqdZuMld66oTXTnWiwWYR+r1Sob32w2YY6ozXa7DXNEba6ursIc0feGS1QyL5ForUbxlOK1WmO91xhn1KbkDKqRI3JJ5zrn6/l92wYAAACAhinYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoA/Hxd143Sz+FwyMbv7++z8Y8fP4Z9/Pjjj9n47e1tmOPh4SEbn0zi+vRiscjGp9PpoL9PKaX5fD44x3K5HJwjmo+S+Sppw0831t6O9H0/OMe5PEuNcY7RR0mOFp61RCtro5Vx8J/5fOob4ywaGk+pzj0lurdF8ZTie1sUj+5sKaW0Wq2y8fV6Hea4uroaFE8ppc1mMzhH9L3hEpV8PpFoPZfsmdks/9W+ZL0P3TMl+zJ6lho5SuarhbOwRh9j5eA0fNsGAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGjI79QB4Wl3XPXkffd9n48fjMcxR0iYym+WX83a7DXOsVqtBfVxdXYV9RG2iMaSU0mKxGBRPKaXJJF+vj+KlbS5NC3tqLGM8a4ka89HKs1wK8wk/3Rj7pkYfY+Qo6SNqU+OeEt3rStosl8tsvORet16vs/HNZhPmiO6XJXfU6K5cMo7D4RC2uTTRGigRrdWSPTPG3b7GOKfT6eAcY5wx53Jmn4vn9Kw1Pb9v2wAAAADQMAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw+gtq7rwjZ935+8j5IcQ/soMZnENdvpdJqNL5fLbHy73Q4ex3q9DnPs9/ts/HA4hDmiOR06FymldH19nY2XPOtiscjGZ7N4a0fPUrJGS9YPP12N86GGGmdMDWPMR41nbWWcrawf4Kcp2bs19vcYOYbGU4rvGCV3kBo55vN5Nh7dyUruhtHdb7PZhDmurq6y8ZL7eNQm6iOllHa7Xdjm0kRroESNPRO1KVnvQ3PUGGcrOUq0cOdqYQw8Hd+2AQAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIbMTj0Afr6u6wa3mUzimu18Ps/G1+v14D42m002vt/vwxzH43FQvET0LLNZvKWi+Vwul4NzlMx5SZsxcpybkn0X6fu+wkieXo1nrWGM+TqXZ21lnMDlGuOcqXGHLckxnU6z8ZJ7W9Qmipfc61arVTYe3ZNL2kT39Vo5ojm/RCXraKgae6ZEdLevsS9byTG0j7FynIvn9Kxjen7ftgEAAACgYQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx7AKXRdl433fX8RfZS0mUzimu10Os3Gl8tlNj6bxcvscDhk48fjMcxR0iYSfS5jzGcUL+1n6DhK1hc/zxhzW+OMaUUra3GMOW3lWYdqZf1dynzCuRlj77VyD57P54PiKcV36dVqFeaI2qzX68HjWCwWYY5Wzv8x1biX1+ijxr6LcgyN1xjDmP2MkaOFPmiXX9gBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADZmdegD8d13XZeN93w/OUWI6nQ7qI/r7lFI6HA4/aUz/STQfJfM1VMl8R20mk7iOHuWoMY4SNXLwNFr5bMbYd2NpYU7PZT5bmKsxPbfnhaFrfqx7yhjjiO5tJffgqE2NHCX3y8jxeAzb1LjTn5uSzydyLuv9qf/+0nK00Ect5zTWS+IXdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw+gRV3XZeN93z95HyX9lOSoIeqnxnxNJvnacUmOGp/LGGp8bkM/kxp9QIlW1tG5nA+RVuYTeN6GnqmXdK+rIXrW4/EY5tjv99n4/f19mOPz58/Z+HQ6DXPc3d1l45vNJsxxbqLvMTW08r3vqf++Vo5W+jmXe9u5jPM58gs7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCFd3/f9qQcBAAAAAPx/fmEHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA35f5EOqgphU/Q8AAAAAElFTkSuQmCC",
|
| 193 |
+
"text/plain": [
|
| 194 |
+
"<Figure size 1600x400 with 4 Axes>"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
"metadata": {},
|
| 198 |
+
"output_type": "display_data"
|
| 199 |
+
}
|
| 200 |
+
],
|
| 201 |
+
"source": [
|
| 202 |
+
"import matplotlib.pyplot as plt\n",
|
| 203 |
+
"import torch\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"batch = next(iter(dataloader)) # batch shape (B, C, H, W)\n",
|
| 206 |
+
"imgs = batch[:4] # take first 4\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
|
| 209 |
+
"for i, img in enumerate(imgs):\n",
|
| 210 |
+
" img = img.detach().cpu()\n",
|
| 211 |
+
" if img.ndim == 3: # (C,H,W) -> (H,W,C)\n",
|
| 212 |
+
" img = img.permute(1, 2, 0)\n",
|
| 213 |
+
" img = (img + 1) / 2 # convert from [-1,1] to [0,1]\n",
|
| 214 |
+
" arr = img.numpy().squeeze()\n",
|
| 215 |
+
" axs[i].imshow(arr, cmap='gray' if arr.ndim == 2 else None)\n",
|
| 216 |
+
" axs[i].axis('off')\n",
|
| 217 |
+
"plt.show()"
|
| 218 |
+
]
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"cell_type": "markdown",
|
| 222 |
+
"id": "06c1dff6",
|
| 223 |
+
"metadata": {},
|
| 224 |
+
"source": [
|
| 225 |
+
"**Create a Architecture**"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": 76,
|
| 231 |
+
"id": "3ee088b3",
|
| 232 |
+
"metadata": {},
|
| 233 |
+
"outputs": [],
|
| 234 |
+
"source": [
|
| 235 |
+
"from diffusers import UNet2DModel\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"model = UNet2DModel(\n",
|
| 238 |
+
" sample_size=config.image_size, # the target image resolution\n",
|
| 239 |
+
" in_channels=1, # the number of input channels, 3 for RGB images\n",
|
| 240 |
+
" out_channels=1, # the number of output channels\n",
|
| 241 |
+
" layers_per_block=2, # how many ResNet layers to use per UNet block\n",
|
| 242 |
+
" block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block\n",
|
| 243 |
+
" down_block_types=(\n",
|
| 244 |
+
" \"DownBlock2D\", # a regular ResNet downsampling block\n",
|
| 245 |
+
" \"DownBlock2D\",\n",
|
| 246 |
+
" \"DownBlock2D\",\n",
|
| 247 |
+
" \"DownBlock2D\",\n",
|
| 248 |
+
" \"AttnDownBlock2D\", # a ResNet downsampling block with spatial self-attention\n",
|
| 249 |
+
" \"DownBlock2D\",\n",
|
| 250 |
+
" ),\n",
|
| 251 |
+
" up_block_types=(\n",
|
| 252 |
+
" \"UpBlock2D\", # a regular ResNet upsampling block\n",
|
| 253 |
+
" \"AttnUpBlock2D\", # a ResNet upsampling block with spatial self-attention\n",
|
| 254 |
+
" \"UpBlock2D\",\n",
|
| 255 |
+
" \"UpBlock2D\",\n",
|
| 256 |
+
" \"UpBlock2D\",\n",
|
| 257 |
+
" \"UpBlock2D\",\n",
|
| 258 |
+
" ),\n",
|
| 259 |
+
")"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "code",
|
| 264 |
+
"execution_count": 77,
|
| 265 |
+
"id": "8a703029",
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"outputs": [
|
| 268 |
+
{
|
| 269 |
+
"name": "stdout",
|
| 270 |
+
"output_type": "stream",
|
| 271 |
+
"text": [
|
| 272 |
+
"Input shape: torch.Size([1, 1, 32, 32])\n",
|
| 273 |
+
"Output shape: torch.Size([1, 1, 32, 32])\n"
|
| 274 |
+
]
|
| 275 |
+
}
|
| 276 |
+
],
|
| 277 |
+
"source": [
|
| 278 |
+
"sample_image = dataset.__getitem__(0).unsqueeze(0) # add batch dimension\n",
|
| 279 |
+
"print(\"Input shape:\", sample_image.shape)\n",
|
| 280 |
+
"print(\"Output shape:\", model(sample_image, timestep=0).sample.shape)"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "markdown",
|
| 285 |
+
"id": "8d3d9af1",
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"source": [
|
| 288 |
+
"**Create a scheduler**"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "code",
|
| 293 |
+
"execution_count": 78,
|
| 294 |
+
"id": "22c2ce21",
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [
|
| 297 |
+
{
|
| 298 |
+
"data": {
|
| 299 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfcAAAGdCAYAAAAPGjobAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAO6BJREFUeJzt3XtwVfW5//FPCCQh5EZIyAUChJuAXKxRMdpalJSAHUeUcbR1RrQWRw90qvQmHS9V25PWniraUjhnWqFOS72cqXjsBdQoWG2QgnKQiylgIOGScM2V3EjW7w8P+2d0Qb5PsmOyNu/XzJohO0+e/d17rb0fvt+99rOiPM/zBAAAIka/3h4AAAAIL4o7AAARhuIOAECEobgDABBhKO4AAEQYijsAABGG4g4AQIShuAMAEGH69/YAPq29vV2HDh1SYmKioqKiens4AAAjz/NUV1en7Oxs9evXc3PIpqYmtbS0dDtPTEyM4uLiwjCivqPPFfdDhw4pJyent4cBAOimiooKDR8+vEdyNzU1KTc3V5WVld3OlZmZqbKysogq8D1W3JctW6af//znqqys1LRp0/TLX/5Sl112Wad/l5iYKEl6/fXXNWjQIKf7On78uPO4Ro4c6RwrSUePHnWOTU1NNeUeNmyYc+y//vUvU+6ysjLn2IsuusiU+8MPPzTFZ2dnO8dWV1ebco8bN845tr6+3pS7rq7OOba9vd2Uu7W11RSfl5fnHLt7925T7tjY2B7LnZaW5hw7YMAAU+62tjbn2NOnT5ty9+9ve2u0vPZjYmJMuf/5z386x06bNs2U23KMDxkyxDm2vr5eeXl5offzntDS0qLKykqVl5crKSmpy3lqa2s1YsQItbS0UNw78/zzz2vx4sVasWKFpk+frqVLl6qwsFClpaUaOnToOf/2zFL8oEGDlJCQ4HR/TU1NzmOzHmyNjY09lttyQLo+F2fEx8c7x1rH7fqfrjMsY7cWPcvYrR/zWC67YC3u1qXEnjxWLG9oluPKOpaeLO7W48o6FstxaC3uPflathzjXSnUn8dHq0lJSd0q7pGqRz4MeeKJJ7RgwQLdcccdmjRpklasWKH4+Hg988wzPXF3AIDzlOd53d4iUdiLe0tLi7Zs2aKCgoL/fyf9+qmgoEAlJSWfiW9ublZtbW2HDQAAFxR3f2Ev7seOHVNbW5syMjI63J6RkeF74kNRUZGSk5NDGyfTAQBcUdz99fr33JcsWaKamprQVlFR0dtDAgAg0MJ+Ql1aWpqio6NVVVXV4faqqiplZmZ+Jj42NtZ0ti4AAGd0d/bNzN1RTEyM8vLyVFxcHLqtvb1dxcXFys/PD/fdAQDOYyzL++uRr8ItXrxY8+fP1yWXXKLLLrtMS5cuVUNDg+64446euDsAAPAJPVLcb775Zh09elQPPfSQKisrddFFF2nt2rWfOckOAIDuYFneX491qFu0aJEWLVrU5b+3NLGxdB6zdLOTbF2Zampqeiz3iRMnTLktnaq2bdtmym3pOCfZuppZG+QcPnzYOdbaQdDS+MT6LQ/rsWLZR6NGjeqxsYwePdqU29IadMKECabcBw8edI61NrGxHuOWzorWLoyWJjbWDmulpaXOsadOnXKObWhoMI2jOyju/nr9bHkAABBefe7CMQAAuGLm7o/iDgAILIq7P5blAQCIMMzcAQCBxczdH8UdABBYFHd/FHcAQGBR3P3xmTsAABGGmTsAILCYufujuAMAAovi7q/PFvddu3Y5t12cPn26c97GxkbTOCw7PikpyZT72LFjzrGWFrvWeGtL0cTERFN8//7uh9nIkSNNufft2+cca2lXah3L6dOnTbkHDBhgio+OjnaOXbdunSm3pV1tcnKyKffQoUOdY5ubm025Bw8e7Bxrvaz09u3bTfGW53D8+PGm3Js2bXKOtbweJNtz2NbW5hwbqQUzSPpscQcAoDPM3P1R3AEAgUVx98fZ8gAARBhm7gCAwGLm7o/iDgAItEgt0N3BsjwAABGGmTsAILBYlvdHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/vpscc/KylJCQoJTrKWP+vHjx03jyM7Odo6trq425T516pRz7JgxY0y5W1panGPr6upMuV17/p9h6UXf2tpqyp2VleUce/jwYVPuDz74wDn2wgsvNOW2PoeWsY8bN86UOz093Tm2tLTUlNv1NSxJBw4cMOW2vDatvfytfe4t1y2w7nuLuLg4U/yQIUOcYy3XT7C+p3QHxd0fy/IAAESYPjtzBwCgM8zc/VHcAQCBRXH3x7I8AAARhpk7ACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3R/FHQAQWBR3f322uCckJDi3rrS0LB02bJhpHJaWsj3ZUnTkyJGm3Pv27XOO7d/fdhi8+uqrpvhZs2Y5x27evNmUe/z48c6xu3btMuVOS0tzjm1razPltj7nFtZj5dChQ86xOTk5ptyWN86xY8eacldVVTnHWtvJDh8+3BTf1NTkHGttg2xpD5ybm2vKbWn5a2lt297ebhoHwq/PFncAADrDzN0fxR0AEFgUd398FQ4AgAjDzB0AEFjM3P1R3AEAgUVx98eyPAAAEYaZOwAgsJi5+6O4AwACLVILdHewLA8AQIRh5g4ACCyW5f1R3AEAgUVx99eni3tUVJRT3ODBg51ztrS0mMZg6S++e/duU+7k5GTnWGtv7NTUVOdY63Ny7bXXmuI/+OAD59hLLrnElNvSp3vmzJmm3JZ+4Tt27DDltl7joKGhwTl248aNptyWnu7WN0JLz/3nn3/elHvu3LnOsTt37jTlrqmpMcVb+q5bX29XX321c6zlmhKS7biy7MtTp06ZxtEdFHd/fOYOAECECXtx/9GPfqSoqKgO24QJE8J9NwAAhGbu3dkiUY/M3C+88EIdPnw4tL399ts9cTcAgPNcbxT3t956S9ddd52ys7MVFRWlNWvWdPo369ev18UXX6zY2FiNHTtWq1atsj9Ygx4p7v3791dmZmZos3xuDQBAX9bQ0KBp06Zp2bJlTvFlZWX66le/qquvvlpbt27Vvffeq29+85tat25dj42xR06o2717t7KzsxUXF6f8/HwVFRVpxIgRvrHNzc0dTharra3tiSEBACJQb5xQN2fOHM2ZM8c5fsWKFcrNzdUvfvELSdLEiRP19ttv68knn1RhYaH5/l2EfeY+ffp0rVq1SmvXrtXy5ctVVlamL33pS6qrq/ONLyoqUnJycmjLyckJ95AAABEqXMvytbW1HTbrN5TOpaSkRAUFBR1uKywsVElJSdju49PCXtznzJmjm266SVOnTlVhYaH++te/qrq6Wi+88IJv/JIlS1RTUxPaKioqwj0kAADOKScnp8NEs6ioKGy5KysrlZGR0eG2jIwM1dbWqrGxMWz380k9/j33lJQUjR8/Xnv27PH9fWxsrGJjY3t6GACACBSuZfmKigolJSWFbg96Xerx77nX19dr7969ysrK6um7AgCcZ8K1LJ+UlNRhC2dxz8zMVFVVVYfbqqqqlJSUpIEDB4btfj4p7MX9u9/9rjZs2KB9+/bpH//4h2644QZFR0fra1/7WrjvCgCAPi8/P1/FxcUdbnvttdeUn5/fY/cZ9mX5AwcO6Gtf+5qOHz+u9PR0ffGLX9TGjRuVnp5uyrN9+3bFx8c7xebm5jrntcRK0uuvv+4cO3ToUFPu6Oho59i9e/eacl955ZXOse+//74pt7W1pOUkyU8ui7mwPC/t7e2m3Pv373eOnTx5sim3tV3t2b5t4sf6WrM858ePHzflbm1tdY6dN2+eKbflOJw0aZIp98GDB03x/fq5z5Msr3tJZz0Z2Y+lpbX08WfBrixtcy1tbburN86Wr6+v7/BRc1lZmbZu3arU1FSNGDFCS5Ys0cGDB/Xss89Kku6++2796le/0ve//3194xvf0BtvvKEXXnhBf/nLX7o87s6Evbg/99xz4U4JAICv3ijumzdv7tDzf/HixZKk+fPna9WqVTp8+LDKy8tDv8/NzdVf/vIX3XfffXrqqac0fPhw/eY3v+mxr8FJffzCMQAAnEtvFPcZM2ac8+/8us/NmDHDvEraHVw4BgCACMPMHQAQWFzy1R/FHQAQWBR3fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7+WJYHACDCMHMHAAQWM3d/FHcAQKBFaoHujj5b3CdPnqzExESn2JMnTzrn3bRpk2kcn2wxGM5xSNLp06edY0tLS025LT3ArT3XExISTPGW/C+//LIpt6Wnu/UqT5ae64cPHzblHjdunCk+Jiamx8ZiuSZCc3OzKbelj7r1tXnBBRc4x1qPcWuPdkvfdcvrXpLzNTYknfXS2mdjef1YckdFRZnGgfDrs8UdAIDOsCzvj+IOAAgsirs/ijsAILAo7v74KhwAABGGmTsAILCYufujuAMAAovi7o9leQAAIgwzdwBAYDFz90dxBwAEFsXdH8vyAABEmD47c29qalL//m7Da2xsdM6blpZmGscrr7ziHPuFL3zBlDs7O9s51tIiVLL9b7S2ttaU2/ocWlpz5ufnm3K3trY6x+7du9eUOyMjwzl2wIABptxNTU2m+L///e/OsdbjsKamxjn21KlTptyW53zKlCmm3A0NDc6xJ06cMOUeM2aMKd71vUqSDh48aMptaT1t3T87duxwjrW0HraOozuYufvrs8UdAIDOUNz9sSwPAECEYeYOAAgsZu7+KO4AgMCiuPujuAMAAovi7o/P3AEAiDDM3AEAgcXM3R/FHQAQWBR3fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7++mxxb2lpce5JPnXqVOe8JSUlpnGMHDnSOTYlJcWU+8iRI86xlj7nkq3ffmJioim3lSW/te/2iBEjnGOtz6FFeXm5Kf706dOm+KysLOfYw4cPm3Jb+uJb3whzcnJ6LLfFoEGDTPHW/Wm53oK19/+uXbucY0ePHm3KbTkOLe9X0dHRpnEg/PpscQcAoDPM3P1R3AEAgRapBbo7OFseAIAIw8wdABBYLMv7o7gDAAKL4u6P4g4ACCyKuz8+cwcAIMIwcwcABBYzd38UdwBAYFHc/bEsDwBAhGHmDgAILGbu/vpscT958qRzb/n6+nrnvMOGDTONIzk52Tm2tbXVlPvYsWPOsXV1dabceXl5zrEfffSRKXd8fLwpvrq62jn2wIEDptyxsbHOsW1tbabcO3bscI619vTu39/20ouLi3OOtR6Hlv1pHbflebH2xLf0i1+3bp0pd1JSkinesn+s10+IiYlxjrW8p0hSQ0ODc6zlOg6W9+Tuorj7Y1keAIAIYy7ub731lq677jplZ2crKipKa9as6fB7z/P00EMPKSsrSwMHDlRBQYF2794drvECABByZubenS0SmYt7Q0ODpk2bpmXLlvn+/vHHH9fTTz+tFStW6N1339WgQYNUWFiopqambg8WAIBPorj7M3/mPmfOHM2ZM8f3d57naenSpXrggQd0/fXXS5KeffZZZWRkaM2aNbrlllu6N1oAANCpsH7mXlZWpsrKShUUFIRuS05O1vTp01VSUuL7N83Nzaqtre2wAQDggpm7v7AW98rKSklSRkZGh9szMjJCv/u0oqIiJScnh7acnJxwDgkAEMEo7v56/Wz5JUuWqKamJrRVVFT09pAAAAFBcfcX1uKemZkpSaqqqupwe1VVVeh3nxYbG6ukpKQOGwAA6LqwFvfc3FxlZmaquLg4dFttba3effdd5efnh/OuAABg5n4W5rPl6+vrtWfPntDPZWVl2rp1q1JTUzVixAjde++9+vGPf6xx48YpNzdXDz74oLKzszV37txwjhsAADrUnYW5uG/evFlXX3116OfFixdLkubPn69Vq1bp+9//vhoaGnTXXXepurpaX/ziF7V27VpTe0ZJio6OVnR0tFOsa5taSerXz7ZY8cEHHzjHWluQWlrh1tTUmHJbWopOnjzZlPvo0aOm+LS0NOfYQ4cOmXJHRUU5x1pfxJaTOy0tQiXbuKWPv1Xiytqq+MSJE86xQ4YMMeV+8803nWPHjh1rym1pjnXNNdeYclufQ0t8YmKiKbflNWF9f8vKyjLFIzjMxX3GjBnnfJOMiorSo48+qkcffbRbAwMAoDPM3P312QvHAADQGYq7v17/KhwAAAgvZu4AgMBi5u6P4g4ACCyKuz+W5QEAiDDM3AEAgRaps+/uoLgDAAKLZXl/LMsDAAKrt9rPLlu2TKNGjVJcXJymT5+uTZs2nTV21apVioqK6rBZG7tZUdwBADB4/vnntXjxYj388MN67733NG3aNBUWFurIkSNn/ZukpCQdPnw4tO3fv79Hx0hxBwAEVm/M3J944gktWLBAd9xxhyZNmqQVK1YoPj5ezzzzzFn/JioqSpmZmaEtIyOjOw+7U332M/eEhAQlJCQ4xaanpzvntfTRlmz94hsbG025XXvnS/ZxW5SWlpriY2NjTfENDQ3OsW1tbabclh7tlv7s1tzWfvvWqyRu27bNOXbw4MGm3Jbj1tqL3NLr3Pr6SU1NdY6trKw05bY6fvy4c+yAAQNMuVNSUpxje/LaGcnJyc6xltd8d4XrM/fa2toOt8fGxvq+17W0tGjLli1asmRJ6LZ+/fqpoKBAJSUlZ72f+vp6jRw5Uu3t7br44ov17//+77rwwgu7PO7OMHMHAJz3cnJylJycHNqKiop8444dO6a2trbPzLwzMjLO+p/ICy64QM8884xefvll/f73v1d7e7uuuOIKHThwIOyP44w+O3MHAKAz4Zq5V1RUKCkpKXS7dYXyXPLz8zus1l1xxRWaOHGi/vM//1OPPfZY2O7nkyjuAIDACldxT0pK6lDczyYtLU3R0dGqqqrqcHtVVZUyMzOd7nPAgAH6whe+oD179tgH7IhleQAAHMXExCgvL0/FxcWh29rb21VcXOx8Lk1bW5s++OAD8zksFszcAQCB1RtNbBYvXqz58+frkksu0WWXXaalS5eqoaFBd9xxhyTptttu07Bhw0Kf2z/66KO6/PLLNXbsWFVXV+vnP/+59u/fr29+85tdHndnKO4AgMDqjeJ+88036+jRo3rooYdUWVmpiy66SGvXrg2dZFdeXt7hmwsnT57UggULVFlZqcGDBysvL0//+Mc/NGnSpC6PuzMUdwBAYPVW+9lFixZp0aJFvr9bv359h5+ffPJJPfnkk126n67iM3cAACIMM3cAQGBx4Rh/FHcAQGBR3P312eK+c+dODRw40Cn22muvdc5rbW86btw459jt27ebch86dMg51trO8e2333aOHTZsmCn3p7/f2Zn6+voeG0tra6tzrGs74zMs7U0trTmljxtmWFha4Vrbm1r2z8mTJ025La1wa2pqTLktr2Vr++aYmBhT/KBBg5xjre9Bn26LGq5xSNKUKVOcY6urq51jLa210TP6bHEHAKAzzNz9UdwBAIFFcffH2fIAAEQYZu4AgMBi5u6P4g4ACCyKuz+W5QEAiDDM3AEAgcXM3R/FHQAQWBR3fxR3AECgRWqB7g4+cwcAIMIwcwcABBbL8v76bHGfMGGCc59kS991a8/ozZs3O8fGx8ebcqenpzvHpqWlmXK3tLQ4x/brZ1vASUpKMsVPnjzZOXbfvn2m3JY+3f372w73gwcPOsfu3r3blDs2NtYUb+n/bomVpKysLOdY6xvhgQMHnGOPHz9uyt3e3u4ce/nll5tyv/POO6Z4y2vfehzGxcU5xzY2NppyW45Dy7g/z97yFHd/LMsDABBh+uzMHQCAzjBz90dxBwAEFsXdH8vyAABEGGbuAIDAYubuj+IOAAgsirs/luUBAIgwzNwBAIHFzN0fxR0AEFgUd38UdwBAYFHc/fXZ4h4VFeXcFrW5udk5ryVWsu14a/vMMWPGOMeeOHHClNvSmtNq/Pjxpvh//etfzrGuLYfPsLbytLC0/D169Kgpd21trSl+4sSJzrGWlq+Src2utcXywIEDnWOtbZAzMjKcY63tga0tli1traOioky5U1JSnGNPnz5tym1hOWZPnTrVY+OAmz5b3AEA6Awzd38UdwBAYFHc/Zm/CvfWW2/puuuuU3Z2tqKiorRmzZoOv7/99tsVFRXVYZs9e3a4xgsAADphnrk3NDRo2rRp+sY3vqEbb7zRN2b27NlauXJl6Gfr5S0BAHDBzN2fubjPmTNHc+bMOWdMbGysMjMzuzwoAABcUNz99UiHuvXr12vo0KG64IILdM8995zzLPLm5mbV1tZ22AAAQNeFvbjPnj1bzz77rIqLi/Wzn/1MGzZs0Jw5c9TW1uYbX1RUpOTk5NCWk5MT7iEBACLUmZl7d7ZIFPaz5W+55ZbQv6dMmaKpU6dqzJgxWr9+vWbOnPmZ+CVLlmjx4sWhn2traynwAAAnLMv76/ELx4wePVppaWnas2eP7+9jY2OVlJTUYQMAAF3X499zP3DggI4fP66srKyevisAwHmGmbs/c3Gvr6/vMAsvKyvT1q1blZqaqtTUVD3yyCOaN2+eMjMztXfvXn3/+9/X2LFjVVhYGNaBAwBAcfdnLu6bN2/W1VdfHfr5zOfl8+fP1/Lly7Vt2zb97ne/U3V1tbKzszVr1iw99thj5u+6Hzt2zLk/cUNDg3Pe6Oho0zhSU1OdY62rE5Y+0OXl5abclq8iNjY2mnJbnm9JGjJkiHOstVe85Xmx5j7bSaB+rG8Q1h7gpaWlzrGVlZWm3JZrBVhfx5b9M3r0aFNuS098a9966/686KKLnGP3799vyl1WVuYcm5CQYMptOa5uuOEG59i6ujrTOLorUgt0d5iL+4wZM875RK5bt65bAwIAAN1Db3kAQGCxLO+P4g4ACCyKu78e/yocAAD4fDFzBwAEFjN3fxR3AEBgUdz9sSwPAECEYeYOAAgsZu7+KO4AgMCiuPtjWR4AgAjDzB0AEFjM3P312eI+aNAgDRo0yCk2Li7OOW98fLxpHJZ+19ae6zt37nSOtV7jfuDAgc6x1oN7wIABpnhLf/H09HRTbkuP9mHDhply/+///q9z7Fe+8hVT7rNdAvlsLH3urX29Xa/hIEl79+415bZcV2DLli2m3JZrORw/ftyUe9SoUab44uJi51jrcWh5DjMyMky5m5ubnWOPHDniHFtfX28aR3dQ3P312eIOAEBnKO7++MwdAIAIw8wdABBYzNz9UdwBAIFFcffHsjwAABGGmTsAILCYufujuAMAAovi7o9leQAAIgwzdwBAYDFz90dxBwAEFsXdX58t7vHx8c7tZysrK53zpqSkmMZhabdpbctqaed44sQJU+6JEyf2yDgkW7tfScrMzHSOtbRClaTY2Fjn2I8++siU23KsvPrqq6bcl19+uSne0pp19+7dptzjxo1zjh0/frwpt4W1tW3//u5vX8nJyabc1mM8LS3NOdZyzEpSv37un55a973lfcJyDFpfxwi/PlvcAQDoDDN3fxR3AEBgUdz9cbY8ACDQzhT4rmxdtWzZMo0aNUpxcXGaPn26Nm3adM74F198URMmTFBcXJymTJmiv/71r12+bxcUdwAADJ5//nktXrxYDz/8sN577z1NmzZNhYWFZz1/6R//+Ie+9rWv6c4779T777+vuXPnau7cudq+fXuPjZHiDgAIrO7M2rs6e3/iiSe0YMEC3XHHHZo0aZJWrFih+Ph4PfPMM77xTz31lGbPnq3vfe97mjhxoh577DFdfPHF+tWvftXdh39WFHcAQGCFq7jX1tZ22Jqbm33vr6WlRVu2bFFBQUHotn79+qmgoEAlJSW+f1NSUtIhXpIKCwvPGh8OFHcAwHkvJydHycnJoa2oqMg37tixY2pra1NGRkaH2zMyMs76tezKykpTfDhwtjwAILDCdbZ8RUWFkpKSQrdb+xH0NRR3AEBghau4JyUldSjuZ5OWlqbo6GhVVVV1uL2qquqsDbsyMzNN8eHAsjwAAI5iYmKUl5en4uLi0G3t7e0qLi5Wfn6+79/k5+d3iJek11577azx4cDMHQAQWL3RxGbx4sWaP3++LrnkEl122WVaunSpGhoadMcdd0iSbrvtNg0bNiz0uf23v/1tffnLX9YvfvELffWrX9Vzzz2nzZs367/+67+6PO7O9Nni7nme2tvbnWI/faLCuVg/R9m5c6dzrGsv/DNGjRrlHFtfX2/KbemjbuldLUk1NTWm+D179jjHWpep9u/f7xxr6f8tScOHD++x3Pv27TPFp6enO8dajivJ9pzv2rXLlHvgwIHOsRdeeKEpt+U1cfr0aVPugwcPmuJbW1udY637x3KMW44TSWpsbHSOtVyDwPp+1R29UdxvvvlmHT16VA899JAqKyt10UUXae3ataFaVF5e3uF99YorrtDq1av1wAMP6Ic//KHGjRunNWvWaPLkyV0ed2f6bHEHAKCvWrRokRYtWuT7u/Xr13/mtptuukk33XRTD4/q/6O4AwACi97y/ijuAIDAorj7o7gDAAKL4u6Pr8IBABBhmLkDAAKLmbs/ijsAILAo7v5YlgcAIMIwcwcABBYzd38UdwBAYFHc/fXZ4j527FinK/RIH7f6c2VpEylJU6ZMcY5NSEgw5baMOzs725Tb8jirq6tNuaOjo03xluewtrbWlHv69OnOsda2uaWlpaZ4iy984QumeMtYrC18LS2Wy8rKTLlTU1OdY61vsnv37nWOLSgoMOU+dOiQKd7SxtV6XFneJ5qbm025LS2z6+rqnGNPnTplGgfCr88WdwAAOsPM3Z/phLqioiJdeumlSkxM1NChQzV37tzP/C+0qalJCxcu1JAhQ5SQkKB58+Z95jq2AACEw5ni3p0tEpmK+4YNG7Rw4UJt3LhRr732mlpbWzVr1iw1NDSEYu677z698sorevHFF7VhwwYdOnRIN954Y9gHDgAA/JmW5deuXdvh51WrVmno0KHasmWLrrrqKtXU1Oi3v/2tVq9erWuuuUaStHLlSk2cOFEbN27U5ZdfHr6RAwDOeyzL++vW99zPnKB05qSZLVu2qLW1tcPJKxMmTNCIESNUUlLim6O5uVm1tbUdNgAAXLAs76/Lxb29vV333nuvrrzyytAF5ysrKxUTE6OUlJQOsRkZGaqsrPTNU1RUpOTk5NCWk5PT1SEBAM5DFPbP6nJxX7hwobZv367nnnuuWwNYsmSJampqQltFRUW38gEAcL7r0lfhFi1apD//+c966623NHz48NDtmZmZamlpUXV1dYfZe1VV1Vm/exsbG6vY2NiuDAMAcJ7jM3d/ppm753latGiRXnrpJb3xxhvKzc3t8Pu8vDwNGDBAxcXFodtKS0tVXl6u/Pz88IwYAID/w2fu/kwz94ULF2r16tV6+eWXlZiYGPocPTk5WQMHDlRycrLuvPNOLV68WKmpqUpKStK3vvUt5efnc6Y8AACfE1NxX758uSRpxowZHW5fuXKlbr/9dknSk08+qX79+mnevHlqbm5WYWGhfv3rX4dlsAAAfBLL8v5Mxd3lSYiLi9OyZcu0bNmyLg9K+rjHeE886f37204zaG9vd449duyYKbelR7u1L7q1/7uF9Tk8ffq0c2x9fb0pt6Wn94kTJ0y5ExMTnWM/2cjJxQsvvGCKz8vLc461doRMTk52jo2Li+ux3NavwVqu5XDw4EFT7g8//NAUb+nn39bWZso9ZMgQ51jr4xw3bpxzrOU9pSfffz6N4u6P67kDABBhuHAMACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/ijuAIDAorj767PFfffu3Ro0aJBT7MmTJ53zDhs2zDSOuro651hLm0hJn7mq3rns37/flLulpcU5NjU11ZTb2gr3+PHjzrGTJ0825U5LS3OO/fvf/27KbWmFm5GRYcpt1dTU5BxreT1I0uDBg51jre1nq6urnWPT09NNuS2Xit60aZMp9/jx403xFpbXpmR7Di2vByvX92P0DX22uAMA0Blm7v4o7gCAwKK4++NseQAAIgwzdwBAYDFz90dxBwAEFsXdH8vyAABEGGbuAIDAYubuj+IOAAgsirs/ijsAILAo7v74zB0AgAjDzB0AEGiROvvujj5b3OPj4xUfH+8Um5yc7Jz36NGjpnFYep1HRUWZcu/Zs8c5trGx0ZTb0nf70KFDptzWXvTt7e3OsceOHTPlPnDggHPskSNHTLmHDh3qHNvW1mbKbXlOJGnDhg3OsbNmzTLl3rFjh3Osdd9b+txb9qVk6/2fkpJiyt3a2mqKj46Odo4tKSkx5b788sudY6196xMSEpxjLa9N6/tVd7As749leQAAIkyfnbkDANAZZu7+KO4AgMCiuPtjWR4AgAjDzB0AEFjM3P1R3AEAgUVx98eyPAAAEYaZOwAgsJi5+6O4AwACi+Luj+IOAAgsiru/Plvc4+PjNWjQIKfY9PR057zDhg0zjcPScrF/f9vT6dpeV5IyMzNNuQ8ePOgcO2TIEFPuiooKU/yFF17oHLtv3z5T7tOnTzvHjhs3zpTb0mq1qqrKlNvaqnj69OnOsdb2pgMHDnSOTUxMNOWuq6tzjk1KSjLlbmpqco5taGgw5ba2qba8B02ZMsWUu7a21hRv8d577znHWlpaf57tZ+GvzxZ3AAA6w8zdH8UdABBYFHd/fBUOAIAIw8wdABBYzNz9UdwBAIFFcffHsjwAABGGmTsAILCYufujuAMAAovi7o9leQAAesCJEyd06623KikpSSkpKbrzzjtVX19/zr+ZMWOGoqKiOmx33323+b6ZuQMAAqsvz9xvvfVWHT58WK+99ppaW1t1xx136K677tLq1avP+XcLFizQo48+GvrZ0s30DIo7ACCw+mpx37Vrl9auXat//vOfuuSSSyRJv/zlL3XttdfqP/7jP5SdnX3Wv42Pjze3HP+0Plvc+/fv79yrfcCAAc55d+zYYRpHdHS0c6y1R7uFpQe0JMXFxTnHWsdt7dNtGXtubq4pt6WH/u7du025R48ebYq3qKmpMcVb/uduOWYlW597a0/81tZW59gjR46Ycre0tDjHWl4Pkq1XvCSVl5c7x1p6+UvS4MGDnWN37txpyp2fn+8ca+m3b9nv3RWu4v7pHv6xsbGmfvqfVlJSopSUlFBhl6SCggL169dP7777rm644Yaz/u0f/vAH/f73v1dmZqauu+46Pfjgg+bZe58t7gAAfF5ycnI6/Pzwww/rRz/6UZfzVVZWaujQoR1u69+/v1JTU1VZWXnWv/v617+ukSNHKjs7W9u2bdMPfvADlZaW6k9/+pPp/k0n1BUVFenSSy9VYmKihg4dqrlz56q0tLRDTLhOBgAAwMWZ2XtXtjMqKipUU1MT2pYsWeJ7X/fff/9natyntw8//LDLj+Wuu+5SYWGhpkyZoltvvVXPPvusXnrpJe3du9eUxzRz37BhgxYuXKhLL71Up0+f1g9/+EPNmjVLO3fu7HB51nCcDAAAQGfCtSyflJTkdNnh73znO7r99tvPGTN69GhlZmZ+5qOm06dP68SJE6bP089c7nnPnj0aM2aM89+ZivvatWs7/Lxq1SoNHTpUW7Zs0VVXXRW6PRwnAwAA0Nekp6c7nZORn5+v6upqbdmyRXl5eZKkN954Q+3t7aGC7WLr1q2SpKysLNM4u/U99zMnBaWmpna4/Q9/+IPS0tI0efJkLVmyRKdOnTprjubmZtXW1nbYAABw0Z0l+e7O+s9l4sSJmj17thYsWKBNmzbpnXfe0aJFi3TLLbeEzpQ/ePCgJkyYoE2bNkmS9u7dq8cee0xbtmzRvn379D//8z+67bbbdNVVV2nq1Kmm++/yCXXt7e269957deWVV2ry5Mmh260nAxQVFemRRx7p6jAAAOexvvpVOOnjie6iRYs0c+ZM9evXT/PmzdPTTz8d+n1ra6tKS0tDE+CYmBi9/vrrWrp0qRoaGpSTk6N58+bpgQceMN93l4v7woULtX37dr399tsdbr/rrrtC/54yZYqysrI0c+ZM7d271/fzgiVLlmjx4sWhn2traz9z1iIAAEGTmpp6zoY1o0aN6vCfi5ycHG3YsCEs992l4r5o0SL9+c9/1ltvvaXhw4efM7azkwG6+11CAMD5qy/P3HuTqbh7nqdvfetbeumll7R+/XqnhiNdPRkAAIDOUNz9mYr7woULtXr1ar388stKTEwMfRE/OTlZAwcO1N69e7V69Wpde+21GjJkiLZt26b77ruvSycDAACArjEV9+XLl0v6uFHNJ61cuVK33357WE8GAACgM8zc/ZmX5c8lnCcDVFRUODe/sfRqvuaaa0zjeP/9951j6+rqTLktB1V7e7spt6Uv+smTJ025J0yYYIrfs2ePc6y1J/Unmyd1Ztq0aabc1dXVPRIrSU1NTab45uZm51hLD3Dp46/suDp06JApt6XnurX/+0cffeQca2n+IUnHjx83xSckJDjHVlVVmXJbji3ra9PCpcHLGa7XBQkHirs/essDAAKL4u6vW01sAABA38PMHQAQWMzc/VHcAQCBRXH3x7I8AAARhpk7ACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3V+fLe7JycnOrUUbGhqc827fvt00jtOnTzvH9utn+5Rj7NixzrGWFqGStH//fufYSZMmmXJbWopKUnp6unOstYVvYmKic2xxcbEpd2ZmpnPsqVOnTLmtx+GoUaOcY62XUH7nnXecYy2vB0mKiopyjrW28LW0OLW8HiTpww8/NMVfeOGFzrFf+tKXTLnr6+udY2NiYky5LW12U1JSnGM/z/az8MceAAAEFjN3fxR3AEBgUdz9UdwBAIFFcffHV+EAAIgwzNwBAIEWqbPv7qC4AwACi2V5fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7+WJYHACDCMHMHAAQWM3d/fba4t7e3q7293Sl26NChznmTk5NN4ygtLXWObWlpMeWuqqpyjt2xY4cpd1ZWlnPsunXrTLnz8vJM8U1NTc6xgwcPNuX+6KOPnGOPHj1qyv3GG284x+bk5JhyR0dHm+Itvegt1yyQpOzsbOfYX/3qV6bco0ePdo5ta2sz5W5tbe2x3PHx8ab4kydPOsc2NzebctfW1jrHWp4TSTpy5Ihz7MiRI51jP8/e8hR3fyzLAwAQYfrszB0AgM4wc/dHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/ijuAIDAorj7Y1keAIAIw8wdABBYzNz9UdwBAIFFcffXZ4v7yJEjlZiY6BRraft64MAB0zgsbSutrW0t487MzDTlrq+vd44dNmyYKfeJEydM8YcOHXKOtbQSlqTKykrn2FdffdWU29La9pJLLjHltrTytLK0K5VsLZaHDBliyt3Q0OAce/jwYVNuS4vlgwcPmnJbWvJK0vDhw51jra2HLS2ZrW1fLfvTUgQjtWAGSZ8t7gAAdIaZuz+KOwAgsCju/jhbHgCACMPMHQAQWMzc/VHcAQCBRXH3R3EHAAQWxd0fn7kDABBhmLkDAAItUmff3UFxBwAEVncLe6T+x4BleQAAIgwzdwBAYDFz99dni3tMTIxiYmKcYi39xadOnWoax65du5xjLX20JVvPaGvfbdfnTpISEhJMuY8ePWqKHzNmjHNseXm5KXdVVZVzbGFhoSn37373O+fY6upqU+6BAwea4o8dO+Yca30OLf3Frdc4aGxsdI7Nyckx5bYct3v27DHlnjJliine8vo8fvy4Kfe0adOcY0+ePGnKPWjQIOdYy3Fl2e/dRXH3x7I8AAARxlTcly9frqlTpyopKUlJSUnKz8/X3/72t9Dvm5qatHDhQg0ZMkQJCQmaN2+eaWYFAIDFme+5d2eLRKbiPnz4cP30pz/Vli1btHnzZl1zzTW6/vrrtWPHDknSfffdp1deeUUvvviiNmzYoEOHDunGG2/skYEDAEBx92f6zP26667r8PNPfvITLV++XBs3btTw4cP129/+VqtXr9Y111wjSVq5cqUmTpyojRs36vLLLw/fqAEAwFl1+TP3trY2Pffcc2poaFB+fr62bNmi1tZWFRQUhGImTJigESNGqKSk5Kx5mpubVVtb22EDAMAFM3d/5uL+wQcfKCEhQbGxsbr77rv10ksvadKkSaqsrFRMTIxSUlI6xGdkZJzzbPaioiIlJyeHNusZswCA8xfF3Z+5uF9wwQXaunWr3n33Xd1zzz2aP3++du7c2eUBLFmyRDU1NaGtoqKiy7kAAOcXirs/8/fcY2JiNHbsWElSXl6e/vnPf+qpp57SzTffrJaWFlVXV3eYvVdVVZ3zu7GxsbGKjY21jxwAAPjq9vfc29vb1dzcrLy8PA0YMEDFxcWh35WWlqq8vFz5+fndvRsAAD6Dmbs/08x9yZIlmjNnjkaMGKG6ujqtXr1a69ev17p165ScnKw777xTixcvVmpqqpKSkvStb31L+fn5nCkPAOgRdKjzZyruR44c0W233abDhw8rOTlZU6dO1bp16/SVr3xFkvTkk0+qX79+mjdvnpqbm1VYWKhf//rXXRpYXV2dc6yl9ePevXtN45g1a5ZzrPXcA0t8VlaWKfeJEyecYy3tRyVpwIABpnhLS8y4uDhT7ry8POfYrVu3mnLfcMMNzrGWFsiSlJSUZIpPTU11jrW2E7a81nJzc025LW15refbWHLPnz/flLutrc0Uf9VVVznHWlvb7tu3zznW8nqQbO8T/fq5L/RaYtEzTMX9t7/97Tl/HxcXp2XLlmnZsmXdGhQAAC6YufvrsxeOAQCgMxR3f6ydAADQA37yk5/oiiuuUHx8/Gd6wJyN53l66KGHlJWVpYEDB6qgoEC7d+823zfFHQAQWH35bPmWlhbddNNNuueee5z/5vHHH9fTTz+tFStW6N1339WgQYNUWFiopqYm032zLA8ACKy+vCz/yCOPSJJWrVrlPJalS5fqgQce0PXXXy9JevbZZ5WRkaE1a9bolltucb5vZu4AgPPep69x0tzc/LmPoaysTJWVlR2u0ZKcnKzp06ef8xotfijuAIDACteyfE5OTofrnBQVFX3uj+XMV2ozMjI63N7ZNVr8sCwPAAiscC3LV1RUdOg/cba26Pfff79+9rOfnTPnrl27NGHChG6Nq7so7gCAwApXcU9KSnJqLvWd73xHt99++zljRo8e3aWxnLkOS1VVVYfGZVVVVbroootMuSjuAAA4Sk9PV3p6eo/kzs3NVWZmpoqLi0PFvLa2NnQVVos+V9zP/C+qvr7e+W8aGhqcYxsbG03jqa2tdY61jFmSTp061WO5Lc+Jpf1oT4/FehJLS0uLc6z1qySW3K2trabc1vj29nbnWMu4uzKWnsptbfnak8+JdSyW49b6HmQ5bi3vKdaxREVFmfN+Xg1i+mojmvLycp04cULl5eVqa2sLtcAeO3ZsqE30hAkTVFRUpBtuuEFRUVG699579eMf/1jjxo1Tbm6uHnzwQWVnZ2vu3Lm2O/f6mIqKCk8SGxsbG1vAt4qKih6rFY2NjV5mZmZYxpmZmek1NjaGfYzz58/3vb8333wzFCPJW7lyZejn9vZ278EHH/QyMjK82NhYb+bMmV5paan5vqP+L3mf0d7erkOHDikxMbHD/xRra2uVk5PzmZMeIg2PM3KcD49R4nFGmnA8Ts/zVFdXp+zs7B69iExTU5N5VcZPTEyM+aJVfV2fW5bv16+fhg8fftbfu570EHQ8zshxPjxGiccZabr7OJOTk8M4Gn9xcXERV5TDhe+5AwAQYSjuAABEmMAU99jYWD388MNnbSwQKXickeN8eIwSjzPSnC+PM9L1uRPqAABA9wRm5g4AANxQ3AEAiDAUdwAAIgzFHQCACBOY4r5s2TKNGjVKcXFxmj59ujZt2tTbQwqrH/3oR4qKiuqw9fYlA7vrrbfe0nXXXafs7GxFRUVpzZo1HX7veZ4eeughZWVlaeDAgSooKNDu3bt7Z7Dd0NnjvP322z+zb2fPnt07g+2ioqIiXXrppUpMTNTQoUM1d+5clZaWdohpamrSwoULNWTIECUkJGjevHmqqqrqpRF3jcvjnDFjxmf25913391LI+6a5cuXa+rUqaFGNfn5+frb3/4W+n0k7MvzXSCK+/PPP6/Fixfr4Ycf1nvvvadp06apsLBQR44c6e2hhdWFF16ow4cPh7a33367t4fULQ0NDZo2bZqWLVvm+/vHH39cTz/9tFasWKF3331XgwYNUmFhofkCL72ts8cpSbNnz+6wb//4xz9+jiPsvg0bNmjhwoXauHGjXnvtNbW2tmrWrFkdLgp033336ZVXXtGLL76oDRs26NChQ7rxxht7cdR2Lo9TkhYsWNBhfz7++OO9NOKuGT58uH76059qy5Yt2rx5s6655hpdf/312rFjh6TI2JfnvW51xf+cXHbZZd7ChQtDP7e1tXnZ2dleUVFRL44qvB5++GFv2rRpvT2MHiPJe+mll0I/t7e3e5mZmd7Pf/7z0G3V1dVebGys98c//rEXRhgen36cnvfxxSOuv/76XhlPTzly5IgnyduwYYPneR/vuwEDBngvvvhiKGbXrl2eJK+kpKS3htltn36cnud5X/7yl71vf/vbvTeoHjJ48GDvN7/5TcTuy/NNn5+5t7S0aMuWLSooKAjd1q9fPxUUFKikpKQXRxZ+u3fvVnZ2tkaPHq1bb71V5eXlvT2kHlNWVqbKysoO+zU5OVnTp0+PuP0qSevXr9fQoUN1wQUX6J577tHx48d7e0jdUlNTI0lKTU2VJG3ZskWtra0d9ueECRM0YsSIQO/PTz/OM/7whz8oLS1NkydP1pIlS8yXWu1L2tra9Nxzz6mhoUH5+fkRuy/PN33uwjGfduzYMbW1tSkjI6PD7RkZGfrwww97aVThN336dK1atUoXXHCBDh8+rEceeURf+tKXtH37diUmJvb28MKusrJSknz365nfRYrZs2frxhtvVG5urvbu3asf/vCHmjNnjkpKShQdHd3bwzNrb2/XvffeqyuvvFKTJ0+W9PH+jImJUUpKSofYIO9Pv8cpSV//+tc1cuRIZWdna9u2bfrBD36g0tJS/elPf+rF0dp98MEHys/PV1NTkxISEvTSSy9p0qRJ2rp1a8Tty/NRny/u54s5c+aE/j116lRNnz5dI0eO1AsvvKA777yzF0eG7rrllltC/54yZYqmTp2qMWPGaP369Zo5c2YvjqxrFi5cqO3btwf+nJDOnO1x3nXXXaF/T5kyRVlZWZo5c6b27t2rMWPGfN7D7LILLrhAW7duVU1Njf77v/9b8+fP14YNG3p7WAiTPr8sn5aWpujo6M+cqVlVVaXMzMxeGlXPS0lJ0fjx47Vnz57eHkqPOLPvzrf9KkmjR49WWlpaIPftokWL9Oc//1lvvvlmh0szZ2ZmqqWlRdXV1R3ig7o/z/Y4/UyfPl2SArc/Y2JiNHbsWOXl5amoqEjTpk3TU089FXH78nzV54t7TEyM8vLyVFxcHLqtvb1dxcXFys/P78WR9az6+nrt3btXWVlZvT2UHpGbm6vMzMwO+7W2tlbvvvtuRO9XSTpw4ICOHz8eqH3reZ4WLVqkl156SW+88YZyc3M7/D4vL08DBgzosD9LS0tVXl4eqP3Z2eP0s3XrVkkK1P70097erubm5ojZl+e93j6jz8Vzzz3nxcbGeqtWrfJ27tzp3XXXXV5KSopXWVnZ20MLm+985zve+vXrvbKyMu+dd97xCgoKvLS0NO/IkSO9PbQuq6ur895//33v/fff9yR5TzzxhPf+++97+/fv9zzP83760596KSkp3ssvv+xt27bNu/76673c3FyvsbGxl0duc67HWVdX5333u9/1SkpKvLKyMu/111/3Lr74Ym/cuHFeU1NTbw/d2T333OMlJyd769ev9w4fPhzaTp06FYq5++67vREjRnhvvPGGt3nzZi8/P9/Lz8/vxVHbdfY49+zZ4z366KPe5s2bvbKyMu/ll1/2Ro8e7V111VW9PHKb+++/39uwYYNXVlbmbdu2zbv//vu9qKgo79VXX/U8LzL25fkuEMXd8zzvl7/8pTdixAgvJibGu+yyy7yNGzf29pDC6uabb/aysrK8mJgYb9iwYd7NN9/s7dmzp7eH1S1vvvmmJ+kz2/z58z3P+/jrcA8++KCXkZHhxcbGejNnzvRKS0t7d9BdcK7HeerUKW/WrFleenq6N2DAAG/kyJHeggULAvcfU7/HJ8lbuXJlKKaxsdH7t3/7N2/w4MFefHy8d8MNN3iHDx/uvUF3QWePs7y83Lvqqqu81NRULzY21hs7dqz3ve99z6upqendgRt94xvf8EaOHOnFxMR46enp3syZM0OF3fMiY1+e77jkKwAAEabPf+YOAABsKO4AAEQYijsAABGG4g4AQIShuAMAEGEo7gAARBiKOwAAEYbiDgBAhKG4AwAQYSjuAABEGIo7AAARhuIOAECE+X9eu/BjisWv5wAAAABJRU5ErkJggg==",
|
| 300 |
+
"text/plain": [
|
| 301 |
+
"<Figure size 640x480 with 2 Axes>"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
"metadata": {},
|
| 305 |
+
"output_type": "display_data"
|
| 306 |
+
}
|
| 307 |
+
],
|
| 308 |
+
"source": [
|
| 309 |
+
"import torch\n",
|
| 310 |
+
"from PIL import Image\n",
|
| 311 |
+
"from diffusers import DDPMScheduler\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"noise_scheduler = DDPMScheduler(num_train_timesteps = 1000)\n",
|
| 314 |
+
"noise = torch.randn(sample_image.shape)\n",
|
| 315 |
+
"timesteps = torch.LongTensor([50]) # example timestep\n",
|
| 316 |
+
"noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"# print(\"Noisy image shape:\", noisy_image.shape)\n",
|
| 319 |
+
"# Image.fromarray(noisy_image.numpy().squeeze())\n",
|
| 320 |
+
"plt.imshow(noisy_image.detach().cpu().numpy().squeeze(), cmap='gray')\n",
|
| 321 |
+
"plt.colorbar()\n",
|
| 322 |
+
"# plt.axis('off')\n",
|
| 323 |
+
"plt.show()\n"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"cell_type": "code",
|
| 328 |
+
"execution_count": 79,
|
| 329 |
+
"id": "4e878652",
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"outputs": [],
|
| 332 |
+
"source": [
|
| 333 |
+
"import torch.nn.functional as F\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"noise_pred = model(noisy_image, timesteps).sample\n",
|
| 336 |
+
"loss = F.mse_loss(noise_pred, noise)"
|
| 337 |
+
]
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"cell_type": "markdown",
|
| 341 |
+
"id": "7dffb56e",
|
| 342 |
+
"metadata": {},
|
| 343 |
+
"source": [
|
| 344 |
+
"**Train the model**"
|
| 345 |
+
]
|
| 346 |
+
},
|
| 347 |
+
{
|
| 348 |
+
"cell_type": "code",
|
| 349 |
+
"execution_count": 80,
|
| 350 |
+
"id": "62915913",
|
| 351 |
+
"metadata": {},
|
| 352 |
+
"outputs": [],
|
| 353 |
+
"source": [
|
| 354 |
+
"from diffusers.optimization import get_cosine_schedule_with_warmup\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n",
|
| 357 |
+
"lr_scheduler = get_cosine_schedule_with_warmup(\n",
|
| 358 |
+
" optimizer=optimizer,\n",
|
| 359 |
+
" num_warmup_steps=config.lr_warmup_steps,\n",
|
| 360 |
+
" num_training_steps=(len(train_dataloader) * config.num_epochs),\n",
|
| 361 |
+
")"
|
| 362 |
+
]
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"cell_type": "code",
|
| 366 |
+
"execution_count": 81,
|
| 367 |
+
"id": "f2f3427a",
|
| 368 |
+
"metadata": {},
|
| 369 |
+
"outputs": [],
|
| 370 |
+
"source": [
|
| 371 |
+
"from diffusers import DDPMPipeline\n",
|
| 372 |
+
"from diffusers.utils import make_image_grid\n",
|
| 373 |
+
"import os\n",
|
| 374 |
+
"\n",
|
| 375 |
+
"def evaluate(config, epoch, pipeline):\n",
|
| 376 |
+
" # Sample some images from random noise (this is the backward diffusion process).\n",
|
| 377 |
+
" # The default pipeline output type is `List[PIL.Image]`\n",
|
| 378 |
+
" images = pipeline(\n",
|
| 379 |
+
" batch_size=config.eval_batch_size,\n",
|
| 380 |
+
" generator=torch.Generator(device='cpu').manual_seed(config.seed), # Use a separate torch generator to avoid rewinding the random state of the main training loop\n",
|
| 381 |
+
" ).images\n",
|
| 382 |
+
"\n",
|
| 383 |
+
" # Make a grid out of the images\n",
|
| 384 |
+
" image_grid = make_image_grid(images, rows=4, cols=4)\n",
|
| 385 |
+
"\n",
|
| 386 |
+
" # Save the images\n",
|
| 387 |
+
" test_dir = os.path.join(config.output_dir, \"samples\")\n",
|
| 388 |
+
" os.makedirs(test_dir, exist_ok=True)\n",
|
| 389 |
+
" image_grid.save(f\"{test_dir}/{epoch:04d}.png\")"
|
| 390 |
+
]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"cell_type": "code",
|
| 394 |
+
"execution_count": 85,
|
| 395 |
+
"id": "4fe935dd",
|
| 396 |
+
"metadata": {},
|
| 397 |
+
"outputs": [],
|
| 398 |
+
"source": [
|
| 399 |
+
"from accelerate import Accelerator\n",
|
| 400 |
+
"from huggingface_hub import create_repo, upload_folder\n",
|
| 401 |
+
"from tqdm.auto import tqdm\n",
|
| 402 |
+
"from pathlib import Path\n",
|
| 403 |
+
"import os"
|
| 404 |
+
]
|
| 405 |
+
},
|
| 406 |
+
{
|
| 407 |
+
"cell_type": "code",
|
| 408 |
+
"execution_count": 86,
|
| 409 |
+
"id": "07df8a4c",
|
| 410 |
+
"metadata": {},
|
| 411 |
+
"outputs": [],
|
| 412 |
+
"source": [
|
| 413 |
+
"def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n",
|
| 414 |
+
" # Initialize accelerator and tensorboard logging\n",
|
| 415 |
+
" accelerator = Accelerator(\n",
|
| 416 |
+
" mixed_precision=config.mixed_precision,\n",
|
| 417 |
+
" gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
|
| 418 |
+
" log_with=\"tensorboard\",\n",
|
| 419 |
+
" project_dir=os.path.join(config.output_dir, \"logs\"),\n",
|
| 420 |
+
" )\n",
|
| 421 |
+
" if accelerator.is_main_process:\n",
|
| 422 |
+
" if config.output_dir is not None:\n",
|
| 423 |
+
" os.makedirs(config.output_dir, exist_ok=True)\n",
|
| 424 |
+
" if config.push_to_hub:\n",
|
| 425 |
+
" repo_id = create_repo(\n",
|
| 426 |
+
" repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
|
| 427 |
+
" ).repo_id\n",
|
| 428 |
+
" accelerator.init_trackers(\"train_example\")\n",
|
| 429 |
+
"\n",
|
| 430 |
+
" # Prepare everything\n",
|
| 431 |
+
" # There is no specific order to remember, you just need to unpack the\n",
|
| 432 |
+
" # objects in the same order you gave them to the prepare method.\n",
|
| 433 |
+
" model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n",
|
| 434 |
+
" model, optimizer, train_dataloader, lr_scheduler\n",
|
| 435 |
+
" )\n",
|
| 436 |
+
"\n",
|
| 437 |
+
" global_step = 0\n",
|
| 438 |
+
"\n",
|
| 439 |
+
" # Now you train the model\n",
|
| 440 |
+
" for epoch in range(config.num_epochs):\n",
|
| 441 |
+
" progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n",
|
| 442 |
+
" progress_bar.set_description(f\"Epoch {epoch}\")\n",
|
| 443 |
+
"\n",
|
| 444 |
+
" for step, batch in enumerate(train_dataloader):\n",
|
| 445 |
+
" clean_images = batch[\"images\"]\n",
|
| 446 |
+
" # Sample noise to add to the images\n",
|
| 447 |
+
" noise = torch.randn(clean_images.shape, device=clean_images.device)\n",
|
| 448 |
+
" bs = clean_images.shape[0]\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" # Sample a random timestep for each image\n",
|
| 451 |
+
" timesteps = torch.randint(\n",
|
| 452 |
+
" 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,\n",
|
| 453 |
+
" dtype=torch.int64\n",
|
| 454 |
+
" )\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" # Add noise to the clean images according to the noise magnitude at each timestep\n",
|
| 457 |
+
" # (this is the forward diffusion process)\n",
|
| 458 |
+
" noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n",
|
| 459 |
+
"\n",
|
| 460 |
+
" with accelerator.accumulate(model):\n",
|
| 461 |
+
" # Predict the noise residual\n",
|
| 462 |
+
" noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n",
|
| 463 |
+
" loss = F.mse_loss(noise_pred, noise)\n",
|
| 464 |
+
" accelerator.backward(loss)\n",
|
| 465 |
+
"\n",
|
| 466 |
+
" if accelerator.sync_gradients:\n",
|
| 467 |
+
" accelerator.clip_grad_norm_(model.parameters(), 1.0)\n",
|
| 468 |
+
" optimizer.step()\n",
|
| 469 |
+
" lr_scheduler.step()\n",
|
| 470 |
+
" optimizer.zero_grad()\n",
|
| 471 |
+
"\n",
|
| 472 |
+
" progress_bar.update(1)\n",
|
| 473 |
+
" logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n",
|
| 474 |
+
" progress_bar.set_postfix(**logs)\n",
|
| 475 |
+
" accelerator.log(logs, step=global_step)\n",
|
| 476 |
+
" global_step += 1\n",
|
| 477 |
+
"\n",
|
| 478 |
+
" # After each epoch you optionally sample some demo images with evaluate() and save the model\n",
|
| 479 |
+
" if accelerator.is_main_process:\n",
|
| 480 |
+
" pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)\n",
|
| 481 |
+
"\n",
|
| 482 |
+
" if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n",
|
| 483 |
+
" evaluate(config, epoch, pipeline)\n",
|
| 484 |
+
"\n",
|
| 485 |
+
" if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:\n",
|
| 486 |
+
" if config.push_to_hub:\n",
|
| 487 |
+
" upload_folder(\n",
|
| 488 |
+
" repo_id=repo_id,\n",
|
| 489 |
+
" folder_path=config.output_dir,\n",
|
| 490 |
+
" commit_message=f\"Epoch {epoch}\",\n",
|
| 491 |
+
" ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
|
| 492 |
+
" )\n",
|
| 493 |
+
" else:\n",
|
| 494 |
+
" pipeline.save_pretrained(config.output_dir)"
|
| 495 |
+
]
|
| 496 |
+
},
|
| 497 |
+
{
|
| 498 |
+
"cell_type": "code",
|
| 499 |
+
"execution_count": 88,
|
| 500 |
+
"id": "5d820f05",
|
| 501 |
+
"metadata": {},
|
| 502 |
+
"outputs": [
|
| 503 |
+
{
|
| 504 |
+
"name": "stdout",
|
| 505 |
+
"output_type": "stream",
|
| 506 |
+
"text": [
|
| 507 |
+
"Launching training on one GPU.\n"
|
| 508 |
+
]
|
| 509 |
+
},
|
| 510 |
+
{
|
| 511 |
+
"name": "stderr",
|
| 512 |
+
"output_type": "stream",
|
| 513 |
+
"text": [
|
| 514 |
+
"/mnt/drive/adarsh/DC_cold3/my-env/lib/python3.12/site-packages/accelerate/accelerator.py:529: UserWarning: `log_with=tensorboard` was passed but no supported trackers are currently installed.\n",
|
| 515 |
+
" warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n",
|
| 516 |
+
"Epoch 0: 0%| | 0/344 [00:00<?, ?it/s]/tmp/ipykernel_288020/2880730127.py:33: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)\n",
|
| 517 |
+
" clean_images = batch[\"images\"]\n"
|
| 518 |
+
]
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"ename": "IndexError",
|
| 522 |
+
"evalue": "too many indices for tensor of dimension 4",
|
| 523 |
+
"output_type": "error",
|
| 524 |
+
"traceback": [
|
| 525 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 526 |
+
"\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
|
| 527 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[88]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01maccelerate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m notebook_launcher\n\u001b[32m 2\u001b[39m args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mnotebook_launcher\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_processes\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
| 528 |
+
"\u001b[36mFile \u001b[39m\u001b[32m/mnt/drive/adarsh/DC_cold3/my-env/lib/python3.12/site-packages/accelerate/launchers.py:270\u001b[39m, in \u001b[36mnotebook_launcher\u001b[39m\u001b[34m(function, args, num_processes, mixed_precision, use_port, master_addr, node_rank, num_nodes, rdzv_backend, rdzv_endpoint, rdzv_conf, rdzv_id, max_restarts, monitor_interval, log_line_prefix_template)\u001b[39m\n\u001b[32m 268\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 269\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mLaunching training on CPU.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m270\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 529 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[86]\u001b[39m\u001b[32m, line 33\u001b[39m, in \u001b[36mtrain_loop\u001b[39m\u001b[34m(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\u001b[39m\n\u001b[32m 30\u001b[39m progress_bar.set_description(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 32\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m step, batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(train_dataloader):\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m clean_images = \u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mimages\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 34\u001b[39m \u001b[38;5;66;03m# Sample noise to add to the images\u001b[39;00m\n\u001b[32m 35\u001b[39m noise = torch.randn(clean_images.shape, device=clean_images.device)\n",
|
| 530 |
+
"\u001b[31mIndexError\u001b[39m: too many indices for tensor of dimension 4"
|
| 531 |
+
]
|
| 532 |
+
}
|
| 533 |
+
],
|
| 534 |
+
"source": [
|
| 535 |
+
"from accelerate import notebook_launcher\n",
|
| 536 |
+
"args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n",
|
| 537 |
+
"notebook_launcher(train_loop, args, num_processes=1)"
|
| 538 |
+
]
|
| 539 |
+
}
|
| 540 |
+
],
|
| 541 |
+
"metadata": {
|
| 542 |
+
"kernelspec": {
|
| 543 |
+
"display_name": "my-env",
|
| 544 |
+
"language": "python",
|
| 545 |
+
"name": "python3"
|
| 546 |
+
},
|
| 547 |
+
"language_info": {
|
| 548 |
+
"codemirror_mode": {
|
| 549 |
+
"name": "ipython",
|
| 550 |
+
"version": 3
|
| 551 |
+
},
|
| 552 |
+
"file_extension": ".py",
|
| 553 |
+
"mimetype": "text/x-python",
|
| 554 |
+
"name": "python",
|
| 555 |
+
"nbconvert_exporter": "python",
|
| 556 |
+
"pygments_lexer": "ipython3",
|
| 557 |
+
"version": "3.12.3"
|
| 558 |
+
}
|
| 559 |
+
},
|
| 560 |
+
"nbformat": 4,
|
| 561 |
+
"nbformat_minor": 5
|
| 562 |
+
}
|
samples/0009.png
CHANGED
|
|
samples/0019.png
CHANGED
|
|
samples/0029.png
CHANGED
|
|
samples/0039.png
CHANGED
|
|
samples/0049.png
CHANGED
|
|