import numpy as np
import itertools
from pomdp_env import sample, POMDP, project


class AsymmetricAC:
    """Tabular asymmetric actor-critic for a finite-horizon POMDP.

    The critic ``Q[h][history + (state,)]`` is indexed by a truncated
    observation-action history *plus* the latent state, while the actor
    ``pi[h][history]`` is indexed by the truncated history alone. A history
    at step ``h`` is ``(o_0, a_0, ..., o_h)`` truncated to its last
    ``2 * Z + 1`` entries, i.e. ``min(h, Z)`` (observation, action) pairs
    followed by the current observation.

    The critic is recomputed by backward induction through an empirical
    (state, action) -> (next state, next observation) count model built from
    sampled trajectories. For each sampled (history, state, action), the actor
    adds ``alpha * Q / (M * pi(a | history))`` and then passes every row
    through ``project`` (imported from ``pomdp_env``; presumably a projection
    back onto the probability simplex -- confirm there).

    Args:
        p: environment. Used attributes: ``H``, ``s_num``, ``o_num``,
            ``a_num``, ``reward[h][state, action]``, ``reset() -> (s, o)``,
            ``step(a) -> (s, o, r, _)``.
        Z: number of past (observation, action) pairs kept in a history.
        alpha: actor step size.
        tau: stored as ``self.tau``; not used by the methods of this class.
        M: trajectories sampled per learning iteration.
    """

    def __init__(self, p, Z, alpha, tau, M):
        self.p = p
        self.Z = Z
        self.alpha = alpha
        self.tau = tau
        self.M = M
        self.pi = []        # pi[h]: history tuple -> action distribution
        self.as_space = []  # as_space[h]: all (history..., state) keys at step h
        self.s_space = []   # s_space[h]: all history keys at step h
        Q = []
        for i in range(p.H + 1):
            # o, a, o, a, ..., o  -- min(i, Z) (o, a) pairs, then the current o.
            hist_shape = [p.o_num, p.a_num] * min(i, Z) + [p.o_num]
            as_space = list(itertools.product(*(range(n) for n in hist_shape + [p.s_num])))
            s_space = list(itertools.product(*(range(n) for n in hist_shape)))
            self.as_space.append(as_space)
            self.s_space.append(s_space)
            Q.append({key: np.zeros(shape=(p.a_num,)) for key in as_space})
            # Start from the uniform policy at every history.
            self.pi.append({key: np.ones(shape=(p.a_num,)) / p.a_num for key in s_space})
        # est[(state, action)][next_state, next_obs] = visit counts, accumulated
        # over every call to update_Q (not indexed by time step).
        self.est = {
            (state, a): np.zeros(shape=(p.s_num, p.o_num))
            for state in range(p.s_num)
            for a in range(p.a_num)
        }
        # Q_list[k] is the critic after k calls to update_Q (Q_list[0] is all zeros).
        self.Q_list = [Q]

    def learn(self, K, eval_episodes=100, verbose=True):
        """Run ``K`` learning iterations and return the evaluation curve.

        Each iteration samples ``self.M`` trajectories with the current
        policy, evaluates that (pre-update) policy over ``eval_episodes``
        episodes, then updates the critic and the actor.

        Returns:
            list with one average return (from ``evaluate``) per iteration.
        """
        r_sum_list = []
        for _ in range(K):
            traj_list = [self.run_traj()[0] for _ in range(self.M)]
            r_sum_list.append(self.evaluate(eval_episodes))
            if verbose:
                print(r_sum_list[-1])
            self.update_Q(traj_list)
            self.update_pi(traj_list)
        return r_sum_list

    def update_Q(self, traj_list):
        """Add ``traj_list`` to the transition counts and recompute the critic.

        Trajectories use the ``run_traj`` layout
        ``[s_0, o_0, a_0, r_0, s_1, o_1, a_1, r_1, ..., s_H, o_H]``. The new
        critic is stored in ``self.new_Q`` and appended to ``self.Q_list``.
        """
        # Copy every array. ``list.copy()`` alone shares the per-step dicts and
        # their arrays, so every snapshot in Q_list would alias, and be
        # overwritten by, the newest values. Each call now retains a full
        # snapshot in Q_list.
        self.new_Q = [{key: q.copy() for key, q in q_h.items()} for q_h in self.Q_list[-1]]
        for traj in traj_list:
            for h in range(self.p.H):
                s_h, a_h = traj[4 * h], traj[4 * h + 2]
                s_next, o_next = traj[4 * h + 4], traj[4 * h + 5]
                self.est[(s_h, a_h)][s_next, o_next] += 1

        # Normalise each (state, action) row once. Unvisited pairs fall back to a
        # uniform distribution over (next_state, next_obs).
        n_outcomes = self.p.s_num * self.p.o_num
        model = {}
        for key, counts in self.est.items():
            total = np.sum(counts)
            if total < 0.5:
                model[key] = np.ones(shape=(self.p.s_num, self.p.o_num)) / n_outcomes
            else:
                model[key] = counts / total

        # Backward induction; new_Q[H] is never written and stays at its initial zeros.
        for h in reversed(range(self.p.H)):
            Q_next = self.new_Q[h + 1]
            pi_next = self.pi[h + 1]
            for s in self.as_space[h]:
                hist, state = s[:-1], s[-1]
                for a in range(self.p.a_num):
                    emp_tran = model[(state, a)]
                    expectation = 0
                    for state_prime, o_prime in itertools.product(range(self.p.s_num), range(self.p.o_num)):
                        prob = emp_tran[state_prime, o_prime]
                        if prob < 0.01:  # deliberate approximation: skip rare transitions
                            continue
                        new_s_for_pi = self.truncate(hist + (a, o_prime))
                        new_s = new_s_for_pi + (state_prime,)
                        for a_prime in range(self.p.a_num):
                            q = Q_next[new_s][a_prime]
                            # Prune by magnitude. A plain ``q < 0.01`` test would
                            # also drop every negative continuation value.
                            if abs(q) < 0.01:
                                continue
                            expectation += prob * pi_next[new_s_for_pi][a_prime] * q
                    self.new_Q[h][s][a] = self.p.reward[h][state, a] + expectation
        self.Q_list.append(self.new_Q)

    def truncate(self, his):
        """Return the last ``2 * Z + 1`` entries of ``his`` as a tuple (all if shorter)."""
        return tuple(his[-(self.Z * 2 + 1):])

    def update_pi(self, traj_list):
        """Take one actor step at every time step using the latest critic.

        For each trajectory, the sampled action at step ``h`` receives
        ``alpha * Q[h][history + (state,)][a] / (len(traj_list) * pi(a | history))``.
        Each updated row is then replaced by ``project(row)``.
        """
        Q = self.Q_list[-1]
        n_traj = len(traj_list)
        for h in reversed(range(self.p.H)):
            old_pi = self.pi[h]
            # Copy each array. With a shallow dict copy the increments below
            # would mutate old_pi in place, so later trajectories in this batch
            # would divide by an already-updated probability.
            new_pi = {key: probs.copy() for key, probs in old_pi.items()}
            for traj in traj_list:
                # Indices 4j+1 are observations and 4j+2 are actions, giving
                # (o_0, a_0, ..., o_h).
                oa_seq = [traj[j] for j in range(4 * h + 2) if j % 4 in (1, 2)]
                pi_s = self.truncate(oa_seq)
                s = pi_s + (traj[4 * h],)
                a = traj[4 * h + 2]
                new_pi[pi_s][a] += self.alpha * Q[h][s][a] / (n_traj * old_pi[pi_s][a])
            for key in self.s_space[h]:
                self.pi[h][key] = project(new_pi[key])

    def run_traj(self):
        """Roll out one episode with the current policy.

        Returns:
            ``(extended_hist, r_sum)``. ``extended_hist`` is
            ``[s_0, o_0, a_0, r_0, s_1, o_1, ..., s_H, o_H]`` and ``r_sum`` is
            the undiscounted sum of rewards.
        """
        s, o = self.p.reset()
        his = [o]  # observation-action history seen by the policy
        extended_hist = [s, o]
        r_sum = 0
        for h in range(self.p.H):
            a = sample(self.pi[h][self.truncate(his)])
            s, o, r, _ = self.p.step(a)
            his += [a, o]
            extended_hist += [a, r, s, o]
            r_sum += r
        return extended_hist, r_sum

    def evaluate(self, K):
        """Return the mean undiscounted return over ``K`` fresh episodes.

        Raises:
            ZeroDivisionError: if ``K`` is 0.
        """
        return sum(self.run_traj()[1] for _ in range(K)) / K

