# Exp 6: viterbi sequence
# ----------------------------

# Emission table: emission_probs[state][drink] is the probability that
# `state` (CP or IP) emits `drink`.
emission_probs = {
    "CP": {"cola": 0.6, "ice_tea": 0.1, "lem": 0.3},
    "IP": {"cola": 0.1, "ice_tea": 0.7, "lem": 0.2},
}
# Viterbi scores for CP (alpha_a) and IP (alpha_b); the trellis starts in CP
# with probability 1.
alpha_a = 1
alpha_b = 0

# Transition probabilities hard-coded in the updates below:
#   CP->CP 0.7, CP->IP 0.3, IP->CP 0.5, IP->IP 0.5
# The emission factor is taken from the state being left (the source state).
for _ in range(3):
    # `state` holds the observed drink, i.e. a key of the emission tables.
    state = input("Enter the state:")
    # Both new scores must be derived from the previous step's scores. The
    # tuple assignment evaluates the whole right-hand side before rebinding
    # either name, so alpha_b's update does not see the new alpha_a.
    alpha_a, alpha_b = (
        max(alpha_a * 0.7 * emission_probs["CP"][state],
            alpha_b * 0.5 * emission_probs["IP"][state]),
        max(alpha_a * 0.3 * emission_probs["CP"][state],
            alpha_b * 0.5 * emission_probs["IP"][state]),
    )
    print(alpha_a, alpha_b)
    # NOTE: this reports the best-scoring state at each step only; recovering
    # the full most-likely path would also require back-pointers.
    print("CP" if alpha_a > alpha_b else "IP")

# --------------------
# HMM trellis
# A. FORWARD PROCEDURE

# Emission table: emission_probs[state][symbol]. In the updates below the
# emission factor is indexed by the state being left (the source state).
# NOTE(review): the listed symbols sum to 0.9 for 'A' and 0.6 for 'B';
# confirm these values against the exercise statement.
emission_probs = {
    'A': {'K': 0.4, 'T': 0.5},
    'B': {'K': 0.3, 'T': 0.3},
}

# The trellis starts in state A with probability 1.
alpha_a, alpha_b = 1, 0
alpha_A, alpha_B = [alpha_a], [alpha_b]

visible_states = ['K', 'T', 'K']  # Update with the actual visible states

# Transition probabilities hard-coded below:
#   A->A 0.2, A->B 0.8, B->A 0.6, B->B 0.4
for symbol in visible_states:
    emit_a = emission_probs["A"][symbol]
    emit_b = emission_probs["B"][symbol]
    # Tuple assignment evaluates both right-hand sides with the previous
    # step's alphas, so no temporary copy of alpha_a is needed.
    alpha_a, alpha_b = (
        (alpha_a * 0.2 * emit_a) + (alpha_b * 0.6 * emit_b),
        (alpha_a * 0.8 * emit_a) + (alpha_b * 0.4 * emit_b),
    )
    alpha_A.append(alpha_a)
    alpha_B.append(alpha_b)

print(alpha_A)
print(alpha_B)

# B. BACKWARD PROCEDURE
# Credit: Ahmed Baari
# Backward
# Emission table, redefined so this section can be read on its own.
emission_probs = {
    'A': {'K': 0.4, 'T': 0.5},
    'B': {'K': 0.3, 'T': 0.3},
}

# Backward variables are 1 for both states after the last observation.
b_A, b_B = 1, 1
beta_A, beta_B = [b_A], [b_B]

# Walk the observations from last to first:
#   beta_A(t) = e_A(o_t) * (0.2 * beta_A(t+1) + 0.8 * beta_B(t+1))
#   beta_B(t) = e_B(o_t) * (0.6 * beta_A(t+1) + 0.4 * beta_B(t+1))
# Each new value is inserted at the front, so the lists finish in
# chronological order (earliest time first) without a reversal step.
for symbol in visible_states[::-1]:
    emit_A = emission_probs["A"][symbol]
    emit_B = emission_probs["B"][symbol]
    # Tuple assignment: both updates read the previous step's betas.
    b_A, b_B = (
        (b_A * 0.2 * emit_A) + (b_B * 0.8 * emit_A),
        (b_A * 0.6 * emit_B) + (b_B * 0.4 * emit_B),
    )
    beta_A.insert(0, b_A)
    beta_B.insert(0, b_B)


# C. BEST STATE SEQUENCE 
# Credit: Ahmed Baari 
# Posterior state probabilities at each time step:
#   gamma_X = alpha_X * beta_X / (alpha_A * beta_A + alpha_B * beta_B)
# followed by the more probable state per step (ties go to "B").
gamma_A = []
gamma_B = []

# One gamma per observed symbol. Deriving the count from visible_states,
# rather than hard-coding 3, keeps this in sync when the sequence changes.
n_steps = len(visible_states)

for i in range(n_steps):
    joint_A = alpha_A[i] * beta_A[i]
    joint_B = alpha_B[i] * beta_B[i]
    total = joint_A + joint_B
    gamma_A.append(joint_A / total)
    gamma_B.append(joint_B / total)

for i in range(n_steps):
    print(
        "A" if gamma_A[i] > gamma_B[i] else "B",
        end=" "
    )
# Terminate the line so subsequent output does not run onto it.
print()
# HMM trellis: forward and backward procedures (NumPy)
# -----------------
import numpy as np

# Observation sequence and the column of B that each symbol selects.
obs = ['K2', 'K1', 'K2', 'K3', 'K1']
obs_map = {'K1': 0, 'K2': 1, 'K3': 2}
obs_idx = [obs_map[symbol] for symbol in obs]

# Two-state model: index 0 = S1, index 1 = S2.
# Here the emission factor B[j, o_t] belongs to the state j being entered.
pi = np.array([0.5, 0.5])         # initial state distribution
A = np.array([[0.4, 0.6],         # transitions out of S1
              [0.5, 0.5]])        # transitions out of S2
B = np.array([[0.4, 0.4, 0.2],    # S1 emissions for K1, K2, K3
              [0.3, 0.4, 0.3]])   # S2 emissions for K1, K2, K3

T = len(obs)
N = 2

# Forward procedure:
#   alpha[t, j] = (sum_i alpha[t-1, i] * A[i, j]) * B[j, o_t]
alpha = np.zeros((T, N))
alpha[0] = pi * B[:, obs_idx[0]]
for t in range(1, T):
    # Row i of A scaled by alpha[t-1, i]; summing down the rows gives the
    # incoming probability mass for every destination state at once.
    incoming = np.sum(alpha[t - 1][:, None] * A, axis=0)
    alpha[t] = incoming * B[:, obs_idx[t]]

forward_prob = np.sum(alpha[-1])
print("Forward Probability:", forward_prob)

# Backward procedure:
#   beta[t, i] = sum_j A[i, j] * B[j, o_{t+1}] * beta[t+1, j]
beta = np.zeros((T, N))
beta[-1] = 1
for t in range(T - 2, -1, -1):
    beta[t] = np.sum(A * B[:, obs_idx[t + 1]] * beta[t + 1], axis=1)

backward_prob = np.sum(pi * B[:, obs_idx[0]] * beta[0])
print("Backward Probability:", backward_prob)
