# Experiment: class-conditional image generation on MNIST using a GAN (conditional GAN)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Hyperparameters -------------------------------------------------------
latent_dim = 100          # length of the generator's noise vector
num_classes = 10          # width of the one-hot conditioning vector
img_size = 28             # image side length in pixels
img_dim = img_size ** 2   # length of one flattened image
batch_size = 64
lr = 2e-4
num_epochs = 102

# --- Data ------------------------------------------------------------------
# Normalising with mean 0.5 / std 0.5 rescales the tensor values so they line
# up with the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

class Generator(nn.Module):
    """Fully connected conditional generator.

    Maps a noise vector joined with a class-conditioning vector to a flat
    vector of ``img_dim`` values in ``[-1, 1]`` (the final layer is Tanh).
    """

    def __init__(self, latent_dim, num_classes, img_dim):
        super().__init__()

        # Layer widths from the joined (noise + condition) input up to the
        # last hidden layer; every hidden Linear is followed by an in-place
        # ReLU.
        widths = (latent_dim + num_classes, 128, 256, 512)

        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
        layers += [nn.Linear(widths[-1], img_dim), nn.Tanh()]

        self.gen = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate flat images from ``noise`` conditioned on ``labels``.

        Both arguments are concatenated along dimension 1, so they must
        share the same size in dimension 0.
        """
        conditioned = torch.cat([noise, labels], dim=1)
        return self.gen(conditioned)

class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    Scores a flat image joined with a class-conditioning vector; the final
    Sigmoid keeps each score in ``[0, 1]``.
    """

    def __init__(self, img_dim, num_classes):
        super().__init__()

        # Layer widths from the joined (image + condition) input down to the
        # last hidden layer; every hidden Linear is followed by an in-place
        # LeakyReLU with negative slope 0.2.
        widths = (img_dim + num_classes, 512, 256)

        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]

        self.disc = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return one score per row for ``img`` conditioned on ``labels``.

        Both arguments are concatenated along dimension 1, so they must
        share the same size in dimension 0.
        """
        conditioned = torch.cat([img, labels], dim=1)
        return self.disc(conditioned)

# Build both networks and place them on the selected device. The generator is
# created before the discriminator.
generator = Generator(
    latent_dim=latent_dim,
    num_classes=num_classes,
    img_dim=img_dim,
).to(device)
discriminator = Discriminator(
    img_dim=img_dim,
    num_classes=num_classes,
).to(device)

# One Adam optimiser per network, both with the same learning rate and betas.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

# Binary cross-entropy on the discriminator's Sigmoid output.
adversarial_loss = nn.BCELoss()

def one_hot(labels, num_classes=num_classes):
    """Return a float one-hot encoding of integer class labels.

    Args:
        labels: Tensor of class indices in ``[0, num_classes)``. It may live
            on any device and use any integer dtype.
        num_classes: Width of the encoding (number of columns).

    Returns:
        A ``(labels.size(0), num_classes)`` float tensor on ``device`` with a
        1 in the column given by each label and 0 elsewhere.
    """
    # scatter_ needs an int64 index on the same device as the tensor being
    # written. Labels from a DataLoader arrive on the CPU, so move/cast them
    # here instead of relying on every caller to do it.
    index = labels.to(device=device, dtype=torch.long).view(-1, 1)
    encoded = torch.zeros(labels.size(0), num_classes, device=device)
    return encoded.scatter_(1, index, 1)

# Adversarial training: every batch performs one generator update followed by
# one discriminator update.
for epoch in range(num_epochs):
    for imgs, labels in train_loader:
        # Use the actual batch size: the final batch of an epoch can be
        # smaller than `batch_size`.
        batch_size_i = imgs.size(0)
        real_imgs = imgs.view(batch_size_i, -1).to(device)
        # The DataLoader yields CPU tensors. The labels must be on `device`
        # before they are used as a scatter index for a tensor on `device`,
        # otherwise one-hot encoding fails when training on CUDA.
        labels = labels.to(device)
        labels_onehot = one_hot(labels)
        # BCE targets: 1 for "real", 0 for "fake".
        valid = torch.ones(batch_size_i, 1, device=device)
        fake = torch.zeros(batch_size_i, 1, device=device)

        # ---- Generator step ----
        # The generator is rewarded when the discriminator scores its samples
        # as real, hence the `valid` target.
        optimizer_G.zero_grad()
        z = torch.randn(batch_size_i, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size_i,), device=device)
        gen_labels_onehot = one_hot(gen_labels)
        gen_imgs = generator(z, gen_labels_onehot)
        validity = discriminator(gen_imgs, gen_labels_onehot)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # zero_grad() also discards the discriminator gradients accumulated
        # by the generator's backward pass above.
        optimizer_D.zero_grad()
        real_validity = discriminator(real_imgs, labels_onehot)
        d_real_loss = adversarial_loss(real_validity, valid)
        # detach() keeps this loss from back-propagating into the generator.
        fake_validity = discriminator(gen_imgs.detach(), gen_labels_onehot)
        d_fake_loss = adversarial_loss(fake_validity, fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Logged once per batch.
        print(f"Epoch [{epoch+1}/{num_epochs}] D loss: {d_loss.item():.4f} G loss: {g_loss.item():.4f}")

def generate_digit_image(digit=0):
    """Sample one image of ``digit`` from the generator and display it.

    Args:
        digit: Class label to condition on; must satisfy
            ``0 <= digit < num_classes``.

    Raises:
        ValueError: If ``digit`` is outside the valid label range.
    """
    # Reject bad labels up front; otherwise they only surface as an opaque
    # index error inside the one-hot scatter.
    if not 0 <= digit < num_classes:
        raise ValueError(
            f"digit must be in [0, {num_classes - 1}], got {digit!r}"
        )

    # Sample in eval mode, then put the generator back in whatever mode it
    # was in so that a preview does not silently change later training.
    was_training = generator.training
    generator.eval()
    try:
        noise = torch.randn(1, latent_dim, device=device)
        label = torch.tensor([digit], device=device)
        label_onehot = one_hot(label)

        with torch.no_grad():
            gen_img = generator(noise, label_onehot)
    finally:
        generator.train(was_training)

    gen_img = gen_img.view(img_size, img_size).cpu()
    # Map the generator's Tanh range [-1, 1] to [0, 1] for display.
    gen_img = (gen_img + 1) / 2

    # Hand matplotlib a NumPy array rather than a torch tensor.
    plt.imshow(gen_img.numpy(), cmap='gray')
    plt.title(f"Generated Digit: {digit}")
    plt.axis('off')
    plt.show()

# Show one generated sample conditioned on the digit 7.
generate_digit_image(digit=7)

