jainadarsh committed on
Commit
3819fa9
·
verified ·
1 Parent(s): 2930b34
samples/.ipynb_checkpoints/0009-checkpoint.png ADDED
samples/.ipynb_checkpoints/0019-checkpoint.png ADDED
samples/.ipynb_checkpoints/0029-checkpoint.png ADDED
samples/.ipynb_checkpoints/0039-checkpoint.png ADDED
samples/.ipynb_checkpoints/0049-checkpoint.png ADDED
samples/.ipynb_checkpoints/overview-checkpoint.ipynb ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 19,
6
+ "id": "9074d4b4",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np\n",
11
+ "from torch.utils.data import Dataset, DataLoader\n",
12
+ "import scipy.io\n",
13
+ "from torchvision import transforms, utils\n",
14
+ "import os\n",
15
+ "from PIL import Image"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 20,
21
+ "id": "c85ff158",
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "class GravityDataset(Dataset):\n",
26
+ " def __init__(self, folder, image_size, exts=['mat']):\n",
27
+ " super().__init__()\n",
28
+ " self.folder = folder\n",
29
+ " self.image_size = image_size\n",
30
+ " self.paths = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.mat')]\n",
31
+ "\n",
32
+ " # Define transformations that are independent of scaling\n",
33
+ " self.transform = transforms.Compose([\n",
34
+ " transforms.Resize((int(image_size * 1.12), int(image_size * 1.12))), # Resize slightly larger\n",
35
+ " transforms.RandomCrop(image_size), # Then crop to the target size\n",
36
+ " transforms.RandomHorizontalFlip(), # Random horizontal flip\n",
37
+ " transforms.ToTensor() # Convert to tensor\n",
38
+ " ])\n",
39
+ "\n",
40
+ " def scale_to_minus1_1(self, tensor):\n",
41
+ " \"\"\"Dynamically scale the tensor to the range [-1, 1] based on its own min and max values.\"\"\"\n",
42
+ " min_val = tensor.min()\n",
43
+ " max_val = tensor.max()\n",
44
+ " # Avoid division by zero if min and max are the same\n",
45
+ " if max_val > min_val:\n",
46
+ " return 2 * ((tensor - min_val) / (max_val - min_val)) - 1\n",
47
+ " else:\n",
48
+ " return tensor # If min and max are the same, return tensor as is.\n",
49
+ " # tensor = tensor\n",
50
+ " return tensor # If min and max are the same, return tensor as is.\n",
51
+ "\n",
52
+ " def __len__(self):\n",
53
+ " return len(self.paths)\n",
54
+ "\n",
55
+ " def __getitem__(self, index):\n",
56
+ " file_path = self.paths[index]\n",
57
+ " # Load the .mat file\n",
58
+ " data_loc = scipy.io.loadmat(file_path)\n",
59
+ " data_val = data_loc['d']\n",
60
+ "\n",
61
+ " data_val = data_val.reshape(32, 32)\n",
62
+ " # Convert the numpy array to a PIL image\n",
63
+ " img = Image.fromarray(data_val)\n",
64
+ "\n",
65
+ " # Apply transformations\n",
66
+ " if self.transform:\n",
67
+ " img = self.transform(img)\n",
68
+ " \n",
69
+ " # Scale the image to [-1, 1] based on its own min and max values\n",
70
+ " img = self.scale_to_minus1_1(img)\n",
71
+ " return img"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 21,
77
+ "id": "c16dc0d1",
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
82
+ "dataset = GravityDataset(data, image_size=32)\n",
83
+ "dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 22,
89
+ "id": "a8bd15c0",
90
+ "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "name": "stderr",
94
+ "output_type": "stream",
95
+ "text": [
96
+ "Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 28.56it/s]\n"
97
+ ]
98
+ },
99
+ {
100
+ "data": {
101
+ "text/plain": [
102
+ "4"
103
+ ]
104
+ },
105
+ "execution_count": 22,
106
+ "metadata": {},
107
+ "output_type": "execute_result"
108
+ }
109
+ ],
110
+ "source": [
111
+ "from diffusers import StableDiffusionPipeline\n",
112
+ "\n",
113
+ "pipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n",
114
+ "pipeline.unet.config[\"in_channels\"]\n",
115
+ "4"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "markdown",
120
+ "id": "72b20338",
121
+ "metadata": {},
122
+ "source": [
123
+ "**Training Configuration**"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 23,
129
+ "id": "e01c3fae",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "from dataclasses import dataclass\n",
134
+ "\n",
135
+ "@dataclass\n",
136
+ "class TrainingConfig:\n",
137
+ " image_size = 32 # the generated image resolution\n",
138
+ " train_batch_size = 32\n",
139
+ " eval_batch_size = 16 # how many images to sample during evaluation\n",
140
+ " num_epochs = 50\n",
141
+ " gradient_accumulation_steps = 1\n",
142
+ " learning_rate = 1e-4\n",
143
+ " lr_warmup_steps = 500\n",
144
+ " save_image_epochs = 10\n",
145
+ " save_model_epochs = 30\n",
146
+ " mixed_precision = \"fp16\" # `no` for float32, `fp16` for automatic mixed precision\n",
147
+ " output_dir = \"ddpm-butterflies-128\" # the model name locally and on the HF Hub\n",
148
+ "\n",
149
+ " push_to_hub = True # whether to upload the saved model to the HF Hub\n",
150
+ " hub_model_id = \"jainadarsh/trial\" # the name of the repository to create on the HF Hub\n",
151
+ " hub_private_repo = None\n",
152
+ " overwrite_output_dir = True # overwrite the old model when re-running the notebook\n",
153
+ " seed = 0\n",
154
+ "\n",
155
+ "\n",
156
+ "config = TrainingConfig()"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "id": "9f52a27e",
162
+ "metadata": {},
163
+ "source": [
164
+ "**Load Dataset**"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 74,
170
+ "id": "b4a98567",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "# from datasets import load_dataset\n",
175
+ "\n",
176
+ "# config.dataset_name = \"huggan/smithsonian_butterflies_subset\"\n",
177
+ "# dataset = load_dataset(config.dataset_name, split=\"train\")\n",
178
+ "\n",
179
+ "data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
180
+ "dataset = GravityDataset(data, image_size=32)\n",
181
+ "train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 75,
187
+ "id": "e3aed193",
188
+ "metadata": {},
189
+ "outputs": [
190
+ {
191
+ "data": {
192
+ "image/png": "",
193
+ "text/plain": [
194
+ "<Figure size 1600x400 with 4 Axes>"
195
+ ]
196
+ },
197
+ "metadata": {},
198
+ "output_type": "display_data"
199
+ }
200
+ ],
201
+ "source": [
202
+ "import matplotlib.pyplot as plt\n",
203
+ "import torch\n",
204
+ "\n",
205
+ "batch = next(iter(dataloader)) # batch shape (B, C, H, W)\n",
206
+ "imgs = batch[:4] # take first 4\n",
207
+ "\n",
208
+ "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
209
+ "for i, img in enumerate(imgs):\n",
210
+ " img = img.detach().cpu()\n",
211
+ " if img.ndim == 3: # (C,H,W) -> (H,W,C)\n",
212
+ " img = img.permute(1, 2, 0)\n",
213
+ " img = (img + 1) / 2 # convert from [-1,1] to [0,1]\n",
214
+ " arr = img.numpy().squeeze()\n",
215
+ " axs[i].imshow(arr, cmap='gray' if arr.ndim == 2 else None)\n",
216
+ " axs[i].axis('off')\n",
217
+ "plt.show()"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "id": "06c1dff6",
223
+ "metadata": {},
224
+ "source": [
225
+ "**Create an Architecture**"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 76,
231
+ "id": "3ee088b3",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "from diffusers import UNet2DModel\n",
236
+ "\n",
237
+ "model = UNet2DModel(\n",
238
+ " sample_size=config.image_size, # the target image resolution\n",
239
+ " in_channels=1, # the number of input channels (1 for single-channel data, 3 for RGB images)\n",
240
+ " out_channels=1, # the number of output channels\n",
241
+ " layers_per_block=2, # how many ResNet layers to use per UNet block\n",
242
+ " block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block\n",
243
+ " down_block_types=(\n",
244
+ " \"DownBlock2D\", # a regular ResNet downsampling block\n",
245
+ " \"DownBlock2D\",\n",
246
+ " \"DownBlock2D\",\n",
247
+ " \"DownBlock2D\",\n",
248
+ " \"AttnDownBlock2D\", # a ResNet downsampling block with spatial self-attention\n",
249
+ " \"DownBlock2D\",\n",
250
+ " ),\n",
251
+ " up_block_types=(\n",
252
+ " \"UpBlock2D\", # a regular ResNet upsampling block\n",
253
+ " \"AttnUpBlock2D\", # a ResNet upsampling block with spatial self-attention\n",
254
+ " \"UpBlock2D\",\n",
255
+ " \"UpBlock2D\",\n",
256
+ " \"UpBlock2D\",\n",
257
+ " \"UpBlock2D\",\n",
258
+ " ),\n",
259
+ ")"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": 77,
265
+ "id": "8a703029",
266
+ "metadata": {},
267
+ "outputs": [
268
+ {
269
+ "name": "stdout",
270
+ "output_type": "stream",
271
+ "text": [
272
+ "Input shape: torch.Size([1, 1, 32, 32])\n",
273
+ "Output shape: torch.Size([1, 1, 32, 32])\n"
274
+ ]
275
+ }
276
+ ],
277
+ "source": [
278
+ "sample_image = dataset.__getitem__(0).unsqueeze(0) # add batch dimension\n",
279
+ "print(\"Input shape:\", sample_image.shape)\n",
280
+ "print(\"Output shape:\", model(sample_image, timestep=0).sample.shape)"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "id": "8d3d9af1",
286
+ "metadata": {},
287
+ "source": [
288
+ "**Create a scheduler**"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 78,
294
+ "id": "22c2ce21",
295
+ "metadata": {},
296
+ "outputs": [
297
+ {
298
+ "data": {
299
+ "image/png": "",
300
+ "text/plain": [
301
+ "<Figure size 640x480 with 2 Axes>"
302
+ ]
303
+ },
304
+ "metadata": {},
305
+ "output_type": "display_data"
306
+ }
307
+ ],
308
+ "source": [
309
+ "import torch\n",
310
+ "from PIL import Image\n",
311
+ "from diffusers import DDPMScheduler\n",
312
+ "\n",
313
+ "noise_scheduler = DDPMScheduler(num_train_timesteps = 1000)\n",
314
+ "noise = torch.randn(sample_image.shape)\n",
315
+ "timesteps = torch.LongTensor([50]) # example timestep\n",
316
+ "noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n",
317
+ "\n",
318
+ "# print(\"Noisy image shape:\", noisy_image.shape)\n",
319
+ "# Image.fromarray(noisy_image.numpy().squeeze())\n",
320
+ "plt.imshow(noisy_image.detach().cpu().numpy().squeeze(), cmap='gray')\n",
321
+ "plt.colorbar()\n",
322
+ "# plt.axis('off')\n",
323
+ "plt.show()\n"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": 79,
329
+ "id": "4e878652",
330
+ "metadata": {},
331
+ "outputs": [],
332
+ "source": [
333
+ "import torch.nn.functional as F\n",
334
+ "\n",
335
+ "noise_pred = model(noisy_image, timesteps).sample\n",
336
+ "loss = F.mse_loss(noise_pred, noise)"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "markdown",
341
+ "id": "7dffb56e",
342
+ "metadata": {},
343
+ "source": [
344
+ "**Train the model**"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": 80,
350
+ "id": "62915913",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "from diffusers.optimization import get_cosine_schedule_with_warmup\n",
355
+ "\n",
356
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n",
357
+ "lr_scheduler = get_cosine_schedule_with_warmup(\n",
358
+ " optimizer=optimizer,\n",
359
+ " num_warmup_steps=config.lr_warmup_steps,\n",
360
+ " num_training_steps=(len(train_dataloader) * config.num_epochs),\n",
361
+ ")"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": 81,
367
+ "id": "f2f3427a",
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "from diffusers import DDPMPipeline\n",
372
+ "from diffusers.utils import make_image_grid\n",
373
+ "import os\n",
374
+ "\n",
375
+ "def evaluate(config, epoch, pipeline):\n",
376
+ " # Sample some images from random noise (this is the backward diffusion process).\n",
377
+ " # The default pipeline output type is `List[PIL.Image]`\n",
378
+ " images = pipeline(\n",
379
+ " batch_size=config.eval_batch_size,\n",
380
+ " generator=torch.Generator(device='cpu').manual_seed(config.seed), # Use a separate torch generator to avoid rewinding the random state of the main training loop\n",
381
+ " ).images\n",
382
+ "\n",
383
+ " # Make a grid out of the images\n",
384
+ " image_grid = make_image_grid(images, rows=4, cols=4)\n",
385
+ "\n",
386
+ " # Save the images\n",
387
+ " test_dir = os.path.join(config.output_dir, \"samples\")\n",
388
+ " os.makedirs(test_dir, exist_ok=True)\n",
389
+ " image_grid.save(f\"{test_dir}/{epoch:04d}.png\")"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 85,
395
+ "id": "4fe935dd",
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "from accelerate import Accelerator\n",
400
+ "from huggingface_hub import create_repo, upload_folder\n",
401
+ "from tqdm.auto import tqdm\n",
402
+ "from pathlib import Path\n",
403
+ "import os"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": 86,
409
+ "id": "07df8a4c",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n",
414
+ " # Initialize accelerator and tensorboard logging\n",
415
+ " accelerator = Accelerator(\n",
416
+ " mixed_precision=config.mixed_precision,\n",
417
+ " gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
418
+ " log_with=\"tensorboard\",\n",
419
+ " project_dir=os.path.join(config.output_dir, \"logs\"),\n",
420
+ " )\n",
421
+ " if accelerator.is_main_process:\n",
422
+ " if config.output_dir is not None:\n",
423
+ " os.makedirs(config.output_dir, exist_ok=True)\n",
424
+ " if config.push_to_hub:\n",
425
+ " repo_id = create_repo(\n",
426
+ " repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
427
+ " ).repo_id\n",
428
+ " accelerator.init_trackers(\"train_example\")\n",
429
+ "\n",
430
+ " # Prepare everything\n",
431
+ " # There is no specific order to remember, you just need to unpack the\n",
432
+ " # objects in the same order you gave them to the prepare method.\n",
433
+ " model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n",
434
+ " model, optimizer, train_dataloader, lr_scheduler\n",
435
+ " )\n",
436
+ "\n",
437
+ " global_step = 0\n",
438
+ "\n",
439
+ " # Now you train the model\n",
440
+ " for epoch in range(config.num_epochs):\n",
441
+ " progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n",
442
+ " progress_bar.set_description(f\"Epoch {epoch}\")\n",
443
+ "\n",
444
+ " for step, batch in enumerate(train_dataloader):\n",
445
+ " clean_images = batch[\"images\"]\n",
446
+ " # Sample noise to add to the images\n",
447
+ " noise = torch.randn(clean_images.shape, device=clean_images.device)\n",
448
+ " bs = clean_images.shape[0]\n",
449
+ "\n",
450
+ " # Sample a random timestep for each image\n",
451
+ " timesteps = torch.randint(\n",
452
+ " 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,\n",
453
+ " dtype=torch.int64\n",
454
+ " )\n",
455
+ "\n",
456
+ " # Add noise to the clean images according to the noise magnitude at each timestep\n",
457
+ " # (this is the forward diffusion process)\n",
458
+ " noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n",
459
+ "\n",
460
+ " with accelerator.accumulate(model):\n",
461
+ " # Predict the noise residual\n",
462
+ " noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n",
463
+ " loss = F.mse_loss(noise_pred, noise)\n",
464
+ " accelerator.backward(loss)\n",
465
+ "\n",
466
+ " if accelerator.sync_gradients:\n",
467
+ " accelerator.clip_grad_norm_(model.parameters(), 1.0)\n",
468
+ " optimizer.step()\n",
469
+ " lr_scheduler.step()\n",
470
+ " optimizer.zero_grad()\n",
471
+ "\n",
472
+ " progress_bar.update(1)\n",
473
+ " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n",
474
+ " progress_bar.set_postfix(**logs)\n",
475
+ " accelerator.log(logs, step=global_step)\n",
476
+ " global_step += 1\n",
477
+ "\n",
478
+ " # After each epoch you optionally sample some demo images with evaluate() and save the model\n",
479
+ " if accelerator.is_main_process:\n",
480
+ " pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)\n",
481
+ "\n",
482
+ " if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n",
483
+ " evaluate(config, epoch, pipeline)\n",
484
+ "\n",
485
+ " if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:\n",
486
+ " if config.push_to_hub:\n",
487
+ " upload_folder(\n",
488
+ " repo_id=repo_id,\n",
489
+ " folder_path=config.output_dir,\n",
490
+ " commit_message=f\"Epoch {epoch}\",\n",
491
+ " ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
492
+ " )\n",
493
+ " else:\n",
494
+ " pipeline.save_pretrained(config.output_dir)"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": 88,
500
+ "id": "5d820f05",
501
+ "metadata": {},
502
+ "outputs": [
503
+ {
504
+ "name": "stdout",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "Launching training on one GPU.\n"
508
+ ]
509
+ },
510
+ {
511
+ "name": "stderr",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "/mnt/drive/adarsh/DC_cold3/my-env/lib/python3.12/site-packages/accelerate/accelerator.py:529: UserWarning: `log_with=tensorboard` was passed but no supported trackers are currently installed.\n",
515
+ " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n",
516
+ "Epoch 0: 0%| | 0/344 [00:00<?, ?it/s]/tmp/ipykernel_288020/2880730127.py:33: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)\n",
517
+ " clean_images = batch[\"images\"]\n"
518
+ ]
519
+ },
520
+ {
521
+ "ename": "IndexError",
522
+ "evalue": "too many indices for tensor of dimension 4",
523
+ "output_type": "error",
524
+ "traceback": [
525
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
526
+ "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
527
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[88]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01maccelerate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m notebook_launcher\n\u001b[32m 2\u001b[39m args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[43mnotebook_launcher\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_processes\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n",
528
+ "\u001b[36mFile \u001b[39m\u001b[32m/mnt/drive/adarsh/DC_cold3/my-env/lib/python3.12/site-packages/accelerate/launchers.py:270\u001b[39m, in \u001b[36mnotebook_launcher\u001b[39m\u001b[34m(function, args, num_processes, mixed_precision, use_port, master_addr, node_rank, num_nodes, rdzv_backend, rdzv_endpoint, rdzv_conf, rdzv_id, max_restarts, monitor_interval, log_line_prefix_template)\u001b[39m\n\u001b[32m 268\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 269\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mLaunching training on CPU.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m270\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
529
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[86]\u001b[39m\u001b[32m, line 33\u001b[39m, in \u001b[36mtrain_loop\u001b[39m\u001b[34m(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\u001b[39m\n\u001b[32m 30\u001b[39m progress_bar.set_description(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 32\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m step, batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(train_dataloader):\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m clean_images = \u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mimages\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[32m 34\u001b[39m \u001b[38;5;66;03m# Sample noise to add to the images\u001b[39;00m\n\u001b[32m 35\u001b[39m noise = torch.randn(clean_images.shape, device=clean_images.device)\n",
530
+ "\u001b[31mIndexError\u001b[39m: too many indices for tensor of dimension 4"
531
+ ]
532
+ }
533
+ ],
534
+ "source": [
535
+ "from accelerate import notebook_launcher\n",
536
+ "args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n",
537
+ "notebook_launcher(train_loop, args, num_processes=1)"
538
+ ]
539
+ }
540
+ ],
541
+ "metadata": {
542
+ "kernelspec": {
543
+ "display_name": "my-env",
544
+ "language": "python",
545
+ "name": "python3"
546
+ },
547
+ "language_info": {
548
+ "codemirror_mode": {
549
+ "name": "ipython",
550
+ "version": 3
551
+ },
552
+ "file_extension": ".py",
553
+ "mimetype": "text/x-python",
554
+ "name": "python",
555
+ "nbconvert_exporter": "python",
556
+ "pygments_lexer": "ipython3",
557
+ "version": "3.12.3"
558
+ }
559
+ },
560
+ "nbformat": 4,
561
+ "nbformat_minor": 5
562
+ }
samples/0009.png CHANGED
samples/0019.png CHANGED
samples/0029.png CHANGED
samples/0039.png CHANGED
samples/0049.png CHANGED