import torch 
from torch import nn
from torch.nn import functional as F
import numpy as np
import json
from numpy.random import Generator, PCG64
import argparse
import os
from tqdm import tqdm
from transformers import AutoTokenizer, BertForSequenceClassification
    

def _load_split(path, filename):
    """Load one JSON split file from directory ``path``.

    The file must be a JSON object with a ``"user_data"`` mapping and a ``"users"``
    list; those are the only two keys read. (Presumably LEAF-format data -- confirm.)

    Returns:
        ``(user_data, users)`` as fresh shallow copies (a new dict and a new list).
    """
    # JSON is UTF-8 by spec; without an explicit encoding, open() uses the locale
    # default, which breaks on non-ASCII text (e.g. tweets) on non-UTF-8 platforms.
    with open(os.path.join(path, filename), 'r', encoding='utf-8') as f:
        data = json.load(f)
    return dict(data["user_data"]), list(data["users"])


def get_train_data(path, filename="all_data_niid_0_keep_0_train_9.json"):
    """Return ``(data_dict, user_names)`` read from the training split in ``path``.

    Args:
        path: directory containing the split file.
        filename: split file name inside ``path`` (defaults to the original
            hard-coded training file).
    """
    return _load_split(path, filename)


def get_test_data(path, filename="all_data_niid_0_keep_0_test_9.json"):
    """Return ``(data_dict, user_names)`` read from the test split in ``path``.

    Args:
        path: directory containing the split file.
        filename: split file name inside ``path`` (defaults to the original
            hard-coded test file).
    """
    return _load_split(path, filename)

def test(model, test_data_dict, device, tokenizer=None):
    """Evaluate ``model`` separately on each client in ``test_data_dict``.

    Each client's whole data is tokenized and scored as one padded batch.

    Args:
        model: sequence-classification model whose output has a ``.logits`` attribute
            (called with ``input_ids`` and ``attention_mask``).
        test_data_dict: mapping client key -> ``{"x": [...], "y": [...]}``; the text fed
            to the tokenizer is ``x[i][4]`` (presumably the tweet text -- confirm
            against the dataset), and ``y`` holds integer class labels.
        device: torch device the model lives on.
        tokenizer: optional pre-loaded tokenizer; when ``None`` the
            ``bert-base-uncased`` tokenizer is loaded (the original behaviour).

    Returns:
        ``(acc_dict, loss_dict)``: per-client accuracy and mean cross-entropy loss,
        both keyed like ``test_data_dict``.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    criterion = nn.CrossEntropyLoss()
    acc_dict = {}
    loss_dict = {}

    # Evaluate with dropout etc. disabled, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for key, client in test_data_dict.items():
                encoded = tokenizer([sample[4] for sample in client["x"]],
                                    return_tensors="pt", padding=True)
                labels = torch.tensor(client["y"]).to(device)
                # The attention mask stops BERT from attending to [PAD] positions;
                # without it, a sequence's logits depend on how much it was padded.
                logits = model(encoded["input_ids"].to(device),
                               attention_mask=encoded["attention_mask"].to(device)).logits
                predicted = logits.argmax(dim=1)
                acc_dict[key] = (predicted == labels).float().mean().item()
                loss_dict[key] = criterion(logits, labels).item()
    finally:
        model.train(was_training)
    return acc_dict, loss_dict


def train(data_dict,user_names,test_data_dict,test_use_names,T=10000,alpha_0=0,alpha_1=0,lr=0.06,save_name=""):
    """Simulate noisy per-client gradient aggregation for a BERT classifier head.

    Every one of the ``T`` rounds samples 3 client indices, computes each client's
    gradient of the cross-entropy loss w.r.t. the classifier-head parameters (on that
    client's full data), adds standard-normal noise scaled by
    ``alpha / sqrt(numel)`` per parameter tensor, and takes one SGD step with the
    sample-size-weighted mean of the noisy gradients. Only the head is updated; the
    pretrained BERT backbone stays fixed. After training the model is evaluated once
    with ``test`` and everything is written to ``save_name``.

    Each client index gets a fixed (seed 42) alpha type: 1 with probability 1/3,
    else 0. Type-0 clients use ``alpha_0`` and type-1 clients use ``alpha_1``.

    Args:
        data_dict: training data keyed by user name; each value is read as
            ``{"x": [...], "y": [...]}`` with the tokenizer input at ``x[i][4]``.
        user_names: users that clients are sampled from (keys into ``data_dict``).
        test_data_dict: held-out data in the same format, evaluated at the end.
        test_use_names: unused; kept for call compatibility.
        T: number of rounds.
        alpha_0, alpha_1: noise scales for alpha-type-0 / alpha-type-1 clients.
        lr: learning rate of the head's SGD step.
        save_name: output directory ("" writes into the current directory).

    Returns:
        None. Writes per-step ``.npy`` logs and JSON summaries into ``save_name``.
    """
    # exist_ok replaces the previous bare ``except`` blocks, which also hid real
    # failures (e.g. permissions) behind a misleading "folder exists" message.
    os.makedirs("FeMNIST_Results", exist_ok=True)
    if save_name:
        if os.path.isdir(save_name):
            print("Folder exists, overwriting results")
        os.makedirs(save_name, exist_ok=True)

    n_clients = 3  # clients per round; every per-step log below is (T, n_clients)
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
    # Only the head is ever updated, so freeze the backbone: no BERT graph is kept and
    # no unused backbone gradients are computed. The new head below requires grad.
    model.requires_grad_(False)
    model.classifier = torch.nn.Sequential(torch.nn.Linear(768, 384), torch.nn.Linear(384, 2)).to(device)
    head_params = list(model.classifier.parameters())
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    rng = Generator(PCG64(42))
    alpha_typedict = {key: int(rng.uniform() > 2/3) for key in range(len(user_names))}

    losses_global = np.zeros((T, n_clients))
    mses_global = np.zeros((T, n_clients))
    clients_global = {}
    alphas_global = np.zeros((T, n_clients))
    alphatypes_global = np.zeros((T, n_clients))
    gradsizes_global = np.zeros((T, n_clients))

    for step in tqdm(range(T), total=T):
        # NOTE(review): np.random.choice samples *with* replacement by default, so a
        # client can be drawn more than once per round -- confirm this is intended.
        indexes = np.random.choice(len(user_names), n_clients)
        clients = [user_names[index] for index in indexes]
        alphas = [alpha_typedict[index] * (alpha_1 - alpha_0) + alpha_0 for index in indexes]

        losses, grads_real, grads, sizes = [], [], [], []
        for client, alpha in zip(clients, alphas):
            loss, grad, size = _client_gradient(model, tokenizer, data_dict[client],
                                                criterion, device, head_params)
            sent = [g + alpha * torch.normal(torch.zeros_like(g)) / np.sqrt(g.numel())
                    for g in grad]
            losses.append(loss)
            grads_real.append(grad)
            grads.append(sent)
            sizes.append(size)

        # Sample-size-weighted average of the noisy client gradients.
        total = sum(sizes)
        mean = [torch.stack([sent[k] * size / total for sent, size in zip(grads, sizes)]).sum(0)
                for k in range(len(head_params))]
        # Squared distance of each noisy gradient from the mean, and squared norm of
        # each clean gradient, both summed over all head parameters.
        mses = [_sq_norm([s - m for s, m in zip(sent, mean)]) for sent in grads]
        gradsizes = [_sq_norm(grad) for grad in grads_real]

        with torch.no_grad():
            for param, update in zip(head_params, mean):
                param -= lr * update

        losses_global[step] = losses
        alphas_global[step] = alphas
        alphatypes_global[step] = [alpha_typedict[index] for index in indexes]
        clients_global[step] = clients
        mses_global[step] = mses
        gradsizes_global[step] = gradsizes

    final_accs, final_losses = test(model, test_data_dict, device)

    # Each client's MSE minus its round's mean MSE, times 1.5 (factor kept from the
    # original; its derivation is not visible here). Rows are centred, so the
    # checksum below should be ~0.
    redist_global = 1.5 * mses_global - 1.5 * np.mean(mses_global, 1, keepdims=True)
    type1 = alphatypes_global
    type0 = 1 - alphatypes_global
    quick_results = {"accs": np.mean(list(final_accs.values())),
                     "losses": np.mean(list(final_losses.values())),
                     "mses_a0": np.sum(mses_global * type0) / np.sum(type0),
                     "mses_a1": np.sum(mses_global * type1) / np.sum(type1),
                     "redist_a0": np.sum(redist_global * type0) / np.sum(type0),
                     "redist_a1": np.sum(redist_global * type1) / np.sum(type1),
                     "redist_checksum": np.mean(redist_global),
                     }

    # os.path.join keeps save_name="" in the current directory (plain "+" concatenation
    # turned it into "/losses_step", i.e. the filesystem root).
    for name, array in (("losses_step", losses_global),
                        ("mses_step", mses_global),
                        ("alphas_step", alphas_global),
                        ("alphatypes_step", alphatypes_global),
                        ("gradsizes_step", gradsizes_global)):
        np.save(os.path.join(save_name, name), array)

    _dump_json(os.path.join(save_name, "clients_step"), clients_global)
    _dump_json(os.path.join(save_name, "accs_final"), final_accs)
    _dump_json(os.path.join(save_name, "losses_final"), final_losses)
    _dump_json(os.path.join(save_name, "quick_results"), quick_results)
    return None


def _client_gradient(model, tokenizer, client_data, criterion, device, params):
    """Return ``(loss, gradient list w.r.t. params, sample count)`` for one client."""
    encoded = tokenizer([sample[4] for sample in client_data["x"]],
                        return_tensors="pt", padding=True)
    labels = torch.tensor(client_data["y"]).to(device)
    # Pass the attention mask so [PAD] positions do not influence the logits.
    logits = model(encoded["input_ids"].to(device),
                   attention_mask=encoded["attention_mask"].to(device)).logits
    loss = criterion(logits, labels)
    # autograd.grad returns fresh tensors. Repeated loss.backward() would accumulate
    # into the shared .grad buffers, so every stored client gradient would alias the
    # running sum over the clients processed so far.
    grads = torch.autograd.grad(loss, params)
    return loss.item(), list(grads), len(labels)


def _sq_norm(tensors):
    """Return the total squared L2 norm of a list of tensors as a Python float."""
    return torch.stack([t.pow(2).sum() for t in tensors]).sum().item()


def _dump_json(path, obj):
    """Write ``obj`` as JSON to ``path``, closing the file afterwards."""
    with open(path, 'w') as f:
        json.dump(obj, f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('id', type=int)  # run id; only used in the results directory name
    parser.add_argument('-T', default=3550, type=int, required=False)  # number of training rounds
    parser.add_argument('-a0', default=0, type=float, required=False)  # noise scale, alpha-type-0 clients
    parser.add_argument('-a1', default=0, type=float, required=False)  # noise scale, alpha-type-1 clients
    parser.add_argument('-lr', default=0.06, type=float, required=False)  # classifier-head learning rate
    args = parser.parse_args()

    # Directory layout kept byte-for-byte so existing result paths remain stable
    # (note: the fields are concatenated without separators).
    save_name = "FeMNIST_Results/"+str(args.id)+"a0_"+str(args.a0)+"a1_"+str(args.a1)+"T_"+str(args.T) +"lr_"+str(args.lr)

    data_dict, user_names = get_train_data("data/train_twitter")
    test_data_dict, test_user_names = get_test_data("data/test_twitter")

    # Restrict the experiment to users with 16-19 training samples.
    user_names = [name for name in user_names if 15 < len(data_dict[name]["y"]) < 20]
    data_dict = {name: data_dict[name] for name in user_names}

    # Evaluate the same users; skip (and report) any without a test split instead of
    # failing with a KeyError.
    missing = [name for name in user_names if name not in test_data_dict]
    if missing:
        print(f"{len(missing)} selected users have no test data; skipping them in evaluation")
    test_data_dict = {name: test_data_dict[name] for name in user_names if name in test_data_dict}
    test_user_names = list(test_data_dict)

    train(data_dict, user_names, test_data_dict, test_user_names, T=args.T, alpha_0=args.a0,
          alpha_1=args.a1, lr=args.lr, save_name=save_name)


           

