#!/usr/bin/env python
# coding: utf-8

# In[1]:
import imp
import torch
from torch import nn
from torch.nn.modules.linear import Linear
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision.models as models
import numpy as np
import os,sys,os.path
from tensorboardX import SummaryWriter
import pickle
from tqdm import tqdm
import copy
import gc
import torch.nn.functional as F
import time
# In[2]:


from option import args_parser
from utils import Accuracy,average_weights
from sampling import LocalDataset, LocalDataloaders , partition_data
from finch import FINCH

# In[3]:

# Restrict the visible GPUs *before* the first CUDA query: once the CUDA
# runtime has been initialised, later changes to this variable are ignored.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Tensors and parameters created from here on default to double precision.
torch.set_default_dtype(torch.float64)
print(torch.__version__)
# Use the first visible GPU when there is one, otherwise fall back to the CPU
# instead of failing later on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Print full arrays (no "..." truncation) in the diagnostics below.
np.set_printoptions(threshold=np.inf)


# In[4]:
# Parse the options once, then pin the experiment-specific overrides.
args = args_parser()
args.num_clients = 10
args.code_len = 32
args.batch_size = 64

# In[5]:

class net(nn.Module):
    """Feature extractor plus linear classification head.

    The extractor is a torchvision ResNet-18 built with
    ``num_classes=code_length``; the head maps its output to ``num_classes``
    logits.
    """

    def __init__(self, code_length=32, num_classes=10):
        super().__init__()
        self.code_length = code_length
        self.num_classes = num_classes
        self.feature_extractor = models.resnet18(num_classes=code_length)
        # Kept inside nn.Sequential so parameter names stay "classifier.0.*".
        head = nn.Linear(code_length, num_classes)
        self.classifier = nn.Sequential(head)

    def forward(self, x):
        """Return ``(features, logits)`` for the input batch ``x``."""
        features = self.feature_extractor(x)
        logits = self.classifier(features)
        return features, logits
# Build the shared (server-side) model and report its parameter count.
global_model = net(code_length=32, num_classes=10)
print('# model parameters:', sum(p.numel() for p in global_model.parameters()))
# Wrap it in DataParallel and move it to the selected device.
global_model = nn.DataParallel(global_model).to(device)

# In[6]:
def agg_func(protos):
    """Average the feature tensors collected for each label.

    Args:
        protos: dict mapping a label to a non-empty list of tensors that all
            share one shape.

    Returns:
        The same dict object, updated in place so that every label maps to a
        single tensor: the element-wise mean of its list. The result is
        detached from the autograd graph in both the single- and the
        multi-element case.

    Raises:
        IndexError: if a label maps to an empty list.
    """
    # Snapshot the items because the dict values are reassigned in the loop.
    for label, proto_list in list(protos.items()):
        if len(proto_list) > 1:
            # Clone so the running sum never writes into a caller's tensor.
            total = proto_list[0].detach().clone()
            for extra in proto_list[1:]:
                total += extra.detach()
            protos[label] = total / len(proto_list)
        else:
            # Detach here as well, so a lone feature does not keep its
            # autograd history alive once it is stored as a prototype.
            protos[label] = proto_list[0].detach()

    return protos
# In[8]:
# Split the SVHN data across the clients.
train_dataset, testset, dict_users, dict_users_test = partition_data(
    n_users=args.num_clients,
    alpha=5,
    rand_seed=0,
    dataset='SVHN',
)
# In[9]:
Loaders_train = LocalDataloaders(
    train_dataset, dict_users, args.batch_size, ShuffleorNot=True, frac=0.1
)
Major_classes = []
Counts = []
Available_labels = []
for idx in range(args.num_clients):
    # Per-class sample histogram of this client's training loader.
    counts = [0] * 10
    for X, y in Loaders_train[idx]:
        for label in np.array(y):
            counts[int(label)] += 1
    print(counts)
    Counts.append(counts)
    # Classes that occur at least once on this client.
    available_labels = [cls for cls, seen in enumerate(counts) if seen != 0]
    Available_labels.append(available_labels)
# In[10]:
Loaders_test = LocalDataloaders(
    testset, dict_users_test, args.batch_size, ShuffleorNot=True, frac=0.2
)
Major_classes = []
for idx in range(args.num_clients):
    # Print the per-class sample histogram of this client's test loader.
    counts = [0] * 10
    for X, y in Loaders_test[idx]:
        for label in np.array(y):
            counts[int(label)] += 1
    print(counts)
# In[11]:
# TensorBoard writer handed to every LocalUpdate instance.
logger = SummaryWriter('./logs')
checkpoint_dir = './checkpoint/'+ args.dataset + '/'
# exist_ok=True replaces the separate os.path.exists() test, which could race
# with another process creating the directory in between.
os.makedirs(checkpoint_dir, exist_ok=True)
# Persist the run configuration next to the checkpoints.
with open(os.path.join(checkpoint_dir, 'args.pkl'), 'wb') as fp:
    pickle.dump(args, fp)
print('Data and model loaded')
print('Checkpoint dir:', checkpoint_dir)
# In[12]:
# Xavier-uniform initialisation (ReLU gain) for every conv and linear weight
# of the global model; biases and other layers are left untouched.
relu_gain = nn.init.calculate_gain('relu')
for module in global_model.modules():
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        continue
    nn.init.xavier_uniform_(module.weight, gain=relu_gain)
# In[13]:
def calculate_infonce(f_now, f_pos, f_neg, infoNCET, device):
    """InfoNCE loss of one feature against positive and negative prototypes.

    Computes ``-log(sum_pos exp(s / T) / sum_all exp(s / T))`` where ``s`` is
    the cosine similarity between ``f_now`` and each prototype row. The ratio
    is evaluated in log space (log-sum-exp), so a small temperature cannot
    overflow ``exp`` and turn the loss into NaN.

    Args:
        f_now: feature of the current sample, shape ``(1, D)``.
        f_pos: positive prototypes, shape ``(P, D)``; expected to be non-empty.
        f_neg: negative prototypes, shape ``(N, D)``.
        infoNCET: temperature ``T`` dividing the similarities.
        device: unused; kept so existing call sites stay valid.

    Returns:
        Tensor of shape ``(1,)`` holding the loss.
    """
    f_proto = torch.cat((f_pos, f_neg), dim=0)
    logits = torch.cosine_similarity(f_now, f_proto, dim=1) / infoNCET

    # Positives occupy the first rows of f_proto, so a slice replaces the
    # explicit 0/1 mask (and its host-to-device copy on every call).
    num_pos = f_pos.shape[0]
    log_pos = torch.logsumexp(logits[:num_pos], dim=0)
    log_all = torch.logsumexp(logits, dim=0)

    # -log(pos / all) == log(all) - log(pos)
    return (log_all - log_pos).view(1)

def hierarchical_info_loss(f_now, label, all_f, mean_f, all_global_protos_keys, infoNCET, device):
    """Contrastive plus alignment loss of one feature against global prototypes.

    The result is the sum of
      * the InfoNCE loss of ``f_now`` against the prototypes of its own label
        (positives) versus the prototypes of every other label (negatives), and
      * the MSE between ``f_now`` and the mean prototype of its own label.

    Args:
        f_now: feature of the current sample, shape ``(1, D)``.
        label: single-element tensor holding the sample's label.
        all_f: per-label prototype tensors, aligned with
            ``all_global_protos_keys``.
        mean_f: per-label mean prototypes, aligned the same way.
        all_global_protos_keys: labels corresponding to ``all_f`` / ``mean_f``.
        infoNCET: InfoNCE temperature.
        device: device the prototypes are moved to before use.

    Returns:
        The summed loss tensor.

    Raises:
        RuntimeError: if ``label`` is not among ``all_global_protos_keys``.
    """
    # Read the label once; the original re-evaluated label.item() for every
    # element of three comprehensions, and each call forces a host sync when
    # the label lives on a CUDA device.
    target = label.item()

    f_pos = None
    mean_f_pos = None
    negatives = []
    for protos, centre, key in zip(all_f, mean_f, all_global_protos_keys):
        if key == target:
            # Only the first match is used, as before.
            if f_pos is None:
                f_pos, mean_f_pos = protos, centre
        else:
            negatives.append(protos)

    if f_pos is None:
        raise RuntimeError('no global prototype found for label {}'.format(target))

    f_pos = f_pos.to(device)
    f_neg = torch.cat(negatives).to(device)
    xi_info_loss = calculate_infonce(f_now, f_pos, f_neg, infoNCET, device)

    mean_f_pos = mean_f_pos.to(device).view(1, -1)
    cu_info_loss = F.mse_loss(f_now, mean_f_pos)

    return xi_info_loss + cu_info_loss

def proto_aggregation(local_protos_list, online_clients):
    """Merge the clients' per-label prototypes into global cluster prototypes.

    For each label, the prototypes contributed by the online clients are
    clustered with FINCH and every resulting cluster is replaced by the mean
    of its members. A label contributed by a single client keeps that one
    prototype unchanged.

    Args:
        local_protos_list: mapping ``client index -> {label: prototype tensor}``.
        online_clients: iterable of the client indices taking part this round.

    Returns:
        dict mapping each label to a list of detached CPU tensors, each of
        shape ``(1, D)``. Both branches produce this layout, so callers can
        always ``torch.cat(..., dim=0)`` the list.

    Raises:
        KeyError: if an online client has no entry in ``local_protos_list``.
    """
    # Group the local prototypes by label across the participating clients.
    agg_protos_label = dict()
    for idx in online_clients:
        for label, proto in local_protos_list[idx].items():
            agg_protos_label.setdefault(label, []).append(proto)

    # Snapshot the items because the dict values are reassigned in the loop.
    for label, proto_list in list(agg_protos_label.items()):
        if len(proto_list) > 1:
            flat_protos = np.array(
                [item.squeeze(0).detach().cpu().numpy().reshape(-1) for item in proto_list]
            )

            c, num_clust, req_c = FINCH(flat_protos, initial_rank=None, req_clust=None, distance='cosine',
                                        ensure_early_exit=False, verbose=True)

            # One cluster id per prototype, taken from the last column of `c`
            # (presumably FINCH's final partition - confirm against its docs).
            cluster_ids = np.asarray(c)[:, -1]

            agg_selected_proto = []
            for cluster_index in np.unique(cluster_ids).tolist():
                members = flat_protos[cluster_ids == cluster_index]
                centre = np.mean(members, axis=0, keepdims=True)
                agg_selected_proto.append(torch.tensor(centre))
            agg_protos_label[label] = agg_selected_proto
        else:
            # Same (1, D) CPU layout as the clustered branch; the original
            # returned a (D,) tensor here, which breaks the later
            # torch.cat / mean(dim=0) over the prototypes.
            agg_protos_label[label] = [proto_list[0].detach().cpu().reshape(1, -1)]

    return agg_protos_label
# In[14]:
# Label -> list of global prototypes. Empty until the first aggregation round;
# it starts as a dict (not a list) because the training code calls .keys() on
# it and later replaces it with the dict returned by proto_aggregation.
global_protos = {}
# Client index -> {label: mean local feature}, filled in by LocalUpdate.
local_protos = {}
# Temperature passed to the InfoNCE loss.
infoNCET = 0.02
class LocalUpdate(object):
    """Local trainer owned by one federated client.

    Holds the client's own copy of the model, trains it on the client's
    training loader (cross-entropy plus a prototype-based contrastive term
    once global prototypes exist), publishes the client's per-label mean
    features in the module-level ``local_protos`` dict and evaluates the
    model on the client's test loader.

    Args:
        index: client index; also the key used in ``local_protos``.
        args: options object; ``lr``, ``lr_sh_rate`` and ``local_ep`` are read.
        Loader_train: training data loader of this client.
        available_labels: labels present in this client's training data.
        Loader_test: test data loader of this client.
        idxs: dataset indices assigned to this client.
        logger: object with an ``add_scalar`` method (TensorBoard writer).
        code_length: dimensionality of the feature code.
        num_classes: number of output classes.
        device: device the model and batches are moved to.
    """

    def __init__(self, index, args, Loader_train,available_labels,Loader_test,idxs, logger, code_length, num_classes, device):
        self.index = index
        self.args = args
        self.logger = logger
        self.trainloader = Loader_train
        self.testloader = Loader_test
        self.idxs = idxs
        self.ce = nn.CrossEntropyLoss()
        self.device = device
        self.code_length = code_length
        self.mse = nn.MSELoss()
        # Honour the requested code length (it was hard-coded to 32 before, so
        # the code_length argument was silently ignored).
        self.model = nn.DataParallel(net(code_length, num_classes).to(device)).to(device)
        self.early_stop = 20
        self.latent_layer_idx = -1
        self.ensemble_loss = nn.KLDivLoss(reduction="batchmean")
        self.available_labels = available_labels
        self.gen_batch_size = 32
        self.batch_size = 64

    @staticmethod
    def _average_instance_losses(instance_losses, batch_size, zero):
        """Return sum(instance_losses) / batch_size, or ``zero`` if the list is empty."""
        if not instance_losses:
            return zero
        total = instance_losses[0]
        for extra in instance_losses[1:]:
            # Out-of-place add: never mutate a tensor that autograd tracks.
            total = total + extra
        return total / batch_size

    def update_weights_Gen(self, global_round,regularization=True):
        """Train the local model for ``args.local_ep`` epochs.

        The loss is cross-entropy plus, once ``global_protos`` is non-empty,
        the batch-averaged ``hierarchical_info_loss``. During the last epoch
        the (detached) features of every sample are grouped by label; their
        per-label means are stored in ``local_protos[self.index]``.

        Args:
            global_round: index of the current global round (logging only).
            regularization: unused; kept for interface compatibility.
        """
        self.model.to(self.device)
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
        # NOTE(review): this scheduler is created but never stepped, so the
        # learning rate stays constant - confirm whether that is intended.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)

        ##############################
        # global_protos is not modified during local training, so the
        # per-label prototype tensors are prepared once, up front.
        have_global_protos = len(global_protos) != 0
        if have_global_protos:
            all_global_protos_keys = np.array(list(global_protos.keys()))
            all_f = []
            mean_f = []
            for protos_key in all_global_protos_keys:
                # Force every prototype to (k, D) so concatenation also works
                # when a prototype arrives as a flat (D,) vector.
                rows = [p.reshape(-1, p.shape[-1]) for p in global_protos[protos_key]]
                # Kept on the training device: the loss moves them there anyway.
                stacked = torch.cat(rows, dim=0).to(self.device).detach()
                all_f.append(stacked)
                mean_f.append(stacked.mean(dim=0))

        last_epoch = self.args.local_ep - 1
        for epoch in range(self.args.local_ep):
            agg_protos_label = {}
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device).double()
                y = y.to(self.device).double()
                optimizer.zero_grad()
                f, user_output_logp = self.model(X)
                lossCE = self.ce(user_output_logp,y.long())

                if have_global_protos:
                    instance_losses = [
                        hierarchical_info_loss(f[i].unsqueeze(0), label, all_f, mean_f,
                                               all_global_protos_keys, infoNCET, self.device)
                        for i, label in enumerate(y)
                        if label.item() in global_protos
                    ]
                    # Averaged over the whole batch. When no label of the batch
                    # has a global prototype the term is zero (the original
                    # divided None by the batch size and raised TypeError).
                    loss_InfoNCE = self._average_instance_losses(instance_losses, len(y), 0 * lossCE)
                else:
                    loss_InfoNCE = 0 * lossCE

                loss = lossCE + loss_InfoNCE
                loss.backward()
                optimizer.step()

                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)] loss: {:.6f} lossCE: {:.6f} loss_InfoNCE: {:.6f}'.format(
                        global_round, epoch, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item(), lossCE.item(), loss_InfoNCE.item()))

                if epoch == last_epoch:
                    # Detach so the stored features do not keep references to
                    # this batch's autograd graph.
                    features = f.detach()
                    for i, label_value in enumerate(y.tolist()):
                        agg_protos_label.setdefault(label_value, []).append(features[i, :])

                self.logger.add_scalar('loss', loss.item())

            if epoch == last_epoch:
                local_protos[self.index] = agg_func(agg_protos_label)

    def exp_lr_scheduler(self, epoch, decay=0.98, init_lr=0.1, lr_decay_epoch=1):
        """Return ``init_lr * decay ** (epoch // lr_decay_epoch)``, floored at 1e-4."""
        lr = max(1e-4, init_lr * (decay ** (epoch // lr_decay_epoch)))
        return lr

    def test_accuracy(self):
        """Return the mean of the per-batch ``Accuracy`` values on the test loader.

        Raises:
            ZeroDivisionError: if the test loader yields no batch.
        """
        self.model.eval()
        accuracy = 0
        cnt = 0
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for X, y in self.testloader:
                X = X.to(self.device).double()
                y = y.to(self.device).double()
                _, p = self.model(X)
                y_pred = p.argmax(1)
                accuracy += Accuracy(y,y_pred)
                cnt += 1
        return accuracy/cnt

    def load_model(self,global_weights):
        """Load ``global_weights`` (a state dict) into the local model."""
        self.model.load_state_dict(global_weights)
# In[15]:
# State dict of the (DataParallel-wrapped) global model; it is loaded into the
# selected clients at the start of every round.
global_weights = global_model.state_dict()
# In[16]:
# training
args.num_epochs = 50
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
# One LocalUpdate trainer per client, each with its own loaders and labels.
LocalModels = [
    LocalUpdate(
        idx,
        args,
        Loaders_train[idx],
        Available_labels[idx],
        Loaders_test[idx],
        idxs=dict_users[idx],
        logger=logger,
        code_length=args.code_len,
        num_classes=10,
        device=device,
    )
    for idx in range(args.num_clients)
]
# In[19]:
# Loader over the complete test set, used to evaluate the global model.
test_loader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=True)


accs_dict = {}
# 'weight': average client models weighted by their number of training
# samples; any other value: plain (uniform) average.
averaing = 'weight'
for epoch in tqdm(range(args.num_epochs)):
    test_accuracy = 0
    begin_time = time.time()

    Knowledges = []
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n')
    global_model.train()
    # Sample the clients taking part in this round (at least one).
    m = max(int(args.sampling_rate * args.num_clients), 1)
    idxs_users = np.random.choice(range(args.num_clients), m, replace=False)
    for idx in idxs_users:
        LocalModels[idx].load_model(global_weights)
        LocalModels[idx].update_weights_Gen(global_round=epoch, regularization = True)
        acc = LocalModels[idx].test_accuracy()
        test_accuracy += acc

    # Cluster the prototypes published by this round's clients.
    global_protos = proto_aggregation(local_protos, idxs_users)

    #####aggregate_nets######
    online_clients = idxs_users
    global_w = global_model.state_dict()

    if averaing == 'weight':
        online_clients_dl = [Loaders_train[online_clients_index] for online_clients_index in online_clients]
        online_clients_len = [len(dl.sampler) for dl in online_clients_dl]
        online_clients_all = np.sum(online_clients_len)
        # Explicit array conversion; the values equal list / numpy-scalar.
        freq = np.array(online_clients_len) / online_clients_all
    else:
        parti_num = len(online_clients)
        freq = [1 / parti_num for _ in range(parti_num)]

    # Weighted sum of the client state dicts: the first client overwrites the
    # global entries, every further client is accumulated on top.
    for index, net_id in enumerate(online_clients):
        net_para = LocalModels[net_id].model.state_dict()
        if index == 0:
            for key in net_para:
                global_w[key] = net_para[key] * freq[index]
        else:
            for key in net_para:
                global_w[key] += net_para[key] * freq[index]

    global_model.load_state_dict(global_w)

    # Broadcast the aggregated model to every client, sampled or not.
    for cur_net in LocalModels:
        cur_net.model.load_state_dict(global_model.state_dict())
    #####################################################################

    global_model.eval()
    accuracy = 0
    cnt = 0
    # Inference only: without no_grad every test batch would build (and keep
    # alive until overwritten) an autograd graph.
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(test_loader):
            X = X.to(device).double()
            y = y.to(device).double()
            _, p = global_model(X)
            y_pred = p.argmax(1)
            accuracy += Accuracy(y,y_pred)
            cnt += 1
    # Only the sampled clients were evaluated, so average over them rather
    # than over all clients (identical when every client is sampled).
    print('average test accuracy:', test_accuracy / len(idxs_users))

    end_time = time.time()
    training_time  = end_time - begin_time

    print('global test accuracy: ', accuracy/cnt)
    print('training time: ', training_time)