import torch
import os
import numpy as np
import h5py
import copy
import time
import random
from utils.utils_spectral_dataloader import *

# from utils.data_utils import read_client_data


class Server(object):
    """Base server for the federated training loop.

    Owns the global model and the list of client objects, and provides the
    per-round steps: client selection, model broadcast, model collection,
    weighted parameter averaging, evaluation (through ``test_Mtrials``) and
    best-PSNR checkpointing. All configuration is read from the ``args``
    namespace given to ``__init__``.
    """

    def __init__(self, args):
        """Read the run configuration from ``args`` and initialise empty state.

        Args:
            args: Configuration namespace; every attribute read below must be
                present. ``args.model`` is deep-copied, so the server's global
                model never aliases the caller's model object.
        """
        # Set up the main attributes
        self.device = args.device
        self.dataset = args.dataset
        self.global_rounds = args.global_rounds
        self.local_steps = args.local_steps
        self.batch_size = args.batch_size
        self.learning_rate = args.local_learning_rate
        self.global_model = copy.deepcopy(args.model)
        self.num_clients = args.num_clients
        self.join_ratio = args.join_ratio
        self.random_join_ratio = args.random_join_ratio
        # Minimum number of clients taking part in a round (see select_clients).
        self.join_clients = int(self.num_clients * self.join_ratio)
        self.algorithm = args.algorithm
        self.time_select = args.time_select
        # Clients whose average train + send cost exceeds this value are not
        # aggregated (see receive_models). Spelling kept for config compatibility.
        self.time_threthold = args.time_threthold
        self.top_cnt = 20
        self.auto_break = args.auto_break

        self.args = args

        self.clients = []
        self.selected_clients = []
        # Per-client boolean flags, filled by set_slow_clients().
        self.train_slow_clients = []
        self.send_slow_clients = []

        # Per-round upload buffers, filled by receive_models().
        self.uploaded_weights = []
        self.uploaded_ids = []
        self.uploaded_models = []

        self.rs_test_acc = []
        self.rs_test_auc = []
        self.rs_train_loss = []

        self.eval_gap = args.eval_gap
        self.client_drop_rate = args.client_drop_rate
        self.train_slow_rate = args.train_slow_rate
        self.send_slow_rate = args.send_slow_rate

        # Best evaluation PSNR seen so far; checkpoints are written only on improvement.
        self.psnr_max = 0

    def set_clients(self, args, clientObj):
        """Instantiate ``num_clients`` clients of type ``clientObj``.

        ``set_slow_clients()`` must have been called first: the loop zips over
        ``train_slow_clients`` / ``send_slow_clients``, so if those lists are
        still empty no clients are created.
        """
        for i, train_slow, send_slow in zip(range(self.num_clients), self.train_slow_clients, self.send_slow_clients):
            client = clientObj(args,
                               id=i,
                               train_slow=train_slow,
                               send_slow=send_slow)
            self.clients.append(client)

    def select_slow_clients(self, slow_rate):
        """Return a per-client boolean list with ``int(slow_rate * num_clients)`` entries set.

        Indices are drawn without replacement so exactly that many distinct
        clients are flagged (drawing with replacement could pick the same client
        twice and under-fill the slow set).
        """
        slow_clients = [False] * self.num_clients
        num_slow = int(slow_rate * self.num_clients)
        for i in np.random.choice(self.num_clients, num_slow, replace=False):
            slow_clients[i] = True
        return slow_clients

    def set_slow_clients(self):
        """Draw the train-slow and send-slow client flags from their configured rates."""
        self.train_slow_clients = self.select_slow_clients(
            self.train_slow_rate)
        self.send_slow_clients = self.select_slow_clients(
            self.send_slow_rate)

    def select_clients(self):
        """Randomly pick (without replacement) the clients taking part in this round.

        With ``random_join_ratio`` the participant count is itself drawn
        uniformly from ``[join_clients, num_clients]``; otherwise exactly
        ``join_clients`` clients are chosen.
        """
        if self.random_join_ratio:
            join_clients = np.random.choice(range(self.join_clients, self.num_clients + 1), 1, replace=False)[0]
        else:
            join_clients = self.join_clients
        selected_clients = list(np.random.choice(self.clients, join_clients, replace=False))

        return selected_clients

    def _client_checkpoint_path(self, client_id):
        """Return the path of client ``client_id``'s pre-trained checkpoint file.

        The path is ``'./' + <model_path up to and including its last '/'> +
        <model_save_filename_clients[client_id]> + '/model_epoch_<k>.pth'`` with
        ``k = last_train_clients[client_id]``.
        """
        model_path = self.args.model_path
        # Strip only the trailing path component. (``str.replace`` of that
        # component would also erase earlier occurrences of the same text,
        # e.g. 'exp/model/model' -> 'exp//'.) With no '/', rfind gives -1 -> ''.
        client_model_path = model_path[:model_path.rfind('/') + 1]
        date_time = self.args.model_save_filename_clients[client_id]
        last_train_id = self.args.last_train_clients[client_id]
        return './' + client_model_path + date_time + '/model_epoch_{}.pth'.format(last_train_id)

    def set_pretrained_client_backbones(self):
        """Load each client's pre-trained checkpoint and pass it to ``client.load_pretrain``."""
        for client_id, client in enumerate(self.clients):
            print('>>>serverbase, client_id=%d' % client_id)
            model_checkpoint = torch.load(self._client_checkpoint_path(client_id))
            client.load_pretrain(model_checkpoint)
            print('---[CLIENT %d] Successfully load the pre-trained model!---' % client_id)

    def send_models(self):
        """Push the global model to every client and record a send-cost sample.

        The recorded cost is twice the measured wall time of
        ``client.set_parameters`` (presumably modelling a round trip -- confirm).
        """
        assert (len(self.clients) > 0)

        for client in self.clients:
            start_time = time.time()

            client.set_parameters(self.global_model)

            client.send_time_cost['num_rounds'] += 1
            client.send_time_cost['total_cost'] += 2 * (time.time() - start_time)

    def receive_models(self):
        """Collect models from the surviving selected clients for aggregation.

        A random ``(1 - client_drop_rate)`` fraction of ``selected_clients`` is
        kept to simulate dropouts. Of those, clients whose average train + send
        cost is within ``time_threthold`` (or who have no recorded rounds yet)
        upload their model. The weights are the clients' ``train_samples``,
        normalised to sum to 1. This call resets ``uploaded_ids``,
        ``uploaded_weights`` and ``uploaded_models``.
        """
        assert (len(self.selected_clients) > 0)

        # Size the sample from the clients actually selected this round: with
        # random_join_ratio more than join_clients may be selected, and sampling
        # from join_clients would drop the extras on top of client_drop_rate.
        num_active = int((1 - self.client_drop_rate) * len(self.selected_clients))
        active_clients = random.sample(self.selected_clients, num_active)
        print('>>> debug selected active_clients !')

        self.uploaded_ids = []
        self.uploaded_weights = []
        self.uploaded_models = []
        tot_samples = 0
        for client in active_clients:
            try:
                client_time_cost = client.train_time_cost['total_cost'] / client.train_time_cost['num_rounds'] + \
                    client.send_time_cost['total_cost'] / client.send_time_cost['num_rounds']
            except ZeroDivisionError:
                # No rounds recorded yet: treat as cost-free so the client is accepted.
                client_time_cost = 0
            print('>>> debug for each client time_cost  !')

            if client_time_cost <= self.time_threthold:
                tot_samples += client.train_samples
                self.uploaded_ids.append(client.id)
                self.uploaded_weights.append(client.train_samples)
                self.uploaded_models.append(client.model)
        self.uploaded_weights = [w / tot_samples for w in self.uploaded_weights]
        print('>>> finish receive model')

    def aggregate_parameters(self):
        """Replace the global model with the weighted sum of the uploaded models.

        Only ``parameters()`` are averaged; buffers (if the model has any) are
        taken unchanged from the first uploaded model via the deep copy.
        """
        assert (len(self.uploaded_models) > 0)

        self.global_model = copy.deepcopy(self.uploaded_models[0])
        for param in self.global_model.parameters():
            param.data.zero_()

        for w, client_model in zip(self.uploaded_weights, self.uploaded_models):
            self.add_parameters(w, client_model)
        print('>>> finish aggregate_parameters')

    def add_parameters(self, w, client_model):
        """Accumulate ``w * client_model`` parameters into the global model, in place."""
        for server_param, client_param in zip(self.global_model.parameters(), client_model.parameters()):
            server_param.data += client_param.data.clone() * w

    def evaluate(self, glob_iter):
        """Evaluate with the ``usr_union`` masks and return the mean PSNR.

        With ``args.PTP`` every client's own model is evaluated; PSNR/SSIM
        statistics are printed and a summary line is written via ``gen_log``.
        Otherwise the global model is evaluated.

        Returns:
            The mean PSNR: pooled over all clients' ``usr_union`` PSNR values
            in the PTP case, or as reported by ``test_Mtrials`` otherwise.

        Raises:
            NotImplementedError: ``args.mask_op == 'fixed256'``.
            ValueError: ``args.mask_op`` is anything other than ``'rand_crop'``.
        """
        mask_op = self.args.mask_op
        if mask_op == 'fixed256':
            raise NotImplementedError
        if mask_op != 'rand_crop':
            # Previously this fell through and failed with UnboundLocalError
            # at the return statement.
            raise ValueError('Unsupported mask_op: {!r}'.format(mask_op))

        if not self.args.PTP:
            # usr_union evaluation of the global model.
            # NOTE(review): this branch unpacks 6 values from test_Mtrials while
            # the PTP branch unpacks 8 -- confirm test_Mtrials' return arity.
            (psnr_mean_M, _, _, _, _, _) = test_Mtrials(args=self.args,
                                                        epoch=glob_iter,
                                                        model_path=self.args.model_path,
                                                        net=self.global_model,
                                                        test_data=self.args.test_data,
                                                        mask4d_ls=self.args.mask4d_ls,
                                                        mask_source='usr_union',
                                                        id=self.args.num_clients)
            return psnr_mean_M

        PSNR_c_union, SSIM_c_union = [], []
        psnr_all, ssim_all = [], []
        for c in self.clients:
            (_, _, _, _, psnr_c_ls_u, ssim_c_ls_u, temp_psnr, temp_ssim) = test_Mtrials(args=self.args,
                                                                                        epoch=glob_iter,
                                                                                        model_path=self.args.model_path,
                                                                                        net=c.model,
                                                                                        test_data=self.args.test_data,
                                                                                        mask4d_ls=self.args.mask4d_ls,
                                                                                        mask_source='usr_union',
                                                                                        id=self.args.num_clients,
                                                                                        stay_log=False)
            PSNR_c_union.append(psnr_c_ls_u)
            SSIM_c_union.append(ssim_c_ls_u)
            psnr_all.append(temp_psnr)
            ssim_all.append(temp_ssim)

        print('>>>psnr_all.shape', np.array(psnr_all).shape)
        print('>>>ssim_all.shape', np.array(ssim_all).shape)
        # NOTE(review): the hard-coded 10 presumably is the number of test
        # scenes per evaluation -- confirm against test_Mtrials.
        psnr_all = np.array(psnr_all).reshape(-1, 10)
        ssim_all = np.array(ssim_all).reshape(-1, 10)
        print('>>>psnr_all.shape', psnr_all.shape)
        print('>>>ssim_all.shape', ssim_all.shape)
        allscene_psnr_mean = np.mean(psnr_all, axis=0)
        allscene_psnr_std = np.std(psnr_all, axis=0)
        allscene_ssim_mean = np.mean(ssim_all, axis=0)
        allscene_ssim_std = np.std(ssim_all, axis=0)

        print('\n>>>psnr mean=', allscene_psnr_mean)
        print('\n>>>psnr std=', allscene_psnr_std)
        print('\n>>>ssim mean=', allscene_ssim_mean)
        print('\n>>>ssim std=', allscene_ssim_std)

        psnr_mean_M = np.mean(np.array(PSNR_c_union))
        msg = '===>mask:usr_union, trials:{}, Epoch {}: testing psnr = {:.5f}/{:5f}(tsa), ssim = {:.5f}/{:5f}'.format(
            self.args.trial_num,
            glob_iter,
            psnr_mean_M,
            np.std(np.array(PSNR_c_union)),
            np.mean(np.array(SSIM_c_union)),
            np.std(np.array(SSIM_c_union)))
        gen_log(model_path=self.args.model_path, msg=msg, user_id=self.args.num_clients)

        return psnr_mean_M

    def _update_best_psnr(self, psnr_mean_M):
        """Record ``psnr_mean_M`` if it beats ``psnr_max``; return True if a checkpoint is due.

        A checkpoint is due only when the value is a new best AND exceeds
        ``args.psnr_set``. NaN never counts as an improvement.
        """
        # `not >` (rather than `<=`) so NaN is rejected just like the original `>` test.
        if not psnr_mean_M > self.psnr_max:
            return False
        self.psnr_max = psnr_mean_M
        return psnr_mean_M > self.args.psnr_set

    def checkpoint_global(self, glob_iter, psnr_mean_M):
        """Save the global model (``prompt=args.MP``) when ``psnr_mean_M`` is a new best above ``args.psnr_set``."""
        if self._update_best_psnr(psnr_mean_M):
            checkpoint(model=self.global_model,
                       epoch=glob_iter,
                       model_path=self.args.model_path,
                       id=self.args.num_clients,
                       prompt=self.args.MP)

    def checkpoint_global_clients(self, glob_iter, psnr_mean_M):
        """Save the global model and every client's backbone when ``psnr_mean_M`` is a new best above ``args.psnr_set``."""
        if self._update_best_psnr(psnr_mean_M):
            checkpoint(model=self.global_model,
                       epoch=glob_iter,
                       model_path=self.args.model_path,
                       id=self.args.num_clients,
                       prompt=True)
            for c in self.clients:
                checkpoint(model=c.backbone,
                           epoch=glob_iter,
                           model_path=self.args.model_path,
                           id=c.id,
                           prompt=False)




