import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from matplotlib.table import Table

class GridWorld:
    """Square grid environment whose four corners are terminal states.

    States are ``(row, col)`` tuples. Every cell carries a reward of -1
    except the top-right corner (10) and the bottom-left corner (5). The
    reward returned by :meth:`step` is the reward of the cell that was
    entered.

    Attributes:
        size: Side length of the grid.
        actions: Action names, in index order.
        action_map: Mapping from action name to its index.
        action_effects: ``(d_row, d_col)`` offset for each action index.
        terminal_states: List of the four corner coordinates.
        rewards: ``(size, size)`` float array of per-cell rewards.
        state: Current agent position.
    """

    def __init__(self, size=5):
        """Build the grid and place the agent on the start cell.

        Args:
            size: Side length of the grid (default 5).
        """
        self.size = size
        self.actions = ['up', 'right', 'down', 'left']
        self.action_map = {a: i for i, a in enumerate(self.actions)}
        self.action_effects = [(-1, 0), (0, 1), (1, 0), (0, -1)]

        # Terminal states and rewards
        self.terminal_states = [(0, 0), (0, size-1), (size-1, 0), (size-1, size-1)]
        self.rewards = np.full((size, size), -1.0)
        self.rewards[0, size-1] = 10  # High-reward terminal state
        self.rewards[size-1, 0] = 5   # Medium-reward terminal state

        # Initialise the position here so step() is usable without an
        # explicit reset() call (previously that raised AttributeError).
        self.reset()

    def reset(self):
        """Move the agent to the centre cell and return that state."""
        self.state = (self.size//2, self.size//2)  # Start in center
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        Args:
            action: One of the names in ``self.actions``.

        Returns:
            A 4-tuple of the new ``(row, col)`` state, the reward of the
            cell entered, whether that cell is terminal, and ``None``.
            If the agent is already on a terminal cell, the state is left
            unchanged and the reward is 0.

        Raises:
            KeyError: If ``action`` is not a known action name.
        """
        if self.state in self.terminal_states:
            # Float zero keeps the reward type consistent with the float
            # values read from self.rewards on ordinary transitions.
            return self.state, 0.0, True, None

        action_idx = self.action_map[action]
        dy, dx = self.action_effects[action_idx]
        y, x = self.state

        # Clamp to the grid: moving into a wall leaves that coordinate as is
        new_y = max(0, min(self.size-1, y + dy))
        new_x = max(0, min(self.size-1, x + dx))
        self.state = (new_y, new_x)

        reward = self.rewards[new_y, new_x]
        done = self.state in self.terminal_states
        return self.state, reward, done, None

def double_q_learning(env, episodes=2000, alpha=0.1, gamma=0.95, epsilon=0.1):
    """Train a tabular Double Q-Learning agent on ``env``.

    Two action-value tables are kept. On every step one of them is picked
    at random to be updated: it chooses the greedy next action, while the
    other table supplies the value of that action for the target.

    Args:
        env: Environment exposing ``size``, ``actions``, ``action_map``,
            ``reset()`` and ``step(action)``.
        episodes: Number of episodes to run.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Probability of taking a random action.

    Returns:
        A tuple ``(Q, rewards)`` where ``Q`` is the element-wise mean of
        the two tables and ``rewards`` lists the summed reward of each
        episode.
    """
    table_shape = (env.size, env.size, len(env.actions))
    Q1 = np.zeros(table_shape)
    Q2 = np.zeros(table_shape)
    rewards = []

    for _ in tqdm(range(episodes), desc="Double Q-Learning"):
        state = env.reset()
        episode_reward = 0
        done = False

        while not done:
            row, col = state[0], state[1]

            # Epsilon-greedy behaviour policy over the sum of both tables
            if np.random.random() < epsilon:
                action = np.random.choice(env.actions)
            else:
                action = env.actions[np.argmax(Q1[row, col] + Q2[row, col])]

            next_state, reward, done, _ = env.step(action)
            episode_reward += reward

            taken = env.action_map[action]
            next_row, next_col = next_state[0], next_state[1]

            # Coin flip decides which table learns; the other one evaluates
            if np.random.random() < 0.5:
                learner, evaluator = Q1, Q2
            else:
                learner, evaluator = Q2, Q1

            greedy_next = np.argmax(learner[next_row, next_col])
            target = reward + gamma * evaluator[next_row, next_col, greedy_next]
            learner[row, col, taken] += alpha * (
                target - learner[row, col, taken]
            )

            state = next_state

        rewards.append(episode_reward)

    # Report the mean of the two estimators
    return (Q1 + Q2) / 2, rewards

def visualize_policy(env, Q):
    """Show the greedy policy derived from ``Q`` as a table of arrows.

    Terminal cells are labelled ``T``; every other cell shows an arrow
    for the action with the largest value in ``Q[row, col]``.

    Args:
        env: Environment exposing ``size``, ``actions`` and
            ``terminal_states``.
        Q: Action-value array indexed as ``Q[row, col]`` to give one
            value per action.
    """
    arrows = {'up': '↑', 'right': '→', 'down': '↓', 'left': '←'}
    side = env.size
    last = side - 1

    # Work out every cell label before any figure is created
    labels = [
        [
            'T' if (y, x) in env.terminal_states
            else arrows[env.actions[np.argmax(Q[y, x])]]
            for x in range(side)
        ]
        for y in range(side)
    ]

    # Colours for the two rewarding terminals. The high-reward entry is
    # listed last so it takes precedence if both keys are the same cell.
    highlight = {
        (last, 0): 'lightblue',    # Medium reward
        (0, last): 'lightgreen',   # High reward
    }

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    cell = 1/side
    for y, row_labels in enumerate(labels):
        for x, label in enumerate(row_labels):
            if label == 'T':
                shade = highlight.get((y, x), 'lightgray')
            else:
                shade = 'white'
            tb.add_cell(y, x, cell, cell, text=label,
                        loc='center', facecolor=shade)

    ax.add_table(tb)
    plt.title("Learned Policy", pad=20)
    plt.show()

# Train an agent on the default grid
env = GridWorld()
print("Training Double Q-Learning...")
Q_dql, rewards_dql = double_q_learning(env, episodes=2000)

# Show the greedy policy that was learned
print("\nDouble Q-Learning Policy:")
visualize_policy(env, Q_dql)

# Smooth the per-episode rewards with a uniform moving-average kernel
window = 100
smoothed = np.convolve(rewards_dql, np.full(window, 1/window), mode='valid')

# Plot the learning curve
plt.figure(figsize=(10, 5))
plt.plot(smoothed)
plt.title('Double Q-Learning Performance')
plt.xlabel('Episodes')
plt.ylabel(f'Average Reward ({window}-episode window)')
plt.grid(True)
plt.show()