import torch
import numpy as np
import time
from flcore.clients.clientbase import Client
from utils.utils_spectral_dataloader import *



class clientMP(Client):
    """Federated client that alternates prompt-network and backbone training.

    Two optimizers are kept:

    * ``self.optimizer`` (Adam) updates the prompt network ``self.model``;
    * ``self.base_optimizer`` (Adam) updates ``self.backbone`` -- all of its
      parameters, or, when ``args.PTP`` is set, only the parameters whose name
      contains ``'ada'``.

    Each optimizer is paired with its own ``MultiStepLR`` scheduler.
    """

    def __init__(self, args, id, **kwargs):
        super().__init__(args, id, **kwargs)

        self.args = args
        self.id = id
        # Number of local steps this client has run so far. It is never reset
        # here, so it keeps growing across calls to train_test. It is forwarded
        # as ``local_iter``/``epoch`` to the train/test helpers.
        self.iters = 0

        # Prompt-network optimizer and its learning-rate schedule.
        if args.optimizer != 'ADAM':
            raise NotImplementedError(
                f"optimizer {args.optimizer!r} is not supported; only 'ADAM' is implemented")
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.local_learning_rate)
        self.learning_rate_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=self.optimizer,
            milestones=args.milestones,
            gamma=args.learning_rate_decay_gamma,
        )
        self.learning_rate_decay = args.learning_rate_decay

        # Backbone optimizer and its learning-rate schedule.
        if args.base_optimizer != 'ADAM':
            raise NotImplementedError(
                f"base_optimizer {args.base_optimizer!r} is not supported; "
                f"only 'ADAM' is implemented")
        if self.args.PTP:
            # Only the adaptor parameters (names containing 'ada') are trained.
            # NOTE(review): if no parameter name matches, this param group is
            # empty and the base optimizer silently updates nothing -- confirm
            # that cannot happen for the backbones used with PTP.
            adaptor_params = [param for name, param in self.backbone.named_parameters()
                              if 'ada' in name]
            base_params = [{'params': adaptor_params}]
        else:
            base_params = self.backbone.parameters()
        self.base_optimizer = torch.optim.Adam(base_params,
                                               lr=self.args.base_local_learning_rate)
        self.base_learning_rate_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=self.base_optimizer,
            milestones=args.base_milestones,
            gamma=args.base_learning_rate_decay_gamma,
        )
        self.base_learning_rate_decay = args.base_learning_rate_decay

    def train_test(self, global_iter, trn_mode):
        """Run one round of local alternating training, evaluating after every step.

        The first ``max_local_steps - args.local_steps_P`` steps update the
        backbone. These steps are skipped entirely when ``args.CF`` is set. The
        remaining steps update the prompt network. After every step, including
        skipped backbone steps, the networks are evaluated and ``self.iters``
        advances by one.

        Args:
            global_iter: Current global round; not used by this method.
            trn_mode: ``'WARMUP+ALT'`` (warmup + backbone + prompt steps) or
                ``'ALT'`` (backbone + prompt steps).

        Returns:
            ``self.model.state_dict()`` after local training.

        Raises:
            ValueError: If ``trn_mode`` is not one of the modes above.
        """
        trainloader = self.load_train_data(batch_size=self.args.batch_size,
                                           epoch_sum_num=self.args.epoch_sum_num)
        # NOTE(review): train() is set once before the loop, and only on the
        # prompt network. If the test_* helpers switch networks to eval() and
        # the train_* helpers do not switch back, later steps would run in eval
        # mode -- confirm.
        self.model.train()

        max_local_steps = self._num_local_steps(trn_mode)
        if self.train_slow:
            max_local_steps = self._slow_local_steps(max_local_steps)

        # Loss histories handed to the train_* helpers for the whole round.
        epoch_loss_p, epoch_loss_b = [], []
        for step in range(max_local_steps):
            if step < max_local_steps - self.args.local_steps_P:
                if not self.args.CF:
                    self._train_backbone_step(trainloader, epoch_loss_b, trn_mode)
            else:
                self._train_prompt_step(trainloader, epoch_loss_p, trn_mode)
            self._evaluate()
            self.iters += 1

        return self.model.state_dict()

    def _num_local_steps(self, trn_mode):
        """Return the configured number of local steps for ``trn_mode``."""
        if trn_mode == 'WARMUP+ALT':
            return (self.args.local_steps_warmup
                    + self.args.local_steps_B
                    + self.args.local_steps_P)
        if trn_mode == 'ALT':
            return self.args.local_steps_B + self.args.local_steps_P
        raise ValueError(f"unknown trn_mode {trn_mode!r}; expected 'WARMUP+ALT' or 'ALT'")

    @staticmethod
    def _slow_local_steps(max_local_steps):
        """Draw a reduced local-step count for a slow client.

        Samples uniformly from ``[1, max_local_steps // 2)``; numpy's upper bound
        is exclusive. When that range is empty (``max_local_steps < 4``),
        ``np.random.randint`` would raise ``ValueError``. In that case fall back
        to at most one step.
        """
        upper = max_local_steps // 2
        if upper > 1:
            return np.random.randint(1, upper)
        return min(1, max_local_steps)

    def _train_backbone_step(self, trainloader, epoch_loss_b, trn_mode):
        """Run one backbone update step with ``self.base_optimizer``.

        In ``'WARMUP+ALT'`` mode ``train_MPT_bnet_warmup`` is used, without the
        prompt network. Otherwise ``train_MPT_bnet`` is used and is given the
        prompt network as well. The helper's return value is discarded.
        """
        step_kwargs = dict(
            args=self.args,
            ldr_train=trainloader,
            mask4d=self.mask4d_ls[self.id],
            optimizer=self.base_optimizer,
            net=self.backbone,
            loss_func=self.loss,
            epoch_loss=epoch_loss_b,
            local_iter=self.iters,
            model_path=self.args.model_path,
            train_slow=self.train_slow,
            learning_rate_decay=self.base_learning_rate_decay,
            learning_rate_scheduler=self.base_learning_rate_scheduler,
            id=self.id,
            glob_iter=None,
            train_mode=trn_mode,
        )
        if trn_mode == 'WARMUP+ALT':
            train_MPT_bnet_warmup(**step_kwargs)
        else:
            train_MPT_bnet(prompt_net=self.model, **step_kwargs)

    def _train_prompt_step(self, trainloader, epoch_loss_p, trn_mode):
        """Run one prompt-network update step with ``self.optimizer``; result discarded."""
        train_MPT_pnet(args=self.args,
                       ldr_train=trainloader,
                       mask4d=self.mask4d_ls[self.id],
                       optimizer=self.optimizer,
                       prompt_net=self.model,
                       net=self.backbone,
                       loss_func=self.loss,
                       epoch_loss=epoch_loss_p,
                       local_iter=self.iters,
                       model_path=self.args.model_path,
                       train_slow=self.train_slow,
                       learning_rate_decay=self.learning_rate_decay,
                       learning_rate_scheduler=self.learning_rate_scheduler,
                       id=self.id,
                       glob_iter=None,
                       train_mode=trn_mode)

    def _evaluate(self):
        """Evaluate the current networks according to ``self.args.mask_op``.

        * ``'fixed256'``: runs ``test_MPT`` on this client's own mask, then
          ``test_Mtrials_MPT`` with ``mask_source='usr_union'``.
        * ``'rand_crop'``: runs ``test_Mtrials_MPT`` with ``'assign_usr'``, then
          with ``'usr_union'``.
        * Any other value performs no evaluation.

        The helpers' return values are discarded. Presumably they are called
        for logging/saving side effects -- confirm.
        """
        mask_op = self.args.mask_op
        if mask_op == 'fixed256':
            test_MPT(args=self.args,
                     epoch=self.iters,
                     model_path=self.args.model_path,
                     prompt_net=self.model,
                     net=self.backbone,
                     test_data=self.test_data,
                     mask4d_cube=self.mask4d_ls[self.id],
                     id=self.id)
            mask_sources = ('usr_union',)
        elif mask_op == 'rand_crop':
            mask_sources = ('assign_usr', 'usr_union')
        else:
            return

        for mask_source in mask_sources:
            test_Mtrials_MPT(args=self.args,
                             epoch=self.iters,
                             model_path=self.args.model_path,
                             prompt_net=self.model,
                             net=self.backbone,
                             test_data=self.test_data,
                             mask4d_ls=self.mask4d_ls,
                             mask_source=mask_source,
                             id=self.id)






