import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

class NonstationaryBandit:
    """k-armed bandit whose true action values drift over time.

    Every arm's true value starts at 0. Each call to ``step`` adds an
    independent Gaussian increment (mean 0, std ``random_walk_std``) to
    every arm. Rewards are drawn from a unit-variance Gaussian centred on
    the chosen arm's current true value.

    Parameters
    ----------
    k : int
        Number of arms.
    random_walk_std : float
        Standard deviation of the per-step random-walk increment.
    """

    def __init__(self, k=10, random_walk_std=0.01):
        self.k = k
        self.random_walk_std = random_walk_std
        self.reset()

    def reset(self):
        """Restore every arm's true value to 0."""
        self.q_true = np.full(self.k, 0.0)

    def step(self):
        """Advance the random walk by one tick.

        The drift is added to ``q_true`` in place. The returned array is
        that same internal array, not a copy.
        """
        drift = np.random.normal(loc=0, scale=self.random_walk_std, size=self.k)
        self.q_true += drift
        return self.q_true

    def get_reward(self, action):
        """Sample a reward for ``action`` from N(q_true[action], 1)."""
        mean = self.q_true[action]
        return np.random.normal(loc=mean, scale=1)

    def optimal_action(self):
        """Index of the arm with the highest current true value (lowest index on ties)."""
        return self.q_true.argmax()

def run_experiment(bandit, method, episodes=10000, epsilon=0.1, alpha=None):
    """Run one epsilon-greedy agent on ``bandit`` for ``episodes`` steps.

    Parameters
    ----------
    bandit : object
        Environment exposing ``k``, ``step()``, ``get_reward(action)`` and
        ``optimal_action()`` (e.g. ``NonstationaryBandit``). It is advanced
        in place, so callers should ``reset()`` it beforehand if needed.
    method : {'sample_avg', 'constant_step'}
        Update rule for the action-value estimates:
        - ``'sample_avg'``: incremental sample average (step size 1/N).
        - ``'constant_step'``: fixed step size ``alpha``.
    episodes : int
        Number of time steps to simulate.
    epsilon : float
        Probability of taking a uniformly random action instead of the
        greedy one.
    alpha : float, optional
        Constant step size. Required when ``method == 'constant_step'``;
        ignored otherwise.

    Returns
    -------
    rewards : ndarray, shape (episodes,)
        Reward received at each step.
    optimal_actions : ndarray, shape (episodes,)
        1.0 where the chosen action was the environment's optimal action at
        that step, else 0.0.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or ``alpha`` is missing for
        ``'constant_step'``.
    """
    # Validate up front so a bad call fails immediately instead of silently
    # not learning (unknown method) or crashing mid-loop (alpha=None).
    if method not in ('sample_avg', 'constant_step'):
        raise ValueError(
            f"unknown method {method!r}; expected 'sample_avg' or 'constant_step'"
        )
    if method == 'constant_step' and alpha is None:
        raise ValueError("alpha must be provided when method='constant_step'")

    Q = np.zeros(bandit.k)  # Action value estimates
    N = np.zeros(bandit.k)  # Action counts (for sample-average method)
    rewards = np.zeros(episodes)
    optimal_actions = np.zeros(episodes)

    for t in range(episodes):
        # Epsilon-greedy action selection
        if np.random.random() < epsilon:
            action = np.random.randint(bandit.k)
        else:
            # Break ties uniformly at random. np.argmax alone always returns
            # the lowest tied index, which (with Q starting at all zeros)
            # biases greedy play toward arm 0.
            greedy = np.flatnonzero(Q == Q.max())
            if greedy.size == 1:
                action = greedy[0]
            else:
                action = greedy[np.random.randint(greedy.size)]

        # The environment drifts before the reward is drawn, so both the
        # reward and the optimality check use the post-step true values.
        bandit.step()

        reward = bandit.get_reward(action)
        rewards[t] = reward
        optimal_actions[t] = 1 if action == bandit.optimal_action() else 0

        # Update action value estimate
        if method == 'sample_avg':
            N[action] += 1
            Q[action] += (reward - Q[action]) / N[action]
        else:  # 'constant_step'
            Q[action] += alpha * (reward - Q[action])

    return rewards, optimal_actions

def run_comparison(runs=2000, episodes=10000, *, epsilon=0.1, alpha=0.1,
                   k=10, random_walk_std=0.01, seed=None):
    """Compare sample-average and constant step-size learners on a drifting bandit.

    For each of ``runs`` independent runs, the bandit is reset and each
    method is run for ``episodes`` steps: first ``'sample_avg'``, then
    ``'constant_step'``, with a reset before each.

    Parameters
    ----------
    runs : int
        Number of independent runs to average over.
    episodes : int
        Number of steps per run.
    epsilon : float, keyword-only
        Exploration rate passed to both methods.
    alpha : float, keyword-only
        Step size for the ``'constant_step'`` method.
    k : int, keyword-only
        Number of bandit arms.
    random_walk_std : float, keyword-only
        Standard deviation of the bandit's per-step random-walk increment.
    seed : int, optional, keyword-only
        If given, seeds NumPy's global RNG before the runs for reproducibility.

    Returns
    -------
    dict
        Mapping method name to a dict with these keys:
        - ``'rewards'`` / ``'optimal'``: per-run arrays of shape
          (runs, episodes).
        - ``'avg_reward'`` / ``'avg_optimal'``: their means over runs.
        ``'avg_optimal'`` is a fraction in [0, 1].
    """
    if seed is not None:
        np.random.seed(seed)

    bandit = NonstationaryBandit(k=k, random_walk_std=random_walk_std)

    # Per-method extra keyword arguments for run_experiment. Insertion order
    # fixes the order in which the methods are run in each iteration.
    method_kwargs = {
        'sample_avg': {},
        'constant_step': {'alpha': alpha},
    }

    results = {method: {'rewards': np.zeros((runs, episodes)),
                        'optimal': np.zeros((runs, episodes))}
               for method in method_kwargs}

    for run in tqdm(range(runs), desc="Running comparison"):
        for method, kwargs in method_kwargs.items():
            # Each method starts from the same initial state (all q* = 0);
            # its random walk is an independent draw.
            bandit.reset()
            rewards, optimal = run_experiment(
                bandit, method, episodes=episodes, epsilon=epsilon, **kwargs
            )
            results[method]['rewards'][run] = rewards
            results[method]['optimal'][run] = optimal

    # Average across runs (axis 0) to get per-step learning curves.
    for method, data in results.items():
        data['avg_reward'] = data['rewards'].mean(axis=0)
        data['avg_optimal'] = data['optimal'].mean(axis=0)

    return results

# Run the comparison (2000 independent runs of 10,000 steps per method).
# NOTE(review): this executes at import time; consider moving it under an
# `if __name__ == "__main__":` guard so the module can be imported cheaply.
results = run_comparison(runs=2000, episodes=10000)

# Plot results
plt.figure(figsize=(14, 6))

# Average reward plot
plt.subplot(1, 2, 1)
plt.plot(results['sample_avg']['avg_reward'], label='Sample Average')
plt.plot(results['constant_step']['avg_reward'], label='Constant Step (α=0.1)')
plt.xlabel('Steps')
plt.ylabel('Average Reward')
plt.title('Average Reward Comparison\nNonstationary Bandit Problem')
plt.legend()

# Optimal action percentage plot. 'avg_optimal' is the mean of 0/1
# indicators (a fraction in [0, 1]), so scale by 100 to match the "%" label.
plt.subplot(1, 2, 2)
plt.plot(100 * results['sample_avg']['avg_optimal'], label='Sample Average')
plt.plot(100 * results['constant_step']['avg_optimal'], label='Constant Step (α=0.1)')
plt.xlabel('Steps')
plt.ylabel('% Optimal Action')
plt.title('Optimal Action Comparison\nNonstationary Bandit Problem')
plt.legend()

plt.tight_layout()
plt.show()