import math
import pandas as pd
import numpy as np

class Importance_Sampling(object):
    """Off-policy evaluation of a policy from logged per-user trajectories.

    Typical use: construct, call :meth:`readData` once to turn the raw
    DataFrame into per-user traces, then call :meth:`IS` and/or :meth:`WIS`.

    ``raw_data`` must provide the columns ``userID``, ``real_action``,
    ``inferred_rew``, the Q-value columns ``ps``/``fwe``/``we`` and the
    behaviour-probability columns ``prob_ps``/``prob_fwe``/``prob_we``.
    ``real_action`` is used as a 0-based index into those three-element lists.
    """

    # Policy names whose evaluation distribution is derived from the Q-values.
    _DQN_POLICIES = ('DQN1', 'DQN2', 'DQN3')

    def __init__(self, raw_data, theta, gamma, policy, gt_data):
        """Store the inputs; no processing happens until :meth:`readData`.

        Args:
            raw_data: DataFrame of logged steps (see class docstring).
            theta: Multiplier applied to Q-values inside the softmax of
                :meth:`IS`.
            gamma: Discount factor applied to rewards.
            policy: Name of the evaluated policy (``'DQN1'``, ``'DQN2'``,
                ``'DQN3'`` or ``'Expert'``).
            gt_data: Ground-truth data; stored but not used by this class.
        """
        self.raw_data = raw_data
        self.theta = theta
        self.gamma = gamma
        self.traces = []
        self.n_action = 0
        self.n_user = 0
        self.random_prob = 0
        self.policy = policy
        self.alpha = 0.5

        self.gt_data = gt_data  # groundtruth

    def readData(self):
        """Build ``self.traces`` (one list of steps per user) from raw_data.

        Each step is the tuple
        ``(action, reward, Qs, beh_probs, eva_probs)``. ``eva_probs`` is

        * 0.8 on the arg-max Q action and 0.1 elsewhere for the DQN policies,
        * uniform for ``'Expert'``,
        * an empty list for any other policy name (such traces can be used
          with :meth:`IS` but not with :meth:`WIS`).

        Also sets ``n_action``, ``n_user`` and ``random_prob``. Traces are
        rebuilt from scratch on every call so repeated calls do not
        accumulate duplicates.
        """
        raw_data = self.raw_data

        Q_list = ['ps', 'fwe', 'we']
        beh_prob_list = ['prob_ps', 'prob_fwe', 'prob_we']
        user_list = list(raw_data['userID'].unique())
        self.n_action = len(Q_list)
        self.n_user = len(user_list)
        self.random_prob = 1.0 / self.n_action
        # Reset so that n_user and len(traces) stay consistent on re-reads.
        self.traces = []

        is_dqn = self.policy in self._DQN_POLICIES
        is_expert = self.policy == 'Expert'

        for user in user_list:
            user_data = raw_data[raw_data['userID'] == user]

            # Pull whole columns out positionally; this avoids label lookups
            # (which break on a duplicated index) and yields plain scalars.
            actions = user_data['real_action'].tolist()
            rewards = user_data['inferred_rew'].tolist()
            q_rows = user_data[Q_list].values.tolist()
            beh_rows = user_data[beh_prob_list].values.tolist()

            user_sequence = []
            for action, reward, Qs, beh_probs in zip(
                    actions, rewards, q_rows, beh_rows):
                # The action indexes Python lists later, so it must be an int
                # even when the column is stored as float.
                action = int(action)

                if is_dqn:
                    eva_action = Qs.index(max(Qs))
                    eva_probs = [0.8 if x == eva_action else 1e-1
                                 for x in range(self.n_action)]
                elif is_expert:
                    eva_probs = [self.random_prob] * self.n_action
                else:
                    eva_probs = []

                user_sequence.append(
                    (action, reward, Qs, beh_probs, eva_probs))

            self.traces.append(user_sequence)

    def _softmax_prob(self, Qs, action):
        """Return the softmax probability of ``action`` over ``theta * Qs``."""
        scaled = [q * self.theta for q in Qs]
        # Shifting by the maximum leaves the ratio unchanged but keeps
        # math.exp from overflowing on large theta * Q.
        peak = max(scaled)
        exps = [math.exp(s - peak) for s in scaled]
        return exps[action] / sum(exps)

    def IS(self):
        """Ordinary importance-sampling estimate, averaged over ``n_user``.

        The target probability of each logged action is a softmax over
        ``theta * Qs``; the behaviour probability is taken to be uniform
        (``random_prob``). The stored ``beh_probs``/``eva_probs`` are not
        used here.

        Returns:
            float: sum over users of (discounted return * trajectory weight),
            divided by ``n_user``.

        Raises:
            ZeroDivisionError: if ``n_user`` or ``random_prob`` is zero (for
                example when :meth:`readData` has not been called).
        """
        IS = 0

        for each_student_data in self.traces:
            cumul_policy_prob = 1
            cumul_random_prob = 1
            cumulative_reward = 0

            # Steps are 5-tuples; only action, reward and Qs are needed here.
            for i, (action, reward, Qs, _beh, _eva) in enumerate(
                    each_student_data):
                cumul_policy_prob *= self._softmax_prob(Qs, action)
                cumul_random_prob *= self.random_prob
                cumulative_reward += math.pow(self.gamma, i) * reward

            weight = cumul_policy_prob / cumul_random_prob
            IS += cumulative_reward * weight

        return float(IS) / self.n_user

    def WIS(self):
        """Weighted importance-sampling estimate.

        Each trajectory weight is the product of ``eva_probs[action]``
        divided by the product of ``beh_probs[action]``; the weighted
        discounted returns are normalised by the sum of the weights.

        Returns:
            float: the weighted estimate.

        Raises:
            ValueError: if a step has no evaluation probabilities, i.e. the
                policy name was not recognised by :meth:`readData`.
            ZeroDivisionError: if a behaviour probability product or the
                total weight is zero (including when there are no traces).
        """
        WIS = 0
        total_weight = 0

        for each_student_data in self.traces:
            cumul_policy_prob = 1
            cumul_random_prob = 1
            cumulative_reward = 0

            for i, (action, reward, _Qs, beh_probs, eva_probs) in enumerate(
                    each_student_data):
                if not eva_probs:
                    raise ValueError(
                        "no evaluation probabilities for policy %r"
                        % (self.policy,))
                cumul_policy_prob *= eva_probs[action]
                cumul_random_prob *= beh_probs[action]
                cumulative_reward += math.pow(self.gamma, i) * reward

            weight = cumul_policy_prob / cumul_random_prob

            total_weight += weight
            WIS += cumulative_reward * weight

        return float(WIS) / total_weight