import argparse
import os

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

from data import target_sln, wga

# Command line: the single positional argument names the dataset directory/prefix.
parser = argparse.ArgumentParser(description='sets the dataset')
parser.add_argument('dataset', metavar='dataset', type=str, help='set the dataset')
args = parser.parse_args()
dataset = args.dataset

# Base seed for NumPy and torch; the noise loop below restores it after every level.
global_seed = 0
np.random.seed(global_seed)
torch.manual_seed(global_seed)

# Label-noise rates to evaluate (passed as `p` to target_sln).
noise = [0, 0.1, 0.2, 0.3]

# Candidate inverse regularization strengths: 10 log-spaced values in [1e-4, 1].
C_VALUES = np.logspace(-4, 0, num=10, base=10)

# One summary row per noise level is appended to this table.
final_results = pd.DataFrame(columns=['dataset', 'noise', 'wga_mean', 'wga_std', 'C'])

def sample_weights(val_df):
    """Return per-row weights that balance every (target, group) cell.

    A row in cell (t, g) gets ``len(val_df) / (num_groups * n_tg)``. Here
    ``n_tg`` is the number of rows in that cell. ``num_groups`` is the number
    of distinct targets multiplied by the number of distinct groups. When
    every combination is present, the weights sum to ``len(val_df)`` and each
    cell contributes equally to a weighted loss.

    Args:
        val_df: DataFrame with ``'target'`` and ``'group'`` columns. Its index
            is ignored; weights follow the row order of ``val_df``.

    Returns:
        1-D float ``np.ndarray`` of length ``len(val_df)``. Rows whose target
        or group is NaN are skipped by ``groupby`` and keep weight 0.
    """
    # Positional 0..n-1 index so each cell's index addresses `weights` directly;
    # the old index is not needed, so drop it instead of inserting a column.
    df = val_df.reset_index(drop=True)
    weights = np.zeros(len(df))

    # NOTE(review): this counts all target x group combinations, including
    # empty ones. If some combination is absent, the weights sum to less than
    # len(df), which for LogisticRegression effectively lowers C. This is kept
    # as-is so previously reported numbers stay reproducible.
    num_groups = len(np.unique(df['target'].values)) * len(np.unique(df['group'].values))

    for _, cell in df.groupby(['target', 'group']):
        # len(cell) is the cell size; no need to re-filter the whole frame per cell.
        weights[cell.index.to_numpy()] = len(df) / (num_groups * len(cell))
    return weights

def run_exp(val_data, test_data, C):
    """Fit one group-balanced, L1-penalized logistic regression and score it.

    The model is trained on ``val_data``, using every column except
    ``'target'``, ``'group'`` and ``'true_target'`` as a feature. Per-row
    weights come from ``sample_weights``. The fitted model is then scored
    with ``wga`` on two frames.

    Args:
        val_data: Training frame. It must contain ``'target'``, ``'group'``
            and ``'true_target'`` columns. ``'true_target'`` is presumably the
            pre-noise label added by ``target_sln``; confirm.
        test_data: Evaluation frame, passed to ``wga`` unchanged.
        C: Inverse regularization strength for ``LogisticRegression``.

    Returns:
        A tuple of two ``wga`` scores. The first is computed on ``val_data``
        with ``'true_target'`` dropped; the second on ``test_data``. These are
        presumably worst-group accuracies; see ``data.wga``.
    """
    balancing = sample_weights(val_data)
    classifier = LogisticRegression(
        penalty='l1',
        solver='liblinear',
        C=C,
        fit_intercept=True,
    )
    inputs = val_data.drop(['target', 'group', 'true_target'], axis=1)
    classifier.fit(inputs, val_data['target'], sample_weight=balancing)

    score_on_train = wga(classifier, val_data.drop(['true_target'], axis=1))
    score_on_test = wga(classifier, test_data)
    return score_on_train, score_on_test
    

    

# The base path is the directory holding the embeddings (extracted by the base model) for the chosen dataset.
# It is './<dataset>/', i.e. a directory named after the dataset under the current working directory.
# That directory must contain the validation and test embeddings, target labels, and domain annotations
# (called "groups" in this code) as NumPy array files named '<dataset>_<split>_<kind>.npy', where <split> is
# 'val' or 'test' and <kind> is 'embeddings', 'labels', or 'groups'. For example, the celebA validation
# embeddings are read from './celebA/celebA_val_embeddings.npy'.

# Embeddings and annotations for the chosen dataset live under ./<dataset>/.
base_path = f'./{dataset}/'


def _load_split(split):
    """Load ``(embeddings, labels, groups)`` arrays for one split ('val' or 'test')."""
    return tuple(
        np.load(f'{base_path}{dataset}_{split}_{kind}.npy')
        for kind in ('embeddings', 'labels', 'groups')
    )


def _to_frame(embeddings, labels, groups):
    """Wrap embeddings in a DataFrame, then append 'target' and 'group' columns."""
    frame = pd.DataFrame(embeddings)
    frame['target'] = labels
    frame['group'] = groups
    return frame


X, y, group = _load_split('val')
test_X, test_y, test_group = _load_split('test')

original_val_data = _to_frame(X, y, group)
final_test_data = _to_frame(test_X, test_y, test_group)

# For each label-noise rate:
#   (1) choose C using a random 50/50 split of the validation embeddings, with
#       noisy labels on the training half and clean labels on the held-out half;
#   (2) retrain on the whole validation set (re-noised per seed) for 10 seeds
#       and record the mean/std of the test-set wga.
# NOTE(review): results depend on the exact order of calls that draw from the
# global NumPy RNG. These are DataFrame.sample, presumably target_sln, and
# possibly the liblinear fits. Together they determine the per-trial seeds, so
# removing even statements whose results are unused changes the reported numbers.
for noise_level in noise:

    print(dataset, noise_level, global_seed)


    # Hold out a random half of the validation rows (labels left clean) for picking C.
    test_data = original_val_data.sample(frac=0.5,replace=False)
    # The remaining half, with label noise injected at rate noise_level.
    train_data = target_sln(original_val_data.drop(test_data.index).reset_index(drop=True),p=noise_level)
    # NOTE(review): full_train_data is overwritten in the seed loop below before it
    # is ever read; this line only matters for any RNG draws target_sln makes.
    full_train_data = pd.concat([target_sln(test_data.reset_index(drop=True),p=noise_level), train_data],ignore_index=True)

    # One row per candidate C: wga on the (noisy) training half and on the clean held-out half.
    results = pd.DataFrame(columns=['C', 'val_wga','test_wga','type'])

    # NOTE(review): DFR_best is never read.
    DFR_best = -np.inf
    for c in C_VALUES:
        dfr_val, dfr_test = run_exp(train_data,test_data, c)
        results.loc[len(results)] = {'C':c, 'val_wga':dfr_val, 'test_wga':dfr_test, 'type':'GUW'}


    # Pick the C with the highest wga on the clean held-out half. Every row has
    # type 'GUW' and each C appears once, so the filter/groupby/mean are no-ops
    # around a plain idxmax.
    avg_param = results[results['type'] == 'GUW'].groupby(['C'])['test_wga'].mean().idxmax()

    print(avg_param)


    # Test-set wga for each of the 10 retraining trials.
    wgas = np.zeros(10)

    # NOTE(review): 10 seeds are drawn from [0, 200) with replacement, so repeats
    # are possible. A repeated seed duplicates a trial and shrinks wga_std.
    # np.random.choice(200, 10, replace=False) would avoid that but changes the
    # reproduced numbers.
    seeds = np.random.randint(200, size=(10))


    for i, seed in enumerate(seeds):
        print(i)
        # Re-seed both RNGs so each trial is reproducible on its own.
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Retrain on the entire validation set with fresh label noise for this seed.
        full_train_data = target_sln(original_val_data.reset_index(drop=True),p=noise_level)
        # Evaluate on the held-out test embeddings; only the test wga is kept.
        _,wgas[i] = run_exp(full_train_data, final_test_data, avg_param)
        print(wgas[i])


    # wgas.std() uses NumPy's default ddof=0.
    print("GUW (" + dataset + ")(" + str(noise_level) + "): ", wgas.mean(), wgas.std())
    final_results.loc[len(final_results)] = {'dataset': dataset, 'noise': noise_level, 'wga_mean': wgas.mean(), 'wga_std': wgas.std(), 'C': avg_param}
    # NOTE(review): redundant, since `results` is rebuilt at the top of every iteration.
    results = results[0:0]
    # Restore the base seed so every noise level starts from the same RNG state.
    np.random.seed(global_seed)
    torch.manual_seed(global_seed)
            
# Append this run's per-noise-level summary to the dataset's results CSV.
# The header is written only when the file is new or empty. Writing it on every
# append would leave extra header rows between runs, which pd.read_csv then
# parses as data rows.
# The output directory is created if missing, so a long run is not lost at the
# final write.
path = 'results/GUW/final_GUW_' + dataset + '.csv'
os.makedirs(os.path.dirname(path), exist_ok=True)
write_header = not os.path.isfile(path) or os.path.getsize(path) == 0
final_results.to_csv(path, mode='a', header=write_header)
