import numpy as np
import matplotlib.pyplot as plt
from matplotlib.table import Table
from tqdm import tqdm
import time

class GridWorld:
    """Deterministic square grid world with terminal states in two corners.

    States are ``(row, col)`` pairs. The top-left and bottom-right cells are
    terminal. Every non-terminal cell stores a reward of -1 and every
    terminal cell stores 0; the reward of a move is the value stored for the
    cell the agent ends up in.
    """

    def __init__(self, size=4):
        """Create a ``size`` x ``size`` grid.

        Args:
            size: Number of rows and columns; must be at least 1.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        if size < 1:
            raise ValueError(f"size must be at least 1, got {size!r}")
        self.size = size
        # Action names; index i here pairs with row i of ``action_effects``.
        self.actions = np.array(['up', 'right', 'down', 'left'])
        # (row delta, col delta) applied by each action.
        self.action_effects = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
        self.terminal_states = [(0, 0), (size - 1, size - 1)]
        # Reward stored per cell: -1 everywhere except 0 on terminal cells.
        self.rewards = np.full((size, size), -1.0)
        for term in self.terminal_states:
            self.rewards[term] = 0

    def is_terminal(self, state):
        """Return True if ``state`` (any ``[row, col]`` indexable) is terminal."""
        return (state[0], state[1]) in self.terminal_states

    def get_next_state(self, state, action_idx):
        """Apply one action and return ``(next_state, reward)``.

        Args:
            state: Current ``(row, col)``; any two-element indexable works.
            action_idx: Index into ``self.actions`` / ``self.action_effects``.

        Returns:
            A ``((row, col), reward)`` pair where the coordinates are plain
            ``int`` and the reward is a plain ``float``. Terminal states are
            absorbing: they return themselves with reward 0.0. Moves that
            would leave the grid are clamped to the border, so the agent
            stays in place and receives that cell's reward.
        """
        row, col = int(state[0]), int(state[1])
        if self.is_terminal((row, col)):
            # Always hand back a fresh tuple and a float so both branches
            # return the same types regardless of what the caller passed in.
            return (row, col), 0.0
        d_row, d_col = self.action_effects[action_idx]
        last = self.size - 1
        new_row = min(max(row + int(d_row), 0), last)
        new_col = min(max(col + int(d_col), 0), last)
        return (new_row, new_col), float(self.rewards[new_row, new_col])

def policy_evaluation(env, policy, gamma=1.0, theta=1e-4, max_iter=100):
    """Iteratively estimate the state values of a deterministic policy.

    Values are updated in place during each sweep, so later states in a
    sweep already see the new values of earlier ones.

    Args:
        env: Environment exposing ``size``, ``actions``, ``is_terminal`` and
            ``get_next_state``.
        policy: 2-D array indexed ``[row, col]`` holding one action name per
            non-terminal state; entries for terminal states are ignored.
        gamma: Discount factor applied to the successor state's value.
        theta: Sweeping stops once the largest value change in a sweep is
            below this threshold.
        max_iter: Upper bound on the number of sweeps. If it is reached the
            returned values may not have converged.

    Returns:
        A ``(size, size)`` float array of state values; terminal states keep
        the value 0.

    Raises:
        IndexError: If a non-terminal policy entry is not in ``env.actions``.
    """
    V = np.zeros((env.size, env.size))

    # The policy is fixed for the whole evaluation, so resolve every
    # non-terminal state's action index once instead of on every sweep.
    state_actions = []
    for y in range(env.size):
        for x in range(env.size):
            if env.is_terminal((y, x)):
                continue
            matches = np.where(env.actions == policy[y, x])[0]
            if matches.size == 0:
                # Same exception type the bare [0][0] lookup used to raise,
                # but with a message that names the offending cell.
                raise IndexError(
                    f"policy[{y}, {x}] = {policy[y, x]!r} is not one of env.actions"
                )
            state_actions.append((y, x, matches[0]))

    # Context manager guarantees the bar is closed even if a sweep raises.
    with tqdm(total=max_iter, desc="Policy Evaluation", leave=False) as pbar:
        for _ in range(max_iter):
            delta = 0
            for y, x, action_idx in state_actions:
                (new_y, new_x), reward = env.get_next_state((y, x), action_idx)
                new_value = reward + gamma * V[new_y, new_x]
                delta = max(delta, abs(V[y, x] - new_value))
                V[y, x] = new_value

            pbar.set_postfix({'Δ': f"{delta:.1e}"})
            pbar.update(1)
            if delta < theta:
                break

    return V

def policy_improvement(env, V, gamma=1.0):
    """Build the greedy policy with respect to the value table ``V``.

    For each non-terminal state every action is tried once and the action
    with the highest ``reward + gamma * V[next_state]`` is chosen. Ties go to
    the action that appears first in ``env.actions`` because only a strictly
    greater value replaces the current best.

    Args:
        env: Environment exposing ``size``, ``actions``, ``is_terminal`` and
            ``get_next_state``.
        V: 2-D array of state values indexed ``[row, col]``.
        gamma: Discount factor applied to the successor state's value.

    Returns:
        A ``(size, size)`` array of dtype ``'<U5'`` holding one action name
        per state; terminal states hold the empty string.
    """
    policy = np.empty((env.size, env.size), dtype='<U5')

    # Context manager guarantees the bar is closed even if a lookup raises.
    with tqdm(total=env.size*env.size, desc="Policy Improvement", leave=False) as pbar:
        for y in range(env.size):
            for x in range(env.size):
                if env.is_terminal((y, x)):
                    policy[y, x] = ''
                    pbar.update(1)
                    continue

                best_value = -np.inf
                # Reset per state: if no action value compares greater than
                # -inf (e.g. NaN or -inf values), fall back to the first
                # action rather than reusing the previous cell's choice.
                best_action = env.actions[0]
                for action_idx, action in enumerate(env.actions):
                    (new_y, new_x), reward = env.get_next_state((y, x), action_idx)
                    action_value = reward + gamma * V[new_y, new_x]
                    if action_value > best_value:
                        best_value = action_value
                        best_action = action
                policy[y, x] = best_action
                pbar.update(1)

    return policy

def policy_iteration(env, gamma=1.0, theta=1e-4, max_iter=10):
    """Run policy iteration starting from a random policy.

    Each iteration evaluates the current policy, derives the greedy policy
    from the resulting values, and stops as soon as the greedy policy equals
    the evaluated one. The first three iterations are also plotted (values
    and improved policy) after a short pause.

    Args:
        env: Environment exposing ``size``, ``actions`` and
            ``terminal_states`` plus whatever ``policy_evaluation`` and
            ``policy_improvement`` require.
        gamma: Discount factor forwarded to evaluation and improvement.
        theta: Convergence threshold forwarded to ``policy_evaluation``.
        max_iter: Maximum number of evaluate/improve iterations.

    Returns:
        A ``(policy, V, history)`` tuple. ``V`` is the value table of the
        most recently *evaluated* policy; if ``max_iter`` is exhausted
        without convergence, ``policy`` is the last improved policy and has
        not itself been evaluated. ``history`` has two lists, ``'policy'``
        and ``'values'``, whose i-th entries are the policy evaluated in
        iteration i and the values obtained for it.
    """
    policy = np.random.choice(env.actions, size=(env.size, env.size))
    for term in env.terminal_states:
        policy[term[0], term[1]] = ''

    history = {'policy': [], 'values': []}
    # Bound up front so the return below is valid even when max_iter <= 0
    # and the loop body never runs.
    V = np.zeros((env.size, env.size))

    with tqdm(total=max_iter, desc="Policy Iteration") as pbar:
        for i in range(max_iter):
            # Policy Evaluation
            V = policy_evaluation(env, policy, gamma, theta)
            history['values'].append(V.copy())

            # Policy Improvement
            new_policy = policy_improvement(env, V, gamma)
            # Record the policy that V was computed for, not the improved one.
            history['policy'].append(policy.copy())

            # Visualize intermediate results
            if i < 3:  # Only show first few iterations to avoid clutter
                time.sleep(0.5)  # Pause to see the progress
                plt.figure(figsize=(12, 5))

                plt.subplot(1, 2, 1)
                visualize_grid(V, f"Iteration {i+1} - State Values", env)

                plt.subplot(1, 2, 2)
                visualize_grid(new_policy, f"Iteration {i+1} - Policy", env, is_policy=True)

                plt.tight_layout()
                plt.show()

            # Count this iteration before the convergence check so the bar
            # also reflects the final, converging pass.
            pbar.update(1)
            if np.array_equal(policy, new_policy):
                pbar.set_postfix({'status': 'Converged!'})
                break

            policy = new_policy
            pbar.set_postfix({'iteration': i+1})

    return policy, V, history

def visualize_grid(data, title, env, is_policy=False, ax=None):
    """Draw a grid of values or policy arrows as a table on a matplotlib axes.

    The table is drawn on ``ax`` (or on the current axes when ``ax`` is
    omitted), so a caller can select a subplot first and place several grids
    in one figure. The function neither creates a figure nor calls
    ``plt.show()``; displaying the figure is left to the caller.

    Args:
        data: 2-D array indexed ``[row, col]`` with one entry per grid cell.
        title: Title placed above the table.
        env: Environment exposing ``size`` and ``is_terminal``.
        is_policy: When True, entries are action names rendered as arrows
            (names without an arrow are shown unchanged). When False,
            entries are numbers shown with one decimal place, and terminal
            cells show "T".
        ax: Axes to draw on; defaults to ``plt.gca()``.
    """
    if ax is None:
        # Honour whatever subplot the caller has made current.
        ax = plt.gca()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    arrows = {'up': '↑', 'right': '→', 'down': '↓', 'left': '←', '': ''}

    for y in range(env.size):
        for x in range(env.size):
            val = data[y, x]
            terminal = env.is_terminal((y, x))
            if is_policy:
                text = arrows.get(val, val)
            else:
                text = f"{val:.1f}" if not terminal else "T"
            color = 'lightgray' if terminal else 'white'
            tb.add_cell(y, x, 1/env.size, 1/env.size, text=text,
                        loc='center', facecolor=color)

    ax.add_table(tb)
    ax.set_title(title, pad=20)

# Run the experiment with progress visualization
print("Starting Policy Iteration...\n")
# GridWorld() with no arguments uses its default size of 4.
env = GridWorld()
# policy_iteration returns the final policy, the most recently evaluated
# value table, and a history dict with per-iteration 'policy' and 'values'.
optimal_policy, optimal_values, history = policy_iteration(env)

# Show final results
print("\nFinal Results:")
# Wide figure laid out as one row with two panels.
plt.figure(figsize=(12, 5))

# Left panel: state values.
plt.subplot(1, 2, 1)
visualize_grid(optimal_values, "Optimal State Values", env)

# Right panel: policy (is_policy=True renders action names as arrows).
plt.subplot(1, 2, 2)
visualize_grid(optimal_policy, "Optimal Policy", env, is_policy=True)

plt.tight_layout()
plt.show()