---
tags:
- gan
- mnist
- pytorch
- generative-model
- deep-learning
license: mit
datasets:
- mnist
library_name: pytorch
---

# GAN for MNIST Digit Generation

This repository contains a Generative Adversarial Network (GAN) trained on the MNIST dataset to generate handwritten digit images. The model was trained as part of the Generative AI course.

## Model Details

- **Model Type**: GAN
- **Dataset**: MNIST (handwritten digits)
- **Generator Input**: Latent vector of size 100
- **Output**: 28x28 grayscale images
- **Framework**: PyTorch

## Training Details

- **Optimizer**: Adam
- **Learning Rate**: 0.0002
- **Beta1**: 0.5
- **Epochs**: 50
- **Batch Size**: 64
- **Weight Decay**: 0.0001
- **Logging**: [Weights & Biases](https://wandb.ai/hussam-alafandi/GAN_MNIST/runs/6ehnzhm0?nw=nwuserhussamalafandi)

## Usage

### Loading the Model

The snippet below builds the generator from the `gan` module, loads the trained weights from `gan_mnist.pth`, and samples 16 images from random latent vectors:

```python
import torch

from gan import Generator  # Generator class from this repository's gan module

latent_dim = 100

# Build the generator and load the trained weights (mapped to CPU so no GPU is required)
generator = Generator(latent_dim)
generator.load_state_dict(torch.load("./gan_mnist.pth", map_location="cpu"))
generator.eval()  # inference mode (affects layers such as BatchNorm/Dropout, if present)

# Generate 16 samples; no_grad skips building the autograd graph
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    samples = generator(z)
```

## Example Results

![Generated MNIST digits](./gan_mnist.png)

## References

- [Generative AI Course Repository](https://github.com/hussamalafandi/Generative_AI)
- [Weights & Biases Training Logs](https://wandb.ai/hussam-alafandi/GAN_MNIST/runs/6ehnzhm0?nw=nwuserhussamalafandi)

## License

This project is licensed under the MIT License.
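
The hyperparameters listed under Training Details translate into a PyTorch training step roughly as follows. This is a minimal sketch, not the course's training script. `TinyGenerator` and `TinyDiscriminator` are hypothetical stand-ins (the real `Generator` lives in the `gan` module). The binary cross-entropy loss, beta2 = 0.999 (PyTorch's default, since only beta1 is documented), and applying weight decay to both networks are all assumptions. A random tensor stands in for a real MNIST batch.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100  # from Model Details
BATCH_SIZE = 64   # from Training Details


# Hypothetical stand-in generator: latent vector -> 1x28x28 image in [-1, 1] (assumed range)
class TinyGenerator(nn.Module):
    def __init__(self, latent_dim: int = LATENT_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)


# Hypothetical stand-in discriminator: image -> single raw logit
class TinyDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        return self.net(x)


def make_optimizer(model: nn.Module) -> torch.optim.Adam:
    # lr, beta1 and weight decay from Training Details; beta2 = 0.999 is assumed
    return torch.optim.Adam(
        model.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4
    )


G, D = TinyGenerator(), TinyDiscriminator()
opt_g, opt_d = make_optimizer(G), make_optimizer(D)
criterion = nn.BCEWithLogitsLoss()  # assumed loss; works on raw logits

# Placeholder batch scaled to [-1, 1] in place of real MNIST images
real = torch.rand(BATCH_SIZE, 1, 28, 28) * 2 - 1
z = torch.randn(BATCH_SIZE, LATENT_DIM)
fake = G(z)

# Discriminator step: label real as 1, fake as 0 (detach so G receives no gradient here)
opt_d.zero_grad()
loss_d = criterion(D(real), torch.ones(BATCH_SIZE, 1)) + criterion(
    D(fake.detach()), torch.zeros(BATCH_SIZE, 1)
)
loss_d.backward()
opt_d.step()

# Generator step: non-saturating loss, push D(fake) toward the "real" label
opt_g.zero_grad()
loss_g = criterion(D(fake), torch.ones(BATCH_SIZE, 1))
loss_g.backward()
opt_g.step()
```

In a full run this step would repeat over the MNIST loader for the 50 documented epochs. Note that `weight_decay` in `torch.optim.Adam` adds an L2 penalty to the gradient; it is not the decoupled weight decay used by `torch.optim.AdamW`.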