import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Cliff Walking Environment
class CliffWalkingEnv:
    """Cliff Walking gridworld (Sutton & Barto, Example 6.6).

    The agent starts in the bottom-left cell and must reach the bottom-right
    cell. Every bottom-row cell strictly between them is cliff: stepping onto
    one costs -100 and sends the agent back to the start. Reaching the goal
    yields 0 and ends the episode; every other move costs -1. Moves into the
    outer wall leave the agent where it is.

    Args:
        rows: grid height (default 4, the textbook layout); must be >= 2.
        cols: grid width (default 12, the textbook layout); must be >= 2.

    Raises:
        ValueError: if the grid is smaller than 2x2.
    """

    # (row delta, column delta) for each action id: 0=up, 1=right, 2=down, 3=left.
    _MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

    def __init__(self, rows=4, cols=12):
        if rows < 2 or cols < 2:
            raise ValueError(f"grid must be at least 2x2, got {rows}x{cols}")
        self.rows = rows
        self.cols = cols
        self.start = (rows - 1, 0)
        self.goal = (rows - 1, cols - 1)
        self.current_state = self.start
        # Bottom-row cells between start and goal.
        self.cliff = [(rows - 1, i) for i in range(1, cols - 1)]

    def reset(self):
        """Place the agent on the start cell and return that state."""
        self.current_state = self.start
        return self.current_state

    def step(self, action):
        """Apply *action* and return ``(next_state, reward, done, info)``.

        ``info`` is always ``None``. Falling off the cliff returns the start
        cell as ``next_state`` with reward -100 and ``done=False``.

        Raises:
            ValueError: if *action* is not one of 0, 1, 2, 3.
        """
        move = self._MOVES.get(action)
        if move is None:
            raise ValueError(f"invalid action {action!r}; expected 0, 1, 2 or 3")

        row, col = self.current_state
        # Clamp to the grid so walking into a wall is a no-op move.
        new_state = (
            min(max(row + move[0], 0), self.rows - 1),
            min(max(col + move[1], 0), self.cols - 1),
        )

        if new_state in self.cliff:
            reward = -100
            new_state = self.start
        elif new_state == self.goal:
            reward = 0
        else:
            reward = -1

        self.current_state = new_state
        return new_state, reward, new_state == self.goal, None

# SARSA Algorithm
def sarsa(env, episodes=500, alpha=0.1, gamma=1.0, epsilon=0.1):
    """Train a tabular agent with on-policy SARSA and epsilon-greedy behaviour.

    Args:
        env: environment exposing ``rows``, ``cols``, ``reset()`` and
            ``step(action) -> (state, reward, done, info)``, with states as
            ``(row, col)`` tuples and four actions numbered 0..3.
        episodes: number of training episodes.
        alpha: step size of the TD update.
        gamma: discount factor.
        epsilon: probability of taking a uniformly random action.

    Returns:
        ``(Q, rewards, steps_per_episode)``. ``Q`` has shape
        ``(env.rows, env.cols, 4)``. The lists hold each episode's
        undiscounted return and its number of steps.
    """
    Q = np.zeros((env.rows, env.cols, 4))  # Q-table: rows x cols x actions

    def choose(state):
        # Epsilon-greedy. Greedy ties are broken uniformly at random, because
        # np.argmax would always favour the lowest index (action 0, "up").
        if np.random.random() < epsilon:
            return np.random.randint(0, 4)
        values = Q[state[0], state[1]]
        return np.random.choice(np.flatnonzero(values == values.max()))

    rewards = []
    steps_per_episode = []

    for _ in tqdm(range(episodes), desc="SARSA"):
        state = env.reset()
        action = choose(state)
        total_reward = 0
        steps = 0
        done = False

        while not done:
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            steps += 1

            if done:
                # Terminal transition: there is no successor value to bootstrap from.
                target = reward
            else:
                next_action = choose(next_state)
                target = reward + gamma * Q[next_state[0], next_state[1], next_action]

            row, col = state
            Q[row, col, action] += alpha * (target - Q[row, col, action])

            if not done:
                state, action = next_state, next_action

        rewards.append(total_reward)
        steps_per_episode.append(steps)

    return Q, rewards, steps_per_episode

# Q-Learning Algorithm
def q_learning(env, episodes=500, alpha=0.1, gamma=1.0, epsilon=0.1):
    """Train a tabular agent with off-policy Q-learning.

    Actions are chosen epsilon-greedily from the current Q. The TD target
    bootstraps from the greedy (max) value of the next state. It uses the
    reward alone when the step ended the episode.

    Args:
        env: environment exposing ``rows``, ``cols``, ``reset()`` and
            ``step(action) -> (state, reward, done, info)``, with states as
            ``(row, col)`` tuples and four actions numbered 0..3.
        episodes: number of training episodes.
        alpha: step size of the TD update.
        gamma: discount factor.
        epsilon: probability of taking a uniformly random action.

    Returns:
        ``(Q, rewards, steps_per_episode)``. ``Q`` has shape
        ``(env.rows, env.cols, 4)``. The lists hold each episode's
        undiscounted return and its number of steps.
    """
    Q = np.zeros((env.rows, env.cols, 4))  # Q-table: rows x cols x actions

    def choose(state):
        # Epsilon-greedy. Greedy ties are broken uniformly at random, because
        # np.argmax would always favour the lowest index (action 0, "up").
        if np.random.random() < epsilon:
            return np.random.randint(0, 4)
        values = Q[state[0], state[1]]
        return np.random.choice(np.flatnonzero(values == values.max()))

    rewards = []
    steps_per_episode = []

    for _ in tqdm(range(episodes), desc="Q-Learning"):
        state = env.reset()
        total_reward = 0
        steps = 0
        done = False

        while not done:
            action = choose(state)
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            steps += 1

            target = reward
            if not done:
                # Off-policy: bootstrap from the best next action, not the one taken.
                target += gamma * np.max(Q[next_state[0], next_state[1]])

            row, col = state
            Q[row, col, action] += alpha * (target - Q[row, col, action])

            state = next_state

        rewards.append(total_reward)
        steps_per_episode.append(steps)

    return Q, rewards, steps_per_episode

# Run both algorithms on one shared environment instance. Every episode in
# both trainers begins with env.reset(), so reusing the instance is safe.
env, episodes = CliffWalkingEnv(), 1000

# SARSA (on-policy TD control)
sarsa_Q, sarsa_rewards, sarsa_steps = sarsa(env, episodes=episodes)
# Q-Learning (off-policy TD control)
q_Q, q_rewards, q_steps = q_learning(env, episodes=episodes)

# Plotting results
WINDOW = 100  # episodes per moving-average window


def _moving_average(values, window=WINDOW):
    """Return ``(x, y)``: a trailing moving average of *values* and its episode axis.

    Only full windows are averaged (``np.convolve`` with ``mode="valid"``).
    The default ``mode="full"`` zero-pads both ends, which produces spurious
    ramps toward 0 at the start and end of the curve. ``x[i]`` is the 1-based
    episode number at which window ``i`` ends. The window shrinks to
    ``len(values)`` when fewer episodes are available, and empty input yields
    empty arrays.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return np.array([], dtype=int), values
    window = max(1, min(window, values.size))
    y = np.convolve(values, np.ones(window) / window, mode="valid")
    x = np.arange(window, values.size + 1)
    return x, y


plt.figure(figsize=(12, 6))

# Plot rewards
plt.subplot(1, 2, 1)
plt.plot(*_moving_average(sarsa_rewards), label='SARSA')
plt.plot(*_moving_average(q_rewards), label='Q-Learning')
plt.xlabel('Episodes')
plt.ylabel(f'Average Reward ({WINDOW}-episode window)')
plt.title('Reward Comparison')
plt.legend()

# Plot steps
plt.subplot(1, 2, 2)
plt.plot(*_moving_average(sarsa_steps), label='SARSA')
plt.plot(*_moving_average(q_steps), label='Q-Learning')
plt.xlabel('Episodes')
plt.ylabel(f'Average Steps ({WINDOW}-episode window)')
plt.title('Steps Comparison')
plt.legend()

plt.tight_layout()
plt.show()

# Extract greedy policies from the learned Q-tables
def extract_policy(Q):
    """Return the greedy action for every cell of a tabular Q-function.

    Args:
        Q: array of shape ``(rows, cols, n_actions)``.

    Returns:
        Integer array of shape ``(rows, cols)`` holding ``argmax`` over the
        action axis. Ties resolve to the lowest action index, as with
        ``np.argmax``.
    """
    # Derive the grid shape from Q itself rather than from a global env.
    return np.argmax(np.asarray(Q), axis=-1).astype(int)

sarsa_policy = extract_policy(sarsa_Q)
q_policy = extract_policy(q_Q)

# Printing only row 3 is uninformative. Cliff cells are never occupied,
# because stepping onto one sends the agent to the start. The goal ends the
# episode. Their Q-values are therefore never updated, and argmax reports 0
# for all of them. Render the whole grid instead, marking the special cells.
_ARROWS = "^>v<"  # indexed by action id: 0=up, 1=right, 2=down, 3=left


def _render_policy(policy):
    """Return *policy* as rows of arrows, with C for cliff cells and G for the goal."""
    lines = []
    for r in range(policy.shape[0]):
        cells = []
        for c in range(policy.shape[1]):
            if (r, c) in env.cliff:
                cells.append("C")
            elif (r, c) == env.goal:
                cells.append("G")
            else:
                cells.append(_ARROWS[policy[r, c]])
        lines.append(" ".join(cells))
    return "\n".join(lines)


print("SARSA Policy:")
print(_render_policy(sarsa_policy))
print("Q-Learning Policy:")
print(_render_policy(q_policy))