import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

class BanditProblem:
    """A k-armed Gaussian bandit whose true action values stay fixed between resets.

    True action values ``q_true`` are drawn from N(0, 1) on every reset; pulling
    an arm returns that arm's true value plus N(0, 1) noise. All sampling goes
    through NumPy's global random state (``np.random``).
    """

    def __init__(self, k=10):
        self.k = k  # number of arms
        self.reset_problem()

    def reset_problem(self):
        """Draw a fresh vector of ``k`` true action values from N(0, 1)."""
        self.q_true = np.random.normal(loc=0, scale=1, size=self.k)

    def get_reward(self, action):
        """Sample a reward for ``action``: its true value plus unit-variance Gaussian noise."""
        mean = self.q_true[action]
        return np.random.normal(loc=mean, scale=1)

    def optimal_action(self):
        """Return the index of the arm with the highest true value (lowest index on ties)."""
        return self.q_true.argmax()

def epsilon_greedy(bandit, episodes=1000, epsilon=0.1):
    """Run an epsilon-greedy agent with sample-average action-value estimates.

    With probability ``epsilon`` a uniformly random arm is explored; otherwise
    an arm with the highest current estimate is exploited, with ties broken
    uniformly at random.

    Args:
        bandit: problem exposing ``k``, ``get_reward(action)`` and
            ``optimal_action()``.
        episodes: number of steps to play.
        epsilon: exploration probability per step (0 gives a purely greedy agent).

    Returns:
        ``(rewards, optimal_actions)``: float arrays of length ``episodes``. The
        second array holds 1.0 on steps where the optimal arm was chosen and
        0.0 otherwise.
    """
    k = bandit.k
    Q = np.zeros(k)  # sample-average action-value estimates
    N = np.zeros(k)  # number of times each action has been taken
    rewards = np.zeros(episodes)
    optimal_actions = np.zeros(episodes)

    for t in range(episodes):
        if np.random.random() < epsilon:
            action = np.random.randint(k)
        else:
            # Break ties uniformly at random. Plain np.argmax always returns the
            # lowest index, so while estimates are equal (e.g. all zero at the
            # start) a greedy agent would be systematically pushed onto arm 0.
            greedy_actions = np.flatnonzero(Q == Q.max())
            action = np.random.choice(greedy_actions)

        reward = bandit.get_reward(action)
        rewards[t] = reward
        optimal_actions[t] = 1 if action == bandit.optimal_action() else 0

        # Incremental sample-average update: Q <- Q + (R - Q) / N.
        N[action] += 1
        Q[action] += (reward - Q[action]) / N[action]

    return rewards, optimal_actions

def gradient_bandit(bandit, episodes=1000, alpha=0.1, baseline=True):
    """Run a softmax gradient-bandit agent on ``bandit``.

    Action preferences are mapped to a policy with a softmax. Each step the
    preferences move by ``alpha * advantage * (onehot(action) - policy)``. When
    ``baseline`` is True, the advantage is the reward minus the running mean
    reward; otherwise it is the raw reward.

    Args:
        bandit: problem exposing ``k``, ``get_reward(action)`` and
            ``optimal_action()``.
        episodes: number of steps to play.
        alpha: step size for the preference update.
        baseline: whether to subtract the running mean reward. The mean
            includes the current step's reward.

    Returns:
        ``(rewards, optimal_actions)``: float arrays of length ``episodes``. The
        second array holds 1.0 on steps where the optimal arm was chosen and
        0.0 otherwise.
    """
    n_arms = bandit.k
    arm_ids = np.arange(n_arms)
    prefs = np.zeros(n_arms)              # action preferences H
    probs = np.ones(n_arms) / n_arms      # start from the uniform policy
    rewards = np.zeros(episodes)
    optimal_actions = np.zeros(episodes)
    mean_reward = 0

    for step in range(episodes):
        chosen = np.random.choice(n_arms, p=probs)
        r = bandit.get_reward(chosen)
        rewards[step] = r
        optimal_actions[step] = float(chosen == bandit.optimal_action())

        if baseline:
            # Incremental mean of the rewards seen on steps 0..step.
            mean_reward += (r - mean_reward) / (step + 1)
            advantage = r - mean_reward
        else:
            advantage = r

        indicator = (arm_ids == chosen).astype(float)
        prefs += alpha * advantage * (indicator - probs)

        # Softmax; shifting by the max keeps np.exp from overflowing.
        probs = np.exp(prefs - prefs.max())
        probs /= probs.sum()

    return rewards, optimal_actions

def run_experiment(bandit, method_funcs, runs=2000, episodes=1000):
    """Average several bandit agents over many freshly sampled problems.

    For every run the shared ``bandit`` is re-sampled with ``reset_problem()``.
    Every method in ``method_funcs`` is then played on that same problem, so
    all methods see identical true action values within a run.

    Args:
        bandit: problem object exposing ``reset_problem()``. It is passed as the
            first positional argument to each method function.
        method_funcs: mapping of label -> callable ``f(bandit, episodes)``
            returning ``(rewards, optimal)`` arrays of length ``episodes``.
        runs: number of independent problems to average over.
        episodes: number of steps per run.

    Returns:
        dict keyed by label. Each entry holds:
        - ``'rewards'`` and ``'optimal'``: raw per-run arrays of shape
          ``(runs, episodes)``.
        - ``'avg_reward'`` and ``'avg_optimal'``: the per-step means over runs.
    """
    results = {}
    for label in method_funcs:
        results[label] = {
            'rewards': np.zeros((runs, episodes)),
            'optimal': np.zeros((runs, episodes)),
        }

    for run_idx in tqdm(range(runs), desc="Running experiments"):
        bandit.reset_problem()
        for label, method in method_funcs.items():
            run_rewards, run_optimal = method(bandit, episodes)
            results[label]['rewards'][run_idx] = run_rewards
            results[label]['optimal'][run_idx] = run_optimal

    # Collapse the run axis to obtain one learning curve per method.
    for record in results.values():
        record['avg_reward'] = record['rewards'].mean(axis=0)
        record['avg_optimal'] = record['optimal'].mean(axis=0)

    return results

# Set up experiment
bandit = BanditProblem(k=10)
episodes = 1000
runs = 2000

# Method label -> agent function plus the extra keyword arguments to call it with.
method_configs = {
    'ε=0.0 (greedy)': {'func': epsilon_greedy, 'params': {'epsilon': 0.0}},
    'ε=0.01': {'func': epsilon_greedy, 'params': {'epsilon': 0.01}},
    'ε=0.1': {'func': epsilon_greedy, 'params': {'epsilon': 0.1}},
    'Gradient Bandit': {'func': gradient_bandit, 'params': {'alpha': 0.1}}
}

# Create functions with parameters bound. Binding f/p as default arguments
# freezes each config per iteration (avoids the late-binding closure pitfall).
method_funcs = {
    name: lambda b, e, f=config['func'], p=config['params']: f(b, episodes=e, **p)
    for name, config in method_configs.items()
}

results = run_experiment(bandit, method_funcs, runs=runs, episodes=episodes)

# Plot learning curves against 1-based step numbers.
steps = np.arange(1, episodes + 1)
panels = [
    # (result key, scale, y-axis label, title)
    ('avg_reward', 1.0, 'Average reward', 'Average Reward Comparison'),
    # avg_optimal is the mean of 0/1 indicators, i.e. a fraction in [0, 1];
    # scale by 100 so the curve matches the "%" axis label.
    ('avg_optimal', 100.0, '% Optimal action', 'Optimal Action Selection Comparison'),
]

plt.figure(figsize=(14, 6))
for panel_idx, (key, scale, ylabel, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_idx)
    for name in method_funcs:
        plt.plot(steps, scale * results[name][key], label=name)
    plt.xlabel('Steps')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()

plt.tight_layout()
plt.show()
