import random
import numpy as np


class BaseAgent:
    """Shared helpers for agents with linear action-value functions.

    Q-values are linear in the features: ``Q(s, a) = features[a, s] . weight``.
    ``features`` is taken from ``env.features`` and is indexed as
    ``[action, state, feature]`` throughout this class.

    ``__init__`` does not set a discount factor. Callers of
    :meth:`bellman_error` must either pass ``gamma`` explicitly or have set
    ``self.gamma`` (e.g. in a subclass).
    """

    def __init__(self, env):
        """Copy the environment's dimensions and feature tensor onto the agent.

        Args:
            env: object exposing ``NUM_STATES``, ``NUM_FEATURES``,
                ``NUM_ACTIONS`` and ``features``.
        """
        self.num_states = env.NUM_STATES
        self.num_features = env.NUM_FEATURES
        self.num_actions = env.NUM_ACTIONS
        self.features = env.features

    def random_policy(self):
        """Return an action drawn uniformly from ``0 .. num_actions - 1``.

        NOTE: this draws from Python's ``random`` module, not ``np.random``.
        Seeding NumPy alone therefore does not make this policy reproducible.
        """
        return random.randint(0, self.num_actions - 1)  # both bounds inclusive

    def fixed_behavior_policy(self, dash_prob=5 / 6):
        """Sample an action from the fixed two-action behaviour policy.

        Args:
            dash_prob: probability of returning the dash action (``1``).
                Otherwise the solid action (``0``) is returned. Defaults to 5/6.

        Returns:
            ``1`` (dash) with probability ``dash_prob``, else ``0`` (solid).
        """
        return 1 if np.random.rand() < dash_prob else 0

    def greedy_policy(self, state, weight):
        """Return an action maximising ``Q(state, .)``, breaking ties uniformly at random.

        Args:
            state: state index into ``features``.
            weight: 1-D weight vector matching the last axis of ``features``.
        """
        feature = self.features[:, state]  # one feature vector per action
        action_values = np.sum(feature * weight, axis=1)
        best_actions = np.where(action_values == action_values.max())[0]
        return np.random.choice(best_actions)

    def action_value(self, state, action, weight):
        """Return ``Q(state, action) = features[action, state] . weight``."""
        feature = self.features[action, state]
        return np.sum(feature * weight)

    def bellman_error(
        self, weight, state, next_state, action, reward, done, gamma=None
    ):
        """Return the Euclidean norm of the sampled TD update ``delta * phi(s, a)``.

        ``delta = reward + gamma * Q(next_state, a') - Q(state, action)``, where
        ``a'`` is the greedy action in ``next_state`` (ties broken at random).
        The bootstrap term is dropped when ``done`` is truthy.

        Args:
            weight: 1-D weight vector.
            state, action: the transition's state and action indices.
            next_state: successor state index.
            reward: scalar reward.
            done: whether the transition was terminal.
            gamma: discount factor. If ``None`` (the default), ``self.gamma``
                is used.

        Raises:
            AttributeError: if ``gamma`` is ``None`` and ``self.gamma`` is unset.
        """
        if gamma is None:
            gamma = self.gamma
        phi = self.features[action, state]
        next_action = self.greedy_policy(next_state, weight)
        next_phi = self.features[next_action, next_state]
        done_mask = 0 if done else 1

        td = (
            reward
            + done_mask * gamma * np.sum(next_phi * weight)
            - np.sum(phi * weight)
        )
        # The default (ord=None) norm ravels its input, so this equals the
        # Frobenius norm of the update reshaped as a column vector.
        return np.linalg.norm(td * phi)

    def dual_gap(self, weight, state, action):
        """Return the squared action value ``Q(state, action) ** 2``."""
        return self.action_value(state, action, weight) ** 2

    def _q_table(self, weight):
        """Return all Q-values for ``weight``, indexed ``[action, state]``."""
        return np.einsum("aij,j->ai", self.features, weight)

    def norm_q_value(self, weight):
        """Return the Frobenius norm of the full Q-value table."""
        return np.linalg.norm(self._q_table(weight), ord="fro")

    def norm_q_by_action(self, weight):
        """Return the Euclidean norms of the Q-values of actions 0 and 1.

        Returns:
            ``(solid_action_norm, dash_action_norm)`` for action indices
            ``0`` and ``1`` respectively.
        """
        q_values = self._q_table(weight)
        solid_action_norm = np.linalg.norm(q_values[0])
        dash_action_norm = np.linalg.norm(q_values[1])
        return solid_action_norm, dash_action_norm

    def infty_q_value(self, weight):
        """Return the max-norm (largest absolute value) of the Q-value table."""
        return np.abs(self._q_table(weight)).max()