# Experiment: compare the MNIST test accuracy of ImageNet-pretrained CNNs
# (VGG16, InceptionV3, ResNet50) used as frozen feature extractors.

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import VGG16, InceptionV3, ResNet50

# Fetch MNIST as uint8 image arrays plus integer digit labels.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Give both image sets an explicit channel axis, (N, 28, 28) -> (N, 28, 28, 1),
# and rescale pixel intensities from 0..255 to float32 values in [0, 1].
x_train, x_test = [
    images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
    for images in (x_train, x_test)
]

# One-hot encode the 10 digit classes to match categorical_crossentropy.
y_train, y_test = [to_categorical(labels, 10) for labels in (y_train, y_test)]

# =============================================================================
# VGG16 Model
# =============================================================================

# Prepare data for VGG16: resize to 32x32 (the smallest input size Keras
# accepts for VGG16 with include_top=False) and replicate the single grey
# channel into three so the 3-channel ImageNet weights can be used.
x_train_vgg = tf.image.grayscale_to_rgb(tf.image.resize(x_train, (32, 32)))
x_test_vgg = tf.image.grayscale_to_rgb(tf.image.resize(x_test, (32, 32)))

# The pretrained weights expect images passed through vgg16.preprocess_input,
# which takes 0-255 RGB pixels, reorders them to BGR and subtracts the
# ImageNet channel means. x_train / x_test were scaled to [0, 1] earlier, so
# scale back to 0-255 first. Without this step the frozen convolutional base
# sees inputs far outside the range its filters were trained on.
x_train_vgg = tf.keras.applications.vgg16.preprocess_input(x_train_vgg * 255.0)
x_test_vgg = tf.keras.applications.vgg16.preprocess_input(x_test_vgg * 255.0)

# Create VGG16 base model (convolutional layers only) and freeze it, so only
# the new classification head below is trained.
vgg_base = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
vgg_base.trainable = False

# Build VGG16 model: frozen features -> flatten -> dense head with dropout
# -> 10-way softmax over the digit classes.
vgg_model = models.Sequential([
    vgg_base,
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])

# Compile and train VGG16; validation_split=0.1 holds out part of the
# training set for per-epoch validation.
vgg_model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
vgg_history = vgg_model.fit(x_train_vgg, y_train,
                            epochs=3, batch_size=64,
                            validation_split=0.1, verbose=1)
# evaluate() returns [loss, accuracy] because 'accuracy' was compiled in.
vgg_acc = vgg_model.evaluate(x_test_vgg, y_test, verbose=0)[1]

# =============================================================================
# InceptionV3 Model
# =============================================================================

# Prepare data for InceptionV3. Keras requires InceptionV3 inputs of at least
# 75x75 with 3 channels. Materialising all 60k training images at 75x75x3
# float32 up front would take roughly 4 GB (60000 * 75 * 75 * 3 * 4 bytes),
# plus temporary copies. Instead, the images stay at 28x28x1 and each batch is
# resized, converted to RGB and preprocessed on the fly in a tf.data pipeline.
def _incept_preprocess_batch(images, labels):
    """Turn a batch of 28x28x1 images in [0, 1] into InceptionV3 inputs.

    Args:
        images: float32 tensor of shape (batch, 28, 28, 1), values in [0, 1].
        labels: the matching label batch; passed through unchanged.

    Returns:
        ``(inputs, labels)`` where ``inputs`` has shape (batch, 75, 75, 3).
    """
    rgb = tf.image.grayscale_to_rgb(tf.image.resize(images, (75, 75)))
    # The ImageNet weights expect inception_v3.preprocess_input applied to
    # 0-255 pixels (it rescales them to [-1, 1]); the data here is in [0, 1].
    rgb = tf.keras.applications.inception_v3.preprocess_input(rgb * 255.0)
    return rgb, labels


def _incept_dataset(images, labels, shuffle=False, batch_size=64):
    """Return a batched dataset yielding preprocessed InceptionV3 batches.

    Args:
        images: array of shape (N, 28, 28, 1) with values in [0, 1].
        labels: one-hot labels aligned with ``images``.
        shuffle: reshuffle the samples every epoch (use for training data).
        batch_size: number of samples per batch.
    """
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        # Shuffle the small raw images before they are resized. A buffer as
        # large as the dataset gives a full per-epoch shuffle, like fit().
        ds = ds.shuffle(len(images), reshuffle_each_iteration=True)
    return (
        ds.batch(batch_size)
        .map(_incept_preprocess_batch, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )


# validation_split is not supported for tf.data inputs, so hold out the last
# 10% of the training samples by hand. These are the same samples that
# validation_split=0.1 would take (Keras splits from the end, before shuffling).
incept_split = int(len(x_train) * (1 - 0.1))
incept_train_ds = _incept_dataset(x_train[:incept_split],
                                  y_train[:incept_split], shuffle=True)
incept_val_ds = _incept_dataset(x_train[incept_split:], y_train[incept_split:])
incept_test_ds = _incept_dataset(x_test, y_test)

# Create InceptionV3 base model (no classifier top) and freeze it, so only
# the new head below is trained.
incept_base = InceptionV3(weights='imagenet', include_top=False, input_shape=(75, 75, 3))
incept_base.trainable = False

# Build InceptionV3 model: frozen features -> global average pool -> dense
# head -> 10-way softmax.
incept_model = models.Sequential([
    incept_base,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile and train InceptionV3 (batch size 64 comes from the datasets).
incept_model.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
incept_history = incept_model.fit(incept_train_ds,
                                  epochs=3,
                                  validation_data=incept_val_ds, verbose=1)
# evaluate() returns [loss, accuracy] because 'accuracy' was compiled in.
incept_acc = incept_model.evaluate(incept_test_ds, verbose=0)[1]

# =============================================================================
# ResNet50 Model
# =============================================================================

# Prepare data for ResNet50: resize to 32x32 (the smallest input size Keras
# accepts for ResNet50 with include_top=False) and replicate the grey channel
# into three so the 3-channel ImageNet weights can be used.
x_train_resnet = tf.image.grayscale_to_rgb(tf.image.resize(x_train, (32, 32)))
x_test_resnet = tf.image.grayscale_to_rgb(tf.image.resize(x_test, (32, 32)))

# The pretrained weights expect images passed through resnet50.preprocess_input,
# which takes 0-255 RGB pixels, reorders them to BGR and subtracts the
# ImageNet channel means. x_train / x_test were scaled to [0, 1] earlier, so
# scale back to 0-255 first; otherwise the frozen base sees inputs far outside
# the range it was trained on.
x_train_resnet = tf.keras.applications.resnet50.preprocess_input(x_train_resnet * 255.0)
x_test_resnet = tf.keras.applications.resnet50.preprocess_input(x_test_resnet * 255.0)

# Create ResNet50 base model (no classifier top) and freeze it, so only the
# new head below is trained.
resnet_base = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
resnet_base.trainable = False

# Build ResNet50 model: frozen features -> global average pool -> dense head
# -> 10-way softmax.
resnet_model = models.Sequential([
    resnet_base,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile and train ResNet50; validation_split=0.1 holds out part of the
# training set for per-epoch validation.
resnet_model.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
resnet_history = resnet_model.fit(x_train_resnet, y_train,
                                  epochs=3, batch_size=64,
                                  validation_split=0.1, verbose=1)
# evaluate() returns [loss, accuracy] because 'accuracy' was compiled in.
resnet_acc = resnet_model.evaluate(x_test_resnet, y_test, verbose=0)[1]

# =============================================================================
# Compare Model Performance
# =============================================================================

# Pair each display name with the test accuracy measured for that model.
# NOTE(review): "GoogLeNet" usually refers to Inception v1; the model trained
# here is InceptionV3, so this label may be worth revisiting.
models_acc = dict([
    ("VGG16", vgg_acc),
    ("GoogLeNet (InceptionV3)", incept_acc),
    ("ResNet50", resnet_acc),
])

# Report every model's accuracy under a simple header.
print("Model Performance Comparison:")
print("=" * 40)
for name, score in models_acc.items():
    print(f"{name} Test Accuracy: {score:.4f}")

# Select the highest-scoring entry; on a tie the earliest listed model wins.
best_model, best_acc = max(models_acc.items(), key=lambda item: item[1])
print(f"\nBest Model: {best_model}")
print(f"Best Accuracy: {best_acc:.4f}")












