import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

class BanditProblem:
    """A k-armed Gaussian bandit testbed.

    Each arm's true value is drawn from a standard normal whenever the
    problem is (re)sampled; pulling an arm returns that true value plus
    unit-variance Gaussian noise. The true values only change when
    ``reset_problem`` is called.
    """

    def __init__(self, k=10):
        """Create a testbed with ``k`` arms and sample its true values."""
        self.k = k
        self.reset_problem()

    def reset_problem(self):
        """Resample every arm's true value ``q_true`` from N(0, 1)."""
        self.q_true = np.random.normal(loc=0, scale=1, size=self.k)

    def get_reward(self, action):
        """Return a noisy reward: N(q_true[action], 1)."""
        mean = self.q_true[action]
        return np.random.normal(loc=mean, scale=1)

    def optimal_action(self):
        """Return the index of the arm with the largest true value."""
        return self.q_true.argmax()

def epsilon_greedy(bandit, episodes=1000, epsilon=0.1, initial_value=0.0,
                   step_size=None):
    """Run one epsilon-greedy agent on ``bandit`` for ``episodes`` steps.

    With probability ``epsilon`` a uniformly random arm is pulled. Otherwise
    the agent pulls the arm with the highest current estimate, breaking ties
    uniformly at random. Plain ``np.argmax`` would always favour the
    lowest-index arm, e.g. arm 0 on every greedy step while all estimates
    still equal ``initial_value``.

    Args:
        bandit: Object exposing ``k`` (number of arms), ``get_reward(action)``
            and ``optimal_action()``.
        episodes: Number of time steps to simulate.
        epsilon: Probability of taking a uniformly random action.
        initial_value: Starting estimate for every arm. May be an int or a
            float; estimates are always stored as floats.
        step_size: ``None`` (default) uses sample-average updates
            ``Q += (R - Q) / N``. A number uses the constant-step update
            ``Q += step_size * (R - Q)``. With sample averages an arm's
            initial value is fully replaced on its first pull, so an
            optimistic ``initial_value`` only has a lasting effect with a
            constant step size.

    Returns:
        ``(rewards, optimal_actions)``: two float arrays of length
        ``episodes`` holding the reward received at each step and a
        1.0/0.0 flag for whether the pulled arm equalled
        ``bandit.optimal_action()`` at that step.
    """
    # dtype=float: an int initial_value (e.g. 5) would otherwise create an
    # integer array and silently truncate every incremental update below.
    Q = np.full(bandit.k, initial_value, dtype=float)
    N = np.zeros(bandit.k)  # Action counts
    rewards = np.zeros(episodes)
    optimal_actions = np.zeros(episodes)

    for t in range(episodes):
        if np.random.random() < epsilon:
            action = np.random.randint(bandit.k)
        else:
            greedy = np.flatnonzero(Q == Q.max())
            # Only draw from the RNG when there is actually a tie.
            action = greedy[0] if greedy.size == 1 else np.random.choice(greedy)

        reward = bandit.get_reward(action)
        rewards[t] = reward
        # Queried every step rather than hoisted, so a bandit whose optimal
        # arm changes over time is still scored correctly.
        optimal_actions[t] = float(action == bandit.optimal_action())

        N[action] += 1
        if step_size is None:
            Q[action] += (reward - Q[action]) / N[action]
        else:
            Q[action] += step_size * (reward - Q[action])

    return rewards, optimal_actions

def ucb(bandit, episodes=1000, c=2, initial_value=0.0, step_size=None):
    """Run one upper-confidence-bound agent on ``bandit`` for ``episodes`` steps.

    The first ``bandit.k`` steps pull each arm once, in index order. After
    that the agent pulls the arm maximising ``Q(a) + c * sqrt(ln t / N(a))``,
    where ``t`` is the number of steps already taken. Ties are broken
    uniformly at random.

    Args:
        bandit: Object exposing ``k`` (number of arms), ``get_reward(action)``
            and ``optimal_action()``.
        episodes: Number of time steps to simulate.
        c: Weight of the exploration bonus (``c=0`` is purely greedy).
        initial_value: Starting estimate for every arm (int or float). With
            the default sample-average updates this has no effect: each
            arm's estimate is overwritten by its first reward during the
            initial round-robin. It only matters together with
            ``step_size``.
        step_size: ``None`` (default) uses sample-average updates
            ``Q += (R - Q) / N``. A number uses the constant-step update
            ``Q += step_size * (R - Q)``.

    Returns:
        ``(rewards, optimal_actions)``: two float arrays of length
        ``episodes`` holding the reward received at each step and a
        1.0/0.0 flag for whether the pulled arm equalled
        ``bandit.optimal_action()`` at that step.
    """
    # dtype=float: an int initial_value would otherwise create an integer
    # array and silently truncate every incremental update below.
    Q = np.full(bandit.k, initial_value, dtype=float)
    N = np.zeros(bandit.k)  # Action counts
    rewards = np.zeros(episodes)
    optimal_actions = np.zeros(episodes)

    for t in range(episodes):  # t == number of actions taken so far
        if t < bandit.k:
            # Round-robin start guarantees N(a) >= 1 for every arm before
            # the bonus term below divides by N.
            action = t
        else:
            ucb_values = Q + c * np.sqrt(np.log(t) / N)
            best = np.flatnonzero(ucb_values == ucb_values.max())
            # Only draw from the RNG when there is actually a tie.
            action = best[0] if best.size == 1 else np.random.choice(best)

        reward = bandit.get_reward(action)
        rewards[t] = reward
        # Queried every step so a bandit whose optimal arm changes over time
        # is still scored correctly.
        optimal_actions[t] = float(action == bandit.optimal_action())

        N[action] += 1
        if step_size is None:
            Q[action] += (reward - Q[action]) / N[action]
        else:
            Q[action] += step_size * (reward - Q[action])

    return rewards, optimal_actions

def run_comparison(bandit, method_configs, runs=2000, episodes=1000):
    """Evaluate several bandit methods over many resampled problem instances.

    At the start of every run ``bandit.reset_problem()`` is called once. Every
    configured method is then run on that same instance, so all methods
    are compared on identical sequences of problems.

    Args:
        bandit: Object exposing ``reset_problem()`` plus whatever interface
            the configured methods use.
        method_configs: Mapping ``name -> {'func': callable, 'params': dict}``.
            Each ``func`` is called as ``func(bandit, **params)`` and must
            return ``(rewards, optimal)`` arrays of length ``episodes``.
            ``params`` may be omitted. When it does not set ``episodes``, this
            function's ``episodes`` argument is passed, so the result arrays
            always have the right length.
        runs: Number of independent problem instances.
        episodes: Steps per run; sizes the per-run result arrays.

    Returns:
        Dict ``name -> stats``, where ``stats`` holds ``'rewards'`` and
        ``'optimal'`` (``runs x episodes`` arrays) plus ``'avg_reward'`` and
        ``'avg_optimal'`` (per-step means over runs).

    Raises:
        ValueError: If a config's ``params['episodes']`` differs from
            ``episodes``. This is checked before any simulation runs.
    """
    # Resolve and validate every call up front so a bad config fails fast,
    # instead of with a broadcast error partway through a long experiment.
    calls = {}
    for name, config in method_configs.items():
        params = {'episodes': episodes, **config.get('params', {})}
        if params['episodes'] != episodes:
            raise ValueError(
                f"method {name!r} sets episodes={params['episodes']}, "
                f"but run_comparison was called with episodes={episodes}"
            )
        calls[name] = (config['func'], params)

    results = {
        name: {'rewards': np.zeros((runs, episodes)),
               'optimal': np.zeros((runs, episodes))}
        for name in method_configs
    }

    for run in tqdm(range(runs), desc="Running comparison"):
        bandit.reset_problem()

        for name, (func, params) in calls.items():
            rewards, optimal = func(bandit, **params)
            results[name]['rewards'][run] = rewards
            results[name]['optimal'][run] = optimal

    # Per-step averages across runs.
    for stats in results.values():
        stats['avg_reward'] = np.mean(stats['rewards'], axis=0)
        stats['avg_optimal'] = np.mean(stats['optimal'], axis=0)

    return results

# Experiment setup: a single 10-armed testbed object (run_comparison resamples
# its true values at the start of every run), `runs` runs of `episodes` steps.
bandit = BanditProblem(k=10)
episodes, runs = 1000, 2000

# Four configurations: two algorithms, each with a neutral (0) and an
# optimistic (5) initial estimate.
# NOTE(review): with the default sample-average updates an arm's initial
# estimate is replaced outright on its first pull, and ucb pulls every arm
# once up front, so 'UCB (Q0=5)' should be statistically identical to
# 'UCB (Q0=0)'. Confirm this comparison is intended.
method_configs = {
    label: {
        'func': algorithm,
        'params': {'episodes': episodes, **settings, 'initial_value': q0},
    }
    for label, algorithm, settings, q0 in (
        ('ε-greedy (Q0=0)', epsilon_greedy, {'epsilon': 0.1}, 0.0),
        ('ε-greedy (Q0=5)', epsilon_greedy, {'epsilon': 0.1}, 5.0),
        ('UCB (Q0=0)', ucb, {'c': 2}, 0.0),
        ('UCB (Q0=5)', ucb, {'c': 2}, 5.0),
    )
}

# Run every configuration on the same sequence of resampled problems.
results = run_comparison(bandit, method_configs, runs=runs, episodes=episodes)

# Plot per-step averages over all runs: average reward on the left, share of
# steps on which the optimal arm was pulled on the right.
fig, (ax_reward, ax_optimal) = plt.subplots(1, 2, figsize=(14, 6))

for name in method_configs:
    ax_reward.plot(results[name]['avg_reward'], label=name)
    # avg_optimal is a fraction in [0, 1]; scale it so the plotted values
    # match the "%" axis label.
    ax_optimal.plot(100 * results[name]['avg_optimal'], label=name)

ax_reward.set_xlabel('Steps')
ax_reward.set_ylabel('Average Reward')
ax_reward.set_title('Average Reward Comparison')
ax_reward.legend()

ax_optimal.set_xlabel('Steps')
ax_optimal.set_ylabel('% Optimal Action')
ax_optimal.set_title('Optimal Action Comparison')
ax_optimal.set_ylim(0, 100)
ax_optimal.legend()

fig.tight_layout()
plt.show()