import numpy as np
import matplotlib.pyplot as plt
from matplotlib.table import Table
from tqdm import tqdm
import time

class GridWorld:
    """Deterministic square grid world with two terminal corner cells.

    States are ``(row, col)`` pairs. The top-left and bottom-right corners
    are terminal. Entering a cell yields that cell's reward: -1 for every
    non-terminal cell and 0 for a terminal cell.

    Attributes:
        size: Side length of the grid.
        actions: Action names, index-aligned with ``action_effects``.
        action_effects: ``(d_row, d_col)`` offset applied by each action.
        terminal_states: List of terminal ``(row, col)`` cells.
        rewards: ``size`` x ``size`` array holding each cell's entry reward.
    """

    def __init__(self, size=4):
        """Build a ``size`` x ``size`` grid with terminal corners."""
        self.size = size
        self.actions = np.array(['up', 'right', 'down', 'left'])
        self.action_effects = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
        self.terminal_states = [(0, 0), (size - 1, size - 1)]
        self.rewards = np.full((size, size), -1.0)
        for term in self.terminal_states:
            self.rewards[term] = 0

    def is_terminal(self, state):
        """Return True when ``state`` (any 2-item sequence) is a terminal cell."""
        return (state[0], state[1]) in self.terminal_states

    def get_next_state(self, state, action_idx):
        """Apply an action and return ``(next_state, reward)``.

        Args:
            state: Current ``(row, col)`` position (tuple, list or array).
            action_idx: Index into ``actions`` / ``action_effects``.

        Returns:
            A pair ``(next_state, reward)`` where ``next_state`` is always a
            tuple of two Python ints and ``reward`` is always a Python float.
            Terminal states are absorbing: they return themselves with
            reward 0.0. Moves that would leave the grid are clipped, so the
            agent stays on the border cell.
        """
        if self.is_terminal(state):
            # Same types as the non-terminal branch, so callers never have to
            # care which branch produced the result.
            return (int(state[0]), int(state[1])), 0.0

        moved = np.asarray(state) + self.action_effects[action_idx]
        # Plain ints keep the returned state free of NumPy scalar types.
        new_row, new_col = (int(c) for c in np.clip(moved, 0, self.size - 1))
        return (new_row, new_col), float(self.rewards[new_row, new_col])

def value_iteration(env, gamma=1.0, theta=1e-4, max_iter=100):
    """Compute state values and a greedy policy by value iteration.

    Every sweep builds a fresh value table from the previous sweep's table
    (no in-place updates), setting each non-terminal cell to the best
    one-step return ``reward + gamma * V[next_state]`` over all actions.
    Terminal cells keep the value 0.

    Args:
        env: Grid environment providing ``size``, ``actions``,
            ``is_terminal(state)`` and ``get_next_state(state, action_idx)``.
        gamma: Discount factor applied to the successor state's value.
        theta: Stop once the largest value change in a sweep is below this.
        max_iter: Upper bound on the number of sweeps.

    Returns:
        A tuple ``(V, policy, history)``: the final value table, the policy
        produced by ``extract_policy`` for it, and a dict with two lists of
        equal length -- ``'values'`` (the value table produced by each sweep)
        and ``'deltas'`` (the largest absolute change made by each sweep).

    Side effects:
        Shows a tqdm progress bar. After each of the first three sweeps it
        sleeps for 0.5 s and displays the current values via
        ``visualize_grid``.
    """
    V = np.zeros((env.size, env.size))
    history = {'values': [], 'deltas': []}
    action_count = len(env.actions)

    with tqdm(total=max_iter, desc="Value Iteration") as pbar:
        for i in range(max_iter):
            delta = 0
            # Starts at zero everywhere, so terminal cells need no write.
            new_V = np.zeros_like(V)

            for y in range(env.size):
                for x in range(env.size):
                    if env.is_terminal((y, x)):
                        continue

                    # Best one-step return over all actions, evaluated
                    # against the previous sweep's table.
                    outcomes = (
                        env.get_next_state((y, x), action_idx)
                        for action_idx in range(action_count)
                    )
                    new_V[y, x] = max(
                        (reward + gamma * V[new_y, new_x]
                         for (new_y, new_x), reward in outcomes),
                        default=-np.inf,
                    )
                    delta = max(delta, abs(new_V[y, x] - V[y, x]))

            V = new_V
            # Record what this sweep produced, so values[i] lines up with
            # deltas[i] and the last entry equals the returned table. The copy
            # keeps the history independent of the array handed to the caller.
            history['values'].append(V.copy())
            history['deltas'].append(delta)

            pbar.set_postfix({'Δ': f"{delta:.1e}"})
            pbar.update(1)

            # Show progress for first few iterations
            if i < 3:
                time.sleep(0.5)
                visualize_grid(V, f"Iteration {i+1} - State Values", env)

            if delta < theta:
                pbar.set_postfix({'status': 'Converged!', 'Δ': f"{delta:.1e}"})
                break

    # Extract optimal policy from value function
    policy = extract_policy(env, V, gamma)
    return V, policy, history

def extract_policy(env, V, gamma=1.0):
    """Derive the greedy policy for value table ``V``.

    For each non-terminal cell, picks the action with the highest
    ``reward + gamma * V[next_state]``; when several actions tie, the one
    listed first in ``env.actions`` is kept. Terminal cells get ``''``.

    Args:
        env: Environment providing ``size``, ``actions``,
            ``is_terminal(state)`` and ``get_next_state(state, action_idx)``.
        V: Value table indexed as ``V[row, col]``.
        gamma: Discount factor applied to the successor state's value.

    Returns:
        An ``env.size`` x ``env.size`` array of action names (dtype ``<U5``).
    """
    side = env.size
    greedy = np.empty((side, side), dtype='<U5')

    # Row-major walk over every cell of the grid.
    for row, col in np.ndindex(side, side):
        cell = (row, col)
        if env.is_terminal(cell):
            greedy[row, col] = ''
            continue

        top_score, top_move = -np.inf, None
        for idx, move in enumerate(env.actions):
            (next_row, next_col), reward = env.get_next_state(cell, idx)
            score = reward + gamma * V[next_row, next_col]
            # Strict comparison: an equal score never displaces an earlier action.
            if score > top_score:
                top_score, top_move = score, move
        greedy[row, col] = top_move

    return greedy

def visualize_grid(data, title, env, is_policy=False):
    """Render ``data`` as an ``env.size`` x ``env.size`` table and show it.

    Args:
        data: 2-D array indexed as ``data[row, col]``; numbers for a value
            table, action names for a policy.
        title: Figure title.
        env: Environment providing ``size`` and ``is_terminal(state)``.
        is_policy: When True, action names are drawn as arrows (names with
            no arrow mapping are shown unchanged). When False, values are
            shown with one decimal and terminal cells are labelled "T".

    Terminal cells are shaded light gray in both modes. The function opens a
    new figure and calls ``plt.show()``; it returns nothing.
    """
    glyphs = {'up': '↑', 'right': '→', 'down': '↓', 'left': '←', '': ''}

    figure, axes = plt.subplots(figsize=(6, 6))
    axes.set_axis_off()
    grid = Table(axes, bbox=[0, 0, 1, 1])

    for row, col in np.ndindex(env.size, env.size):
        entry = data[row, col]
        at_terminal = env.is_terminal((row, col))

        if is_policy:
            label = glyphs.get(entry, entry)
        elif at_terminal:
            label = "T"
        else:
            label = f"{entry:.1f}"

        shade = 'lightgray' if at_terminal else 'white'
        # Every cell takes an equal share of the unit bounding box.
        span = 1/env.size
        grid.add_cell(row, col, span, span, text=label,
                      loc='center', facecolor=shade)

    axes.add_table(grid)
    plt.title(title, pad=20)
    plt.show()

# Solve the default grid world
print("Running Value Iteration...\n")
env = GridWorld()
optimal_values, optimal_policy, history = value_iteration(env=env)

# Show the resulting policy first, then the value table behind it
print("\nFinal Optimal Policy:")
visualize_grid(data=optimal_policy, title="Optimal Policy", env=env, is_policy=True)

print("\nFinal State Values:")
visualize_grid(data=optimal_values, title="Optimal State Values", env=env)

# Convergence curve: largest value change per sweep against sweep index
convergence_fig, convergence_ax = plt.subplots(figsize=(10, 4))
convergence_ax.plot(history['deltas'], 'o-')
convergence_ax.set(
    xlabel='Iteration',
    ylabel='Max Value Change (Δ)',
    title='Value Iteration Convergence',
)
convergence_ax.grid(True)
plt.show()