import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm

class GridWorld:
    """Square grid environment with two terminal corner cells.

    States are ``(row, col)`` tuples. The top-right corner is a terminal
    cell with reward +10 and the bottom-left corner is a terminal cell with
    reward -5. Every transition that does not end in a terminal cell yields
    a reward of -1. Moves that would leave the grid are clamped, so the
    agent stays on the border cell it was already on.
    """

    def __init__(self, size=5):
        """Build a ``size`` x ``size`` grid."""
        self.size = size
        self.actions = ['up', 'right', 'down', 'left']
        # (row delta, column delta) applied by each action
        self.action_effects = {'up': (-1, 0), 'right': (0, 1),
                               'down': (1, 0), 'left': (0, -1)}
        # Terminal states with rewards and labels
        self.terminal_states = {
            (0, size-1): {'reward': 10, 'label': 'High Reward'},
            (size-1, 0): {'reward': -5, 'label': 'Penalty'}
        }
        # Special states with different properties. The start marker is
        # derived from reset() so the label always sits on the cell episodes
        # really begin in, whatever the grid size.
        self.special_states = {
            self.reset(): {'label': 'Start', 'color': 'yellow'}
        }

    def reset(self):
        """Return the start state: the center cell of the grid."""
        return (self.size//2, self.size//2)

    def step(self, state, action):
        """Apply ``action`` in ``state``.

        Returns:
            ``(next_state, reward, done)``. Stepping from a terminal state
            leaves the state unchanged and returns its reward with
            ``done=True``. Entering a terminal state returns that state's
            reward with ``done=True``. Any other move returns -1 and
            ``done=False``.

        Raises:
            KeyError: if ``action`` is not a key of ``action_effects``.
        """
        if state in self.terminal_states:
            return state, self.terminal_states[state]['reward'], True

        dy, dx = self.action_effects[action]
        # Clamp to the grid so walking into a wall is a no-op move
        new_y = max(0, min(self.size-1, state[0] + dy))
        new_x = max(0, min(self.size-1, state[1] + dx))
        new_state = (new_y, new_x)

        if new_state in self.terminal_states:
            # Pay out the terminal reward on the transition that reaches it;
            # episodes stop as soon as done is True, so this is the only
            # point at which the agent can ever receive it.
            return new_state, self.terminal_states[new_state]['reward'], True

        return new_state, -1, False  # Default reward for non-terminal states

def mc_prediction(env, episodes=1000):
    """Estimate state values with every-visit Monte Carlo prediction.

    Episodes start at ``env.reset()`` and follow a uniformly random policy
    over ``env.actions`` until ``env.step`` reports ``done``. Returns are
    undiscounted, and every occurrence of a state in an episode contributes
    one sample to that state's average.

    While running, a two-panel figure (value heatmap and start-state
    convergence curve) is redrawn every 100 episodes and after the last
    one. The call blocks on ``plt.show(block=True)`` before returning.

    Args:
        env: environment providing ``reset()``, ``step(state, action)`` and
            an ``actions`` sequence.
        episodes: number of episodes to sample.

    Returns:
        ``defaultdict(float)`` mapping each visited non-terminal state to
        the mean of its sampled returns.
    """
    V = defaultdict(float)
    # Number of return samples folded into V[state] so far. Keeping a count
    # and updating the mean incrementally avoids storing every return and
    # re-averaging the whole history on each visit.
    visit_counts = defaultdict(int)
    convergence_data = []

    # Set up the figure
    plt.figure(figsize=(15, 6))

    for ep in tqdm(range(episodes), desc="MC Prediction"):
        state = env.reset()
        episode = []
        done = False

        # Generate episode with random policy
        while not done:
            action = np.random.choice(env.actions)
            next_state, reward, done = env.step(state, action)
            episode.append((state, reward))
            state = next_state

        # Walk the episode backwards so G is the return from each step on
        G = 0
        for visited_state, step_reward in reversed(episode):
            G += step_reward
            visit_counts[visited_state] += 1
            # Running mean: new_mean = old_mean + (sample - old_mean) / n
            V[visited_state] += (G - V[visited_state]) / visit_counts[visited_state]

        # Track convergence for center state
        convergence_data.append(V.get(env.reset(), 0))

        # Update plots every 100 episodes
        if ep % 100 == 0 or ep == episodes-1:
            plt.clf()

            # Plot 1: State Values Heatmap
            ax1 = plt.subplot(1, 2, 1)
            plot_state_values(env, V, ax1)

            # Plot 2: Convergence Graph
            ax2 = plt.subplot(1, 2, 2)
            plot_convergence(convergence_data, ax2)

            plt.suptitle(f"Episode {ep+1}/{episodes}", y=1.05)
            plt.tight_layout()
            plt.pause(0.01)

    plt.show(block=True)
    return V

def plot_state_values(env, V, ax):
    """Draw the current value estimates as an annotated heatmap on ``ax``.

    Cells missing from ``V`` are left as NaN in the image. Terminal cells
    are annotated with their label and reward, special cells with their
    label on a colored box, and every remaining cell that has an estimate
    with the value rounded to one decimal place.
    """
    n = env.size
    values = np.full((n, n), np.nan)
    for cell, estimate in V.items():
        values[cell] = estimate

    # Heatmap with a fixed color range and an attached colorbar
    heatmap = ax.imshow(values, cmap='RdYlGn', vmin=-10, vmax=10)
    plt.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04, label='State Value')

    # Per-cell annotations, visited in row-major order
    for (row, col), estimate in np.ndenumerate(values):
        cell = (row, col)

        if cell in env.terminal_states:
            terminal = env.terminal_states[cell]
            caption = f"{terminal['label']}\n({terminal['reward']})"
            ax.text(col, row, caption, ha='center', va='center',
                    bbox=dict(facecolor='white', alpha=0.9))
            continue

        if cell in env.special_states:
            special = env.special_states[cell]
            ax.text(col, row, special['label'], ha='center', va='center',
                    bbox=dict(facecolor=special['color'], alpha=0.7))
            continue

        if np.isnan(estimate):
            # Never visited: nothing to print in this cell
            continue

        ax.text(col, row, f"{estimate:.1f}", ha='center', va='center')

    # Axis cosmetics: one tick per row/column, labelled by index
    ticks = range(n)
    tick_labels = [str(k) for k in ticks]
    ax.set_title("State Value Heatmap")
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(list(tick_labels))
    ax.set_xlabel("Column")
    ax.set_ylabel("Row")

def plot_convergence(data, ax):
    """Plot the start-state value history and a moving-average overlay.

    Args:
        data: sequence of value estimates, one per episode so far.
        ax: matplotlib axes to draw on.

    The raw series is drawn in blue. A dashed red moving average is drawn
    over it, with a window of ``len(data) // 20`` episodes (at least 1).
    When ``data`` is empty only the empty, labelled axes are drawn.
    """
    ax.plot(data, color='blue')
    ax.set_title("Value Convergence (Start State)")
    ax.set_xlabel("Episodes")
    ax.set_ylabel("State Value")
    ax.grid(True)

    # np.convolve raises ValueError on an empty input, and there would be no
    # labelled line for the legend either, so stop here.
    if len(data) == 0:
        return

    # Add smoothed version
    window_size = max(1, len(data)//20)
    smoothed = np.convolve(data, np.ones(window_size)/window_size, mode='valid')
    # 'valid' mode yields len(data) - window_size + 1 points; align each one
    # with the last episode of the window it averages.
    ax.plot(range(window_size-1, len(data)), smoothed, color='red',
            linestyle='--', label=f'Smoothed (window={window_size})')
    ax.legend()

def print_state_values(V, env):
    """Print state values in a readable table.

    Args:
        V: mapping from ``(row, col)`` state to its estimated value.
        env: environment providing ``size`` and ``terminal_states``.

    One row is printed per grid cell, with a separator after each grid row.
    Terminal cells show their fixed reward instead of an estimate, and
    cells missing from ``V`` show ``nan``.
    """
    header = "| {:^5} | {:^5} | {:^7} |".format("Row", "Col", "Value")
    # Size the border from the header so the "+---+" lines are exactly as
    # wide as the rows they frame.
    border = "+" + "-" * (len(header) - 2) + "+"

    print("\nFinal State Values:")
    print(border)
    print(header)
    print(border)

    for i in range(env.size):
        for j in range(env.size):
            state = (i, j)
            if state in env.terminal_states:
                value = env.terminal_states[state]['reward']
            else:
                # .get() does not insert missing keys, even on a defaultdict
                value = V.get(state, float('nan'))
            print("| {:^5} | {:^5} | {:^7.2f} |".format(i, j, value))
        print(border)

# Script driver: train on a 5x5 grid, then report what was learned.
print("Starting Monte Carlo Prediction...")
env = GridWorld(5)
V = mc_prediction(env, 1000)

# Full table of estimates, one row per cell
print_state_values(V=V, env=env)

# Headline numbers: the start-state estimate and both terminal rewards.
# A single print with a newline separator emits the same four lines as
# four consecutive print calls would.
center_state = env.reset()
print(
    f"\nAdditional Analysis:",
    f"- Start state value: {V.get(center_state, 0):.2f}",
    f"- High reward state: {env.terminal_states[(0, env.size-1)]['reward']}",
    f"- Penalty state: {env.terminal_states[(env.size-1, 0)]['reward']}",
    sep="\n",
)