import scipy.io as sio
import os
import time
import numpy as np
import math
import torch 
import logging
from utils.ssim_torch import ssim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn as nn
import copy
from fvcore.nn import FlopCountAnalysis
import datetime


# Side length (in pixels) of the square spatial patches: LoadTest sizes its
# output buffer with it and shuffle_crop uses it as the random-crop window.
patch_size = 256


def global_model_init(model, param_init):
    """Re-initialise every parameter of ``model`` in place.

    :param model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
    :param param_init: initialisation scheme; one of ``'default'``,
        ``'uniform'``, ``'zero'``, ``'normal'``, ``'xavier_uniform'``,
        ``'kaiming_uniform'`` or ``'kaiming_normal'``.
    :return: None
    :raises NotImplementedError: if ``param_init`` is not one of the schemes
        above (raised when the first parameter is visited).
    """
    if param_init == 'default':
        # Keep whatever initialisation the layers were constructed with.
        return

    # These initialisers derive their scale from fan-in/fan-out, which torch
    # can only compute for tensors with at least two dimensions.
    fan_based = {
        'xavier_uniform': nn.init.xavier_uniform_,
        'kaiming_uniform': nn.init.kaiming_uniform_,
        'kaiming_normal': nn.init.kaiming_normal_,
    }

    for _, params in model.named_parameters():
        if param_init == 'uniform':
            nn.init.uniform_(params, -0.1, 0.1)
        elif param_init == 'zero':
            nn.init.constant_(params, 0.)
        elif param_init == 'normal':
            nn.init.normal_(params, 0.0, 0.01)
        elif param_init in fan_based:
            # 1-D parameters (e.g. biases) are left untouched instead of
            # making the fan-based initialiser raise.
            if len(params.shape) >= 2:
                fan_based[param_init](params)
        else:
            raise NotImplementedError

def generate_masks(args):
    """Load the masks listed in ``args.mask_ids`` and replicate each over 28 channels.

    For every id the file ``<args.mask_path>/mask_<id>/mask.mat`` is read and
    the variable stored under the last key of the loaded dict is used as the
    mask.

    :param args: object with ``mask_ids`` (iterable of ints) and ``mask_path``
        (directory prefix, str).
    :return: list with one ``numpy`` array per id; for a 2-D ``(H, W)`` mask
        each entry has shape ``(1, 28, H, W)``.
    """
    mask_ls = []

    for mask_id in args.mask_ids:
        mask_dict = sio.loadmat(args.mask_path + '/mask_{:d}/'.format(mask_id) + 'mask.mat')
        # loadmat puts its '__header__'-style metadata keys first, so the last
        # key is taken as the stored variable.
        last_key = list(mask_dict)[-1]
        mask = mask_dict[last_key]
        print(last_key)
        # assumes the stored mask is 2-D (H, W) -- TODO confirm.
        # Add explicit batch and channel axes, then repeat along the channel axis.
        mask4d = np.tile(mask[np.newaxis, np.newaxis, :, :], (1, 28, 1, 1))
        print('>>>generate_masks: mask4d.shape=',mask4d.shape)
        mask_ls.append(mask4d)

    return mask_ls



def LoadTraining(path):
    """Load every ``.mat`` training scene found directly inside ``path``.

    Files are visited in sorted name order. From each file the variable
    ``img_expand`` is used when present, otherwise ``img``; the array is
    divided by 65536 and converted to ``float32``.

    :param path: directory containing the scene files (with or without a
        trailing separator).
    :return: list of ``float32`` numpy arrays, one per loaded scene.
    """
    imgs = []
    scene_list = os.listdir(path)
    scene_list.sort()
    print('training sences:', len(scene_list))
    for i, scene_name in enumerate(scene_list):
        print('start for=',i)
        # Test the file name only: a directory component containing 'mat'
        # must not make every file look like a scene.
        if not scene_name.endswith('.mat'):
            continue
        img_dict = sio.loadmat(os.path.join(path, scene_name))
        if "img_expand" in img_dict:
            img = img_dict['img_expand']/65536.
        elif "img" in img_dict:
            img = img_dict['img']/65536.
        else:
            # Without this branch the previous scene's array would be appended
            # a second time (or NameError raised on the first file).
            print('Sence {} is skipped (no img_expand/img variable). {}'.format(i, scene_name))
            continue
        imgs.append(img.astype(np.float32))
        print('Sence {} is loaded. {}'.format(i, scene_name))

    return imgs




def LoadTest(path_test):
    """Load all test scenes in ``path_test`` into one channel-first tensor.

    Every file in the directory is read with ``scipy.io.loadmat`` (in sorted
    name order) and its ``'img'`` variable is copied into a
    ``(N, patch_size, patch_size, 28)`` buffer.

    :param path_test: directory containing the test ``.mat`` files (with or
        without a trailing separator).
    :return: tensor of shape ``(N, 28, patch_size, patch_size)`` built from a
        ``numpy.zeros`` buffer with the default dtype.
    """
    scene_list = os.listdir(path_test)
    scene_list.sort()
    test_data = np.zeros((len(scene_list), patch_size, patch_size, 28))
    for i, scene_name in enumerate(scene_list):
        # os.path.join keeps the path valid whether or not path_test ends
        # with a separator.
        img = sio.loadmat(os.path.join(path_test, scene_name))['img']
        test_data[i, :, :, :] = img
        print(i, img.shape, img.max(), img.min())
    # Move the spectral axis in front of the two spatial axes.
    test_data = torch.from_numpy(np.transpose(test_data, (0, 3, 1, 2)))
    return test_data



def time2file_name(time):
    """Turn a timestamp string into an underscore-separated file-name stem.

    The characters at positions 0-3, 5-6, 8-9, 11-12, 14-15 and 17-18 of
    ``time`` (year, month, day, hour, minute, second in a
    ``YYYY-MM-DD HH:MM:SS`` style string) are joined with ``'_'``.

    :param time: timestamp string.
    :return: string of the form ``YYYY_MM_DD_HH_MM_SS``.
    """
    # (start, stop) slice bounds of each field, in output order.
    field_spans = ((0, 4), (5, 7), (8, 10), (11, 13), (14, 16), (17, 19))
    return '_'.join(time[start:stop] for start, stop in field_spans)



def shuffle_crop(batch_data, argument=False):
    """Cut one random ``patch_size`` x ``patch_size`` window out of a batch.

    The same window is applied to every sample of the batch.

    :param batch_data: tensor unpacked as ``(bs, h, w, channels)``; it must
        support ``permute``.
    :param argument: must be falsy; the truthy path is not implemented.
    :return: cropped batch permuted to ``(bs, channels, patch_size, patch_size)``.
    :raises NotImplementedError: if ``argument`` is truthy.
    """
    if argument:
        raise NotImplementedError

    _, h, w, _ = batch_data.shape
    # np.random.randint excludes its upper bound, so "+ 1" makes the last
    # valid offset reachable and lets inputs of exactly patch_size through.
    x_index = np.random.randint(0, h - patch_size + 1)
    y_index = np.random.randint(0, w - patch_size + 1)
    cropped = batch_data[:, x_index:x_index + patch_size, y_index:y_index + patch_size, :]
    gt_batch = cropped.permute(0, 3, 1, 2)
    return gt_batch

def arguement_1(x):
    """Apply a random rotation and random flips to one sample.

    :param x: c,h,w
    :return: c,h,w (h and w are swapped by an odd number of rotations when
        the input is not square)
    """
    # np.random.randint excludes its upper bound: (0, 4) yields 0..3 quarter
    # turns and (0, 2) yields a fair 0/1 coin for each flip.
    rotTimes = np.random.randint(0, 4)
    vFlip = np.random.randint(0, 2)
    hFlip = np.random.randint(0, 2)
    # Random rotation in the (h, w) plane
    for _ in range(rotTimes):
        x = torch.rot90(x, dims=(1, 2))
    # Random flip along dim 2
    if vFlip:
        x = torch.flip(x, dims=(2,))
    # Random flip along dim 1
    if hFlip:
        x = torch.flip(x, dims=(1,))
    return x

def arguement_2(generate_gt):
    """Tile four 128x128 patches into one 256x256 mosaic.

    :param generate_gt: indexable holding at least four patches;
        ``generate_gt.shape[1]`` gives the channel count and entries 0..3 are
        written to the top-left, top-right, bottom-left and bottom-right
        quadrants respectively.
    :return: tensor of shape ``(c, 256, 256)``; on CUDA when a GPU is
        available, otherwise on the device of ``generate_gt``.
    """
    c, h, w = generate_gt.shape[1], 256, 256
    divid_point_h = 128
    divid_point_w = 128
    # Same placement as before on GPU hosts; CPU-only hosts fall back to the
    # input's device instead of failing inside .cuda().
    device = torch.device('cuda') if torch.cuda.is_available() else generate_gt.device
    output_img = torch.zeros(c, h, w, device=device)
    output_img[:, :divid_point_h, :divid_point_w] = generate_gt[0]
    output_img[:, :divid_point_h, divid_point_w:] = generate_gt[1]
    output_img[:, divid_point_h:, :divid_point_w] = generate_gt[2]
    output_img[:, divid_point_h:, divid_point_w:] = generate_gt[3]
    return output_img


def gen_meas_gt(args, data_batch, mask4d_batch):
    """Simulate the measurement for a batch and build the network input.

    The masked cube is passed through ``shift`` (step 2), summed over the
    channel axis and scaled by ``2 / nC``; the result is expanded again with
    ``shift_back``.

    :param args: object with ``meas_init``: ``'meas'`` returns the
        shifted-back measurement, ``'meas*mask'`` returns it multiplied
        element-wise by ``mask4d_batch``.
    :param data_batch: ground-truth tensor; channel count is read from
        ``shape[1]``. It is moved to CUDA and cast to float.
    :param mask4d_batch: mask tensor multiplied element-wise with the data.
    :return: tuple ``(network_input, data_batch)`` with ``data_batch`` on CUDA.
    :raises NotImplementedError: if ``args.meas_init`` is not supported
        (previously the function fell through and returned ``None``).
    """
    if args.meas_init not in ('meas*mask', 'meas'):
        raise NotImplementedError(
            "unsupported meas_init: {!r}".format(args.meas_init))

    data_batch = Variable(data_batch).cuda().float()

    nC = data_batch.shape[1]

    temp = shift(mask4d_batch*data_batch, 2)
    meas = torch.sum(temp, 1)/nC*2
    y_temp = shift_back(meas)
    if args.meas_init == 'meas':
        return y_temp, data_batch
    PhiTy = torch.mul(y_temp, mask4d_batch)
    return PhiTy, data_batch



def shift(inputs, step=2):
    """Shift channel ``i`` of a batch to the right by ``step * i`` columns.

    :param inputs: tensor unpacked as ``(bs, nC, row, col)``.
    :param step: column offset between consecutive channels.
    :return: float32 tensor of shape ``(bs, nC, row, col + (nC - 1) * step)``,
        zero outside the shifted content; on CUDA when a GPU is available,
        otherwise on the device of ``inputs``.
    """
    bs, nC, row, col = inputs.shape
    # Same placement as before on GPU hosts; CPU-only hosts fall back to the
    # input's device instead of failing inside .cuda().
    device = torch.device('cuda') if torch.cuda.is_available() else inputs.device
    output = torch.zeros(bs, nC, row, col + (nC - 1) * step, device=device).float()
    for i in range(nC):
        output[:, i, :, step * i:step * i + col] = inputs[:, i, :, :]
    return output

def shift_back(inputs, step=2, nC=28):
    """Expand a 2-D measurement into ``nC`` channels by undoing the column shift.

    Channel ``i`` of the output is the window of ``inputs`` that starts at
    column ``step * i``.

    :param inputs: tensor unpacked as ``(bs, row, col)``.
    :param step: column offset between consecutive channels.
    :param nC: number of output channels (default 28, the value that used to
        be hard-coded).
    :return: float32 tensor of shape ``(bs, nC, row, col - (nC - 1) * step)``;
        on CUDA when a GPU is available, otherwise on the device of ``inputs``.
    """
    bs, row, col = inputs.shape
    out_col = col - (nC - 1) * step
    # Same placement as before on GPU hosts; CPU-only hosts fall back to the
    # input's device instead of failing inside .cuda().
    device = torch.device('cuda') if torch.cuda.is_available() else inputs.device
    output = torch.zeros(bs, nC, row, out_col, device=device).float()
    for i in range(nC):
        output[:, i, :, :] = inputs[:, :, step * i:step * i + out_col]
    return output

def init_mask(mask_path, mask_type, batch_size):
    """Build the mask batch and the mask-derived input selected by ``mask_type``.

    :param mask_path: forwarded to ``generate_masks`` and ``generate_shift_masks``.
    :param mask_type: ``'Phi'`` (shifted mask), ``'Phi_PhiPhiT'`` (pair returned
        by ``generate_shift_masks``), ``'Mask'`` (the mask batch itself) or
        ``None`` (no extra input).
    :param batch_size: forwarded to ``generate_masks`` and ``generate_shift_masks``.
    :return: tuple ``(mask3d_batch, input_mask)``.
    :raises NotImplementedError: if ``mask_type`` is none of the values above
        (previously this surfaced as an unbound-local error at the return).
    """
    # NOTE(review): the generate_masks visible in this file is declared as
    # generate_masks(args); this two-argument call does not match that
    # signature -- confirm which definition is intended before relying on it.
    mask3d_batch = generate_masks(mask_path, batch_size)
    if mask_type == 'Phi':
        shift_mask3d_batch = shift(mask3d_batch)
        input_mask = shift_mask3d_batch
    elif mask_type == 'Phi_PhiPhiT':
        Phi_batch, Phi_s_batch = generate_shift_masks(mask_path, batch_size)
        input_mask = (Phi_batch, Phi_s_batch)
    elif mask_type == 'Mask':
        input_mask = mask3d_batch
    elif mask_type is None:
        input_mask = None
    else:
        raise NotImplementedError(
            "unsupported mask_type: {!r}".format(mask_type))
    return mask3d_batch, input_mask

def generate_shift_masks(mask_path, batch_size):
    """Load the pre-shifted 3-D mask and derive its per-pixel squared sum.

    Reads variable ``mask_3d_shift`` from ``<mask_path>/mask_3d_shift.mat``,
    moves its last axis to the front and expands it over the batch.

    :param mask_path: directory prefix of ``mask_3d_shift.mat``.
    :param batch_size: size of the leading batch axis of the outputs.
    :return: tuple ``(Phi_batch, Phi_s_batch)``: ``Phi_batch`` is float32 with
        shape ``(batch_size, nC, H, W)``; ``Phi_s_batch`` is the sum of
        ``Phi_batch ** 2`` over axis 1 with exact zeros replaced by 1. Both
        are on CUDA when a GPU is available, otherwise on the CPU.
    """
    mask = sio.loadmat(mask_path + '/mask_3d_shift.mat')
    mask_3d_shift = np.transpose(mask['mask_3d_shift'], [2, 0, 1])
    mask_3d_shift = torch.from_numpy(mask_3d_shift)
    nC, H, W = mask_3d_shift.shape
    # Same placement as before on GPU hosts; CPU-only hosts keep the tensor
    # where it was loaded instead of failing inside .cuda().
    device = torch.device('cuda') if torch.cuda.is_available() else mask_3d_shift.device
    Phi_batch = mask_3d_shift.expand([batch_size, nC, H, W]).to(device).float()
    Phi_s_batch = torch.sum(Phi_batch ** 2, 1)
    # Replace exact zeros with 1 in the squared-sum map.
    Phi_s_batch[Phi_s_batch == 0] = 1
    return Phi_batch, Phi_s_batch



def shuffle_crop_mask(args, mask4d_batch):
    """Crop an independent random square window out of every mask in a batch.

    :param args: object with ``patch_size`` (side length of the window).
    :param mask4d_batch: array unpacked as ``(batch_size, nC, h, w)``.
    :return: ``float32`` numpy array of shape
        ``(batch_size, nC, args.patch_size, args.patch_size)``.
    """
    batch_size, nC, _, _ = mask4d_batch.shape
    size = args.patch_size
    mask4d_batchcrop = np.zeros((batch_size, nC, size, size), dtype=np.float32)

    for i in range(batch_size):
        _, h, w = mask4d_batch[i].shape
        # np.random.randint excludes its upper bound, so "+ 1" makes the last
        # valid offset reachable and accepts masks of exactly patch_size.
        x_index = np.random.randint(0, h - size + 1)
        y_index = np.random.randint(0, w - size + 1)
        mask4d_batchcrop[i, :, :, :] = mask4d_batch[i][:, x_index:x_index + size, y_index:y_index + size]
    return mask4d_batchcrop



def data_initialization(args, data, mask4d, use_batch_size_trntst=False):
    """Prepare network inputs for one batch: cropped GT, masks and measurement.

    :param args: object providing ``batch_size`` / ``batch_size_trntst``,
        ``mask_op`` (``'rand_crop'`` or ``'fixed256'``) plus whatever
        ``shuffle_crop_mask`` and ``gen_meas_gt`` read from it.
    :param data: batch handed to ``shuffle_crop``.
    :param mask4d: numpy mask tiled along the first axis to the batch size.
    :param use_batch_size_trntst: tile with ``args.batch_size_trntst`` instead
        of ``args.batch_size``.
    :return: ``(PhiTy_cuda, gt_cuda, mask4d_batch_shift, mask4d_batch)``.
    :raises NotImplementedError: for any other ``args.mask_op``.
    :raises AssertionError: if the mask size does not fit the chosen ``mask_op``.
    """
    repeats = args.batch_size_trntst if use_batch_size_trntst else args.batch_size
    mask4d_batch = np.tile(mask4d, (repeats, 1, 1, 1))

    gt = shuffle_crop(data)

    mask_h, mask_w = mask4d_batch.shape[-2], mask4d_batch.shape[-1]
    if args.mask_op == 'rand_crop':
        assert (mask_w > 256) and (mask_h > 256), 'ERROR: mask is not large enough'
        mask4d_batch = shuffle_crop_mask(args, mask4d_batch)
    elif args.mask_op == 'fixed256':
        assert (mask_w == 256) and (mask_h == 256), 'ERROR: mask should be 256x256'
    else:
        raise NotImplementedError

    # Both supported modes end with the same tensor conversion and shift.
    mask4d_batch = torch.from_numpy(mask4d_batch).cuda().float()
    mask4d_batch_shift = shift(mask4d_batch)

    PhiTy_cuda, gt_cuda = gen_meas_gt(args, gt, mask4d_batch)
    return PhiTy_cuda, gt_cuda, mask4d_batch_shift, mask4d_batch


def train(args,
          ldr_train,
          mask4d,
          optimizer,
          net,
          loss_func,
          epoch_loss,
          local_iter,
          model_path,
          train_slow,
          learning_rate_decay,
          learning_rate_scheduler,
          id=None,
          glob_iter=None):
    '''
    Run one training epoch of ``net`` over ``ldr_train``.

    Each batch is prepared with ``data_initialization``; the loss is the
    square root of ``loss_func(net(y, m_shift), gt)``. When ``train_slow`` is
    set, a random sleep of up to 0.1 s is inserted before every optimizer
    step. The scheduler is stepped once after the epoch if
    ``learning_rate_decay`` is set. The mean batch loss is appended to
    ``epoch_loss`` and a summary line is sent to ``gen_log``.

    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, mask_shift, _ = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prediction = net(meas, mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))
        optimizer.step()
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===> Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===> User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss


def train_prox(args,
          ldr_train,
          mask4d,
          optimizer,
          net,
          global_params,
          device,
          loss_func,
          epoch_loss,
          local_iter,
          model_path,
          train_slow,
          learning_rate_decay,
          learning_rate_scheduler,
          id=None,
          glob_iter=None):
    '''
    Run one training epoch of ``net`` with an optimizer whose ``step`` takes
    ``global_params`` and ``device`` keyword arguments.

    Apart from the optimizer call the flow matches a plain epoch: batches are
    prepared with ``data_initialization``, the loss is
    ``sqrt(loss_func(net(y, m_shift), gt))``, an optional random sleep (up to
    0.1 s) precedes each step when ``train_slow`` is set, and the scheduler is
    stepped once afterwards if ``learning_rate_decay`` is set. The mean batch
    loss is appended to ``epoch_loss`` and logged through ``gen_log``.

    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, mask_shift, _ = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prediction = net(meas, mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))
        optimizer.step(global_params=global_params, device=device)
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===> Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===> User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss


def train_scaffold(args,
          ldr_train,
          mask4d,
          optimizer,
          net,
          loss_func,
          epoch_loss,
          local_iter,
          model_path,
          train_slow,
          learning_rate_decay,
          learning_rate_scheduler,
          global_c,
          client_c,
          id=None,
          glob_iter=None):
    '''
    Run one training epoch of ``net`` with an optimizer whose ``step`` takes
    ``global_c`` and ``client_c`` as positional arguments.

    Batches are prepared with ``data_initialization``; the loss is
    ``sqrt(loss_func(net(y, m_shift), gt))``. When ``train_slow`` is set a
    random sleep of up to 0.1 s precedes each step. The scheduler is stepped
    once after the epoch if ``learning_rate_decay`` is set. The mean batch
    loss is appended to ``epoch_loss`` and logged through ``gen_log``.

    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, mask_shift, _ = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prediction = net(meas, mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))
        optimizer.step(global_c, client_c)
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===> Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===> User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss



def prompt_train(args,
              ldr_train,
              mask4d,
              optimizer,
              prompt_net,
              net,
              loss_func,
              epoch_loss,
              local_iter,
              model_path,
              train_slow,
              learning_rate_decay,
              learning_rate_scheduler,
              id=None,
              glob_iter=None):
    '''
    Run one training epoch in which ``prompt_net`` pre-processes the input of ``net``.

    ``prompt_net`` is put in train mode. For each batch the measurement ``y``
    and the un-shifted mask from ``data_initialization`` are fed to
    ``prompt_net(m_4d, y)``; its output goes through ``net`` together with
    the shifted mask, and the loss is ``sqrt(loss_func(model_out, gt))``.
    When ``train_slow`` is set a random sleep of up to 0.1 s precedes each
    step. The scheduler is stepped once after the epoch if
    ``learning_rate_decay`` is set. The mean batch loss is appended to
    ``epoch_loss`` and logged through ``gen_log``.

    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    prompt_net.train()
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, mask_shift, mask_plain = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prompted_meas = prompt_net(mask_plain, meas)
        prediction = net(prompted_meas, mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))
        optimizer.step()
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===> [Prompt model] Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===> [Prompt model] User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss


def align_train(args,
              ldr_train,
              mask4d,
              optimizer,
              prompt_net,
              net,
              loss_func,
              epoch_loss,
              local_iter,
              model_path,
              train_slow,
              learning_rate_decay,
              learning_rate_scheduler,
              id=None,
              glob_iter=None):
    '''
    Run one alignment epoch for ``prompt_net`` (``net`` is not used here).

    ``prompt_net`` is put in train mode. For each batch a pseudo ground truth
    ``gt * m_4d + prompt_net(m_4d)`` is built and compared with ``gt``:
    with ``args.KL_softmax`` set, ``loss_func`` receives the log of the
    softmax (last dim) of the pseudo GT and the softmax of ``gt``; otherwise
    it receives ``pseudo_gt.log()`` and ``gt`` directly. When ``train_slow``
    is set a random sleep of up to 0.1 s precedes each step. The scheduler is
    stepped once after the epoch if ``learning_rate_decay`` is set. The mean
    batch loss is appended to ``epoch_loss`` and logged through ``gen_log``.

    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    prompt_net.train()
    tic = time.time()
    losses = []
    for batch in ldr_train:
        _, target, _, mask_plain = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        masked_target = target * mask_plain
        pseudo_gt = masked_target + prompt_net(mask_plain)
        if args.KL_softmax:
            log_input = pseudo_gt.softmax(dim=-1).log()
            reference = target.softmax(dim=-1)
        else:
            log_input = pseudo_gt.log()
            reference = target
        loss = loss_func(log_input, reference)

        loss.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))

        optimizer.step()
        losses.append(loss.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===> [Prompt model] Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===> [Prompt model] User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss


def train_MPT_pnet(args,
              ldr_train,
              mask4d,
              optimizer,
              prompt_net,
              net,
              loss_func,
              epoch_loss,
              local_iter,
              model_path,
              train_slow,
              learning_rate_decay,
              learning_rate_scheduler,
              train_mode,
              id=None,
              glob_iter=None):
    '''
    Run one training epoch for the prompt network in the MPT setup.

    ``prompt_net`` is put in train mode. For each batch the un-shifted mask
    is transformed by ``prompt_net``; the transformed mask, scaled by
    ``args.align_intensity``, is added to the measurement ``y``, and its
    shifted version replaces the shifted mask as second input of ``net``.
    The loss is ``sqrt(loss_func(model_out, gt))``. When ``train_slow`` is
    set a random sleep of up to 0.1 s precedes each step. The scheduler is
    stepped once after the epoch if ``learning_rate_decay`` is set. The mean
    batch loss is appended to ``epoch_loss`` and logged through ``gen_log``.

    :param train_mode: only used as a label in the log line.
    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    prompt_net.train()
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, _, mask_plain = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prompted_mask = prompt_net(mask_plain)
        prompted_mask_shift = shift(prompted_mask)
        meas = meas + prompted_mask * args.align_intensity
        prediction = net(meas, prompted_mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))

        optimizer.step()
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===>Mode {} || [Prompt model] Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            train_mode, local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===>Mode {} || [Prompt model] User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            train_mode, id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss


def train_MPT_bnet(args,
              ldr_train,
              mask4d,
              optimizer,
              prompt_net,
              net,
              loss_func,
              epoch_loss,
              local_iter,
              model_path,
              train_slow,
              learning_rate_decay,
              learning_rate_scheduler,
              train_mode,
              id=None,
              glob_iter=None):
    '''
    Run one training epoch for the backbone ``net`` in the MPT setup.

    ``net`` is put in train mode. For each batch the un-shifted mask is
    transformed by ``prompt_net`` and shifted; ``net`` receives the
    unmodified measurement ``y`` together with that shifted, transformed
    mask. The loss is ``sqrt(loss_func(model_out, gt))``. When ``train_slow``
    is set a random sleep of up to 0.1 s precedes each step. The scheduler is
    stepped once after the epoch if ``learning_rate_decay`` is set. The mean
    batch loss is appended to ``epoch_loss`` and logged through ``gen_log``.

    :param train_mode: only used as a label in the log line.
    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    net.train()
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, _, mask_plain = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prompted_mask = prompt_net(mask_plain)
        prompted_mask_shift = shift(prompted_mask)
        prediction = net(meas, prompted_mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))

        optimizer.step()
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===>Mode {} | [Backbone model] Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            train_mode, local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===>Mode {} | [Backbone model] User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            train_mode, id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss


def train_MPT_bnet_warmup(args,
              ldr_train,
              mask4d,
              optimizer,
              net,
              loss_func,
              epoch_loss,
              local_iter,
              model_path,
              train_slow,
              learning_rate_decay,
              learning_rate_scheduler,
              train_mode,
              id=None,
              glob_iter=None):
    '''
    Run one warm-up training epoch for the backbone ``net`` (no prompt network).

    ``net`` is put in train mode. For each batch ``net`` receives the
    measurement and the shifted mask produced by ``data_initialization``; the
    loss is ``sqrt(loss_func(model_out, gt))``. When ``train_slow`` is set a
    random sleep of up to 0.1 s precedes each step. The scheduler is stepped
    once after the epoch if ``learning_rate_decay`` is set. The mean batch
    loss is appended to ``epoch_loss`` and logged through ``gen_log``.

    :param train_mode: only used as a label in the log line.
    :param id: user id used in the log line and passed to ``gen_log``.
    :param glob_iter: when given, the logged epoch number is
        ``local_iter + (glob_iter - 1) * args.local_steps``.
    :return: ``epoch_loss`` (the same list, extended by one entry).
    '''
    net.train()
    tic = time.time()
    losses = []
    for batch in ldr_train:
        meas, target, mask_shift, _ = data_initialization(args, batch, mask4d)
        optimizer.zero_grad()
        prediction = net(meas, mask_shift)
        rmse = torch.sqrt(loss_func(prediction, target))
        rmse.backward()
        if train_slow:
            time.sleep(0.1 * np.abs(np.random.rand()))

        optimizer.step()
        losses.append(rmse.item())

    # learning rate decay
    if learning_rate_decay:
        learning_rate_scheduler.step()

    toc = time.time()
    mean_loss = sum(losses) / len(losses)
    epoch_loss.append(mean_loss)
    minutes = (toc - tic) / 60.

    if glob_iter is None:
        msg = "===>Mode {} | [Backbone model] Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            train_mode, local_iter, mean_loss, minutes)
    else:
        global_epoch = local_iter + (glob_iter - 1) * args.local_steps
        msg = "===>Mode {} | [Backbone model] User: {} Epoch {} Complete: Avg. Loss: {:.6f} time: {:.5f}mins".format(
            train_mode, id, global_epoch, mean_loss, minutes)
    gen_log(model_path=model_path, msg=msg, user_id=id)

    return epoch_loss



def psnr(img1, img2):
    """Return the channel-averaged PSNR of every sample in a batch.

    :param img1: array indexed as ``[sample, :, :, channel]``.
    :param img2: reference array with the same layout; the peak value of each
        sample is taken as the maximum of that sample in ``img2``.
    :return: list with one PSNR value (in dB) per sample.
    """
    # Iterate over the channels actually present, so the loop and the
    # averaging divisor always agree.
    n_channels = img1.shape[3]
    psnr_list = []
    for i in range(img1.shape[0]):
        total_psnr = 0
        PIXEL_MAX = img2[i, :, :, :].max()
        for ch in range(n_channels):
            mse = np.mean((img1[i, :, :, ch] - img2[i, :, :, ch]) ** 2)
            total_psnr += 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
        psnr_list.append(total_psnr / n_channels)
    return psnr_list


def torch_psnr(img, ref):  # input [28,256,256]
    """Return the channel-averaged PSNR between two channel-first tensors.

    Both inputs are multiplied by 256 and rounded before comparison; the
    per-channel PSNR uses ``255 * 255`` as the squared peak value.

    :param img: tensor indexed as ``[channel, :, :]``.
    :param ref: reference tensor with the same layout.
    :return: mean of the per-channel PSNR values.
    """
    scaled_img = (img*256).round()
    scaled_ref = (ref*256).round()
    n_bands = scaled_img.shape[0]
    peak_sq = 255*255
    total = 0
    for band in range(n_bands):
        diff = scaled_img[band, :, :] - scaled_ref[band, :, :]
        total += 10 * torch.log10(peak_sq / torch.mean(diff ** 2))
    return total / n_bands

def torch_ssim(img, ref):   #input [28,patch_size,patch_size]
    """Return ``ssim`` of two tensors after adding a leading axis to each."""
    batched_img = img.unsqueeze(0)
    batched_ref = ref.unsqueeze(0)
    return ssim(batched_img, batched_ref)

def _as_floats(im1, im2):
    """Convert both images to one common floating-point dtype.

    The dtype is the NumPy result type of the two input dtypes and
    ``float32``.

    :return: tuple ``(im1, im2)`` of arrays in that common dtype.
    """
    common = np.result_type(im1.dtype, im2.dtype, np.float32)
    return tuple(np.asarray(image, dtype=common) for image in (im1, im2))

def compare_mse(im1, im2):
    """Return the mean squared error between two images.

    Inputs are first converted by ``_as_floats``; the mean is accumulated in
    ``float64``.
    """
    first, second = _as_floats(im1, im2)
    squared_diff = np.square(first - second)
    return np.mean(squared_diff, dtype=np.float64)

def compare_psnr(im_true, im_test, data_range=None):
    """Return the PSNR (in dB) of ``im_test`` with respect to ``im_true``.

    Computed as ``10 * log10(data_range ** 2 / mse)`` with the MSE from
    ``compare_mse``.

    :param im_true: reference image (array with a ``dtype``).
    :param im_test: image under test (array with a ``dtype``).
    :param data_range: distance between the minimum and maximum possible
        pixel values. It has no usable default and must be supplied.
    :raises TypeError: if ``data_range`` is left as ``None`` (the same
        exception type the ``None ** 2`` expression used to produce, now
        raised up front with an explicit message).
    """
    if data_range is None:
        raise TypeError(
            "compare_psnr() requires a numeric data_range; got None")
    im_true, im_test = _as_floats(im_true, im_test)
    err = compare_mse(im_true, im_test)
    return 10 * np.log10((data_range ** 2) / err)


def test(args,
         epoch,
         model_path,
         net,
         test_data,
         mask4d_cube,
         stay_log=True,
         id=None):
    """Evaluate ``net`` on ``test_data`` and report PSNR / SSIM.

    The numpy mask is moved to CUDA, the measurement is simulated with
    ``gen_meas_gt`` and ``net`` is run once (in eval mode, without gradients)
    on the whole test set with the shifted mask expanded over the batch.
    ``net`` is switched back to train mode before returning.

    :param stay_log: when true, a summary line is sent to ``gen_log``.
    :param id: user id used in the log line and passed to ``gen_log``.
    :return: ``(pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa,
        ssim_mean)``; ``pred`` and ``truth`` are float32 numpy arrays with the
        channel axis moved last.
    """
    psnr_list_tsa, psnr_list_gsm, ssim_list = [], [], []

    mask4d_cube = torch.from_numpy(mask4d_cube).cuda().float()
    PhiTy_cuda, gt_cuda = gen_meas_gt(args, test_data, mask4d_cube)
    n_scenes = gt_cuda.shape[0]

    shifted_mask = shift(mask4d_cube)
    _, nC, Hm, Wm = shifted_mask.shape
    shifted_mask = shifted_mask.expand([n_scenes, nC, Hm, Wm])

    net.eval()
    tic = time.time()
    with torch.no_grad():
        model_out = net(PhiTy_cuda, shifted_mask)
    toc = time.time()

    for k in range(n_scenes):
        recon = model_out[k, :, :, :]
        reference = gt_cuda[k, :, :, :]
        psnr_list_tsa.append(torch_psnr(recon, reference).detach().cpu().numpy())
        psnr_list_gsm.append(
            compare_psnr(reference.cpu().numpy(), recon.cpu().numpy(), data_range=1.0))
        ssim_list.append(torch_ssim(recon, reference).detach().cpu().numpy())

    # Channel-last float32 copies of prediction and ground truth.
    pred = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    truth = np.transpose(gt_cuda.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    psnr_mean_tsa = np.mean(np.asarray(psnr_list_tsa))
    psnr_mean_gsm = np.mean(np.asarray(psnr_list_gsm))
    ssim_mean = np.mean(np.asarray(ssim_list))

    if stay_log:
        template = '===> User: {} Epoch {}: testing psnr = {:.5f}(tsa)/{:.5f}(gsm), ssim = {:.5f}, time: {:.5f}mins'
        msg = template.format(id, epoch, psnr_mean_tsa, psnr_mean_gsm, ssim_mean, (toc - tic) / 60.)
        gen_log(model_path=model_path, msg=msg, user_id=id)

    net.train()

    return (pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa, ssim_mean)

def test_wprompt(args,
         epoch,
         model_path,
         net,
         prompt_net,
         test_data,
         mask4d_cube,
         stay_log=True,
         id=None):
    """Evaluate ``net`` fed by ``prompt_net`` on ``test_data`` and report PSNR / SSIM.

    The numpy mask is moved to CUDA and the measurement is simulated with
    ``gen_meas_gt``. ``prompt_net`` (in eval mode, without gradients) maps the
    batch-expanded mask and the measurement to the input of ``net``, which
    also receives the batch-expanded shifted mask. ``prompt_net`` is switched
    back to train mode before returning; the mode of ``net`` is not touched.

    :param stay_log: when true, a summary line is sent to ``gen_log``.
    :param id: user id used in the log line and passed to ``gen_log``.
    :return: ``(pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa,
        ssim_mean)``; ``pred`` and ``truth`` are float32 numpy arrays with the
        channel axis moved last.
    """
    psnr_list_tsa, psnr_list_gsm, ssim_list = [], [], []

    mask4d_cube = torch.from_numpy(mask4d_cube).cuda().float()
    PhiTy_cuda, gt_cuda = gen_meas_gt(args, test_data, mask4d_cube)
    n_scenes = gt_cuda.shape[0]

    shifted_mask = shift(mask4d_cube)
    mask4d_in = mask4d_cube.expand(
        [n_scenes, mask4d_cube.shape[1], mask4d_cube.shape[2], mask4d_cube.shape[3]])
    _, nC, Hm, Wm = shifted_mask.shape
    shifted_mask = shifted_mask.expand([n_scenes, nC, Hm, Wm])

    prompt_net.eval()
    tic = time.time()
    with torch.no_grad():
        prompted_meas = prompt_net(mask4d_in, PhiTy_cuda)
        model_out = net(prompted_meas, shifted_mask)
    toc = time.time()

    for k in range(n_scenes):
        recon = model_out[k, :, :, :]
        reference = gt_cuda[k, :, :, :]
        psnr_list_tsa.append(torch_psnr(recon, reference).detach().cpu().numpy())
        psnr_list_gsm.append(
            compare_psnr(reference.cpu().numpy(), recon.cpu().numpy(), data_range=1.0))
        ssim_list.append(torch_ssim(recon, reference).detach().cpu().numpy())

    # Channel-last float32 copies of prediction and ground truth.
    pred = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    truth = np.transpose(gt_cuda.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    psnr_mean_tsa = np.mean(np.asarray(psnr_list_tsa))
    psnr_mean_gsm = np.mean(np.asarray(psnr_list_gsm))
    ssim_mean = np.mean(np.asarray(ssim_list))

    if stay_log:
        template = '===> [w/ prompt] User: {} Epoch {}: testing psnr = {:.5f}(tsa)/{:.5f}(gsm), ssim = {:.5f}, time: {:.5f}mins'
        msg = template.format(id, epoch, psnr_mean_tsa, psnr_mean_gsm, ssim_mean, (toc - tic) / 60.)
        gen_log(model_path=model_path, msg=msg, user_id=id)

    prompt_net.train()

    return (pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa, ssim_mean)

def align_test(args,
         epoch,
         model_path,
         net,
         prompt_net,
         test_data,
         mask4d_cube,
         stay_log=True,
         id=None):
    """Evaluate ``net`` fed by ``prompt_net`` on ``test_data`` and report PSNR / SSIM.

    The numpy mask is moved to CUDA and the measurement is simulated with
    ``gen_meas_gt``. ``prompt_net`` (in eval mode, without gradients) is
    called with the batch-expanded mask and the measurement; its output and
    the batch-expanded shifted mask are passed to ``net``. ``prompt_net`` is
    switched back to train mode before returning; the mode of ``net`` is not
    touched.

    :param stay_log: when true, a summary line is sent to ``gen_log``.
    :param id: user id used in the log line and passed to ``gen_log``.
    :return: ``(pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa,
        ssim_mean)``; ``pred`` and ``truth`` are float32 numpy arrays with the
        channel axis moved last.
    """
    psnr_list_tsa, psnr_list_gsm, ssim_list = [], [], []

    mask4d_cube = torch.from_numpy(mask4d_cube).cuda().float()
    PhiTy_cuda, gt_cuda = gen_meas_gt(args, test_data, mask4d_cube)
    n_scenes = gt_cuda.shape[0]

    shifted_mask = shift(mask4d_cube)
    mask4d_in = mask4d_cube.expand(
        [n_scenes, mask4d_cube.shape[1], mask4d_cube.shape[2], mask4d_cube.shape[3]])
    _, nC, Hm, Wm = shifted_mask.shape
    shifted_mask = shifted_mask.expand([n_scenes, nC, Hm, Wm])

    prompt_net.eval()
    tic = time.time()
    with torch.no_grad():
        prompted_meas = prompt_net(mask4d_in, PhiTy_cuda)
        model_out = net(prompted_meas, shifted_mask)
    toc = time.time()

    for k in range(n_scenes):
        recon = model_out[k, :, :, :]
        reference = gt_cuda[k, :, :, :]
        psnr_list_tsa.append(torch_psnr(recon, reference).detach().cpu().numpy())
        psnr_list_gsm.append(
            compare_psnr(reference.cpu().numpy(), recon.cpu().numpy(), data_range=1.0))
        ssim_list.append(torch_ssim(recon, reference).detach().cpu().numpy())

    # Channel-last float32 copies of prediction and ground truth.
    pred = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    truth = np.transpose(gt_cuda.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    psnr_mean_tsa = np.mean(np.asarray(psnr_list_tsa))
    psnr_mean_gsm = np.mean(np.asarray(psnr_list_gsm))
    ssim_mean = np.mean(np.asarray(ssim_list))

    if stay_log:
        template = '===> [w/ prompt] User: {} Epoch {}: testing psnr = {:.5f}(tsa)/{:.5f}(gsm), ssim = {:.5f}, time: {:.5f}mins'
        msg = template.format(id, epoch, psnr_mean_tsa, psnr_mean_gsm, ssim_mean, (toc - tic) / 60.)
        gen_log(model_path=model_path, msg=msg, user_id=id)

    prompt_net.train()

    return (pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa, ssim_mean)


def test_MPT(args,
         epoch,
         model_path,
         net,
         prompt_net,
         test_data,
         mask4d_cube,
         stay_log=True,
         id=None):
    '''
    Evaluate `net` on `test_data`, feeding it a mask that was first passed
    through `prompt_net`.

    The numpy mask is moved to the GPU as float, `gen_meas_gt` produces the
    network input and the ground truth from it, and the batch-expanded mask is
    transformed by `prompt_net` and then by `shift` before being given to `net`.
    Both networks are put in eval mode and are still in eval mode on return.

    args / test_data: forwarded unchanged to `gen_meas_gt`.
    epoch, id:        only used in the log message (and `id` as `user_id`).
    stay_log:         when True, a summary line is written through `gen_log`.

    Returns (pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa, ssim_mean).
    `pred` and `truth` are float32 numpy arrays with axis 1 moved to the end;
    the two lists hold one 0-d numpy value per scene.
    '''
    mask_gpu = torch.from_numpy(mask4d_cube).cuda().float()

    PhiTy_cuda, gt_cuda = gen_meas_gt(args, test_data, mask_gpu)
    n_scenes = gt_cuda.shape[0]

    # NOTE(review): the shifted *raw* mask is never read below; the call is kept
    # only so that this function performs exactly the same work as before.
    shift(mask_gpu)
    mask4d_in = mask_gpu.expand([n_scenes, mask_gpu.shape[1], mask_gpu.shape[2], mask_gpu.shape[3]])

    prompt_net.eval()
    net.eval()
    begin = time.time()
    with torch.no_grad():
        prompted_mask = prompt_net(mask4d_in)
        model_out = net(PhiTy_cuda, shift(prompted_mask))
    end = time.time()

    # Per-scene metrics: torch PSNR ("tsa"), compare_psnr ("gsm") and SSIM.
    psnr_list_tsa, psnr_list_gsm, ssim_list = [], [], []
    for k in range(n_scenes):
        recon_k = model_out[k, :, :, :]
        gt_k = gt_cuda[k, :, :, :]

        psnr_list_tsa.append(torch_psnr(recon_k, gt_k).detach().cpu().numpy())
        psnr_list_gsm.append(compare_psnr(gt_k.cpu().numpy(), recon_k.cpu().numpy(), data_range=1.0))
        ssim_list.append(torch_ssim(recon_k, gt_k).detach().cpu().numpy())

    axes_last = (0, 2, 3, 1)
    pred = np.transpose(model_out.detach().cpu().numpy(), axes_last).astype(np.float32)
    truth = np.transpose(gt_cuda.cpu().numpy(), axes_last).astype(np.float32)

    psnr_mean_tsa = np.mean(np.asarray(psnr_list_tsa))
    psnr_mean_gsm = np.mean(np.asarray(psnr_list_gsm))
    ssim_mean = np.mean(np.asarray(ssim_list))

    if stay_log:
        elapsed_mins = (end - begin) / 60.
        msg = ('===> [w/ prompt] User: {} Epoch {}: testing psnr = {:.5f}(tsa)/{:.5f}(gsm), ssim = {:.5f}, time: {:.5f}mins'.format(
                id,
                epoch,
                psnr_mean_tsa,
                psnr_mean_gsm,
                ssim_mean,
                elapsed_mins))
        gen_log(model_path=model_path, msg=msg, user_id=id)

    return (pred, truth, psnr_list_tsa, ssim_list, psnr_mean_tsa, ssim_mean)



def mask_determine(args, mask4d_ls, id, mask_source=None):
    '''
    Select the mask to use for one run.

    mask_source decides which object the mask is taken from:
        None          -> `mask4d_ls` itself is the mask;
        'assign_usr'  -> `mask4d_ls[id]`;
        'usr_union'   -> `mask4d_ls[np.random.choice(args.num_clients)]`.

    args.mask_op decides what is done with it:
        'fixed256'    -> returned as-is; its last two dimensions must both be 256;
        anything else -> treated as 'rand_crop': both of the last two dimensions
                         must be larger than 256 and the mask is passed to
                         `shuffle_crop_mask(args, mask)`.

    Raises:
        ValueError      if `mask_source` is not one of the values above.
        AssertionError  if the selected mask fails the size check.
    '''
    # Step 1: resolve where the mask comes from.
    if mask_source is None:
        candidate = mask4d_ls
    elif mask_source == 'assign_usr':
        candidate = mask4d_ls[id]
    elif mask_source == 'usr_union':
        candidate = mask4d_ls[np.random.choice(args.num_clients)]
    else:
        # Carry the message in the exception instead of printing it, so it
        # is visible to whoever catches or logs the error.
        raise ValueError('ERROR: invalid mask_source! got {!r}'.format(mask_source))

    # Step 2: apply the mask operation.
    if args.mask_op == 'fixed256':
        assert (candidate.shape[-1] == 256) and (candidate.shape[-2] == 256), 'ERROR: mask should be 256x256'
        return candidate

    # args.mask_op == 'rand_crop'
    assert (candidate.shape[-1] > 256) and (candidate.shape[-2] > 256), 'ERROR: mask is not large enough'
    return shuffle_crop_mask(args, candidate)


def test_Mtrials(args,
         epoch,
         model_path,
         net,
         test_data,
         mask4d_ls,
         mask_source,
         stay_log=True,
         for_client=False,
         id=None):
    '''
    Run `test` repeatedly, drawing a mask with `mask_determine` for every
    trial, and summarise PSNR / SSIM over the trials.

    The number of trials is `args.trial_num`, or
    `max(args.trial_num // args.num_clients, 10)` when `for_client` is True.
    The prediction of the first trial is saved to
    `<model_path>/epo<args.last_train>.mat`; per-scene mean and std across
    trials are printed to stdout.

    stay_log: when True, one summary line is written through `gen_log`. The
              line reports the number of trials that were actually run.

    Returns (psnr_ave, ssim_ave, psnr_std, ssim_std, psnr_trials, ssim_trials)
    where `psnr_trials` / `ssim_trials` hold one scene-averaged value per trial.
    '''
    trials = args.trial_num if not for_client else max(args.trial_num // args.num_clients, 10)

    psnr_trials, ssim_trials = [], []
    psnr_allscene_trials, ssim_allscene_trials = [], []

    begin = time.time()
    for trial in range(trials):

        mask4d_cube = mask_determine(args=args,
                                     mask4d_ls=mask4d_ls,
                                     mask_source=mask_source,
                                     id=id)

        # Per-trial logging is disabled; only the aggregate is logged below.
        (pred, _, psnr_list_tsa, ssim_list, _, _) = test(args=args,
                                                         epoch=epoch,
                                                         model_path=model_path,
                                                         net=net,
                                                         test_data=test_data,
                                                         mask4d_cube=mask4d_cube,
                                                         stay_log=False)
        psnr_allscene_trials.append(psnr_list_tsa)
        ssim_allscene_trials.append(ssim_list)

        if trial == 0:
            print('pred shape',pred.shape)
            sio.savemat(model_path + '/epo{}.mat'.format(args.last_train), {'pred': pred})

        psnr_trials.append(np.mean(np.asarray(psnr_list_tsa)))
        ssim_trials.append(np.mean(np.asarray(ssim_list)))

    # Per-scene statistics: reduce over the trial axis (axis 0).
    psnr_per_scene = np.array(psnr_allscene_trials)
    ssim_per_scene = np.array(ssim_allscene_trials)
    allscene_psnr_mean = np.mean(psnr_per_scene, axis=0)
    allscene_psnr_std  = np.std(psnr_per_scene, axis=0)
    allscene_ssim_mean = np.mean(ssim_per_scene, axis=0)
    allscene_ssim_std  = np.std(ssim_per_scene, axis=0)

    print('\n>>>psnr mean=', allscene_psnr_mean)
    print('\n>>>psnr std=', allscene_psnr_std)
    print('\n>>>ssim mean=', allscene_ssim_mean)
    print('\n>>>ssim std=', allscene_ssim_std)

    psnr_ave = np.mean(psnr_trials)
    psnr_std = np.std(psnr_trials)
    ssim_ave = np.mean(ssim_trials)
    ssim_std = np.std(ssim_trials)
    end = time.time()

    if stay_log:
        # `trials` (not args.trial_num) is logged so the line is correct when
        # for_client changes the count; std fields use `.5f` like the means.
        msg = '===>mask:{}, trials:{},  User: {}, Epoch {}: testing psnr = {:.5f}/{:.5f}(tsa), ssim = {:.5f}/{:.5f}, time: {:.5f}mins'.format(
            mask_source,
            trials,
            id,
            epoch,
            psnr_ave,
            psnr_std,
            ssim_ave,
            ssim_std,
            (end - begin) / 60.)
        gen_log(model_path=model_path, msg=msg, user_id=id)

    return psnr_ave, ssim_ave, psnr_std, ssim_std, psnr_trials, ssim_trials


def test_Mtrials_wprompt(args,
         epoch,
         model_path,
         net,
         prompt_net,
         test_data,
         mask4d_ls,
         mask_source,
         stay_log=True,
         for_client=False,
         id=None):
    '''
    Run `test_wprompt` repeatedly, drawing a mask with `mask_determine` for
    every trial, and summarise PSNR / SSIM over the trials.

    The number of trials is `args.trial_num`, or
    `max(args.trial_num // args.num_clients, 10)` when `for_client` is True.

    stay_log: when True, one summary line is written through `gen_log`. The
              line reports the number of trials that were actually run.

    Returns (psnr_ave, ssim_ave, psnr_std, ssim_std, psnr_trials, ssim_trials)
    where `psnr_trials` / `ssim_trials` hold one scene-averaged value per trial.
    '''
    trials = args.trial_num if not for_client else max(args.trial_num // args.num_clients, 10)

    psnr_trials, ssim_trials = [], []
    begin = time.time()
    for _ in range(trials):

        mask4d_cube = mask_determine(args=args,
                                     mask4d_ls=mask4d_ls,
                                     mask_source=mask_source,
                                     id=id)

        # Per-trial logging is disabled; only the aggregate is logged below.
        (_, _, psnr_list_tsa, ssim_list, _, _) = test_wprompt(args=args,
                                                              epoch=epoch,
                                                              model_path=model_path,
                                                              net=net,
                                                              prompt_net=prompt_net,
                                                              test_data=test_data,
                                                              mask4d_cube=mask4d_cube,
                                                              stay_log=False)
        psnr_trials.append(np.mean(np.asarray(psnr_list_tsa)))
        ssim_trials.append(np.mean(np.asarray(ssim_list)))

    psnr_ave = np.mean(psnr_trials)
    psnr_std = np.std(psnr_trials)
    ssim_ave = np.mean(ssim_trials)
    ssim_std = np.std(ssim_trials)
    end = time.time()

    if stay_log:
        # `trials` (not args.trial_num) is logged so the line is correct when
        # for_client changes the count; std fields use `.5f` like the means.
        msg = '===>[w/ prompt] mask:{}, trials:{},  User: {}, Epoch {}: testing psnr = {:.5f}/{:.5f}(tsa), ssim = {:.5f}/{:.5f}, time: {:.5f}mins'.format(
            mask_source,
            trials,
            id,
            epoch,
            psnr_ave,
            psnr_std,
            ssim_ave,
            ssim_std,
            (end - begin) / 60.)
        gen_log(model_path=model_path, msg=msg, user_id=id)

    return psnr_ave, ssim_ave, psnr_std, ssim_std, psnr_trials, ssim_trials


def test_Mtrials_MPT(args,
         epoch,
         model_path,
         net,
         prompt_net,
         test_data,
         mask4d_ls,
         mask_source,
         stay_log=True,
         for_client=False,
         id=None):
    '''
    Run `test_MPT` repeatedly, drawing a mask with `mask_determine` for every
    trial, and summarise PSNR / SSIM over the trials.

    The number of trials is `args.trial_num`, or
    `max(args.trial_num // args.num_clients, 10)` when `for_client` is True.

    stay_log: when True, one summary line is written through `gen_log`. The
              line reports the number of trials that were actually run.

    Returns (psnr_ave, ssim_ave, psnr_std, ssim_std, psnr_trials, ssim_trials)
    where `psnr_trials` / `ssim_trials` hold one scene-averaged value per trial.
    '''
    trials = args.trial_num if not for_client else max(args.trial_num // args.num_clients, 10)

    psnr_trials, ssim_trials = [], []
    begin = time.time()
    for _ in range(trials):

        mask4d_cube = mask_determine(args=args,
                                     mask4d_ls=mask4d_ls,
                                     mask_source=mask_source,
                                     id=id)

        # Per-trial logging is disabled; only the aggregate is logged below.
        (_, _, psnr_list_tsa, ssim_list, _, _) = test_MPT(args=args,
                                                          epoch=epoch,
                                                          model_path=model_path,
                                                          net=net,
                                                          prompt_net=prompt_net,
                                                          test_data=test_data,
                                                          mask4d_cube=mask4d_cube,
                                                          stay_log=False)
        psnr_trials.append(np.mean(np.asarray(psnr_list_tsa)))
        ssim_trials.append(np.mean(np.asarray(ssim_list)))

    psnr_ave = np.mean(psnr_trials)
    psnr_std = np.std(psnr_trials)
    ssim_ave = np.mean(ssim_trials)
    ssim_std = np.std(ssim_trials)
    end = time.time()

    if stay_log:
        # `trials` (not args.trial_num) is logged so the line is correct when
        # for_client changes the count; std fields use `.5f` like the means.
        msg = '===>[w/ prompt] mask:{}, trials:{},  User: {}, Epoch {}: testing psnr = {:.5f}/{:.5f}(tsa), ssim = {:.5f}/{:.5f}, time: {:.5f}mins'.format(
            mask_source,
            trials,
            id,
            epoch,
            psnr_ave,
            psnr_std,
            ssim_ave,
            ssim_std,
            (end - begin) / 60.)
        gen_log(model_path=model_path, msg=msg, user_id=id)

    return psnr_ave, ssim_ave, psnr_std, ssim_std, psnr_trials, ssim_trials


def checkpoint(model, epoch, model_path, id=None, prompt=False):
    '''
    Save `model.state_dict()` below `./<model_path>` and log the destination.
    (Original note: only operates on the global model.)

    prompt=True  -> `./<model_path>/prompt_model_epoch_<epoch>.pth`
    prompt=False -> `./<model_path>/<timestamp>/model_epoch_<epoch>.pth`, where
                    the folder name is `time2file_name(str(datetime.now()))`
                    and the folder is created when missing.

    The weights are stored under the key "model_weights".
    '''
    base_dir = './' + model_path

    if not prompt:
        # NOTE(review): presumably the pause keeps successive timestamp folders
        # distinct - confirm before removing.
        time.sleep(1)
        stamp = time2file_name(str(datetime.datetime.now()))
        stamp_dir = base_dir + '/' + stamp
        if not os.path.exists(stamp_dir):
            os.makedirs(stamp_dir)
        model_out_path = stamp_dir + "/model_epoch_{}.pth".format(epoch)
    else:
        model_out_path = base_dir + "/prompt_model_epoch_{}.pth".format(epoch)

    payload = {"model_weights": model.state_dict()}
    torch.save(payload, model_out_path)

    gen_log(model_path=model_path,
            msg="Global_iter={}, ID={}, Checkpoint saved to {}".format(epoch, id, model_out_path),
            user_id=id)



def gen_log(model_path, msg, user_id=None):
    '''
    Append `msg` to `<model_path>/log_user<user_id>.txt` and echo it on the
    console, both formatted as "<asctime> - <levelname>: <message>".

    A file handler and a stream handler are attached to the root logger only
    for the duration of this call, so repeated calls do not stack handlers.
    Both are detached and closed afterwards, even if emitting the record
    fails, so the log file's descriptor is released right away.

    Side effect: the root logger's level is set to INFO.
    '''
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")

    log_file = model_path + '/log_user{}.txt'.format(user_id)
    handlers = (logging.FileHandler(log_file, mode='a'), logging.StreamHandler())

    try:
        for handler in handlers:
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.info(msg)
    finally:
        for handler in handlers:
            logger.removeHandler(handler)
            # removeHandler() alone leaves the underlying file open.
            handler.close()



def global_init_CA(args, strict=True):

    '''
    Initialize the global model as the weighted sum of the clients'
    pre-trained weights.

    For each client `c` in `range(args.num_clients)` the checkpoint
    `./<args.model_path>/<args.model_save_filename_clients[c]>/model_epoch_<args.last_train_clients[c]>.pth`
    is loaded into `args.model` (its "model_weights" entry, with the given
    `strict` flag) and a deep copy of `args.model` is kept. Note that
    `args.model` is therefore left holding the last client's weights.

    The returned model is a deep copy of the first client model whose
    parameters are replaced by sum_c w_c * param_c with
    w_c = args.trn_split_ratio[c] / sum(args.trn_split_ratio).
    Only `parameters()` are combined.
    '''
    # Load one model per client.
    client_models = []
    for client in range(args.num_clients):
        client_dir = './' + args.model_path + '/' + args.model_save_filename_clients[client]
        ckpt_file = client_dir + '/model_epoch_{}.pth'.format(args.last_train_clients[client])

        weights = torch.load(ckpt_file)['model_weights']
        args.model.load_state_dict(weights, strict=strict)
        client_models.append(copy.deepcopy(args.model))

    # Start from an all-zero copy, then accumulate the weighted client parameters.
    global_model = copy.deepcopy(client_models[0])
    for server_param in global_model.parameters():
        server_param.data.zero_()

    for client_id, client_model in enumerate(client_models):
        share = args.trn_split_ratio[client_id] / sum(args.trn_split_ratio)
        param_pairs = zip(global_model.parameters(), client_model.parameters())
        for server_param, client_param in param_pairs:
            server_param.data += client_param.data * share

    return global_model






def my_summary(test_model, H = 256, W = 256, C = 28, N = 1):
    '''
    Print the model, its operation count and its parameter count.

    The model is moved to the GPU and analysed with `FlopCountAnalysis` on a
    random input of shape (N, C, H, W). The operation total is printed divided
    by 1024**3 under the label "GMac"; the parameter count is the sum of
    `nelement()` over `model.parameters()`.
    '''
    model = test_model.cuda()
    print(model)

    dummy_input = torch.randn((N, C, H, W)).cuda()
    flops = FlopCountAnalysis(model, dummy_input)
    n_param = sum(p.nelement() for p in model.parameters())

    giga = 1024 * 1024 * 1024
    print(f'GMac:{flops.total()/giga}')
    print(f'Params:{n_param}')



def loadto_model(source_model, target_model):
    '''
    Copy into `target_model` every state-dict entry of `source_model` whose
    key also exists in `target_model`.

    Entries that only exist in the target keep their current values; entries
    that only exist in the source are ignored. The merged dictionary is applied
    with `target_model.load_state_dict`. Returns None.
    '''
    merged_state = target_model.state_dict()
    source_state = source_model.state_dict()

    # Overwrite only the keys the target already knows about.
    for key, tensor in source_state.items():
        if key in merged_state:
            merged_state[key] = tensor

    target_model.load_state_dict(merged_state)













