import tensorflow as tf
import matplotlib.pyplot as plt
from keras import models, layers, optimizers

from keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

from sklearn import preprocessing

import pandas as pd
import numpy as np
# Load the IMDB reviews and split the rows 80/20 into train and test frames
# (fixed seed so the split is reproducible).
imdb_df = pd.read_csv('data/IMDB Dataset.csv')
train_df = imdb_df.sample(frac=0.8, random_state=42)
test_df = imdb_df.drop(train_df.index)

# Materialise the raw review strings once; they are reused for fitting the
# tokenizer and for converting to integer sequences.
train_texts = train_df['review'].tolist()
test_texts = test_df['review'].tolist()

# The vocabulary is learned from the training reviews only.
tokenizer = Tokenizer(num_words=10000)
tokenizer.fit_on_texts(train_texts)

train_seq = tokenizer.texts_to_sequences(train_texts)
test_seq = tokenizer.texts_to_sequences(test_texts)

# Bring every sequence to the same fixed length expected by the model input.
max_len = 200
train_data, test_data = (
    pad_sequences(sequences, maxlen=max_len)
    for sequences in (train_seq, test_seq)
)


# Encode the sentiment strings as integer class ids.
# The encoder is fitted on the training labels only; the test labels are
# mapped with the already-fitted encoder so both splits share one
# class -> id mapping (refitting on the test split could silently produce a
# different mapping, e.g. if a class were missing there).
label_encoder = preprocessing.LabelEncoder()
train_labels = label_encoder.fit_transform(train_df['sentiment'])
test_labels = label_encoder.transform(test_df['sentiment'])

# train_seq[:25]

# Size the embedding table from the tokenizer's vocabulary cap rather than a
# separate literal: the tokenizer only emits indices below num_words, so a
# larger table just adds weights that can never be looked up.
vocab_size = tokenizer.num_words

# Binary sentiment classifier: embedding -> flatten -> two dense layers ->
# single sigmoid unit. Built entirely from tf.keras so the model container
# and its layers come from the same Keras implementation.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, 64, input_length=max_len),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# NOTE: the held-out test split is also passed as validation data here, so
# the per-epoch validation metrics are computed on the same rows that are
# evaluated after training.
history = model.fit(
    train_data,
    train_labels,
    epochs=10,
    batch_size=512,
    validation_data=(test_data, test_labels),
)

from sklearn.metrics import confusion_matrix
import seaborn as sns

# Final loss/accuracy on the test split.
test_loss, test_accuracy = model.evaluate(test_data, test_labels)
print(f'Test loss: {test_loss:.2f}, Test accuracy: {test_accuracy:.2%}')

# Raw model outputs (sigmoid activations of the final layer).
predictions = model.predict(test_data)

# Spot-check one example: raw score next to its encoded label.
print(predictions[21], test_labels[21])

# Turn scores into hard 0.0/1.0 decisions with a 0.5 cut-off, using one
# vectorised comparison instead of a per-element Python loop. Labels and
# predictions are both float arrays so they compare as the same label type.
test_labels = np.asarray(test_labels, dtype='float')
predictions = (np.asarray(predictions, dtype='float') > 0.5).astype('float')

cm = confusion_matrix(test_labels, predictions)
print(cm)

# fmt='d' prints the integer counts in full; the default format would show
# large counts in scientific notation.
sns.heatmap(cm, annot=True, fmt='d')
# Explicitly display the figure so the heatmap is visible when run as a script.
plt.show()