jainadarsh commited on
Commit
3819fa9
·
verified ·
1 Parent(s): 2930b34
samples/.ipynb_checkpoints/0009-checkpoint.png ADDED
samples/.ipynb_checkpoints/0019-checkpoint.png ADDED
samples/.ipynb_checkpoints/0029-checkpoint.png ADDED
samples/.ipynb_checkpoints/0039-checkpoint.png ADDED
samples/.ipynb_checkpoints/0049-checkpoint.png ADDED
samples/.ipynb_checkpoints/overview-checkpoint.ipynb ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 19,
6
+ "id": "9074d4b4",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np\n",
11
+ "from torch.utils.data import Dataset, DataLoader\n",
12
+ "import scipy.io\n",
13
+ "from torchvision import transforms, utils\n",
14
+ "import os\n",
15
+ "from PIL import Image"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 20,
21
+ "id": "c85ff158",
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "class GravityDataset(Dataset):\n",
26
+ " def __init__(self, folder, image_size, exts=['mat']):\n",
27
+ " super().__init__()\n",
28
+ " self.folder = folder\n",
29
+ " self.image_size = image_size\n",
30
+ " self.paths = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.mat')]\n",
31
+ "\n",
32
+ " # Define transformations that are independent of scaling\n",
33
+ " self.transform = transforms.Compose([\n",
34
+ " transforms.Resize((int(image_size * 1.12), int(image_size * 1.12))), # Resize slightly larger\n",
35
+ " transforms.RandomCrop(image_size), # Then crop to the target size\n",
36
+ " transforms.RandomHorizontalFlip(), # Random horizontal flip\n",
37
+ " transforms.ToTensor() # Convert to tensor\n",
38
+ " ])\n",
39
+ "\n",
40
+ " def scale_to_minus1_1(self, tensor):\n",
41
+ " \"\"\"Dynamically scale the tensor to the range [-1, 1] based on its own min and max values.\"\"\"\n",
42
+ " min_val = tensor.min()\n",
43
+ " max_val = tensor.max()\n",
44
+ " # Avoid division by zero if min and max are the same\n",
45
+ " if max_val > min_val:\n",
46
+ " return 2 * ((tensor - min_val) / (max_val - min_val)) - 1\n",
47
+ " else:\n",
48
+ " return tensor # If min and max are the same, return tensor as is.\n",
49
+ " # tensor = tensor\n",
50
+ " return tensor # If min and max are the same, return tensor as is.\n",
51
+ "\n",
52
+ " def __len__(self):\n",
53
+ " return len(self.paths)\n",
54
+ "\n",
55
+ " def __getitem__(self, index):\n",
56
+ " file_path = self.paths[index]\n",
57
+ " # Load the .mat file\n",
58
+ " data_loc = scipy.io.loadmat(file_path)\n",
59
+ " data_val = data_loc['d']\n",
60
+ "\n",
61
+ " data_val = data_val.reshape(32, 32)\n",
62
+ " # Convert numpy array as a PIL image\n",
63
+ " img = Image.fromarray(data_val)\n",
64
+ "\n",
65
+ " # Apply transformations\n",
66
+ " if self.transform:\n",
67
+ " img = self.transform(img)\n",
68
+ " \n",
69
+ " # Scale the image to [-1, 1] based on its own min and max values\n",
70
+ " img = self.scale_to_minus1_1(img)\n",
71
+ " return img"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 21,
77
+ "id": "c16dc0d1",
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
82
+ "dataset = GravityDataset(data, image_size=32)\n",
83
+ "dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 22,
89
+ "id": "a8bd15c0",
90
+ "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "name": "stderr",
94
+ "output_type": "stream",
95
+ "text": [
96
+ "Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 28.56it/s]\n"
97
+ ]
98
+ },
99
+ {
100
+ "data": {
101
+ "text/plain": [
102
+ "4"
103
+ ]
104
+ },
105
+ "execution_count": 22,
106
+ "metadata": {},
107
+ "output_type": "execute_result"
108
+ }
109
+ ],
110
+ "source": [
111
+ "from diffusers import StableDiffusionPipeline\n",
112
+ "\n",
113
+ "pipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n",
114
+ "pipeline.unet.config[\"in_channels\"]\n",
115
+ "4"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "markdown",
120
+ "id": "72b20338",
121
+ "metadata": {},
122
+ "source": [
123
+ "**Training Configuration**"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 23,
129
+ "id": "e01c3fae",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "from dataclasses import dataclass\n",
134
+ "\n",
135
+ "@dataclass\n",
136
+ "class TrainingConfig:\n",
137
+ " image_size = 32 # the generated image resolution\n",
138
+ " train_batch_size = 32\n",
139
+ " eval_batch_size = 16 # how many images to sample during evaluation\n",
140
+ " num_epochs = 50\n",
141
+ " gradient_accumulation_steps = 1\n",
142
+ " learning_rate = 1e-4\n",
143
+ " lr_warmup_steps = 500\n",
144
+ " save_image_epochs = 10\n",
145
+ " save_model_epochs = 30\n",
146
+ " mixed_precision = \"fp16\" # `no` for float32, `fp16` for automatic mixed precision\n",
147
+ " output_dir = \"ddpm-butterflies-128\" # the model name locally and on the HF Hub\n",
148
+ "\n",
149
+ " push_to_hub = True # whether to upload the saved model to the HF Hub\n",
150
+ " hub_model_id = \"jainadarsh/trial\" # the name of the repository to create on the HF Hub\n",
151
+ " hub_private_repo = None\n",
152
+ " overwrite_output_dir = True # overwrite the old model when re-running the notebook\n",
153
+ " seed = 0\n",
154
+ "\n",
155
+ "\n",
156
+ "config = TrainingConfig()"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "id": "9f52a27e",
162
+ "metadata": {},
163
+ "source": [
164
+ "**Load Dataset**"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 74,
170
+ "id": "b4a98567",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "# from datasets import load_dataset\n",
175
+ "\n",
176
+ "# config.dataset_name = \"huggan/smithsonian_butterflies_subset\"\n",
177
+ "# dataset = load_dataset(config.dataset_name, split=\"train\")\n",
178
+ "\n",
179
+ "data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
180
+ "dataset = GravityDataset(data, image_size=32)\n",
181
+ "train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 75,
187
+ "id": "e3aed193",
188
+ "metadata": {},
189
+ "outputs": [
190
+ {
191
+ "data": {
192
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAABOwAAAEhCAYAAADMCz9IAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAISNJREFUeJzt2smSI7e1MGAkZ7JY1ZNaLdmK8MZb77T2+/hF/RDeOMJ2eAi1ZXWruyYO+S/+xb03bONASlQSZH3f9qAOkCCABE+x6/u+TwAAAABAEyanHgAAAAAA8D8U7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGzE49gBb1fX/qIaSU6oyjhRwlfx+1aSXH8Xh88j7G0nXdoHiJ7XY7OEdr3r17l43/61//CnPsdrtsfDKJ/5eyWq2y8c1mE+aIPp+XL19m419++WXYxy9/+cts/Fe/+lWY45tvvhk8juhZStZqNOfz+Twbn06nYR819l10xhwOhzBH1CZawyV9RDmieEop7ff7weOI5qvkM4n27G9/+9swxzmqsV55nqyd5yu6S5+jkvdVC1rZd2OM4zk9aw3nMs7nyC/sAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIbMTj2A56rv+1MPoZroWUqetUaO4/E4KF7Szxh91FgbXdcNblOSYzJ5fjX/7Xabje/3+zDH4+NjNj6dTsMcq9UqG4/GmVJK19fX2firV6+y8ZcvX4Z93NzcZOPr9TrMsVgssvHZLH6V1VjvkaHnR8k4apyFh8MhzBGt42gNR/GUUnp4eBgUL+mnZD/WOHNL1iBckhpnJlyKc/leVzLOoXu75O/H+K4zxrOWiMbhLCXy/L5tAwAAAEDDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPYBT6Pv+1EOoosZzlOSI2gyNl7Q5Ho9hjqhNSY7D4TAoR/T3JW1qfK5d14VtJpN8vX46nQ7OcYlev36djZfMyePj4+Acm80mG99ut2GOly9fZuOvXr3Kxr/44ouwj2i+bm5uwhzr9Tobn8/nYY7ZLP+6K9kzQ411FkZnzH6/D3NEa/Th4SEbv7u7C/uI2pTkuL+/z8Z3u12YI5qvkv1YsgbhXIxxHj4n5vPyXcp3y5TiZ4nWc8lcRDlK9szQcZbkKDF0f9eYrxpaGQf/7vl92wYAAACAhinYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaMjv1AC5V3/enHkJKaZxxRH2UjOF4PA6Kl7Q5HA5hjv1+n43vdrtBf1/SpmSckckkrsXPZvntP51Owxzz+bx4TJfi7du32fhyuQxzPDw8ZOMln996vc7Gr6+vwxyvXr0aFH/z5k3Yx+vXr7Pxm5ubMMdqtcrGF4tFmCNa7yVz3nVdNl7jvI1y1DjHHh8fwxzRGr29vR0UTymlT58+DYqX9BM9R0rxnJasjWiNwliic4r/y3xRQ8n3lHNhT/yPkrmI7m2XNJ/P6Vlb4hd2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIbNTD6C2vu9PPYRqxniWkj6GjqPk74/H4+Ach8MhG9/v92GO3W6XjT88PGTjd3d3YR9RjpJxRmazeGsvl8tsfLVaDR7HJfrqq6+y8aurqzDH4+NjNt51XZgj+vyur6/DHC9fvnzSeEmb7XYb5ojmNJqLlOI9MZ1OwxzR5xLFa5yF0TmXUnyGROdcSind399n47e3t9n4x48fwz5+/PHHJ88RPUdK8ZyWnKfr9TpsAzWUvB+eC3Pxf5mP0yl5N5+LyST/e56hd6GW1Li3jfG80TjOac756fzCDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGjI7NQDOEd93596CCml8cYR9TM0XuJ4PA5uczgcwhyPj4/Z+N3dXTb+448/hn18+vRp0BhKLJfLsM12u83GS+a8lb0wpl/84hfZePT5plTnM14sFtn4ZrMJc7x48SIbv7m5ycavr6/DPqIcV1dXYY7VapWNz+fzMMdsln/dTSbx/6+6rgvbDBXtqZJ9ud/vs/GHh4cwx/39fTZ+e3ubjZfsgx9++GFQPKX4zI3O7JTiOY32Wkop7Xa7sA1ExjhjWnFJz3pJz8LPU/IdI1LjTh2txZK1Go0jylHjPlUyzjH2XY0+hs5njT5q9TN0HM7Kn8cv7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx7Ac9X3fRP91BhHlON4PIY5ojYlOQ6HQza+3+/DHLvdLhu/v7/Pxj9//hz28fHjx8E5IldXV4NzTKfTKm0uzVdffZWN397ehjkeHx8Hj2OxWGTjq9UqzBGtkyi+3W7DPjabTTa+Xq/DHMvlMhuP5iKllGaz/OtuMon/fxW1qXHeRudYFE8pPsdK1t/d3V02Hp1TP/74Y9hHdBZ++PAhzPHp06dsPDqzS5TMebS+IKWUuq479RBGcy7PapzUEH3HGOt7X7ROStZRdNeJcpQ8a5Tjkr5fjDFfXDa/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaMjv1AH6qvu9PPYRnJ5rzGp9JlON4PIY5ojaHwyHMEbXZ7/fZ+G63C/t4fHwcnCOar+l0GuZYLpfZ+Gq1CnPMZmd3hAz29u3bbPzh4SHMEa2jEtFnvFgswhzRZ7xerwf9fUmbaB2mlNJ8Ps/GS9ZhNF+TyXn8/6rkHKtxTkXr+Pb2Nhv//Plz2MenT58G9VHSpmS+os++5P3TdV3Yhst2SWvgXJ6llXG2Mo7IuYzzHEV3+5LvSiXvmkiNu07UJuqj5FmjHCXv7pLvOkPV2DPRfIy1L1sYR8nacE79u/P4hgIAAAAAz4SCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIbNTD4D/ru/7QfFaOYYaa5zH43FQvKRNNI6u68I+ptPpoHhKw8eZUkqHwyEb3+12YY6SNpfmzZs32fjj42OYI5r7EtFam8/nYY7FYjEox3K5fPI+UkppNsu/qiaT+H9PUZuSvRspOWMiNc7CGnv74eEhG7+/v8/G7+7uwj6iHCV7KZqPkrURrdHVahXmKGnDeatxRrSihWdpYQwpjTOO5/Ssz1X0Piu5H9T4Tha980q+Y0RtontbSR+Rknd3pMY4anwmNfZdje+fnC+/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoALer7/tRDaEo0H8fjcdDfl7QpyVFjHFGOruuy8el0GvYxn8+z8cViEebY7/fZ+GQS1+KjZ436KG1zaW5ubrLxkjmJ5r6G2Sw+3qP1Gq3Vkj6iNiV7JmpTst6jvVsiOkNq9DF0DCVtDodDmCNax7vdblC8ZBwl81ljja5Wq2z8+vo6zFHShraNsX/H0MpzjDGOsZ51aD/nMk5+vs+fP2fjJXfDGneM6L5U8k6M3qvRHTb6+xIl44zmo+SuE81XKzWBVu6XzpjT8As7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw/gUvV9f+ohFGllnNE4aoxzjGedTOIa+HQ6zcZns+HbsmQckePxGLbZ7/eD+zk3m80mGy9ZZyVzO1TJGui6LhuP1mK0lkvalIwzalNjvZd8bkPPqWi+S9ucg5LPZD6fZ+Ml+ySar+VyGeaI9vT19XWY48WLF2EbTudS9lVK7TzLGOOo0UcLOVoYQ2v9XJoffvghG394eAhzHA6HbLzGe7XknbharbLx9Xqdjbfy3bLkjlrj+2cLe+Zcxlmixl360viFHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaMjs1AOgfX3fn/Tva+Wooeu6QfGUUppM8nXyKF7SpmQc0Zwej8fBOS7RZrPJxmvMW8m81liLUZtW1uoY671E9NlGfdSYrxo5ptNpmGM2y18PFotFNr5arcI+DofDoD5Sip+lZBzb7XZQPKWUbm5uwjY8nZIz4By08hw1xjHGs4w1zqH9jNFHSY5W1tcl+tvf/paNf/r0Kcyx3++z8ZL3/3K5zMavrq7CHNfX19n4ixcvsvGxvj+McQ+ucb+07xjKL+wAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPQBIKaW+7588xxh9jKXGsx4Oh2x8t9uFOfb7fdjm0iwWi2y8ZO7HWEdd11VpkzOZxP/zifoYY5wlxjgfSuYrajObxa/tqE20hlNKabVaZeObzSYbj86XlOJxHo/HwTmi50gpfpbr6+swx9XVVdgGxjjLStQYx9AcLYyhNMfQ99gYfZS2GSPHc/TnP/85G3///n2Y4/HxMRsvuUOs1+tsvOR99ubNm2w8GmfJ+7/GnSuajxp3rhp3+lbu/Jwvv7ADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANmZ16AP9b3/enHsJoxnrWS5nTGs9RkiNqE8WPx2PYx+FwyMb3+32YY7fbZeMlzxqNdTKJ6/nROC7RYrHIxmuss7F0XTco3kofJYbu7RLRnppOp2GO2Sz/Wo7iKaW0XC6z8ZJ9u9lssvHoHCt51qiPks8kmo9oLlJKab1eZ+PROEvb8POMdUYM1co4xzi3W8kx1jiHjqPkPlXjWaN+xnjW5+r9+/fZ+N///vcwx93d3eBxzOfzbLzkXfX9999n458/f87Gb29vwz6++OKLbPzVq1dhjkjJvovuKiXf60r6uRTRvcz58DSezwoDAAAAgDOgYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaMjs1APg/PV9PyheK8fxeHzyHFH8cDiEfez3+2z88fExzHF/fz94HJNJvl4fjfO5ms3yx2bJOqsh6qfruicfwxh91FLjjIlEa6NEdMbM5/MwR7T/1+v14HFE58dyuQz7iM6Yks9kOp1m44vFIswRjXW1WoU5StpADTXO3RZylPx91OZcckTnZUmbVnKc03t/TNF7t+ReHt3tHx4ewhwfPnzIxt+/fx/m2G632fjbt2+z8a+//jrs45tvvsnGf/nLX4Y5vv3222y85P0f3alK9gw8NasQAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0ZHbqAfDz9X0/So7j8Ti4n6HjKBln1KbkOaI2h8NhUDyllPb7fTb+8PAQ5vj06VM2fnt7G+aIxjqbxcfDy5cvwzaXJpqXGmv1knRdNzjHGGddSR9Dn2U6nYZtxjgLS0TPOp/Ps/H1eh32EZ1BJc8RzWnJORY9y2KxCHOUtOE/q3FGjGGMcZ7LXJSInqXkWVvJMZnkf98wNJ5SfJaVvD9qnIc1noV/V7LOonde9P0hpZS+//77bPwPf/hDmCNaJ9Hd/927d2Ef3333XTb+4cOHMMevf/3rweNo4X451rnfyjj46Zy6AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx4A/13f96ceQpFonCXPEbU5Ho9hjqjNfr8Pc0RthsZL2jw8PIQ5Pn78mI2/f/8+zPHdd99l43/5y1/CHG/fvs3Gf/e734U5zs1kMvz/HOeyt1tRY77GOKda0XXdoHhKKU2n02x8sVhk4yVn4eFwCNtEomeJniOllGaz/FUoiqeU0nw+D9vAGEr2d40cQ8+ZMfooaVPyTo/aRPGxzqGozXK5DHNEZ1nJszxH//znP7Pxv/71r2GO6G5f8v3g/v4+G1+tVmGOaA1E8ZJ9Gd0RdrtdmCP63ncudzaI+IUdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0JDZqQfA5ev7fnCb4/EY5tjv99n44XB48hwl46zxrI+Pj9n4x48fwxx/+tOfsvHf//73YY63b9+GbS7NZPL0/+coWQPPSdd1g3NE+67GOVVD9Kwlc1Ejx2yWvx7M5/NsvOS8HWOdT6fTwW1q5HiuauzdsYwx1hp9jJGjlc+txjijd3bJOz1qE+3/6DxNKT5TF4tFmGO1Wg2Kp5TSer3OxqNxPlf/+Mc/svE//vGPYY67u7tsvGS9R2vxm2++CXNE6+Tq6iobv7m5Cft4+fJlNh6tw5TifVWyt52nnAO/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoA/1vXdWGbvu9HGAk/RfSZlHxmx+MxGz8cDmGO/X6fjT8+PoY5drvdoD5KTCb5OvlisQhzrFarbHy9Xoc5rq6usvF3796FOd68eRO2uTTR51ci2hPT6XRwH2No5TyOzo+U4vfLGO+fkj5qjDNqU7KGozmdzfLXh5LPpMb6qTFf0XyUzFeNcwFqKFnzY/TTyllWY39H7+ToPJzP52Ef0d0vuvelFN/9rq+vwxw3NzeDx/EcRXfmh4eHMEf0HaNkHUWfz2azCXNE6yjqY7vdhn1EbV6/fh3miNZqtC9TGu+8hCHcMAEAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGzE49AC5f3/dhm8PhkI3v9/swx8PDQzZ+d3cX5oja7Ha7MEdksVhk45vNJszx5s2bweOYz+fZ+Ndffx3muLm5GTyOc9N1XRM5SvbVU2vlOabT6eB+jsdjmCN63skk/z+wkmeN+iiZ86HjTCmej9ksf30omc+hY0gpfpYa81WSo2ROL1GNM2AMY4yzlXdDjX7G2Dcle2aMvVny/ojaROdhSR/RnWy1WoU5rq+vs/HXr1+HOaI22+02zPEcffvtt9n4b37zmzBHtJ6Xy2WYY71eZ+Ml3zGitRb1UTLOaBwl44zWYrSnUorPh+f6bv9vzuWdf2msQgAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0JDZqQfA+ev7Phs/Ho9hjqjNfr8Pc9zd3WXjHz58GJxjOp1m4/P5POxjsVhk47NZvC2jHMvlMsyx2Wyy8c+fP4c5VqtV2ObSTCb5/3NE+6GWrusG5xhrrDklzzHGOKO9XTKO6Byr8awlOaI20RouGcfQeGmboWrMV40cUEsra23ovil5juisKjnLauSI3g/RvS26s6UU36fW63WYY7vdZuOvXr0Kc7x79y4bf/36dZijhbvF2KJ5KfkuFK2zkjt3tE5KckTfD2p8B4najPV9aoz3vzsGQ/mFHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANCQ2akH8Fx1XTc4R9/3o+SI2gyNp5TS8XjMxne7XZjj9vY2G//+++/DHHd3d9n4ZrPJxq+vr8M+lstlNj6dTsMcq9UqG4/GmVJK2+02G398fAxzLBaLsM2lqbF3GV/0uZWcU1GOGn1EZ2HJ+qsxjhrn+tA+aqgxX7X64emMMf/nsk5qrPkxnrXGOCeT+HcHUZuSO9dslv+6FMVL7kpRm/V6HeaI7nUvXrwIc3zxxRfZ+JdffhnmGONsb00099G7PaWU5vN5Nh7d/Uva1MgRrdXoOUraRHuqpE3J3o7Oh7HuEJDjF3YAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhs1MPgKfV9/3J+zgej2GOqM3hcAhz7Ha7bPzu7i7McXt7m413XZeNr1arsI/IcrkM28xm+a1bMl/X19fZeMnnFo3jOYrWCM9XydqI2pSc6dHerbFGx3i3jMWepRWtrMUaZ9XQeEopTSb53xVE8ZRSmk6ng+IpxXedKL5YLMI+ojbr9TrMcXV1lY3f3NyEOV6+fJmNv3r1KsxxSe+HUtHc11jvJd8xou8QJd8xorUYrff5fB72EeUo2Zcl+z9S45xqwbmMk5/HL+wAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPYAWdV2Xjfd9P9JInt7xeBzc5nA4DIqnlNJ+vx+co6RNJHrW+/v7bPzu7i7sY7PZZOOzWbwt1+t1Nj6ZjFOLH6uflkTnQytKzqkWnuVcxtmKkrmoMV9jvOdK3j9DjXVGWaOMJVprY50RQ8dRY5wl+ztqM51OwxzRvWyxWAyKpxTf61arVZgjul9eXV2FOa6vr7Pxm5ubMMcYZ3trttttNl6y3qO1OJ/PwxxRm5K1GI0jylGyp6I2JTmiOS3JEZ0PNc6pGmfhGIyjXc/v2zYAAAAANEzBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQkNmpB/BTdV2Xjfd9P9JI8s5lnCWOx2M2vtvtsvG7u7uwj9vb28E5DodDNj6ZxPXpqE30uUVjKOljNou35XK5zMbn8/ngcURruCTHJSqZl6FqnA9jjLNE9CytjLOGGuf+ueSooUYfrayfVsbBzzf0M7ykNVDyLFGbKF7jTlaSYzqdZuMld66oTXTnWiwWYR+r1Sob32w2YY6ozXa7DXNEba6ursIc0feGS1QyL5ForUbxlOK1WmO91xhn1KbkDKqRI3JJ5zrn6/l92wYAAACAhinYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoA/Hxd143Sz+FwyMbv7++z8Y8fP4Z9/Pjjj9n47e1tmOPh4SEbn0zi+vRiscjGp9PpoL9PKaX5fD44x3K5HJwjmo+S+Sppw0831t6O9H0/OMe5PEuNcY7RR0mOFp61RCtro5Vx8J/5fOob4ywaGk+pzj0lurdF8ZTie1sUj+5sKaW0Wq2y8fV6Hea4uroaFE8ppc1mMzhH9L3hEpV8PpFoPZfsmdks/9W+ZL0P3TMl+zJ6lho5SuarhbOwRh9j5eA0fNsGAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGjI79QB4Wl3XPXkffd9n48fjMcxR0iYym+WX83a7DXOsVqtBfVxdXYV9RG2iMaSU0mKxGBRPKaXJJF+vj+KlbS5NC3tqLGM8a4ka89HKs1wK8wk/3Rj7pkYfY+Qo6SNqU+OeEt3rStosl8tsvORet16vs/HNZhPmiO6XJXfU6K5cMo7D4RC2uTTRGigRrdWSPTPG3b7GOKfT6eAcY5wx53Jmn4vn9Kw1Pb9v2wAAAADQMAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw+gtq7rwjZ935+8j5IcQ/soMZnENdvpdJqNL5fLbHy73Q4ex3q9DnPs9/ts/HA4hDmiOR06FymldH19nY2XPOtiscjGZ7N4a0fPUrJGS9YPP12N86GGGmdMDWPMR41nbWWcrawf4Kcp2bs19vcYOYbGU4rvGCV3kBo55vN5Nh7dyUruhtHdb7PZhDmurq6y8ZL7eNQm6iOllHa7Xdjm0kRroESNPRO1KVnvQ3PUGGcrOUq0cOdqYQw8Hd+2AQAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIbMTj0Afr6u6wa3mUzimu18Ps/G1+v14D42m002vt/vwxzH43FQvET0LLNZvKWi+Vwul4NzlMx5SZsxcpybkn0X6fu+wkieXo1nrWGM+TqXZ21lnMDlGuOcqXGHLckxnU6z8ZJ7W9Qmipfc61arVTYe3ZNL2kT39Vo5ojm/RCXraKgae6ZEdLevsS9byTG0j7FynIvn9Kxjen7ftgEAAACgYQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx7AKXRdl433fX8RfZS0mUzimu10Os3Gl8tlNj6bxcvscDhk48fjMcxR0iYSfS5jzGcUL+1n6DhK1hc/zxhzW+OMaUUra3GMOW3lWYdqZf1dynzCuRlj77VyD57P54PiKcV36dVqFeaI2qzX68HjWCwWYY5Wzv8x1biX1+ijxr6LcgyN1xjDmP2MkaOFPmiXX9gBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADZmdegD8d13XZeN93w/OUWI6nQ7qI/r7lFI6HA4/aUz/STQfJfM1VMl8R20mk7iOHuWoMY4SNXLwNFr5bMbYd2NpYU7PZT5bmKsxPbfnhaFrfqx7yhjjiO5tJffgqE2NHCX3y8jxeAzb1LjTn5uSzydyLuv9qf/+0nK00Ect5zTWS+IXdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw+gRV3XZeN93z95HyX9lOSoIeqnxnxNJvnacUmOGp/LGGp8bkM/kxp9QIlW1tG5nA+RVuYTeN6GnqmXdK+rIXrW4/EY5tjv99n4/f19mOPz58/Z+HQ6DXPc3d1l45vNJsxxbqLvMTW08r3vqf++Vo5W+jmXe9u5jPM58gs7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCFd3/f9qQcBAAAAAPx/fmEHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA35f5EOqgphU/Q8AAAAAElFTkSuQmCC",
193
+ "text/plain": [
194
+ "<Figure size 1600x400 with 4 Axes>"
195
+ ]
196
+ },
197
+ "metadata": {},
198
+ "output_type": "display_data"
199
+ }
200
+ ],
201
+ "source": [
202
+ "import matplotlib.pyplot as plt\n",
203
+ "import torch\n",
204
+ "\n",
205
+ "batch = next(iter(dataloader)) # batch shape (B, C, H, W)\n",
206
+ "imgs = batch[:4] # take first 4\n",
207
+ "\n",
208
+ "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
209
+ "for i, img in enumerate(imgs):\n",
210
+ " img = img.detach().cpu()\n",
211
+ " if img.ndim == 3: # (C,H,W) -> (H,W,C)\n",
212
+ " img = img.permute(1, 2, 0)\n",
213
+ " img = (img + 1) / 2 # convert from [-1,1] to [0,1]\n",
214
+ " arr = img.numpy().squeeze()\n",
215
+ " axs[i].imshow(arr, cmap='gray' if arr.ndim == 2 else None)\n",
216
+ " axs[i].axis('off')\n",
217
+ "plt.show()"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "id": "06c1dff6",
223
+ "metadata": {},
224
+ "source": [
225
+ "**Create a Architecture**"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 76,
231
+ "id": "3ee088b3",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "from diffusers import UNet2DModel\n",
236
+ "\n",
237
+ "model = UNet2DModel(\n",
238
+ " sample_size=config.image_size, # the target image resolution\n",
239
+ " in_channels=1, # the number of input channels, 3 for RGB images\n",
240
+ " out_channels=1, # the number of output channels\n",
241
+ " layers_per_block=2, # how many ResNet layers to use per UNet block\n",
242
+ " block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block\n",
243
+ " down_block_types=(\n",
244
+ " \"DownBlock2D\", # a regular ResNet downsampling block\n",
245
+ " \"DownBlock2D\",\n",
246
+ " \"DownBlock2D\",\n",
247
+ " \"DownBlock2D\",\n",
248
+ " \"AttnDownBlock2D\", # a ResNet downsampling block with spatial self-attention\n",
249
+ " \"DownBlock2D\",\n",
250
+ " ),\n",
251
+ " up_block_types=(\n",
252
+ " \"UpBlock2D\", # a regular ResNet upsampling block\n",
253
+ " \"AttnUpBlock2D\", # a ResNet upsampling block with spatial self-attention\n",
254
+ " \"UpBlock2D\",\n",
255
+ " \"UpBlock2D\",\n",
256
+ " \"UpBlock2D\",\n",
257
+ " \"UpBlock2D\",\n",
258
+ " ),\n",
259
+ ")"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": 77,
265
+ "id": "8a703029",
266
+ "metadata": {},
267
+ "outputs": [
268
+ {
269
+ "name": "stdout",
270
+ "output_type": "stream",
271
+ "text": [
272
+ "Input shape: torch.Size([1, 1, 32, 32])\n",
273
+ "Output shape: torch.Size([1, 1, 32, 32])\n"
274
+ ]
275
+ }
276
+ ],
277
+ "source": [
278
+ "sample_image = dataset.__getitem__(0).unsqueeze(0) # add batch dimension\n",
279
+ "print(\"Input shape:\", sample_image.shape)\n",
280
+ "print(\"Output shape:\", model(sample_image, timestep=0).sample.shape)"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "id": "8d3d9af1",
286
+ "metadata": {},
287
+ "source": [
288
+ "**Create a scheduler**"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 78,
294
+ "id": "22c2ce21",
295
+ "metadata": {},
296
+ "outputs": [
297
+ {
298
+ "data": {
299
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfcAAAGdCAYAAAAPGjobAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAO6BJREFUeJzt3XtwVfW5//FPCCQh5EZIyAUChJuAXKxRMdpalJSAHUeUcbR1RrQWRw90qvQmHS9V25PWniraUjhnWqFOS72cqXjsBdQoWG2QgnKQiylgIOGScM2V3EjW7w8P+2d0Qb5PsmOyNu/XzJohO0+e/d17rb0fvt+99rOiPM/zBAAAIka/3h4AAAAIL4o7AAARhuIOAECEobgDABBhKO4AAEQYijsAABGG4g4AQIShuAMAEGH69/YAPq29vV2HDh1SYmKioqKiens4AAAjz/NUV1en7Oxs9evXc3PIpqYmtbS0dDtPTEyM4uLiwjCivqPPFfdDhw4pJyent4cBAOimiooKDR8+vEdyNzU1KTc3V5WVld3OlZmZqbKysogq8D1W3JctW6af//znqqys1LRp0/TLX/5Sl112Wad/l5iYKEl6/fXXNWjQIKf7On78uPO4Ro4c6RwrSUePHnWOTU1NNeUeNmyYc+y//vUvU+6ysjLn2IsuusiU+8MPPzTFZ2dnO8dWV1ebco8bN845tr6+3pS7rq7OOba9vd2Uu7W11RSfl5fnHLt7925T7tjY2B7LnZaW5hw7YMAAU+62tjbn2NOnT5ty9+9ve2u0vPZjYmJMuf/5z386x06bNs2U23KMDxkyxDm2vr5eeXl5offzntDS0qLKykqVl5crKSmpy3lqa2s1YsQItbS0UNw78/zzz2vx4sVasWKFpk+frqVLl6qwsFClpaUaOnToOf/2zFL8oEGDlJCQ4HR/TU1NzmOzHmyNjY09lttyQLo+F2fEx8c7x1rH7fqfrjMsY7cWPcvYrR/zWC67YC3u1qXEnjxWLG9oluPKOpaeLO7W48o6FstxaC3uPflathzjXSnUn8dHq0lJSd0q7pGqRz4MeeKJJ7RgwQLdcccdmjRpklasWKH4+Hg988wzPXF3AIDzlOd53d4iUdiLe0tLi7Zs2aKCgoL/fyf9+qmgoEAlJSWfiW9ublZtbW2HDQAAFxR3f2Ev7seOHVNbW5syMjI63J6RkeF74kNRUZGSk5NDGyfTAQBcUdz99fr33JcsWaKamprQVlFR0dtDAgAg0MJ+Ql1aWpqio6NVVVXV4faqqiplZmZ+Jj42NtZ0ti4AAGd0d/bNzN1RTEyM8vLyVFxcHLqtvb1dxcXFys/PD/fdAQDOYyzL++uRr8ItXrxY8+fP1yWXXKLLLrtMS5cuVUNDg+64446euDsAAPAJPVLcb775Zh09elQPPfSQKisrddFFF2nt2rWfOckOAIDuYFneX491qFu0aJEWLVrU5b+3NLGxdB6zdLOTbF2Zampqeiz3iRMnTLktnaq2bdtmym3pOCfZuppZG+QcPnzYOdbaQdDS+MT6LQ/rsWLZR6NGjeqxsYwePdqU29IadMKECabcBw8edI61NrGxHuOWzorWLoyWJjbWDmulpaXOsadOnXKObWhoMI2jOyju/nr9bHkAABBefe7CMQAAuGLm7o/iDgAILIq7P5blAQCIMMzcAQCBxczdH8UdABBYFHd/FHcAQGBR3P3xmTsAABGGmTsAILCYufujuAMAAovi7q/PFvddu3Y5t12cPn26c97GxkbTOCw7PikpyZT72LFjzrGWFrvWeGtL0cTERFN8//7uh9nIkSNNufft2+cca2lXah3L6dOnTbkHDBhgio+OjnaOXbdunSm3pV1tcnKyKffQoUOdY5ubm025Bw8e7Bxrvaz09u3bTfGW53D8+PGm3Js2bXKOtbweJNtz2NbW5hwbqQUzSPpscQcAoDPM3P1R3AEAgUVx98fZ8gAARBhm7gCAwGLm7o/iDgAItEgt0N3BsjwAABGGmTsAILBYlvdHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/vpscc/KylJCQoJTrKWP+vHjx03jyM7Odo6trq425T516pRz7JgxY0y5W1panGPr6upMuV17/p9h6UXf2tpqyp2VleUce/jwYVPuDz74wDn2wgsvNOW2PoeWsY8bN86UOz093Tm2tLTUlNv1NSxJBw4cMOW2vDatvfytfe4t1y2w7nuLuLg4U/yQIUOcYy3XT7C+p3QHxd0fy/IAAESYPjtzBwCgM8zc/VHcAQCBRXH3x7I8AAARhpk7ACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3R/FHQAQWBR3f322uCckJDi3rrS0LB02bJhpHJaWsj3ZUnTkyJGm3Pv27XOO7d/fdhi8+uqrpvhZs2Y5x27evNmUe/z48c6xu3btMuVOS0tzjm1razPltj7nFtZj5dChQ86xOTk5ptyWN86xY8eacldVVTnHWtvJDh8+3BTf1NTkHGttg2xpD5ybm2vKbWn5a2lt297ebhoHwq/PFncAADrDzN0fxR0AEFgUd398FQ4AgAjDzB0AEFjM3P1R3AEAgUVx98eyPAAAEYaZOwAgsJi5+6O4AwACLVILdHewLA8AQIRh5g4ACCyW5f1R3AEAgUVx99eni3tUVJRT3ODBg51ztrS0mMZg6S++e/duU+7k5GTnWGtv7NTUVOdY63Ny7bXXmuI/+OAD59hLLrnElNvSp3vmzJmm3JZ+4Tt27DDltl7joKGhwTl248aNptyWnu7WN0JLz/3nn3/elHvu3LnOsTt37jTlrqmpMcVb+q5bX29XX321c6zlmhKS7biy7MtTp06ZxtEdFHd/fOYOAECECXtx/9GPfqSoqKgO24QJE8J9NwAAhGbu3dkiUY/M3C+88EIdPnw4tL399ts9cTcAgPNcbxT3t956S9ddd52ys7MVFRWlNWvWdPo369ev18UXX6zY2FiNHTtWq1atsj9Ygx4p7v3791dmZmZos3xuDQBAX9bQ0KBp06Zp2bJlTvFlZWX66le/qquvvlpbt27Vvffeq29+85tat25dj42xR06o2717t7KzsxUXF6f8/HwVFRVpxIgRvrHNzc0dTharra3tiSEBACJQb5xQN2fOHM2ZM8c5fsWKFcrNzdUvfvELSdLEiRP19ttv68knn1RhYaH5/l2EfeY+ffp0rVq1SmvXrtXy5ctVVlamL33pS6qrq/ONLyoqUnJycmjLyckJ95AAABEqXMvytbW1HTbrN5TOpaSkRAUFBR1uKywsVElJSdju49PCXtznzJmjm266SVOnTlVhYaH++te/qrq6Wi+88IJv/JIlS1RTUxPaKioqwj0kAADOKScnp8NEs6ioKGy5KysrlZGR0eG2jIwM1dbWqrGxMWz380k9/j33lJQUjR8/Xnv27PH9fWxsrGJjY3t6GACACBSuZfmKigolJSWFbg96Xerx77nX19dr7969ysrK6um7AgCcZ8K1LJ+UlNRhC2dxz8zMVFVVVYfbqqqqlJSUpIEDB4btfj4p7MX9u9/9rjZs2KB9+/bpH//4h2644QZFR0fra1/7WrjvCgCAPi8/P1/FxcUdbnvttdeUn5/fY/cZ9mX5AwcO6Gtf+5qOHz+u9PR0ffGLX9TGjRuVnp5uyrN9+3bFx8c7xebm5jrntcRK0uuvv+4cO3ToUFPu6Oho59i9e/eacl955ZXOse+//74pt7W1pOUkyU8ui7mwPC/t7e2m3Pv373eOnTx5sim3tV3t2b5t4sf6WrM858ePHzflbm1tdY6dN2+eKbflOJw0aZIp98GDB03x/fq5z5Msr3tJZz0Z2Y+lpbX08WfBrixtcy1tbburN86Wr6+v7/BRc1lZmbZu3arU1FSNGDFCS5Ys0cGDB/Xss89Kku6++2796le/0ve//3194xvf0BtvvKEXXnhBf/nLX7o87s6Evbg/99xz4U4JAICv3ijumzdv7tDzf/HixZKk+fPna9WqVTp8+LDKy8tDv8/NzdVf/vIX3XfffXrqqac0fPhw/eY3v+mxr8FJffzCMQAAnEtvFPcZM2ac8+/8us/NmDHDvEraHVw4BgCACMPMHQAQWFzy1R/FHQAQWBR3fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7+WJYHACDCMHMHAAQWM3d/FHcAQKBFaoHujj5b3CdPnqzExESn2JMnTzrn3bRpk2kcn2wxGM5xSNLp06edY0tLS025LT3ArT3XExISTPGW/C+//LIpt6Wnu/UqT5ae64cPHzblHjdunCk+Jiamx8ZiuSZCc3OzKbelj7r1tXnBBRc4x1qPcWuPdkvfdcvrXpLzNTYknfXS2mdjef1YckdFRZnGgfDrs8UdAIDOsCzvj+IOAAgsirs/ijsAILAo7v74KhwAABGGmTsAILCYufujuAMAAovi7o9leQAAIgwzdwBAYDFz90dxBwAEFsXdH8vyAABEmD47c29qalL//m7Da2xsdM6blpZmGscrr7ziHPuFL3zBlDs7O9s51tIiVLL9b7S2ttaU2/ocWlpz5ufnm3K3trY6x+7du9eUOyMjwzl2wIABptxNTU2m+L///e/OsdbjsKamxjn21KlTptyW53zKlCmm3A0NDc6xJ06cMOUeM2aMKd71vUqSDh48aMptaT1t3T87duxwjrW0HraOozuYufvrs8UdAIDOUNz9sSwPAECEYeYOAAgsZu7+KO4AgMCiuPujuAMAAovi7o/P3AEAiDDM3AEAgcXM3R/FHQAQWBR3fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7++mxxb2lpce5JPnXqVOe8JSUlpnGMHDnSOTYlJcWU+8iRI86xlj7nkq3ffmJioim3lSW/te/2iBEjnGOtz6FFeXm5Kf706dOm+KysLOfYw4cPm3Jb+uJb3whzcnJ6LLfFoEGDTPHW/Wm53oK19/+uXbucY0ePHm3KbTkOLe9X0dHRpnEg/PpscQcAoDPM3P1R3AEAgRapBbo7OFseAIAIw8wdABBYLMv7o7gDAAKL4u6P4g4ACCyKuz8+cwcAIMIwcwcABBYzd38UdwBAYFHc/bEsDwBAhGHmDgAILGbu/vpscT958qRzb/n6+nrnvMOGDTONIzk52Tm2tbXVlPvYsWPOsXV1dabceXl5zrEfffSRKXd8fLwpvrq62jn2wIEDptyxsbHOsW1tbabcO3bscI619vTu39/20ouLi3OOtR6Hlv1pHbflebH2xLf0i1+3bp0pd1JSkinesn+s10+IiYlxjrW8p0hSQ0ODc6zlOg6W9+Tuorj7Y1keAIAIYy7ub731lq677jplZ2crKipKa9as6fB7z/P00EMPKSsrSwMHDlRBQYF2794drvECABByZubenS0SmYt7Q0ODpk2bpmXLlvn+/vHHH9fTTz+tFStW6N1339WgQYNUWFiopqambg8WAIBPorj7M3/mPmfOHM2ZM8f3d57naenSpXrggQd0/fXXS5KeffZZZWRkaM2aNbrlllu6N1oAANCpsH7mXlZWpsrKShUUFIRuS05O1vTp01VSUuL7N83Nzaqtre2wAQDggpm7v7AW98rKSklSRkZGh9szMjJCv/u0oqIiJScnh7acnJxwDgkAEMEo7v56/Wz5JUuWqKamJrRVVFT09pAAAAFBcfcX1uKemZkpSaqqqupwe1VVVeh3nxYbG6ukpKQOGwAA6LqwFvfc3FxlZmaquLg4dFttba3effdd5efnh/OuAABg5n4W5rPl6+vrtWfPntDPZWVl2rp1q1JTUzVixAjde++9+vGPf6xx48YpNzdXDz74oLKzszV37txwjhsAADrUnYW5uG/evFlXX3116OfFixdLkubPn69Vq1bp+9//vhoaGnTXXXepurpaX/ziF7V27VpTe0ZJio6OVnR0tFOsa5taSerXz7ZY8cEHHzjHWluQWlrh1tTUmHJbWopOnjzZlPvo0aOm+LS0NOfYQ4cOmXJHRUU5x1pfxJaTOy0tQiXbuKWPv1Xiytqq+MSJE86xQ4YMMeV+8803nWPHjh1rym1pjnXNNdeYclufQ0t8YmKiKbflNWF9f8vKyjLFIzjMxX3GjBnnfJOMiorSo48+qkcffbRbAwMAoDPM3P312QvHAADQGYq7v17/KhwAAAgvZu4AgMBi5u6P4g4ACCyKuz+W5QEAiDDM3AEAgRaps+/uoLgDAAKLZXl/LMsDAAKrt9rPLlu2TKNGjVJcXJymT5+uTZs2nTV21apVioqK6rBZG7tZUdwBADB4/vnntXjxYj388MN67733NG3aNBUWFurIkSNn/ZukpCQdPnw4tO3fv79Hx0hxBwAEVm/M3J944gktWLBAd9xxhyZNmqQVK1YoPj5ezzzzzFn/JioqSpmZmaEtIyOjOw+7U332M/eEhAQlJCQ4xaanpzvntfTRlmz94hsbG025XXvnS/ZxW5SWlpriY2NjTfENDQ3OsW1tbabclh7tlv7s1tzWfvvWqyRu27bNOXbw4MGm3Jbj1tqL3NLr3Pr6SU1NdY6trKw05bY6fvy4c+yAAQNMuVNSUpxje/LaGcnJyc6xltd8d4XrM/fa2toOt8fGxvq+17W0tGjLli1asmRJ6LZ+/fqpoKBAJSUlZ72f+vp6jRw5Uu3t7br44ov17//+77rwwgu7PO7OMHMHAJz3cnJylJycHNqKiop8444dO6a2trbPzLwzMjLO+p/ICy64QM8884xefvll/f73v1d7e7uuuOIKHThwIOyP44w+O3MHAKAz4Zq5V1RUKCkpKXS7dYXyXPLz8zus1l1xxRWaOHGi/vM//1OPPfZY2O7nkyjuAIDACldxT0pK6lDczyYtLU3R0dGqqqrqcHtVVZUyMzOd7nPAgAH6whe+oD179tgH7IhleQAAHMXExCgvL0/FxcWh29rb21VcXOx8Lk1bW5s++OAD8zksFszcAQCB1RtNbBYvXqz58+frkksu0WWXXaalS5eqoaFBd9xxhyTptttu07Bhw0Kf2z/66KO6/PLLNXbsWFVXV+vnP/+59u/fr29+85tdHndnKO4AgMDqjeJ+88036+jRo3rooYdUWVmpiy66SGvXrg2dZFdeXt7hmwsnT57UggULVFlZqcGDBysvL0//+Mc/NGnSpC6PuzMUdwBAYPVW+9lFixZp0aJFvr9bv359h5+ffPJJPfnkk126n67iM3cAACIMM3cAQGBx4Rh/FHcAQGBR3P312eK+c+dODRw40Cn22muvdc5rbW86btw459jt27ebch86dMg51trO8e2333aOHTZsmCn3p7/f2Zn6+voeG0tra6tzrGs74zMs7U0trTmljxtmWFha4Vrbm1r2z8mTJ025La1wa2pqTLktr2Vr++aYmBhT/KBBg5xjre9Bn26LGq5xSNKUKVOcY6urq51jLa210TP6bHEHAKAzzNz9UdwBAIFFcffH2fIAAEQYZu4AgMBi5u6P4g4ACCyKuz+W5QEAiDDM3AEAgcXM3R/FHQAQWBR3fxR3AECgRWqB7g4+cwcAIMIwcwcABBbL8v76bHGfMGGCc59kS991a8/ozZs3O8fGx8ebcqenpzvHpqWlmXK3tLQ4x/brZ1vASUpKMsVPnjzZOXbfvn2m3JY+3f372w73gwcPOsfu3r3blDs2NtYUb+n/bomVpKysLOdY6xvhgQMHnGOPHz9uyt3e3u4ce/nll5tyv/POO6Z4y2vfehzGxcU5xzY2NppyW45Dy7g/z97yFHd/LMsDABBh+uzMHQCAzjBz90dxBwAEFsXdH8vyAABEGGbuAIDAYubuj+IOAAgsirs/luUBAIgwzNwBAIHFzN0fxR0AEFgUd38UdwBAYFHc/fXZ4h4VFeXcFrW5udk5ryVWsu14a/vMMWPGOMeeOHHClNvSmtNq/Pjxpvh//etfzrGuLYfPsLbytLC0/D169Kgpd21trSl+4sSJzrGWlq+Src2utcXywIEDnWOtbZAzMjKcY63tga0tli1traOioky5U1JSnGNPnz5tym1hOWZPnTrVY+OAmz5b3AEA6Awzd38UdwBAYFHc/Zm/CvfWW2/puuuuU3Z2tqKiorRmzZoOv7/99tsVFRXVYZs9e3a4xgsAADphnrk3NDRo2rRp+sY3vqEbb7zRN2b27NlauXJl6Gfr5S0BAHDBzN2fubjPmTNHc+bMOWdMbGysMjMzuzwoAABcUNz99UiHuvXr12vo0KG64IILdM8995zzLPLm5mbV1tZ22AAAQNeFvbjPnj1bzz77rIqLi/Wzn/1MGzZs0Jw5c9TW1uYbX1RUpOTk5NCWk5MT7iEBACLUmZl7d7ZIFPaz5W+55ZbQv6dMmaKpU6dqzJgxWr9+vWbOnPmZ+CVLlmjx4sWhn2traynwAAAnLMv76/ELx4wePVppaWnas2eP7+9jY2OVlJTUYQMAAF3X499zP3DggI4fP66srKyevisAwHmGmbs/c3Gvr6/vMAsvKyvT1q1blZqaqtTUVD3yyCOaN2+eMjMztXfvXn3/+9/X2LFjVVhYGNaBAwBAcfdnLu6bN2/W1VdfHfr5zOfl8+fP1/Lly7Vt2zb97ne/U3V1tbKzszVr1iw99thj5u+6Hzt2zLk/cUNDg3Pe6Oho0zhSU1OdY62rE5Y+0OXl5abclq8iNjY2mnJbnm9JGjJkiHOstVe85Xmx5j7bSaB+rG8Q1h7gpaWlzrGVlZWm3JZrBVhfx5b9M3r0aFNuS098a9966/686KKLnGP3799vyl1WVuYcm5CQYMptOa5uuOEG59i6ujrTOLorUgt0d5iL+4wZM875RK5bt65bAwIAAN1Db3kAQGCxLO+P4g4ACCyKu78e/yocAAD4fDFzBwAEFjN3fxR3AEBgUdz9sSwPAECEYeYOAAgsZu7+KO4AgMCiuPtjWR4AgAjDzB0AEFjM3P312eI+aNAgDRo0yCk2Li7OOW98fLxpHJZ+19ae6zt37nSOtV7jfuDAgc6x1oN7wIABpnhLf/H09HRTbkuP9mHDhply/+///q9z7Fe+8hVT7rNdAvlsLH3urX29Xa/hIEl79+415bZcV2DLli2m3JZrORw/ftyUe9SoUab44uJi51jrcWh5DjMyMky5m5ubnWOPHDniHFtfX28aR3dQ3P312eIOAEBnKO7++MwdAIAIw8wdABBYzNz9UdwBAIFFcffHsjwAABGGmTsAILCYufujuAMAAovi7o9leQAAIgwzdwBAYDFz90dxBwAEFsXdX58t7vHx8c7tZysrK53zpqSkmMZhabdpbctqaed44sQJU+6JEyf2yDgkW7tfScrMzHSOtbRClaTY2Fjn2I8++siU23KsvPrqq6bcl19+uSne0pp19+7dptzjxo1zjh0/frwpt4W1tW3//u5vX8nJyabc1mM8LS3NOdZyzEpSv37un55a973lfcJyDFpfxwi/PlvcAQDoDDN3fxR3AEBgUdz9cbY8ACDQzhT4rmxdtWzZMo0aNUpxcXGaPn26Nm3adM74F198URMmTFBcXJymTJmiv/71r12+bxcUdwAADJ5//nktXrxYDz/8sN577z1NmzZNhYWFZz1/6R//+Ie+9rWv6c4779T777+vuXPnau7cudq+fXuPjZHiDgAIrO7M2rs6e3/iiSe0YMEC3XHHHZo0aZJWrFih+Ph4PfPMM77xTz31lGbPnq3vfe97mjhxoh577DFdfPHF+tWvftXdh39WFHcAQGCFq7jX1tZ22Jqbm33vr6WlRVu2bFFBQUHotn79+qmgoEAlJSW+f1NSUtIhXpIKCwvPGh8OFHcAwHkvJydHycnJoa2oqMg37tixY2pra1NGRkaH2zMyMs76tezKykpTfDhwtjwAILDCdbZ8RUWFkpKSQrdb+xH0NRR3AEBghau4JyUldSjuZ5OWlqbo6GhVVVV1uL2qquqsDbsyMzNN8eHAsjwAAI5iYmKUl5en4uLi0G3t7e0qLi5Wfn6+79/k5+d3iJek11577azx4cDMHQAQWL3RxGbx4sWaP3++LrnkEl122WVaunSpGhoadMcdd0iSbrvtNg0bNiz0uf23v/1tffnLX9YvfvELffWrX9Vzzz2nzZs367/+67+6PO7O9Nni7nme2tvbnWI/faLCuVg/R9m5c6dzrGsv/DNGjRrlHFtfX2/KbemjbuldLUk1NTWm+D179jjHWpep9u/f7xxr6f8tScOHD++x3Pv27TPFp6enO8dajivJ9pzv2rXLlHvgwIHOsRdeeKEpt+U1cfr0aVPugwcPmuJbW1udY637x3KMW44TSWpsbHSOtVyDwPp+1R29UdxvvvlmHT16VA899JAqKyt10UUXae3ataFaVF5e3uF99YorrtDq1av1wAMP6Ic//KHGjRunNWvWaPLkyV0ed2f6bHEHAKCvWrRokRYtWuT7u/Xr13/mtptuukk33XRTD4/q/6O4AwACi97y/ijuAIDAorj7o7gDAAKL4u6Pr8IBABBhmLkDAAKLmbs/ijsAILAo7v5YlgcAIMIwcwcABBYzd38UdwBAYFHc/fXZ4j527FinK/RIH7f6c2VpEylJU6ZMcY5NSEgw5baMOzs725Tb8jirq6tNuaOjo03xluewtrbWlHv69OnOsda2uaWlpaZ4iy984QumeMtYrC18LS2Wy8rKTLlTU1OdY61vsnv37nWOLSgoMOU+dOiQKd7SxtV6XFneJ5qbm025LS2z6+rqnGNPnTplGgfCr88WdwAAOsPM3Z/phLqioiJdeumlSkxM1NChQzV37tzP/C+0qalJCxcu1JAhQ5SQkKB58+Z95jq2AACEw5ni3p0tEpmK+4YNG7Rw4UJt3LhRr732mlpbWzVr1iw1NDSEYu677z698sorevHFF7VhwwYdOnRIN954Y9gHDgAA/JmW5deuXdvh51WrVmno0KHasmWLrrrqKtXU1Oi3v/2tVq9erWuuuUaStHLlSk2cOFEbN27U5ZdfHr6RAwDOeyzL++vW99zPnKB05qSZLVu2qLW1tcPJKxMmTNCIESNUUlLim6O5uVm1tbUdNgAAXLAs76/Lxb29vV333nuvrrzyytAF5ysrKxUTE6OUlJQOsRkZGaqsrPTNU1RUpOTk5NCWk5PT1SEBAM5DFPbP6nJxX7hwobZv367nnnuuWwNYsmSJampqQltFRUW38gEAcL7r0lfhFi1apD//+c966623NHz48NDtmZmZamlpUXV1dYfZe1VV1Vm/exsbG6vY2NiuDAMAcJ7jM3d/ppm753latGiRXnrpJb3xxhvKzc3t8Pu8vDwNGDBAxcXFodtKS0tVXl6u/Pz88IwYAID/w2fu/kwz94ULF2r16tV6+eWXlZiYGPocPTk5WQMHDlRycrLuvPNOLV68WKmpqUpKStK3vvUt5efnc6Y8AACfE1NxX758uSRpxowZHW5fuXKlbr/9dknSk08+qX79+mnevHlqbm5WYWGhfv3rX4dlsAAAfBLL8v5Mxd3lSYiLi9OyZcu0bNmyLg9K+rjHeE886f37204zaG9vd449duyYKbelR7u1L7q1/7uF9Tk8ffq0c2x9fb0pt6Wn94kTJ0y5ExMTnWM/2cjJxQsvvGCKz8vLc461doRMTk52jo2Li+ux3NavwVqu5XDw4EFT7g8//NAUb+nn39bWZso9ZMgQ51jr4xw3bpxzrOU9pSfffz6N4u6P67kDABBhuHAMACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/ijuAIDAorj767PFfffu3Ro0aJBT7MmTJ53zDhs2zDSOuro651hLm0hJn7mq3rns37/flLulpcU5NjU11ZTb2gr3+PHjzrGTJ0825U5LS3OO/fvf/27KbWmFm5GRYcpt1dTU5BxreT1I0uDBg51jre1nq6urnWPT09NNuS2Xit60aZMp9/jx403xFpbXpmR7Di2vByvX92P0DX22uAMA0Blm7v4o7gCAwKK4++NseQAAIgwzdwBAYDFz90dxBwAEFsXdH8vyAABEGGbuAIDAYubuj+IOAAgsirs/ijsAILAo7v74zB0AgAjDzB0AEGiROvvujj5b3OPj4xUfH+8Um5yc7Jz36NGjpnFYep1HRUWZcu/Zs8c5trGx0ZTb0nf70KFDptzWXvTt7e3OsceOHTPlPnDggHPskSNHTLmHDh3qHNvW1mbKbXlOJGnDhg3OsbNmzTLl3rFjh3Osdd9b+txb9qVk6/2fkpJiyt3a2mqKj46Odo4tKSkx5b788sudY6196xMSEpxjLa9N6/tVd7As749leQAAIkyfnbkDANAZZu7+KO4AgMCiuPtjWR4AgAjDzB0AEFjM3P1R3AEAgUVx98eyPAAAEYaZOwAgsJi5+6O4AwACi+Luj+IOAAgsiru/Plvc4+PjNWjQIKfY9PR057zDhg0zjcPScrF/f9vT6dpeV5IyMzNNuQ8ePOgcO2TIEFPuiooKU/yFF17oHLtv3z5T7tOnTzvHjhs3zpTb0mq1qqrKlNvaqnj69OnOsdb2pgMHDnSOTUxMNOWuq6tzjk1KSjLlbmpqco5taGgw5ba2qba8B02ZMsWUu7a21hRv8d577znHWlpaf57tZ+GvzxZ3AAA6w8zdH8UdABBYFHd/fBUOAIAIw8wdABBYzNz9UdwBAIFFcffHsjwAABGGmTsAILCYufujuAMAAovi7o9leQAAesCJEyd06623KikpSSkpKbrzzjtVX19/zr+ZMWOGoqKiOmx33323+b6ZuQMAAqsvz9xvvfVWHT58WK+99ppaW1t1xx136K677tLq1avP+XcLFizQo48+GvrZ0s30DIo7ACCw+mpx37Vrl9auXat//vOfuuSSSyRJv/zlL3XttdfqP/7jP5SdnX3Wv42Pjze3HP+0Plvc+/fv79yrfcCAAc55d+zYYRpHdHS0c6y1R7uFpQe0JMXFxTnHWsdt7dNtGXtubq4pt6WH/u7du025R48ebYq3qKmpMcVb/uduOWYlW597a0/81tZW59gjR46Ycre0tDjHWl4Pkq1XvCSVl5c7x1p6+UvS4MGDnWN37txpyp2fn+8ca+m3b9nv3RWu4v7pHv6xsbGmfvqfVlJSopSUlFBhl6SCggL169dP7777rm644Yaz/u0f/vAH/f73v1dmZqauu+46Pfjgg+bZe58t7gAAfF5ycnI6/Pzwww/rRz/6UZfzVVZWaujQoR1u69+/v1JTU1VZWXnWv/v617+ukSNHKjs7W9u2bdMPfvADlZaW6k9/+pPp/k0n1BUVFenSSy9VYmKihg4dqrlz56q0tLRDTLhOBgAAwMWZ2XtXtjMqKipUU1MT2pYsWeJ7X/fff/9natyntw8//LDLj+Wuu+5SYWGhpkyZoltvvVXPPvusXnrpJe3du9eUxzRz37BhgxYuXKhLL71Up0+f1g9/+EPNmjVLO3fu7HB51nCcDAAAQGfCtSyflJTkdNnh73znO7r99tvPGTN69GhlZmZ+5qOm06dP68SJE6bP089c7nnPnj0aM2aM89+ZivvatWs7/Lxq1SoNHTpUW7Zs0VVXXRW6PRwnAwAA0Nekp6c7nZORn5+v6upqbdmyRXl5eZKkN954Q+3t7aGC7WLr1q2SpKysLNM4u/U99zMnBaWmpna4/Q9/+IPS0tI0efJkLVmyRKdOnTprjubmZtXW1nbYAABw0Z0l+e7O+s9l4sSJmj17thYsWKBNmzbpnXfe0aJFi3TLLbeEzpQ/ePCgJkyYoE2bNkmS9u7dq8cee0xbtmzRvn379D//8z+67bbbdNVVV2nq1Kmm++/yCXXt7e269957deWVV2ry5Mmh260nAxQVFemRRx7p6jAAAOexvvpVOOnjie6iRYs0c+ZM9evXT/PmzdPTTz8d+n1ra6tKS0tDE+CYmBi9/vrrWrp0qRoaGpSTk6N58+bpgQceMN93l4v7woULtX37dr399tsdbr/rrrtC/54yZYqysrI0c+ZM7d271/fzgiVLlmjx4sWhn2traz9z1iIAAEGTmpp6zoY1o0aN6vCfi5ycHG3YsCEs992l4r5o0SL9+c9/1ltvvaXhw4efM7azkwG6+11CAMD5qy/P3HuTqbh7nqdvfetbeumll7R+/XqnhiNdPRkAAIDOUNz9mYr7woULtXr1ar388stKTEwMfRE/OTlZAwcO1N69e7V69Wpde+21GjJkiLZt26b77ruvSycDAACArjEV9+XLl0v6uFHNJ61cuVK33357WE8GAACgM8zc/ZmX5c8lnCcDVFRUODe/sfRqvuaaa0zjeP/9951j6+rqTLktB1V7e7spt6Uv+smTJ025J0yYYIrfs2ePc6y1J/Unmyd1Ztq0aabc1dXVPRIrSU1NTab45uZm51hLD3Dp46/suDp06JApt6XnurX/+0cffeQca2n+IUnHjx83xSckJDjHVlVVmXJbji3ra9PCpcHLGa7XBQkHirs/essDAAKL4u6vW01sAABA38PMHQAQWMzc/VHcAQCBRXH3x7I8AAARhpk7ACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3V+fLe7JycnOrUUbGhqc827fvt00jtOnTzvH9utn+5Rj7NixzrGWFqGStH//fufYSZMmmXJbWopKUnp6unOstYVvYmKic2xxcbEpd2ZmpnPsqVOnTLmtx+GoUaOcY62XUH7nnXecYy2vB0mKiopyjrW28LW0OLW8HiTpww8/NMVfeOGFzrFf+tKXTLnr6+udY2NiYky5LW12U1JSnGM/z/az8MceAAAEFjN3fxR3AEBgUdz9UdwBAIFFcffHV+EAAIgwzNwBAIEWqbPv7qC4AwACi2V5fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7+WJYHACDCMHMHAAQWM3d/fba4t7e3q7293Sl26NChznmTk5NN4ygtLXWObWlpMeWuqqpyjt2xY4cpd1ZWlnPsunXrTLnz8vJM8U1NTc6xgwcPNuX+6KOPnGOPHj1qyv3GG284x+bk5JhyR0dHm+Itvegt1yyQpOzsbOfYX/3qV6bco0ePdo5ta2sz5W5tbe2x3PHx8ab4kydPOsc2NzebctfW1jrHWp4TSTpy5Ihz7MiRI51jP8/e8hR3fyzLAwAQYfrszB0AgM4wc/dHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/ijuAIDAorj7Y1keAIAIw8wdABBYzNz9UdwBAIFFcffXZ4v7yJEjlZiY6BRraft64MAB0zgsbSutrW0t487MzDTlrq+vd44dNmyYKfeJEydM8YcOHXKOtbQSlqTKykrn2FdffdWU29La9pJLLjHltrTytLK0K5VsLZaHDBliyt3Q0OAce/jwYVNuS4vlgwcPmnJbWvJK0vDhw51jra2HLS2ZrW1fLfvTUgQjtWAGSZ8t7gAAdIaZuz+KOwAgsCju/jhbHgCACMPMHQAQWMzc/VHcAQCBRXH3R3EHAAQWxd0fn7kDABBhmLkDAAItUmff3UFxBwAEVncLe6T+x4BleQAAIgwzdwBAYDFz99dni3tMTIxiYmKcYi39xadOnWoax65du5xjLX20JVvPaGvfbdfnTpISEhJMuY8ePWqKHzNmjHNseXm5KXdVVZVzbGFhoSn37373O+fY6upqU+6BAwea4o8dO+Yca30OLf3Frdc4aGxsdI7Nyckx5bYct3v27DHlnjJliine8vo8fvy4Kfe0adOcY0+ePGnKPWjQIOdYy3Fl2e/dRXH3x7I8AAARxlTcly9frqlTpyopKUlJSUnKz8/X3/72t9Dvm5qatHDhQg0ZMkQJCQmaN2+eaWYFAIDFme+5d2eLRKbiPnz4cP30pz/Vli1btHnzZl1zzTW6/vrrtWPHDknSfffdp1deeUUvvviiNmzYoEOHDunGG2/skYEDAEBx92f6zP26667r8PNPfvITLV++XBs3btTw4cP129/+VqtXr9Y111wjSVq5cqUmTpyojRs36vLLLw/fqAEAwFl1+TP3trY2Pffcc2poaFB+fr62bNmi1tZWFRQUhGImTJigESNGqKSk5Kx5mpubVVtb22EDAMAFM3d/5uL+wQcfKCEhQbGxsbr77rv10ksvadKkSaqsrFRMTIxSUlI6xGdkZJzzbPaioiIlJyeHNusZswCA8xfF3Z+5uF9wwQXaunWr3n33Xd1zzz2aP3++du7c2eUBLFmyRDU1NaGtoqKiy7kAAOcXirs/8/fcY2JiNHbsWElSXl6e/vnPf+qpp57SzTffrJaWFlVXV3eYvVdVVZ3zu7GxsbGKjY21jxwAAPjq9vfc29vb1dzcrLy8PA0YMEDFxcWh35WWlqq8vFz5+fndvRsAAD6Dmbs/08x9yZIlmjNnjkaMGKG6ujqtXr1a69ev17p165ScnKw777xTixcvVmpqqpKSkvStb31L+fn5nCkPAOgRdKjzZyruR44c0W233abDhw8rOTlZU6dO1bp16/SVr3xFkvTkk0+qX79+mjdvnpqbm1VYWKhf//rXXRpYXV2dc6yl9ePevXtN45g1a5ZzrPXcA0t8VlaWKfeJEyecYy3tRyVpwIABpnhLS8y4uDhT7ry8POfYrVu3mnLfcMMNzrGWFsiSlJSUZIpPTU11jrW2E7a81nJzc025LW15refbWHLPnz/flLutrc0Uf9VVVznHWlvb7tu3zznW8nqQbO8T/fq5L/RaYtEzTMX9t7/97Tl/HxcXp2XLlmnZsmXdGhQAAC6YufvrsxeOAQCgMxR3f6ydAADQA37yk5/oiiuuUHx8/Gd6wJyN53l66KGHlJWVpYEDB6qgoEC7d+823zfFHQAQWH35bPmWlhbddNNNuueee5z/5vHHH9fTTz+tFStW6N1339WgQYNUWFiopqYm032zLA8ACKy+vCz/yCOPSJJWrVrlPJalS5fqgQce0PXXXy9JevbZZ5WRkaE1a9bolltucb5vZu4AgPPep69x0tzc/LmPoaysTJWVlR2u0ZKcnKzp06ef8xotfijuAIDACteyfE5OTofrnBQVFX3uj+XMV2ozMjI63N7ZNVr8sCwPAAiscC3LV1RUdOg/cba26Pfff79+9rOfnTPnrl27NGHChG6Nq7so7gCAwApXcU9KSnJqLvWd73xHt99++zljRo8e3aWxnLkOS1VVVYfGZVVVVbroootMuSjuAAA4Sk9PV3p6eo/kzs3NVWZmpoqLi0PFvLa2NnQVVos+V9zP/C+qvr7e+W8aGhqcYxsbG03jqa2tdY61jFmSTp061WO5Lc+Jpf1oT4/FehJLS0uLc6z1qySW3K2trabc1vj29nbnWMu4uzKWnsptbfnak8+JdSyW49b6HmQ5bi3vKdaxREVFmfN+Xg1i+mojmvLycp04cULl5eVqa2sLtcAeO3ZsqE30hAkTVFRUpBtuuEFRUVG699579eMf/1jjxo1Tbm6uHnzwQWVnZ2vu3Lm2O/f6mIqKCk8SGxsbG1vAt4qKih6rFY2NjV5mZmZYxpmZmek1NjaGfYzz58/3vb8333wzFCPJW7lyZejn9vZ278EHH/QyMjK82NhYb+bMmV5paan5vqP+L3mf0d7erkOHDikxMbHD/xRra2uVk5PzmZMeIg2PM3KcD49R4nFGmnA8Ts/zVFdXp+zs7B69iExTU5N5VcZPTEyM+aJVfV2fW5bv16+fhg8fftbfu570EHQ8zshxPjxGiccZabr7OJOTk8M4Gn9xcXERV5TDhe+5AwAQYSjuAABEmMAU99jYWD388MNnbSwQKXickeN8eIwSjzPSnC+PM9L1uRPqAABA9wRm5g4AANxQ3AEAiDAUdwAAIgzFHQCACBOY4r5s2TKNGjVKcXFxmj59ujZt2tTbQwqrH/3oR4qKiuqw9fYlA7vrrbfe0nXXXafs7GxFRUVpzZo1HX7veZ4eeughZWVlaeDAgSooKNDu3bt7Z7Dd0NnjvP322z+zb2fPnt07g+2ioqIiXXrppUpMTNTQoUM1d+5clZaWdohpamrSwoULNWTIECUkJGjevHmqqqrqpRF3jcvjnDFjxmf25913391LI+6a5cuXa+rUqaFGNfn5+frb3/4W+n0k7MvzXSCK+/PPP6/Fixfr4Ycf1nvvvadp06apsLBQR44c6e2hhdWFF16ow4cPh7a33367t4fULQ0NDZo2bZqWLVvm+/vHH39cTz/9tFasWKF3331XgwYNUmFhofkCL72ts8cpSbNnz+6wb//4xz9+jiPsvg0bNmjhwoXauHGjXnvtNbW2tmrWrFkdLgp033336ZVXXtGLL76oDRs26NChQ7rxxht7cdR2Lo9TkhYsWNBhfz7++OO9NOKuGT58uH76059qy5Yt2rx5s6655hpdf/312rFjh6TI2JfnvW51xf+cXHbZZd7ChQtDP7e1tXnZ2dleUVFRL44qvB5++GFv2rRpvT2MHiPJe+mll0I/t7e3e5mZmd7Pf/7z0G3V1dVebGys98c//rEXRhgen36cnvfxxSOuv/76XhlPTzly5IgnyduwYYPneR/vuwEDBngvvvhiKGbXrl2eJK+kpKS3htltn36cnud5X/7yl71vf/vbvTeoHjJ48GDvN7/5TcTuy/NNn5+5t7S0aMuWLSooKAjd1q9fPxUUFKikpKQXRxZ+u3fvVnZ2tkaPHq1bb71V5eXlvT2kHlNWVqbKysoO+zU5OVnTp0+PuP0qSevXr9fQoUN1wQUX6J577tHx48d7e0jdUlNTI0lKTU2VJG3ZskWtra0d9ueECRM0YsSIQO/PTz/OM/7whz8oLS1NkydP1pIlS8yXWu1L2tra9Nxzz6mhoUH5+fkRuy/PN33uwjGfduzYMbW1tSkjI6PD7RkZGfrwww97aVThN336dK1atUoXXHCBDh8+rEceeURf+tKXtH37diUmJvb28MKusrJSknz365nfRYrZs2frxhtvVG5urvbu3asf/vCHmjNnjkpKShQdHd3bwzNrb2/XvffeqyuvvFKTJ0+W9PH+jImJUUpKSofYIO9Pv8cpSV//+tc1cuRIZWdna9u2bfrBD36g0tJS/elPf+rF0dp98MEHys/PV1NTkxISEvTSSy9p0qRJ2rp1a8Tty/NRny/u54s5c+aE/j116lRNnz5dI0eO1AsvvKA777yzF0eG7rrllltC/54yZYqmTp2qMWPGaP369Zo5c2YvjqxrFi5cqO3btwf+nJDOnO1x3nXXXaF/T5kyRVlZWZo5c6b27t2rMWPGfN7D7LILLrhAW7duVU1Njf77v/9b8+fP14YNG3p7WAiTPr8sn5aWpujo6M+cqVlVVaXMzMxeGlXPS0lJ0fjx47Vnz57eHkqPOLPvzrf9KkmjR49WWlpaIPftokWL9Oc//1lvvvlmh0szZ2ZmqqWlRdXV1R3ig7o/z/Y4/UyfPl2SArc/Y2JiNHbsWOXl5amoqEjTpk3TU089FXH78nzV54t7TEyM8vLyVFxcHLqtvb1dxcXFys/P78WR9az6+nrt3btXWVlZvT2UHpGbm6vMzMwO+7W2tlbvvvtuRO9XSTpw4ICOHz8eqH3reZ4WLVqkl156SW+88YZyc3M7/D4vL08DBgzosD9LS0tVXl4eqP3Z2eP0s3XrVkkK1P70097erubm5ojZl+e93j6jz8Vzzz3nxcbGeqtWrfJ27tzp3XXXXV5KSopXWVnZ20MLm+985zve+vXrvbKyMu+dd97xCgoKvLS0NO/IkSO9PbQuq6ur895//33v/fff9yR5TzzxhPf+++97+/fv9zzP83760596KSkp3ssvv+xt27bNu/76673c3FyvsbGxl0duc67HWVdX5333u9/1SkpKvLKyMu/111/3Lr74Ym/cuHFeU1NTbw/d2T333OMlJyd769ev9w4fPhzaTp06FYq5++67vREjRnhvvPGGt3nzZi8/P9/Lz8/vxVHbdfY49+zZ4z366KPe5s2bvbKyMu/ll1/2Ro8e7V111VW9PHKb+++/39uwYYNXVlbmbdu2zbv//vu9qKgo79VXX/U8LzL25fkuEMXd8zzvl7/8pTdixAgvJibGu+yyy7yNGzf29pDC6uabb/aysrK8mJgYb9iwYd7NN9/s7dmzp7eH1S1vvvmmJ+kz2/z58z3P+/jrcA8++KCXkZHhxcbGejNnzvRKS0t7d9BdcK7HeerUKW/WrFleenq6N2DAAG/kyJHeggULAvcfU7/HJ8lbuXJlKKaxsdH7t3/7N2/w4MFefHy8d8MNN3iHDx/uvUF3QWePs7y83Lvqqqu81NRULzY21hs7dqz3ve99z6upqendgRt94xvf8EaOHOnFxMR46enp3syZM0OF3fMiY1+e77jkKwAAEabPf+YOAABsKO4AAEQYijsAABGG4g4AQIShuAMAEGEo7gAARBiKOwAAEYbiDgBAhKG4AwAQYSjuAABEGIo7AAARhuIOAECE+X9eu/BjisWv5wAAAABJRU5ErkJggg==",
300
+ "text/plain": [
301
+ "<Figure size 640x480 with 2 Axes>"
302
+ ]
303
+ },
304
+ "metadata": {},
305
+ "output_type": "display_data"
306
+ }
307
+ ],
308
+ "source": [
309
+ "import torch\n",
310
+ "from PIL import Image\n",
311
+ "from diffusers import DDPMScheduler\n",
312
+ "\n",
313
+ "noise_scheduler = DDPMScheduler(num_train_timesteps = 1000)\n",
314
+ "noise = torch.randn(sample_image.shape)\n",
315
+ "timesteps = torch.LongTensor([50]) # example timestep\n",
316
+ "noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n",
317
+ "\n",
318
+ "# print(\"Noisy image shape:\", noisy_image.shape)\n",
319
+ "# Image.fromarray(noisy_image.numpy().squeeze())\n",
320
+ "plt.imshow(noisy_image.detach().cpu().numpy().squeeze(), cmap='gray')\n",
321
+ "plt.colorbar()\n",
322
+ "# plt.axis('off')\n",
323
+ "plt.show()\n"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": 79,
329
+ "id": "4e878652",
330
+ "metadata": {},
331
+ "outputs": [],
332
+ "source": [
333
+ "import torch.nn.functional as F\n",
334
+ "\n",
335
+ "noise_pred = model(noisy_image, timesteps).sample\n",
336
+ "loss = F.mse_loss(noise_pred, noise)"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "markdown",
341
+ "id": "7dffb56e",
342
+ "metadata": {},
343
+ "source": [
344
+ "**Train the model**"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": 80,
350
+ "id": "62915913",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "from diffusers.optimization import get_cosine_schedule_with_warmup\n",
355
+ "\n",
356
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n",
357
+ "lr_scheduler = get_cosine_schedule_with_warmup(\n",
358
+ " optimizer=optimizer,\n",
359
+ " num_warmup_steps=config.lr_warmup_steps,\n",
360
+ " num_training_steps=(len(train_dataloader) * config.num_epochs),\n",
361
+ ")"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": 81,
367
+ "id": "f2f3427a",
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "from diffusers import DDPMPipeline\n",
372
+ "from diffusers.utils import make_image_grid\n",
373
+ "import os\n",
374
+ "\n",
375
+ "def evaluate(config, epoch, pipeline):\n",
376
+ " # Sample some images from random noise (this is the backward diffusion process).\n",
377
+ " # The default pipeline output type is `List[PIL.Image]`\n",
378
+ " images = pipeline(\n",
379
+ " batch_size=config.eval_batch_size,\n",
380
+ " generator=torch.Generator(device='cpu').manual_seed(config.seed), # Use a separate torch generator to avoid rewinding the random state of the main training loop\n",
381
+ " ).images\n",
382
+ "\n",
383
+ " # Make a grid out of the images\n",
384
+ " image_grid = make_image_grid(images, rows=4, cols=4)\n",
385
+ "\n",
386
+ " # Save the images\n",
387
+ " test_dir = os.path.join(config.output_dir, \"samples\")\n",
388
+ " os.makedirs(test_dir, exist_ok=True)\n",
389
+ " image_grid.save(f\"{test_dir}/{epoch:04d}.png\")"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 85,
395
+ "id": "4fe935dd",
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "from accelerate import Accelerator\n",
400
+ "from huggingface_hub import create_repo, upload_folder\n",
401
+ "from tqdm.auto import tqdm\n",
402
+ "from pathlib import Path\n",
403
+ "import os"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": 86,
409
+ "id": "07df8a4c",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n",
414
+ " # Initialize accelerator and tensorboard logging\n",
415
+ " accelerator = Accelerator(\n",
416
+ " mixed_precision=config.mixed_precision,\n",
417
+ " gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
418
+ " log_with=\"tensorboard\",\n",
419
+ " project_dir=os.path.join(config.output_dir, \"logs\"),\n",
420
+ " )\n",
421
+ " if accelerator.is_main_process:\n",
422
+ " if config.output_dir is not None:\n",
423
+ " os.makedirs(config.output_dir, exist_ok=True)\n",
424
+ " if config.push_to_hub:\n",
425
+ " repo_id = create_repo(\n",
426
+ " repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
427
+ " ).repo_id\n",
428
+ " accelerator.init_trackers(\"train_example\")\n",
429
+ "\n",
430
+ " # Prepare everything\n",
431
+ " # There is no specific order to remember, you just need to unpack the\n",
432
+ " # objects in the same order you gave them to the prepare method.\n",
433
+ " model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n",
434
+ " model, optimizer, train_dataloader, lr_scheduler\n",
435
+ " )\n",
436
+ "\n",
437
+ " global_step = 0\n",
438
+ "\n",
439
+ " # Now you train the model\n",
440
+ " for epoch in range(config.num_epochs):\n",
441
+ " progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n",
442
+ " progress_bar.set_description(f\"Epoch {epoch}\")\n",
443
+ "\n",
444
+ " for step, batch in enumerate(train_dataloader):\n",
445
+ " clean_images = batch[\"images\"]\n",
446
+ " # Sample noise to add to the images\n",
447
+ " noise = torch.randn(clean_images.shape, device=clean_images.device)\n",
448
+ " bs = clean_images.shape[0]\n",
449
+ "\n",
450
+ " # Sample a random timestep for each image\n",
451
+ " timesteps = torch.randint(\n",
452
+ " 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,\n",
453
+ " dtype=torch.int64\n",
454
+ " )\n",
455
+ "\n",
456
+ " # Add noise to the clean images according to the noise magnitude at each timestep\n",
457
+ " # (this is the forward diffusion process)\n",
458
+ " noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n",
459
+ "\n",
460
+ " with accelerator.accumulate(model):\n",
461
+ " # Predict the noise residual\n",
462
+ " noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n",
463
+ " loss = F.mse_loss(noise_pred, noise)\n",
464
+ " accelerator.backward(loss)\n",
465
+ "\n",
466
+ " if accelerator.sync_gradients:\n",
467
+ " accelerator.clip_grad_norm_(model.parameters(), 1.0)\n",
468
+ " optimizer.step()\n",
469
+ " lr_scheduler.step()\n",
470
+ " optimizer.zero_grad()\n",
471
+ "\n",
472
+ " progress_bar.update(1)\n",
473
+ " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n",
474
+ " progress_bar.set_postfix(**logs)\n",
475
+ " accelerator.log(logs, step=global_step)\n",
476
+ " global_step += 1\n",
477
+ "\n",
478
+ " # After each epoch you optionally sample some demo images with evaluate() and save the model\n",
479
+ " if accelerator.is_main_process:\n",
480
+ " pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)\n",
481
+ "\n",
482
+ " if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n",
483
+ " evaluate(config, epoch, pipeline)\n",
484
+ "\n",
485
+ " if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:\n",
486
+ " if config.push_to_hub:\n",
487
+ " upload_folder(\n",
488
+ " repo_id=repo_id,\n",
489
+ " folder_path=config.output_dir,\n",
490
+ " commit_message=f\"Epoch {epoch}\",\n",
491
+ " ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
492
+ " )\n",
493
+ " else:\n",
494
+ " pipeline.save_pretrained(config.output_dir)"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": 88,
500
+ "id": "5d820f05",
501
+ "metadata": {},
502
+ "outputs": [
503
+ {
504
+ "name": "stdout",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "Launching training on one GPU.\n"
508
+ ]
509
+ },
510
+ {
511
+ "name": "stderr",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "/mnt/drive/adarsh/DC_cold3/my-env/lib/python3.12/site-packages/accelerate/accelerator.py:529: UserWarning: `log_with=tensorboard` was passed but no supported trackers are currently installed.\n",
515
+ " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n",
516
+ "Epoch 0: 0%| | 0/344 [00:00<?, ?it/s]/tmp/ipykernel_288020/2880730127.py:33: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)\n",
517
+ " clean_images = batch[\"images\"]\n"
518
+ ]
519
+ },
520
+ {
521
+ "ename": "IndexError",
522
+ "evalue": "too many indices for tensor of dimension 4",
523
+ "output_type": "error",
524
+ "traceback": [
525
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
526
+ "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
527
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[88]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01maccelerate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m notebook_launcher\n\u001b[32m 2\u001b[39m args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mnotebook_launcher\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_processes\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n",
528
+ "\u001b[36mFile \u001b[39m\u001b[32m/mnt/drive/adarsh/DC_cold3/my-env/lib/python3.12/site-packages/accelerate/launchers.py:270\u001b[39m, in \u001b[36mnotebook_launcher\u001b[39m\u001b[34m(function, args, num_processes, mixed_precision, use_port, master_addr, node_rank, num_nodes, rdzv_backend, rdzv_endpoint, rdzv_conf, rdzv_id, max_restarts, monitor_interval, log_line_prefix_template)\u001b[39m\n\u001b[32m 268\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 269\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mLaunching training on CPU.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m270\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
529
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[86]\u001b[39m\u001b[32m, line 33\u001b[39m, in \u001b[36mtrain_loop\u001b[39m\u001b[34m(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\u001b[39m\n\u001b[32m 30\u001b[39m progress_bar.set_description(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 32\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m step, batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(train_dataloader):\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m clean_images = \u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mimages\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 34\u001b[39m \u001b[38;5;66;03m# Sample noise to add to the images\u001b[39;00m\n\u001b[32m 35\u001b[39m noise = torch.randn(clean_images.shape, device=clean_images.device)\n",
530
+ "\u001b[31mIndexError\u001b[39m: too many indices for tensor of dimension 4"
531
+ ]
532
+ }
533
+ ],
534
+ "source": [
535
+ "from accelerate import notebook_launcher\n",
536
+ "args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n",
537
+ "notebook_launcher(train_loop, args, num_processes=1)"
538
+ ]
539
+ }
540
+ ],
541
+ "metadata": {
542
+ "kernelspec": {
543
+ "display_name": "my-env",
544
+ "language": "python",
545
+ "name": "python3"
546
+ },
547
+ "language_info": {
548
+ "codemirror_mode": {
549
+ "name": "ipython",
550
+ "version": 3
551
+ },
552
+ "file_extension": ".py",
553
+ "mimetype": "text/x-python",
554
+ "name": "python",
555
+ "nbconvert_exporter": "python",
556
+ "pygments_lexer": "ipython3",
557
+ "version": "3.12.3"
558
+ }
559
+ },
560
+ "nbformat": 4,
561
+ "nbformat_minor": 5
562
+ }
samples/0009.png CHANGED
samples/0019.png CHANGED
samples/0029.png CHANGED
samples/0039.png CHANGED
samples/0049.png CHANGED