from collections import defaultdict
from copy import copy 
import numpy as np
import os
import pandas as pd
from sklearn.linear_model import LinearRegression
from math import prod

from collections import defaultdict
import cvxpy as cp

# One-hot race indicator columns, checked in this order by get_race.
races = ['Race_A', 'Race_B', 'Race_C']

def race_name(r):
    """Return the short race label encoded in a one-hot column name.

    The label is the first whitespace-separated word of the text after the
    last underscore, e.g. ``'Race_A'`` -> ``'A'``.
    """
    suffix = r.rsplit('_', 1)[-1]
    words = suffix.split()
    return words[0]

def get_race(x):
    """Return the short label of the first race indicator equal to 1 in row *x*.

    Indicator columns are examined in the order given by ``races``; when none
    of them equals 1 the string ``'Unknown'`` is returned.
    """
    hit = next((col for col in races if x[col] == 1), None)
    return 'Unknown' if hit is None else race_name(hit)

def get_instance(group_feature='race', data_path='warfarin_processed_dataset.csv'):
    """Build the warfarin linear-bandit instance from the processed dataset.

    Reads the CSV at ``data_path``, keeps rows with a known race, and
    discretizes age and weight. It then counts how often each binary feature
    vector (with a trailing intercept 1) occurs in each group. Rare
    (feature vector, group) pairs are dropped. Finally it fits one linear
    regression per arm to obtain the reward parameters.

    Args:
        group_feature: column defining the groups. ``'age'`` is an alias for
            ``'Age_discrete'``; any other value (e.g. ``'race'``,
            ``'race_age'``) is used as a column name directly.
        data_path: path of the processed CSV (defaults to the original
            hard-coded file name, resolved against the working directory).

    Returns:
        ``(bandit, groups, group_to_unavail_actions, counts_by_group,
        all_features_expanded, features_for_race)``:

        - ``bandit``: ``LinearOptimalBandit`` with ``mu`` set and deltas computed.
        - ``groups``: distinct values of ``group_feature``.
        - ``group_to_unavail_actions``: group -> action indices whose context
          was filtered out for that group.
        - ``counts_by_group``: kept feature tuple -> {group: count}.
        - ``all_features_expanded``: each kept feature tuple repeated once per
          arm, matching the bandit's action ordering.
        - ``features_for_race``: group -> kept feature tuples.
    """
    if group_feature == 'age':
        group_feature = 'Age_discrete'

    df = pd.read_csv(data_path)
    df['race'] = df.apply(get_race, axis=1)
    # Copy so that the column assignments below modify an independent frame
    # instead of a slice of the unfiltered one (SettingWithCopyWarning).
    df = df[df.race != 'Unknown'].copy()

    # Discretized columns are 1 when the value is strictly less than the cutoff.
    weight_cutoff = 76
    age_cutoff = 7  # NOTE(review): presumably 'Age' is an ordinal code, not years — confirm
    K = 3  # number of arms

    df['Age_discrete'] = (df['Age'] < age_cutoff).astype(int)
    df['Weight_discrete'] = (df['Weight (kg)'] < weight_cutoff).astype(int)
    df['race_age'] = df['race'] + df['Age_discrete'].astype(str)

    groups = df[group_feature].unique()
    group_size = df[group_feature].value_counts().to_dict()

    features = ['Age_discrete', 'Weight_discrete', 'amiodarone',
                'Cyp2C9 genotypes_*1/*1', 'VKORC1_rs9923231_G/G']

    # feature tuple (features + trailing intercept 1) -> group -> row count
    raw_counts_by_group = defaultdict(lambda: defaultdict(int))
    for _, row in df.iterrows():
        f = tuple(np.append(row[features], 1))
        raw_counts_by_group[f][row[group_feature]] += 1

    # Keep a (feature, group) pair only if it covers more than 1% of the group
    # AND more than `hard_threshold` rows; drop features with no group left.
    threshold = 0.01
    hard_threshold = 10
    counts_by_group = dict()
    for feature, per_group in raw_counts_by_group.items():
        kept = {g: c for g, c in per_group.items()
                if c > threshold * group_size[g] and c > hard_threshold}
        if kept:
            counts_by_group[feature] = kept

    all_features = []
    all_features_expanded = []
    # group -> feature tuples available to it
    features_for_race = defaultdict(list)
    for c in sorted(counts_by_group):
        for r in counts_by_group[c]:
            features_for_race[r].append(c)
        all_features.append(c)
        # One copy per arm, mirroring the bandit's per-context action layout.
        all_features_expanded.extend([c] * K)

    # Actions are laid out context-major: context i with arm j is i*K + j.
    group_to_unavail_actions = defaultdict(list)
    for i, c in enumerate(all_features):
        groups_for_action = counts_by_group[c]
        for g in groups:
            if g not in groups_for_action:
                group_to_unavail_actions[g].extend(range(i * K, (i + 1) * K))

    X = df[features]
    true_mus = []
    for arm in range(K):
        lr = LinearRegression()
        lr.fit(X, df['Warfarin_arm_%d' % arm])
        # Coefficients followed by the intercept, matching the trailing 1
        # appended to every feature tuple above.
        true_mus.append(np.append(lr.coef_, lr.intercept_))

    bandit = LinearOptimalBandit()
    bandit.mu = true_mus
    bandit.initialize_with_contexts(K, all_features)
    bandit.compute_deltas()

    return (bandit, groups, group_to_unavail_actions, counts_by_group,
            all_features_expanded, features_for_race)


def get_disagreement_point(groups, f_to_deltas, features_for_race, multiple):
    """Compute each group's regret when its allocation is optimized alone.

    For every group, feature columns that take a single value across that
    group's contexts are dropped, and the last column is always kept. A
    ``LinearOptimalBandit`` is then built on the reduced contexts with 3 arms.
    Next, ``minimize_regret`` is solved with the group's gaps multiplied by
    ``multiple``. The returned allocation is multiplied by ``multiple**2``
    before the regret ``alpha @ deltas`` is taken on the unscaled gaps.

    Args:
        groups: iterable of group identifiers.
        f_to_deltas: maps a full feature tuple to its per-arm gaps.
        features_for_race: maps group -> list of full feature tuples.
        multiple: factor applied to the gaps passed to the solver.

    Returns:
        dict mapping group -> regret at that group's standalone optimum.
    """
    reduced = {}
    for g in groups:
        contexts = features_for_race[g]
        table = np.array(contexts)
        n_cols = len(contexts[0])
        varying = [col for col in range(n_cols)
                   if len(np.unique(table[:, col])) > 1]
        keep = varying + [-1]
        reduced[g] = [tuple(np.array(ctx)[keep]) for ctx in contexts]

    disagreement_point = {}
    for race in groups:
        bandit = LinearOptimalBandit()
        bandit.initialize_with_contexts(3, reduced[race])
        bandit.deltas = np.array(
            [gap for f in features_for_race[race] for gap in f_to_deltas[f]])
        _, _, alpha, _ = minimize_regret(bandit.all_actions,
                                         bandit.deltas * multiple)
        rescaled_alpha = alpha * multiple ** 2
        disagreement_point[race] = rescaled_alpha @ bandit.deltas
    return disagreement_point


def common_constraints(constraints, actions, deltas, d, K,
                       warm_start=None, num_samples=500):
    """Append the shared design-matrix constraints and return their variables.

    Creates a symmetric ``d x d`` variable ``H`` constrained to equal
    ``sum_k alpha[k] * a_k a_k^T`` over the ``K`` action vectors ``a_k``,
    with ``alpha >= 0``. Every action with a non-zero gap ``deltas[k]`` also
    gets a ``(d+1) x (d+1)`` PSD variable. Its leading block is ``H``, its
    border is ``a_k`` and its corner is ``deltas[k]**2 / 2``.

    Args:
        constraints: list that is extended IN PLACE with the new constraints.
        actions: sequence of ``K`` vectors of length ``d`` (must support
            ``.reshape``).
        deltas: sequence of ``K`` gaps; actions with gap 0 get no PSD block.
        d: action dimension.
        K: number of actions.
        warm_start: optional ``(alpha, H, H_for_action)`` values used to
            initialize the variables. Extended-matrix values are applied only
            to actions that are constrained in this call. The constrained set
            depends on which deltas are non-zero, so it may differ from the
            call that produced the warm start.
        num_samples: unused; kept for backward compatibility.

    Returns:
        ``(H, alpha, H_for_action)``, where ``H_for_action`` maps action
        index -> its extended PSD variable.
    """
    H = cp.Variable((d, d), symmetric=True)
    # Nonnegative allocation weight per action (not normalized to sum to 1).
    alpha = cp.Variable(K)
    constraints += [
        H == cp.sum([alpha[k] * (x.reshape(-1, 1) @ x.reshape(1, -1))
                     for k, x in enumerate(actions)]),
        alpha >= 0,
    ]

    H_for_action = {}
    for k in range(K):
        if deltas[k] == 0:
            continue
        H_extended = cp.Variable((d + 1, d + 1), symmetric=True)
        # Schur complement (https://en.wikipedia.org/wiki/Schur_complement):
        # with H positive definite, [[H, a], [a^T, delta^2/2]] >> 0 is
        # equivalent to a^T H^{-1} a <= delta^2 / 2.
        constraints += [
            H_extended[:d, :d] == H,
            H_extended[d, d] == deltas[k] ** 2 / 2,
            H_extended[d, :d] == actions[k],
            H_extended >> 0,
        ]
        H_for_action[k] = H_extended

    if warm_start is not None:
        warm_alpha, warm_H, warm_H_for_action = warm_start
        H.value = warm_H
        alpha.value = warm_alpha
        for k, var in H_for_action.items():
            # Skip actions that were unconstrained (zero gap) in the warm run.
            if k in warm_H_for_action:
                var.value = warm_H_for_action[k]

    return H, alpha, H_for_action


def max_fairness_frank_wolfe_subroutine(actions, groups, group_to_unavail_actions,
        deltas, disagreement_point, initial_group_alphas):
    """Solve one Frank-Wolfe linear subproblem of the fairness objective.

    The fairness objective is ``sum_i log(D_i - group_alpha[:, i] @ deltas)``
    with ``D_i = disagreement_point[groups[i]]``. This routine evaluates the
    gradient of its negation at ``initial_group_alphas``. It then minimizes
    the resulting linear objective over the constraints from
    ``common_constraints``, plus the split of ``alpha`` among groups.

    Args:
        actions: sequence of ``K`` action vectors.
        groups: sequence of group identifiers; column ``i`` of
            ``group_alpha`` belongs to ``groups[i]``.
        group_to_unavail_actions: mapping group -> indices of actions the
            group may not use (missing groups have none).
        deltas: length-``K`` gaps.
        disagreement_point: mapping group -> baseline regret.
        initial_group_alphas: ``(K, len(groups))`` array, the current iterate.

    Returns:
        ``(prob, alpha, group_alpha, H, H_for_action)``: the solved problem
        and its cvxpy variables (not their values).
    """
    K = len(actions)
    d = len(actions[0])
    deltas_arr = np.asarray(deltas, dtype=float)

    constraints = []
    H, alpha, H_for_action = common_constraints(
        constraints, actions, deltas, d, K)

    num_groups = len(groups)
    group_alpha = cp.Variable((K, num_groups), nonneg=True)
    # Nonnegative, so each group does no worse than its disagreement point.
    regret_decrease = cp.Variable(num_groups, nonneg=True)

    # Each action's allocation is split among the groups.
    constraints.append(cp.sum(group_alpha, axis=1) == alpha)

    for i, g in enumerate(groups):
        constraints.append(
            regret_decrease[i] == disagreement_point[g] - group_alpha[:, i] @ deltas)
        # Actions unavailable to this group get zero allocation.
        unavailable = group_to_unavail_actions.get(g)
        if unavailable:
            constraints.append(group_alpha[unavailable, i] == 0)

    # d/d group_alpha[k, i] of -log(D_i - group_alpha[:, i] @ deltas) is
    # deltas[k] / (D_i - group_alpha[:, i] @ deltas). A non-positive utility
    # gain (log undefined, and 0 would divide by zero) falls back to 1 so
    # the gradient stays finite.
    gradients = np.zeros((K, num_groups))
    for i, g in enumerate(groups):
        gain = disagreement_point[g] - initial_group_alphas[:, i] @ deltas_arr
        if gain <= 0:
            gain = 1
        gradients[:, i] = deltas_arr / gain

    objective = cp.sum(cp.multiply(gradients, group_alpha))
    prob = cp.Problem(cp.Minimize(objective), constraints)
    prob.solve(solver=cp.MOSEK, verbose=False)

    return prob, alpha, group_alpha, H, H_for_action


# group_to_unavail_actions maps each group to the list of indices of actions unavailable to it.
def maximize_fairness(actions, groups, group_to_unavail_actions, deltas,
                      disagreement_point, warm_start=None):
    """Maximize the sum of log regret decreases over all groups.

    Solves ``max sum_i log(D_i - group_alpha[:, i] @ deltas)`` with
    ``D_i = disagreement_point[groups[i]]``. It is subject to the constraints
    from ``common_constraints`` and to the per-action split
    ``sum_i group_alpha[k, i] == alpha[k]``. Every group's allocation on its
    unavailable actions is forced to 0.

    Args:
        actions: sequence of ``K`` action vectors.
        groups: sequence of group identifiers; column ``i`` of
            ``group_alpha`` belongs to ``groups[i]``.
        group_to_unavail_actions: mapping group -> indices of actions the
            group may not use (missing groups have none).
        deltas: length-``K`` gaps.
        disagreement_point: mapping group -> baseline regret.
        warm_start: forwarded to ``common_constraints`` as
            ``(alpha, H, H_for_action)`` values.

    Returns:
        ``(prob, alpha_value, group_alpha_value, regret_decrease_value)``.
    """
    K = len(actions)
    d = len(actions[0])

    constraints = []
    _, alpha, _ = common_constraints(
        constraints, actions, deltas, d, K, warm_start)

    num_groups = len(groups)
    group_alpha = cp.Variable((K, num_groups), nonneg=True)
    # Nonnegative, so each group does no worse than its disagreement point.
    regret_decrease = cp.Variable(num_groups, nonneg=True)

    # Each action's allocation is split among the groups.
    constraints.append(cp.sum(group_alpha, axis=1) == alpha)

    for i, g in enumerate(groups):
        constraints.append(
            regret_decrease[i] == disagreement_point[g] - group_alpha[:, i] @ deltas)
        # Actions unavailable to this group get zero allocation.
        unavailable = group_to_unavail_actions.get(g)
        if unavailable:
            constraints.append(group_alpha[unavailable, i] == 0)

    prob = cp.Problem(cp.Maximize(cp.sum(cp.log(regret_decrease))), constraints)
    prob.solve(solver=cp.MOSEK, verbose=False)

    return prob, alpha.value, group_alpha.value, regret_decrease.value


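# The only thing that changes between calls is `deltas`. Note that it affects
# both the objective and the constraints: common_constraints skips actions
# whose delta is 0 and bounds the others by delta**2 / 2.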
def minimize_regret(actions, deltas, warm_start=None):
    """Find the action allocation minimizing ``alpha @ deltas``.

    The feasible set is the one built by ``common_constraints``.

    Args:
        actions: sequence of action vectors (all the same length).
        deltas: per-action gaps.
        warm_start: forwarded to ``common_constraints`` as
            ``(alpha, H, H_for_action)`` values. This is not the order in
            which this function returns them.

    Returns:
        ``(prob, H_value, alpha_value, H_for_action_values)``, where the last
        item maps each constrained action index to its extended matrix value.
    """
    num_actions = len(actions)
    dim = len(actions[0])

    constraints = []
    H, alpha, H_for_action = common_constraints(
        constraints, actions, deltas, dim, num_actions, warm_start)

    problem = cp.Problem(cp.Minimize(alpha @ deltas), constraints)
    problem.solve(solver=cp.MOSEK, verbose=False)

    extended_values = {k: var.value for k, var in H_for_action.items()}
    return problem, H.value, alpha.value, extended_values


class LinearOptimalBandit:
    """Linear contextual bandit over a finite set of contexts.

    Every (context, arm) pair is an action: a vector of length ``d*K`` that
    is zero except in the arm's ``d``-wide slot, which holds the context.
    Actions are stored context-major, so action ``i*K + j`` is context ``i``
    with arm ``j``.

    Attributes:
        all_actions: list of action vectors (set by ``initialize_with_contexts``).
        all_contexts: list of contexts in input order.
        mu: per-arm parameter vectors; assigned by the caller before
            ``compute_deltas``.
        deltas: concatenated per-context gaps, aligned with ``all_actions``.
        context_to_deltas: ``tuple(context)`` -> per-arm gaps.
        solution: result of the last ``do_minimize_regret`` call.
    """

    # all_possible_contexts is a list of tuples of the form (0, 1, 1, 0, 0, 1)
    def initialize_with_contexts(self, K, all_possible_contexts):
        """Build the ``K`` actions of every context.

        The context is used as-is: no intercept term is appended here.
        """
        self.all_actions = []
        self.all_contexts = []
        d = len(all_possible_contexts[0])
        for context in all_possible_contexts:
            self.all_contexts.append(context)
            for arm in range(K):
                action = np.zeros(d * K)
                action[arm * d:(arm + 1) * d] = context
                self.all_actions.append(action)

    def compute_deltas(self):
        """Compute each action's gap to the best arm for its context.

        Requires ``self.mu``, which holds one parameter vector per arm. The
        number of arms is ``len(self.mu)`` rather than a hard-coded 3, so it
        stays consistent with the ``K`` used for the actions.
        """
        per_context = []
        self.context_to_deltas = dict()
        for context in self.all_contexts:
            expected_rewards = np.array(
                [np.dot(mu_arm, context) for mu_arm in self.mu], dtype=float).ravel()
            gaps = np.max(expected_rewards) - expected_rewards
            per_context.append(gaps)
            self.context_to_deltas[tuple(context)] = gaps
        # Concatenate once instead of re-stacking inside the loop (quadratic).
        self.deltas = np.concatenate(per_context) if per_context else np.zeros(0)

    def do_minimize_regret(self):
        """Solve the regret-minimization problem for the current actions and deltas."""
        self.solution = minimize_regret(self.all_actions, self.deltas)
        return self.solution
