import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import time

class GridWorld:
    """Rectangular grid environment with obstacles and a single goal cell.

    Positions are ``(row, col)`` tuples. Entering the goal cell yields a
    reward of 1.0 and ends the episode; every move additionally costs
    ``step_reward``, and a move that would leave the grid or enter an
    obstacle leaves the agent in place with an extra -0.1 penalty.
    """

    # Obstacle layout used when the caller does not pass ``obstacles``.
    DEFAULT_OBSTACLES = ((1, 1), (2, 3), (3, 1))

    def __init__(self, width=5, height=5, start_pos=(0, 0), goal_pos=(4, 4),
                 obstacles=None):
        """Build the grid, obstacle map and reward table.

        Args:
            width: Number of columns.
            height: Number of rows.
            start_pos: ``(row, col)`` returned by :meth:`reset`.
            goal_pos: ``(row, col)`` of the terminal goal cell.
            obstacles: Optional iterable of ``(row, col)`` cells that cannot
                be entered. When omitted, ``DEFAULT_OBSTACLES`` is used,
                restricted to cells that lie inside the grid and do not
                coincide with the start or goal.

        Raises:
            IndexError: If ``start_pos``, ``goal_pos`` or an explicitly
                supplied obstacle lies outside the grid.
            ValueError: If an explicitly supplied obstacle coincides with
                the start or goal cell.
        """
        self.width = width
        self.height = height
        # Normalise to tuples: ``step`` compares positions with ``==`` and
        # NumPy treats a list index as fancy indexing rather than one cell.
        self.start_pos = tuple(start_pos)
        self.goal_pos = tuple(goal_pos)

        for label, pos in (("start_pos", self.start_pos),
                           ("goal_pos", self.goal_pos)):
            if not self._in_bounds(pos):
                raise IndexError(
                    f"{label} {pos} is outside the {height}x{width} grid")

        # Define actions: up, right, down, left (as (row, col) offsets)
        self.actions = [(-1, 0), (0, 1), (1, 0), (0, -1)]
        self.action_names = ['up', 'right', 'down', 'left']

        reserved = (self.start_pos, self.goal_pos)
        if obstacles is None:
            # Clip the default layout so small grids remain constructible.
            obstacle_cells = [obs for obs in self.DEFAULT_OBSTACLES
                              if self._in_bounds(obs) and obs not in reserved]
        else:
            obstacle_cells = [tuple(obs) for obs in obstacles]
            for obs in obstacle_cells:
                if not self._in_bounds(obs):
                    raise IndexError(
                        f"obstacle {obs} is outside the {height}x{width} grid")
                if obs in reserved:
                    raise ValueError(
                        f"obstacle {obs} coincides with the start or goal cell")

        # Grid of cell types: 0 = free, -1 = obstacle
        self.grid = np.zeros((height, width))
        self.obstacles = obstacle_cells
        for obs in self.obstacles:
            self.grid[obs] = -1

        # Reward table: only the goal cell carries a positive reward
        self.rewards = np.zeros((height, width))
        self.rewards[self.goal_pos] = 1.0

        # Small negative reward per move to encourage the shortest path
        self.step_reward = -0.01

    def _in_bounds(self, position):
        """Return True when ``position`` lies inside the grid rectangle."""
        row, col = position
        return 0 <= row < self.height and 0 <= col < self.width

    def reset(self):
        """Return the starting position for a new episode."""
        return self.start_pos

    def is_valid_position(self, position):
        """Return True when ``position`` is inside the grid and not an obstacle."""
        row, col = position
        # The bounds check must come first so the array is never indexed
        # with an out-of-range (or wrapping negative) coordinate.
        return self._in_bounds(position) and bool(self.grid[row, col] != -1)

    def step(self, position, action_idx):
        """Apply one action from ``position``.

        Args:
            position: Current ``(row, col)``.
            action_idx: Index into ``self.actions``.

        Returns:
            ``(next_position, reward, done)``. On an invalid move the agent
            stays at ``position``, receives ``-0.1 + step_reward`` and
            ``done`` is False.
        """
        d_row, d_col = self.actions[action_idx]
        next_pos = (position[0] + d_row, position[1] + d_col)

        if not self.is_valid_position(next_pos):
            # Bumping into a wall or obstacle: no movement, extra penalty.
            return position, -0.1 + self.step_reward, False

        reward = self.rewards[next_pos] + self.step_reward
        done = next_pos == self.goal_pos
        return next_pos, reward, done

    def get_policy_direction_map(self, policy):
        """Convert policy indices to directional arrows for visualization.

        Args:
            policy: 2-D array-like of action indices, one per grid cell.

        Returns:
            Object array of shape ``(height, width)`` holding ``'X'`` for
            obstacles, ``'G'`` for the goal and an arrow for every other cell.
        """
        policy = np.asarray(policy)
        direction_map = np.full((self.height, self.width), '', dtype=object)
        arrow_map = {0: '↑', 1: '→', 2: '↓', 3: '←'}
        blocked = set(self.obstacles)

        for i in range(self.height):
            for j in range(self.width):
                if (i, j) in blocked:
                    direction_map[i, j] = 'X'
                elif (i, j) == self.goal_pos:
                    direction_map[i, j] = 'G'
                else:
                    direction_map[i, j] = arrow_map[int(policy[i, j])]

        return direction_map

class TDLearning:
    """Tabular temporal-difference learner for a grid environment.

    Each step applies a TD(0) update to the state-value table ``V`` and a
    max-bootstrapped update to the action-value table ``Q``. Actions are
    chosen epsilon-greedily from ``policy``, which is refreshed from ``Q``
    once at the end of every episode.

    The environment must expose ``height``, ``width``, ``actions``,
    ``obstacles``, ``goal_pos``, ``reset()`` and ``step(state, action)``.
    """

    def __init__(self, env, gamma=0.9, alpha=0.1, epsilon=0.1):
        """Create zero-initialised tables sized to ``env``.

        Args:
            env: Grid environment (see class docstring for the attributes used).
            gamma: Discount factor.
            alpha: Learning rate.
            epsilon: Exploration probability for epsilon-greedy selection.
        """
        self.env = env
        self.gamma = gamma  # Discount factor
        self.alpha = alpha  # Learning rate
        self.epsilon = epsilon  # For epsilon-greedy policy

        # State-value function V(s)
        self.V = np.zeros((env.height, env.width))

        # Action-value function Q(s, a)
        self.Q = np.zeros((env.height, env.width, len(env.actions)))

        # Greedy policy table; starts as action index 0 in every cell
        self.policy = np.zeros((env.height, env.width), dtype=int)

        # Metrics for tracking
        self.value_history = []    # copy of V after each episode
        self.policy_history = []   # copy of policy after each episode
        self.td_errors = []        # |TD error| of every V update
        self.episode_returns = []  # undiscounted reward sum per episode

    def choose_action(self, state, epsilon=None):
        """Pick an action index epsilon-greedily from the current policy.

        Args:
            state: ``(row, col)`` tuple indexing the policy table.
            epsilon: Exploration probability; defaults to ``self.epsilon``.

        Returns:
            A uniformly random action index with probability ``epsilon``,
            otherwise the policy's action for ``state``.
        """
        if epsilon is None:
            epsilon = self.epsilon

        if np.random.random() < epsilon:
            # Explore: choose random action
            return np.random.choice(len(self.env.actions))
        # Exploit: choose action from policy
        return self.policy[state]

    def update_policy(self):
        """Set the policy to the argmax of Q in every non-obstacle, non-goal cell."""
        blocked = set(self.env.obstacles)
        for i in range(self.env.height):
            for j in range(self.env.width):
                cell = (i, j)
                if cell in blocked or cell == self.env.goal_pos:
                    continue
                self.policy[i, j] = np.argmax(self.Q[i, j])

    def run_episode(self, max_steps=100):
        """Run one episode, updating V and Q after every step.

        Args:
            max_steps: Upper bound on the number of environment steps.

        Returns:
            The undiscounted sum of rewards collected in the episode.
        """
        state = self.env.reset()
        total_reward = 0

        for _ in range(max_steps):
            action = self.choose_action(state)
            next_state, reward, done = self.env.step(state, action)

            # A terminal transition has no future value, so the bootstrap
            # term is dropped explicitly instead of relying on the goal
            # cell's table entries happening to stay at zero.
            next_v = 0.0 if done else self.V[next_state]
            next_q = 0.0 if done else np.max(self.Q[next_state])

            # TD(0) update of the state-value function
            td_error = reward + self.gamma * next_v - self.V[state]
            self.td_errors.append(abs(td_error))
            self.V[state] += self.alpha * td_error

            # Max-bootstrapped update of the action-value function
            q_error = reward + self.gamma * next_q - self.Q[state][action]
            self.Q[state][action] += self.alpha * q_error

            total_reward += reward
            state = next_state

            if done:
                break

        # Store episode return
        self.episode_returns.append(total_reward)

        # The policy only changes between episodes
        self.update_policy()

        # Snapshots for later visualization
        self.value_history.append(self.V.copy())
        self.policy_history.append(self.policy.copy())

        return total_reward

    def train(self, num_episodes=500, log_interval=10):
        """Run ``num_episodes`` episodes, printing progress periodically.

        Args:
            num_episodes: Number of episodes to run.
            log_interval: Print a progress line every this many episodes.
                ``None`` or a non-positive value disables progress output
                (previously ``0`` raised ``ZeroDivisionError``).

        Returns:
            ``(value_history, policy_history, td_errors, episode_returns)``.
        """
        start_time = time.time()
        logging_enabled = log_interval is not None and log_interval > 0

        for episode in range(1, num_episodes + 1):
            total_reward = self.run_episode()

            if logging_enabled and episode % log_interval == 0:
                # Average over the most recent 100 updates only
                avg_td_error = np.mean(self.td_errors[-100:]) if self.td_errors else 0
                elapsed_time = time.time() - start_time
                print(f"Episode {episode}/{num_episodes} - Avg TD Error: {avg_td_error:.4f} - Return: {total_reward:.2f} - Time: {elapsed_time:.2f}s")

        print("Training completed!")
        return self.value_history, self.policy_history, self.td_errors, self.episode_returns

class Visualizer:
    """Matplotlib/seaborn plots for a grid environment and its TD learner.

    ``env`` must expose ``height``, ``width``, ``obstacles``, ``start_pos``,
    ``goal_pos``, ``reset()``, ``step()`` and ``get_policy_direction_map()``;
    ``td_learner`` must expose ``V``, ``policy``, ``value_history``,
    ``policy_history``, ``td_errors`` and ``episode_returns``.
    """

    def __init__(self, env, td_learner):
        self.env = env
        self.td_learner = td_learner

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _special_cell_grid(self):
        """Return an array coding free cells as 0, obstacles as 1, goal as 2."""
        grid = np.zeros((self.env.height, self.env.width))
        for obs in self.env.obstacles:
            grid[obs] = 1
        grid[self.env.goal_pos] = 2
        return grid

    def _draw_policy_grid(self, ax, policy):
        """Draw coloured cells, policy symbols and cell borders on ``ax``."""
        cmap = ListedColormap(['white', 'gray', 'lightgreen'])
        # Fix the colour range so codes 0/1/2 always map to the three colours.
        ax.imshow(self._special_cell_grid(), cmap=cmap, vmin=0, vmax=2)

        direction_map = self.env.get_policy_direction_map(policy)
        for i in range(self.env.height):
            for j in range(self.env.width):
                symbol = direction_map[i, j]
                if symbol == 'G':
                    ax.text(j, i, 'GOAL', ha='center', va='center', fontsize=8)
                else:
                    # Arrows and the obstacle marker 'X' share one style.
                    ax.text(j, i, symbol, ha='center', va='center', fontsize=15)

        # imshow centres cells on integer coordinates, so cell borders sit
        # on the half-integer minor ticks. Drawing the grid on major ticks
        # would run the lines through the middle of every cell.
        ax.set_xticks(np.arange(self.env.width))
        ax.set_yticks(np.arange(self.env.height))
        ax.set_xticks(np.arange(-.5, self.env.width, 1), minor=True)
        ax.set_yticks(np.arange(-.5, self.env.height, 1), minor=True)
        ax.grid(False, which='major')
        ax.grid(which='minor', color='black', linewidth=1.5)
        ax.tick_params(which='minor', length=0)

    @staticmethod
    def _select_epochs(total_epochs, epochs_to_show):
        """Return up to ``epochs_to_show`` evenly spaced indices below ``total_epochs``."""
        if total_epochs <= 0 or epochs_to_show <= 0:
            return []
        if total_epochs <= epochs_to_show:
            return list(range(total_epochs))
        spaced = np.linspace(0, total_epochs - 1, epochs_to_show, dtype=int)
        return [int(idx) for idx in spaced]

    def _plot_history(self, history, epochs_to_show, plot_one, title_prefix):
        """Plot selected snapshots from ``history`` side by side.

        ``plot_one`` is called as ``plot_one(snapshot, ax=..., title=...)``.
        Returns ``(fig, axes)``, or None when there is nothing to plot.
        """
        epochs = self._select_epochs(len(history), epochs_to_show)
        if not epochs:
            return None

        fig, axes = plt.subplots(1, len(epochs), figsize=(5*len(epochs), 5))
        if len(epochs) == 1:
            # plt.subplots returns a bare Axes for a single panel.
            axes = [axes]

        for panel, epoch in enumerate(epochs):
            plot_one(history[epoch], ax=axes[panel],
                     title=f"{title_prefix} at Epoch {epoch+1}")

        plt.tight_layout()
        return fig, axes

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------
    def plot_value_function(self, value=None, ax=None, title="Value Function"):
        """Plot a value table as an annotated heatmap.

        Args:
            value: 2-D value table; defaults to the learner's current ``V``.
            ax: Axes to draw on; a new 8x6 figure is created when omitted.
            title: Axes title.

        Returns:
            The Axes that was drawn on.
        """
        if value is None:
            value = self.td_learner.V

        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))

        # Hand obstacle cells to seaborn through its ``mask`` argument so
        # they are left out of both the colouring and the annotations.
        value = np.asarray(value)
        obstacle_mask = np.zeros(value.shape, dtype=bool)
        for obs in self.env.obstacles:
            obstacle_mask[obs] = True

        sns.heatmap(value, mask=obstacle_mask, annot=True, fmt=".2f",
                    cmap="YlGnBu", ax=ax, cbar=True)

        # Outline start (green) and goal (red); patches take (x, y) = (col, row)
        ax.add_patch(plt.Rectangle((self.env.start_pos[1], self.env.start_pos[0]),
                                  1, 1, fill=False, edgecolor='green', lw=3))
        ax.add_patch(plt.Rectangle((self.env.goal_pos[1], self.env.goal_pos[0]),
                                  1, 1, fill=False, edgecolor='red', lw=3))

        # Shade obstacles
        for obs in self.env.obstacles:
            ax.add_patch(plt.Rectangle((obs[1], obs[0]), 1, 1,
                                      fill=True, color='gray', alpha=0.7))

        ax.set_title(title)
        ax.set_xlabel("Column")
        ax.set_ylabel("Row")

        return ax

    def plot_policy(self, policy=None, ax=None, title="Policy"):
        """Plot a policy table as arrows on the grid.

        Args:
            policy: 2-D table of action indices; defaults to the learner's
                current policy.
            ax: Axes to draw on; a new 8x6 figure is created when omitted.
            title: Axes title.

        Returns:
            The Axes that was drawn on.
        """
        if policy is None:
            policy = self.td_learner.policy

        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))

        self._draw_policy_grid(ax, policy)
        ax.set_title(title)
        ax.set_xlabel("Column")
        ax.set_ylabel("Row")

        return ax

    def plot_learning_metrics(self):
        """Plot per-update TD errors and per-episode returns side by side.

        Returns:
            ``(fig, axes)`` with the TD-error plot in ``axes[0]`` and the
            episode-return plot in ``axes[1]``.
        """
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        td_errors = self.td_learner.td_errors
        axes[0].plot(range(len(td_errors)), td_errors)
        axes[0].set_title("TD Error Over Training")
        axes[0].set_xlabel("Update Step")
        axes[0].set_ylabel("TD Error")

        returns = self.td_learner.episode_returns
        axes[1].plot(range(len(returns)), returns)
        axes[1].set_title("Episode Returns Over Training")
        axes[1].set_xlabel("Episode")
        axes[1].set_ylabel("Total Return")

        plt.tight_layout()
        return fig, axes

    def plot_value_evolution(self, epochs_to_show=5):
        """Plot value-function snapshots taken at evenly spaced episodes.

        Returns:
            ``(fig, axes)``, or None when the history is empty or
            ``epochs_to_show`` is not positive.
        """
        return self._plot_history(self.td_learner.value_history, epochs_to_show,
                                  self.plot_value_function, "Value Function")

    def plot_policy_evolution(self, epochs_to_show=5):
        """Plot policy snapshots taken at evenly spaced episodes.

        Returns:
            ``(fig, axes)``, or None when the history is empty or
            ``epochs_to_show`` is not positive.
        """
        return self._plot_history(self.td_learner.policy_history, epochs_to_show,
                                  self.plot_policy, "Policy")

    def visualize_trajectory(self, max_steps=20):
        """Roll out the learner's greedy policy from the start and plot the path.

        Args:
            max_steps: Maximum number of steps to follow the policy.

        Returns:
            ``(fig, ax)`` of the trajectory figure.
        """
        state = self.env.reset()
        trajectory = [state]

        for _ in range(max_steps):
            action = self.td_learner.policy[state]
            state, _, done = self.env.step(state, action)
            trajectory.append(state)
            if done:
                break

        fig, ax = plt.subplots(figsize=(8, 6))
        self._draw_policy_grid(ax, self.td_learner.policy)

        # Positions are (row, col); matplotlib wants x = col, y = row
        traj_x = [pos[1] for pos in trajectory]
        traj_y = [pos[0] for pos in trajectory]
        ax.plot(traj_x, traj_y, 'r-', linewidth=2)
        ax.plot(traj_x, traj_y, 'bo', markersize=10, alpha=0.5)

        ax.set_title("Agent Trajectory with Learned Policy")

        return fig, ax

def run_td_experiment(grid_size=5, num_episodes=100, alpha=0.1, gamma=0.9, epsilon=0.1):
    """Run a complete TD learning experiment.

    Builds a square grid world, trains a TD agent on it, shows and saves the
    result figures as PNG files, and dumps the raw training data as ``.npy``
    files in the current working directory.

    Args:
        grid_size: Side length of the square grid.
        num_episodes: Number of training episodes.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Exploration probability.

    Returns:
        ``(env, td_agent, visualizer)``.
    """
    print(f"Starting TD(0) learning experiment with {num_episodes} episodes")
    print(f"Parameters: α={alpha}, γ={gamma}, ε={epsilon}")

    # Create environment. The goal is pinned to the bottom-right corner so
    # it follows ``grid_size`` instead of staying at the constructor's
    # default cell, which only matches a 5x5 grid.
    env = GridWorld(width=grid_size, height=grid_size,
                    goal_pos=(grid_size - 1, grid_size - 1))

    # Create TD learning agent
    td_agent = TDLearning(env, gamma=gamma, alpha=alpha, epsilon=epsilon)

    # Train the agent and track progress (about ten progress lines in total)
    print("Training agent...")
    value_history, policy_history, td_errors, episode_returns = td_agent.train(
        num_episodes=num_episodes, log_interval=max(1, num_episodes//10))

    # Create visualizer
    visualizer = Visualizer(env, td_agent)

    # Plot final value function and policy
    plt.figure(figsize=(15, 6))

    plt.subplot(1, 2, 1)
    visualizer.plot_value_function(ax=plt.gca(), title="Final Value Function")

    plt.subplot(1, 2, 2)
    visualizer.plot_policy(ax=plt.gca(), title="Final Policy")

    plt.tight_layout()
    plt.savefig("td_final_results.png")
    plt.show()

    # Each helper below leaves its figure current, so savefig must be
    # called before show.

    # Plot learning metrics
    visualizer.plot_learning_metrics()
    plt.savefig("td_learning_metrics.png")
    plt.show()

    # Plot value function evolution
    visualizer.plot_value_evolution(epochs_to_show=5)
    plt.savefig("td_value_evolution.png")
    plt.show()

    # Plot policy evolution
    visualizer.plot_policy_evolution(epochs_to_show=5)
    plt.savefig("td_policy_evolution.png")
    plt.show()

    # Visualize agent trajectory
    visualizer.visualize_trajectory()
    plt.savefig("td_agent_trajectory.png")
    plt.show()

    # Save raw data for further analysis
    np.save("td_value_history.npy", value_history)
    np.save("td_policy_history.npy", policy_history)
    np.save("td_errors.npy", td_errors)
    np.save("td_returns.npy", episode_returns)

    return env, td_agent, visualizer

# Script entry point: seed NumPy's global RNG, then launch one experiment.
if __name__ == "__main__":
    # A fixed seed makes the agent's random exploration repeatable.
    np.random.seed(42)

    # Hyper-parameters for this run; raising num_episodes may improve
    # convergence.
    experiment_settings = dict(
        grid_size=5,
        num_episodes=500,
        alpha=0.1,
        gamma=0.9,
        epsilon=0.1,
    )
    env, td_agent, visualizer = run_td_experiment(**experiment_settings)