metadata
tags:
  - gan
  - mnist
  - pytorch
  - generative-model
  - deep-learning
license: mit
datasets:
  - mnist
library_name: pytorch

GAN for MNIST Digit Generation

This repository contains a Generative Adversarial Network (GAN) trained on the MNIST dataset to generate realistic handwritten digits. The model was trained as part of the Generative AI course.

Model Details

  • Model Type: GAN
  • Dataset: MNIST (handwritten digits)
  • Generator Input: Latent vector of size 100
  • Output: 28x28 grayscale images
  • Framework: PyTorch

Training Details

  • Optimizer: Adam
  • Learning Rate: 0.0002
  • Beta1: 0.5
  • Epochs: 50
  • Batch Size: 64
  • Weight Decay: 0.0001
  • Logging: Weights & Biases
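
As a rough illustration, the optimizer settings listed above could map onto `torch.optim.Adam` as in the sketch below. Several parts are assumptions, not taken from the training code: the second Adam coefficient is left at PyTorch's default of 0.999, the same settings are applied to both networks, and the two `nn.Linear` layers are hypothetical stand-ins for the real generator and discriminator.

```python
import torch
from torch import nn

# Hypothetical stand-ins; the actual networks are defined in the training code.
generator = nn.Linear(100, 28 * 28)
discriminator = nn.Linear(28 * 28, 1)

def make_optimizer(model: nn.Module) -> torch.optim.Adam:
    """Adam configured with the hyperparameters listed in this README."""
    return torch.optim.Adam(
        model.parameters(),
        lr=0.0002,            # Learning Rate
        betas=(0.5, 0.999),   # Beta1 = 0.5; beta2 = 0.999 is an assumed default
        weight_decay=0.0001,  # Weight Decay
    )

# Assumption: one optimizer per network, both with identical settings.
opt_g = make_optimizer(generator)
opt_d = make_optimizer(discriminator)
```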

Usage

Loading the Model

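To load the trained generator and sample from it, use the following snippet. It assumes the Generator class is importable from a local gan module and that gan_mnist.pth is in the working directory:

from gan import Generator
import torch

latent_dim = 100
generator = Generator(latent_dim)
# map_location="cpu" lets the checkpoint load on machines without a GPU
generator.load_state_dict(torch.load("./gan_mnist.pth", map_location="cpu"))
generator.eval()

# Generate 16 samples; no_grad() skips gradient tracking during inference
z = torch.randn(16, latent_dim)
with torch.no_grad():
    samples = generator(z)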
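
To inspect the generated samples, they can be tiled into a single image. The sketch below is one possible way to do that with plain PyTorch. The helper name `to_grid` is hypothetical, and a random tensor stands in for the generator output so the example runs on its own. Because the output range and exact output shape of the generator are not documented here, the helper reshapes to 28x28 and applies min-max scaling rather than assuming a particular range.

```python
import torch

def to_grid(samples: torch.Tensor, per_row: int = 4) -> torch.Tensor:
    """Tile a batch of 28x28 images into one 2-D tensor scaled to [0, 1]."""
    imgs = samples.detach().cpu().reshape(-1, 28, 28)
    # Min-max scaling, so the generator's output range need not be known
    lo, hi = imgs.min(), imgs.max()
    imgs = (imgs - lo) / (hi - lo + 1e-8)
    rows = imgs.shape[0] // per_row  # an incomplete last row is dropped
    tiles = imgs[: rows * per_row].reshape(rows, per_row, 28, 28)
    # (row, col, y, x) -> (row, y, col, x), then merge to (rows*28, per_row*28)
    return tiles.permute(0, 2, 1, 3).reshape(rows * 28, per_row * 28)

# Stand-in for generator(z): random values instead of real generator output
samples = torch.randn(16, 1, 28, 28)
grid = to_grid(samples, per_row=4)
print(grid.shape)  # torch.Size([112, 112])
```

The resulting tensor can then be written to disk with any image library, for example `matplotlib.pyplot.imsave("samples.png", grid.numpy(), cmap="gray")` if matplotlib is installed.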

Example Results

[Image: generated images]

License

This project is licensed under the MIT License.