# Explore generative adversarial networks and their features on a simple dataset

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Square side length, in pixels, that every input image is resized to.
img_size = 64
# Colour channels per image; matches the RGB conversion done in preprocessing.
channels = 3
# Prefer the GPU when PyTorch reports one is available, otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Discriminator(nn.Module):
    """Convolutional discriminator mapping an image batch to scores in (0, 1).

    Four 4x4 stride-2 convolutions (LeakyReLU(0.2), with BatchNorm after all
    but the first) halve height and width at each stage. A final 4x4 valid
    convolution and a sigmoid then produce the score map. For 64x64 inputs
    that map is 1x1, i.e. one score per image.

    NOTE(review): for other input sizes the final map is larger than 1x1, and
    ``view(-1, 1)`` in ``forward`` turns every spatial position into its own
    row rather than keeping one row per image.
    """

    def __init__(self):
        super().__init__()
        # First stage has no BatchNorm: channels -> 64 feature maps.
        stages = [nn.Conv2d(channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        # Remaining downsampling stages: 64 -> 128 -> 256 -> 512 feature maps.
        for maps_in, maps_out in ((64, 128), (128, 256), (256, 512)):
            stages.extend([
                nn.Conv2d(maps_in, maps_out, 4, 2, 1),
                nn.BatchNorm2d(maps_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Collapse the remaining map to one channel, then squash into (0, 1).
        stages.extend([nn.Conv2d(512, 1, 4, 1, 0), nn.Sigmoid()])
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        """Run the network on ``img`` and return its scores reshaped to (-1, 1)."""
        return self.model(img).view(-1, 1)

# Shared model instance used by classify_and_show_image.
# NOTE(review): nothing visible here trains this model or loads a checkpoint,
# so its weights are PyTorch's default initialisation and its scores carry no
# real/fake meaning until trained weights are supplied.
discriminator = Discriminator()
discriminator.to(device)  # nn.Module.to moves the parameters in place

def preprocess_image(image_path):
    """Load an image file and convert it into a normalised batch for the discriminator.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).

    Returns:
        ``(img_tensor, img)``. ``img_tensor`` has shape ``(1, 3, img_size, img_size)``,
        lives on ``device``, and holds values scaled to [-1, 1]. ``img`` is the
        full-resolution RGB ``PIL.Image``, kept for display.
    """
    # Close the source file deterministically. Without the context manager,
    # Pillow keeps it open for multi-frame formats (GIF/TIFF) or if convert()
    # raises. convert() returns an independent copy, so it outlives the `with`.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # HWC [0, 255] -> CHW float [0, 1]
        # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)  # prepend batch dimension
    return img_tensor, img

def classify_and_show_image(image_path, threshold=0.5):
    """Score one image with the discriminator, plot it, and print the verdict.

    Args:
        image_path: Image file, forwarded to ``preprocess_image``.
        threshold: Scores strictly greater than this are labelled "Real",
            everything else "Fake". Defaults to 0.5, the previous fixed cut-off.

    Returns:
        ``(pred, classification)``: the discriminator's scalar score and its
        label. The function previously returned None, so callers that ignore
        the result are unaffected.
    """
    img_tensor, orig_img = preprocess_image(image_path)

    # eval() makes BatchNorm use its running statistics for this one-image
    # batch. Restore the previous mode afterwards so this helper does not
    # silently leave a shared model that is being trained in inference mode.
    was_training = discriminator.training
    discriminator.eval()
    try:
        with torch.no_grad():
            pred = discriminator(img_tensor).item()
    finally:
        discriminator.train(was_training)

    classification = "Real" if pred > threshold else "Fake"

    # Display and logging need no autograd context, so they sit outside no_grad().
    plt.figure(figsize=(8, 6))
    plt.imshow(orig_img)
    plt.title(f"Discriminator Prediction: {classification} ({pred:.4f})")
    plt.axis('off')
    plt.show()

    print(f"Discriminator Output: {pred:.4f} -> {classification}")
    return pred, classification

# Demo entry point. The guard stops a plain `import` of this module (for
# example, to reuse Discriminator) from opening a hard-coded file and showing
# a plot. Scripts and notebooks still run it, since __name__ is "__main__" there.
# NOTE(review): the path looks like a Colab upload location — adjust as needed.
if __name__ == "__main__":
    classify_and_show_image('/content/thala.jpg')
