""""""
"""
Step 2: obtain the attention masks and the part bounding boxes.
It contains two phases:
Phase 1: pretrain the stacked FC layers with the channel grouping results from Step 1
         so as to get a good initialization of the channel grouping.
Phase 2: train the stacked FC layers with the L_dis and L_div losses.
"""

import json
import argparse
import pickle
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from models_MA import VGG19_Finetune_STEP1and2_2Group
from dataset import Dataset_MA_STEP2
from utils import *

# Command-line configuration. The parsed namespace `opt` is a module-level
# global that every function in this file reads (and some extend, e.g. with
# `opt.train_ncls`).
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default='0', type=str, help='index of GPU to use')
parser.add_argument('--dataset', default='AWA1', type=str, help='which dataset: CUB, AWA1, AWA2, FLO')

# Hyper-Parameter
# parser.add_argument('--cls_LAMBDA',   default=1.0,   type=float, help='weight for classification loss')
parser.add_argument('--dis_LAMBDA',   default=50.0, type=float, help='weight for channel grouping loss(distance)')
parser.add_argument('--div_LAMBDA',   default=1000.0, type=float, help='weight for channel grouping loss(divergence)')
parser.add_argument('--peak_LAMBDA',   default=0.0, type=float, help='weight for channel grouping loss(distance, peak)')
parser.add_argument('--BATCH_SIZE',      default=32, type=int, help='batch size')
parser.add_argument('--BATCH_SIZE_TEST', default=64, type=int, help='batch size testing')
# Optimization
parser.add_argument('--LEARNING_RATE', default=0.01, type=float, help='base learning rate')
parser.add_argument('--LEARNING_RATE_CHG', default=0.001, type=float, help='base learning rate')
parser.add_argument('--LEARNING_RATE_CLS', default=0.001, type=float, help='base learning rate')
parser.add_argument('--MOMENTUM',      default=0.9,    type=float, help='base momentum')
parser.add_argument('--WEIGHT_DECAY',  default=0.0005, type=float, help='base weight decay')
# display and log
parser.add_argument('--disp_interval', default=20, type=int, help='display interval')
parser.add_argument('--evl_interval',  default=100000000, type=int, help='Epoch zero-shot learning evl interval')
parser.add_argument('--save_interval', default=5, type=int, help='Epoch save model interval')
# exp
parser.add_argument('--split',    default='PP', type=str, help='split mode; standard split or proposed split')
# The margin is fractional (default 0.1), so it must be parsed as a float;
# with type=int any value given on the command line such as 0.05 is rejected.
parser.add_argument('--margin', default=0.1, type=float, help='margin for channel grouping, divergence loss')
parser.add_argument('--exp_info',    default='MA_CNN_STEP2', type=str, help='sub folder name for this exp')
parser.add_argument('--resume_STEP1', default=None, type=str, help='path of model to resume, STEP1')
parser.add_argument('--channel_cluster', default=None, type=str, help='path of cluster from STEP1, '
                                                                     'Note no need if we have STEP2_Phase1')
parser.add_argument('--resume_Phase1', default=None, type=str, help='path of model to resume, STEP2_Phase1')
parser.add_argument('--resume_Phase2', default=None, type=str, help='path of model to resume, STEP2_Phase2')

opt = parser.parse_args()
# Restrict the visible CUDA devices to the requested GPU index/indices.
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
print('Running parameters:')
print(json.dumps(vars(opt), indent=4, separators=(',', ': ')))

def train():
    """Train the STEP2 channel-grouping layers in two phases, saving one checkpoint per phase.

    Phase 1 fits the two channel-group outputs returned by the network in
    mode ``'STEP2_Phase1'`` to the channel clustering labels loaded from
    ``opt.channel_cluster`` with a BCE loss. Phase 2 continues from those
    weights and minimises the weighted sum of the distance, divergence and
    balance losses in mode ``'STEP2_Phase2'``.

    All settings are read from the module-level ``opt``. The log file and
    the checkpoints are written to ``models_MA/<dataset>/<exp_info>_<split>/``.

    Raises:
        ValueError: if ``opt.channel_cluster`` is not set.
    """
    # Fail fast with a clear message instead of a TypeError from open(None).
    if not opt.channel_cluster:
        raise ValueError('You should provide the channel cluster result from STEP1 (--channel_cluster).')

    sub_dir = opt.exp_info + '_' + opt.split if opt.exp_info else opt.split
    output_dir = os.path.join('models_MA/{}'.format(opt.dataset), sub_dir)
    os.makedirs(output_dir, exist_ok=True)

    lr = opt.LEARNING_RATE
    fout = output_dir + '/log_{}.txt'.format(strftime("%a, %d %b %Y %H:%M:%S", gmtime()))
    with open(fout, 'w') as f:
        f.write('Training Start:\n')
        f.write('TRAIN.LEARNING_RATE: {}\n'.format(opt.LEARNING_RATE))
        f.write('TRAIN.MOMENTUM: {}\n'.format(opt.MOMENTUM))
        f.write('TRAIN.IMS_PER_BATCH: {}\n'.format(opt.BATCH_SIZE))

    # load dataset and net
    dataset = Dataset_MA_STEP2(opt)
    opt.train_ncls = dataset.train_ncls
    net = VGG19_Finetune_STEP1and2_2Group(opt=opt)
    criterion_BCE = nn.BCELoss().cuda()

    # ------------------------------------------------------------------
    # Phase 1: train only the sig_mask branches so that they reproduce the
    # channel clustering obtained from the STEP1 model.
    # ------------------------------------------------------------------
    resume_model(opt.resume_STEP1, net)
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point --channel_cluster at files produced by your own STEP1 run.
    with open(opt.channel_cluster, 'rb') as f:
        result_from_STEP1 = pickle.load(f)
    channel_clustering_label_matrix = result_from_STEP1['channel_clustering_label_matrix']
    channel_clustering_label_matrix = channel_clustering_label_matrix.cuda()
    # Order the cluster rows by ascending row sum so that target 0 / target 1
    # are assigned deterministically.
    _, sort_idx = channel_clustering_label_matrix.sum(dim=1).sort()
    sort_idx = sort_idx.cpu().numpy()
    channel_clustering_label_matrix = channel_clustering_label_matrix[sort_idx]

    if torch.cuda.device_count() > 1:
        c_print("Let's use {} GPUs!".format(torch.cuda.device_count()), color='red', attrs=['bold'])
        net = nn.DataParallel(net)
    net.cuda()
    net.train()

    optimizer = set_optimizer_Initialize_CHG_Layers(net, opt, lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True, cooldown=2,
                                  min_lr=0.00001)
    earlyStopper = EarlyStopping(mode='min', patience=15)
    # training
    train_loss = 0.0
    step_cnt = 0
    t = Timer()
    t.tic()
    # Stays empty if no display interval is reached, so that saving the
    # checkpoint below never hits an unbound name.
    log_text = ''

    for epoch in range(1, 5+1):  # 400
        print(time2str())
        with open(fout, 'a') as f:
            f.write(time2str() + '\n')
        train_loss_epoch = []
        for i_batch, sample_batched in enumerate(dataset.dataloader):
            imgs = sample_batched['imgs']
            (sig_0, sig_1) = net(imgs.cuda(), 'STEP2_Phase1')
            # Every sample in the batch shares the same per-group target row.
            loss_0 = criterion_BCE(sig_0, (channel_clustering_label_matrix[0]).repeat(sig_0.shape[0], 1))
            loss_1 = criterion_BCE(sig_1, (channel_clustering_label_matrix[1]).repeat(sig_1.shape[0], 1))

            loss = (loss_0 + loss_1)/2
            train_loss += loss.item()
            train_loss_epoch.append(loss.item())
            step_cnt += 1
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 10.)
            optimizer.step()

            if i_batch % opt.disp_interval == 0 and i_batch:
                duration = t.toc(average=False)
                inv_fps = duration / step_cnt

                log_text = 'Epoch:{:2d} [{:3d}/{:3d}], Tloss {:.4f}   (lr: {}, {:.2f}s per iteration)'.format(
                    epoch, i_batch, len(dataset.dataloader), train_loss / step_cnt,
                    scheduler.optimizer.param_groups[0]['lr'], inv_fps)
                c_print(log_text, color='green', attrs=['bold'])

        train_loss_epoch_mean = np.asarray(train_loss_epoch).mean()
        scheduler.step(train_loss_epoch_mean)
        if earlyStopper.step(train_loss_epoch_mean):
            break  # jump out of loop
    save_name = os.path.join(output_dir, 'ZSL_Finetune_STEP2_Phase1_2Group_Epoch{}.tar'.format(epoch))
    # Save the bare model, not the DataParallel wrapper, so the checkpoint
    # loads regardless of the number of GPUs.
    net2save = net.module if isinstance(net, nn.DataParallel) else net
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': net2save.state_dict(),
        'log': log_text,
        'optimizer': optimizer.state_dict(),
    }, save_name)
    print('save model: {}'.format(save_name))
    print('Phase1 Done!!!!')
    print('-'*100)

    # ------------------------------------------------------------------
    # Phase 2: end-to-end training of the channel grouping layers with the
    # dis/div losses (Conv1-5 stay fixed, see set_optimizer_Phase2_L_chg).
    # ------------------------------------------------------------------
    print('Start Phase2.')
    net.cuda()
    net.train()

    lr = opt.LEARNING_RATE_CHG
    optimizer = set_optimizer_Phase2_L_chg(net, opt, lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
                                  verbose=True, cooldown=2, min_lr=0.000001)
    earlyStopper = EarlyStopping(mode='min', patience=10)
    t = Timer()
    t.tic()
    step_cnt = 0
    train_loss = 0.0
    train_dis_loss = 0.0
    train_div_loss = 0.0
    train_bal_loss = 0.0
    # Reset so a Phase 1 message can never end up in the Phase 2 checkpoint.
    log_text = ''
    epoch = None
    for epoch in range(1, 3+1):  # checkpoint['epoch'] if opt.resume_Phase2 else 1, 50 + 1):  # 400
        print(time2str())
        with open(fout, 'a') as f:
            f.write(time2str() + '\n')
        train_loss_epoch = []
        for i_batch, sample_batched in enumerate(dataset.dataloader):
            imgs = sample_batched['imgs']
            _, _, Mask_1, Mask_2, Mask_1_norm, Mask_2_norm, sig_1_org, sig_2_org, _ = net(imgs.cuda(), 'STEP2_Phase2')

            # Channel grouping loss. The peak loss returned by
            # get_channel_group_loss is not used; opt.peak_LAMBDA weights the
            # balance loss instead.
            bal_loss = get_channel_balance_loss(sig_1_org, sig_2_org)
            dis_loss, div_loss, _ = get_channel_group_loss(Mask_1_norm, Mask_2_norm)
            dis_loss = opt.dis_LAMBDA * dis_loss
            div_loss = opt.div_LAMBDA * div_loss
            bal_loss = opt.peak_LAMBDA * bal_loss
            loss = dis_loss + div_loss + bal_loss

            train_loss += loss.item()
            train_dis_loss += dis_loss.item()
            train_div_loss += div_loss.item()
            train_bal_loss += bal_loss.item()
            train_loss_epoch.append(loss.item())

            step_cnt += 1
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 10.)
            optimizer.step()

            if i_batch % opt.disp_interval == 0 and i_batch:
                duration = t.toc(average=False)
                inv_fps = duration / step_cnt
                log_text = ('CNG Epoch{:2d} [{:3d}/{:3d}], Loss T|dis|div|bal: [{:.4f}|{:.4f}|{:.4f}|{:.4f}] '
                            '(lr: {}, {:.2f}s per iteration)').format(
                    epoch, i_batch, len(dataset.dataloader), train_loss / step_cnt,
                    train_dis_loss / step_cnt, train_div_loss / step_cnt, train_bal_loss/step_cnt,
                    scheduler.optimizer.param_groups[0]['lr'], inv_fps)
                # Only copy the signals to the host when they are displayed.
                sig_1 = (sig_1_org[0]).data.cpu().numpy()
                sig_2 = (sig_2_org[0]).data.cpu().numpy()
                print('Sig_1: {:.2f}, Sig_2: {:.2f}, Sig_T: {:.2f}'.format(sig_1.sum() / 512, sig_2.sum() / 512,
                                                                           (sig_1 + sig_2).sum() / 512))
                c_print(log_text, color='green', attrs=['bold'])
                with open(fout, 'a') as f:
                    f.write(log_text + '\n')
                # reset the counter
                train_loss = 0.0
                train_dis_loss = 0.0
                train_div_loss = 0.0
                train_bal_loss = 0.0
                step_cnt = 0
                t.tic()
            if i_batch % opt.disp_interval == 0:
                print('Max activation: [{:.4f}|{:.4f}]'.format(torch.max(Mask_1[0]).item(),
                                                               torch.max(Mask_2[0]).item()))
                print('Min activation: [{:.4f}|{:.4f}]'.format(torch.min(Mask_1[0]).item(),
                                                               torch.min(Mask_2[0]).item()))

        train_loss_epoch_mean = np.asarray(train_loss_epoch).mean()
        # Step the plateau scheduler on the epoch mean (as in Phase 1 and as
        # the early stopper does), not on the noisy loss of the last batch.
        scheduler.step(train_loss_epoch_mean)
        if earlyStopper.step(train_loss_epoch_mean):
            break  # jump out of loop
    if epoch is not None:
        save_name = os.path.join(output_dir,
                                 'ZSL_Finetune_STEP2_Phase2_CHG_2Group_Epoch{}.tar'.format(epoch))
        net2save = net.module if isinstance(net, nn.DataParallel) else net
        torch.save({
            'epoch': epoch + 1,
            'mode': 'CHG',
            'model_state_dict': net2save.state_dict(),
            'log': log_text,
            'optimizer': optimizer.state_dict(),
        }, save_name)
        print('save model: {}'.format(save_name))



def visualize_test():
    """Interactively plot, batch by batch, the part masks and channel signals of a trained model.

    The model is restored from ``opt.resume_Phase2``. For the first sample of
    every batch of ``dataloader_train_for_test`` two figures are shown: the
    two 28x28 masks together with the raw image, and the two channel signals
    plus their sum. The normalised signal sums are printed afterwards.

    Raises:
        ValueError: if ``opt.resume_Phase2`` is not set.
    """
    # load dataset
    dataset = Dataset_MA_STEP2(opt)
    opt.train_ncls = dataset.train_ncls
    net = VGG19_Finetune_STEP1and2_2Group(opt=opt)
    if not opt.resume_Phase2:
        raise ValueError('You should provide a trained model.')
    resume_model(opt.resume_Phase2, net)
    net.cuda()
    net.eval()
    import matplotlib.pyplot as plt
    with torch.no_grad():
        for batch in dataset.dataloader_train_for_test:
            imgs = batch['imgs']
            raw_imgs = batch['raw_imgs']
            gpu_imgs = imgs.cuda()

            # First figure: both part masks and the raw image of sample 0.
            _, _, mask_a, mask_b, _, _, _, _, _ = net(gpu_imgs, 'STEP2_Phase2')
            mask_a_img = mask_a[0].data.cpu().numpy().reshape((28, 28))
            mask_b_img = mask_b[0].data.cpu().numpy().reshape((28, 28))
            raw_img = raw_imgs[0].numpy()

            _, mask_axes = plt.subplots(2, 2)
            mask_axes[0, 0].imshow(mask_a_img)
            mask_axes[0, 1].imshow(mask_b_img)
            mask_axes[1, 0].imshow(raw_img)
            plt.show()

            # Second figure: the two channel signals of sample 0 and their sum.
            out_1, out_2 = net(gpu_imgs, 'STEP2_Phase1')
            sig_1 = out_1[0].data.cpu().numpy()
            sig_2 = out_2[0].data.cpu().numpy()
            sig_total = sig_1 + sig_2

            _, sig_axes = plt.subplots(3, 1)
            for axis, signal in zip(sig_axes, (sig_1, sig_2, sig_total)):
                axis.plot(signal)
            plt.show()

            summary = 'Sig_1: {:.2f}, Sig_2: {:.2f}, Sig_T: {:.2f}'.format(
                sum(sig_1)/512, sum(sig_2)/512, sum(sig_total)/512)
            print(summary)

def test_generate_rec_Img(save_image_folder=None, size='half'):
    """Predict the object box and the two part boxes for every image and dump them to JSON.

    The model is restored from ``opt.resume_Phase2`` and run over the
    train-for-test loader, the unseen-test loader and, when the dataset has
    one, the seen-test loader. Results are written to
    ``models_<dataset>/<exp_info>_<split>/Pred_Boxes_<size>_<split>.json``,
    keyed by ``<parent folder>/<file name>`` of each image path.

    Args:
        save_image_folder: if given, each image is also saved there (in a
            sub-folder named after the image's parent folder) with the object
            box and both part boxes drawn on it.
        size: ``'fix'`` uses ``get_part``; ``'half'`` uses
            ``get_part_half_box``, which additionally receives the object box.

    Raises:
        ValueError: if ``size`` is not supported or no trained model is given.
        IOError: if an image listed by the dataset cannot be read.
    """
    # Validate up front so a typo is reported before the model is loaded.
    if size not in ('fix', 'half'):
        raise ValueError('no implemented size option')

    sub_dir = opt.exp_info + '_' + opt.split if opt.exp_info else opt.split
    output_dir = os.path.join('models_{}/'.format(opt.dataset), sub_dir)
    os.makedirs(output_dir, exist_ok=True)
    json_path = output_dir + '/Pred_Boxes_{}_{}.json'.format(size, opt.split)

    # load dataset
    dataset = Dataset_MA_STEP2(opt)
    opt.train_ncls = dataset.train_ncls
    net = VGG19_Finetune_STEP1and2_2Group(opt=opt)
    if opt.resume_Phase2:
        resume_model(opt.resume_Phase2, net)
    else:
        raise ValueError('You should provide a trained model.')
    if torch.cuda.device_count() > 1:
        c_print("Let's use {} GPUs!".format(torch.cuda.device_count()), color='red', attrs=['bold'])
        net = nn.DataParallel(net)
    net.cuda()
    net.eval()

    pred_box = dict()
    loaders = [dataset.dataloader_train_for_test, dataset.dataloader_test_unseen]
    if hasattr(dataset, 'dataloader_test_seen'):
        loaders.append(dataset.dataloader_test_seen)

    colors = ['r', 'b', 'g', 'y']
    for loader in loaders:
        for i_batch, sample_batched in enumerate(loader):
            imgs = sample_batched['imgs']
            img_path = sample_batched['img_path']
            with torch.no_grad():
                _, _, Mask_1, Mask_2, _, _, _, _, heat_map = net(imgs.cuda(), 'STEP2_Phase2')
            print('[{} | {}]'.format(i_batch, len(loader)))
            for i in range(imgs.shape[0]):
                path = img_path[i]
                # The boxes are computed in the coordinates of the original
                # image, so it is re-read from disk to get its size.
                img = cv2.imread(path)
                if img is None:
                    # cv2.imread signals failure by returning None.
                    raise IOError('Failed to read image: {}'.format(path))
                if img.ndim < 3:
                    img = np.transpose(np.array([img, img, img]), (1, 2, 0))
                h, w = img.shape[0], img.shape[1]

                mask_1 = (Mask_1[i]).data.cpu().numpy()
                mask_2 = (Mask_2[i]).data.cpu().numpy()
                box = get_box(h, w, (heat_map[i]).data.cpu().numpy())
                if size == 'fix':
                    boxes = get_part(h, w, mask_1, mask_2)
                else:  # size == 'half'
                    boxes = get_part_half_box(h, w, mask_1, mask_2, box)

                path_parts = path.split('/')
                folder_name, file_name = path_parts[-2], path_parts[-1]
                pred_box[os.path.join(folder_name, file_name)] = {'box': box, 'part_boxes': boxes}

                if save_image_folder:
                    fig, ax = plt.subplots(1)
                    # OpenCV loads BGR; matplotlib expects RGB.
                    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                    # Object box in black, then one coloured box per part.
                    ax.add_patch(patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                                   linewidth=3, edgecolor='black', facecolor='none'))
                    for part_idx in range(2):
                        part = boxes[part_idx]
                        ax.add_patch(patches.Rectangle((part[0], part[1]), part[2] - part[0], part[3] - part[1],
                                                       linewidth=5, edgecolor=colors[part_idx],
                                                       facecolor='none'))
                    plt.axis("off")
                    subfolder = os.path.join(save_image_folder, folder_name)
                    os.makedirs(subfolder, exist_ok=True)
                    fig.savefig(os.path.join(subfolder, file_name), bbox_inches='tight')
                    # Release the figure; otherwise they pile up across images.
                    plt.close('all')

    with open(json_path, 'w') as outfile:
        json.dump(pred_box, outfile)
    print('Boundingbox is saved to {}...'.format(json_path))
    print('Total images {}'.format(len(pred_box)))


def set_optimizer_Initialize_CHG_Layers(net, opt, lr):
    """Build the SGD optimizer used to initialise the channel grouping layers (Phase 1).

    The convolutional base is frozen and only the parameters of
    ``net.sig_mask_1`` and ``net.sig_mask_2`` are optimised.

    Args:
        net: the model, either bare or wrapped in ``nn.DataParallel``.
        opt: options providing ``MOMENTUM`` and ``WEIGHT_DECAY``.
        lr: base learning rate; both parameter groups run at ``lr * 10``.

    Returns:
        The configured ``torch.optim.SGD`` instance.
    """
    # Unwrap based on what `net` actually is rather than on the GPU count, so
    # an unwrapped model on a multi-GPU machine works as well.
    if isinstance(net, nn.DataParallel):
        net = net.module
    c_print("Learning rate changed to {}".format(lr), color='red', attrs=['bold'])
    set_trainable(net.convBase, requires_grad=False)
    # Every group sets its own rate, so the optimizer-level `lr` below is only
    # the fallback default and the effective rate is lr * 10.
    optimizer = torch.optim.SGD([
        {'params': net.sig_mask_1.parameters(), 'lr': lr * 10.0},
        {'params': net.sig_mask_2.parameters(), 'lr': lr * 10.0},
    ], lr=lr,  momentum=opt.MOMENTUM, weight_decay=opt.WEIGHT_DECAY)
    return optimizer

""" Fix Conv1-5, optimize channel group layers for L_chg
"""
def set_optimizer_Phase2_L_chg(net, opt, lr):
    """Build the SGD optimizer for Phase 2 (channel grouping loss L_chg).

    The convolutional base is frozen and only the parameters of
    ``net.sig_mask_1`` and ``net.sig_mask_2`` are optimised.

    Args:
        net: the model, either bare or wrapped in ``nn.DataParallel``.
        opt: options providing ``MOMENTUM`` and ``WEIGHT_DECAY``.
        lr: base learning rate; both parameter groups run at ``lr * 10``.

    Returns:
        The configured ``torch.optim.SGD`` instance.
    """
    # Unwrap based on what `net` actually is rather than on the GPU count, so
    # an unwrapped model on a multi-GPU machine works as well.
    if isinstance(net, nn.DataParallel):
        net = net.module
    c_print("Learning rate changed to {}".format(lr), color='red', attrs=['bold'])
    set_trainable(net.convBase, requires_grad=False)
    # Every group sets its own rate, so the optimizer-level `lr` below is only
    # the fallback default and the effective rate is lr * 10.
    optimizer = torch.optim.SGD([
        {'params': net.sig_mask_1.parameters(), 'lr': lr * 10.0},
        {'params': net.sig_mask_2.parameters(), 'lr': lr * 10.0}
    ], lr=lr,  momentum=opt.MOMENTUM, weight_decay=opt.WEIGHT_DECAY)
    return optimizer


def get_channel_balance_loss(sig_1_org, sig_2_org):
    """Return the squared difference between the mean values of the two signal tensors."""
    mean_gap = sig_1_org.mean() - sig_2_org.mean()
    return torch.pow(mean_gap, 2)

def get_channel_group_loss(Mask_1, Mask_2, margin=None, map_size=28):
    """Compute the channel grouping losses for a pair of part attention masks.

    Args:
        Mask_1: flattened attention mask of the first part, indexed as
            (batch, map_size * map_size) in row-major spatial order.
        Mask_2: flattened attention mask of the second part, same layout.
        margin: value subtracted from the other mask inside the divergence
            loss. Defaults to ``opt.margin`` when omitted.
        map_size: side length of the square attention map.

    Returns:
        A tuple ``(dis_loss, div_loss, peak_loss)``, each the average over the
        two masks. With the Euclidean distance loss used here ``peak_loss``
        is always a zero tensor of shape (1,).
    """
    if margin is None:
        margin = opt.margin

    def get_peak_distance_map(Mask):
        """Return, per sample, the distance of every cell to the mask's peak, flattened."""
        peak = torch.argmax(Mask, dim=1)
        peak_col = peak % map_size
        # `peak - peak_col` is an exact multiple of map_size, so this is an
        # exact floor division. A plain `peak / map_size` on the integer index
        # is true division in current PyTorch and yields fractional rows.
        peak_row = (peak - peak_col).float() / map_size
        peak_col = peak_col.float()
        # Broadcast against the grid coordinates instead of looping over all
        # cells in Python; everything stays on the mask's device.
        coords = torch.arange(map_size, dtype=torch.float32, device=Mask.device)
        row_diff = peak_row.view(-1, 1, 1) - coords.view(1, -1, 1)
        col_diff = peak_col.view(-1, 1, 1) - coords.view(1, 1, -1)
        dis_map = torch.sqrt(row_diff.pow(2) + col_diff.pow(2))
        return dis_map.view(Mask.shape[0], -1)

    def get_div_loss(Mask_a, Mask_b):
        # Penalise positions where both masks respond (beyond the margin).
        return torch.mean(Mask_a * F.relu(Mask_b - margin))

    def get_dis_loss(Mask):
        # NOTE: thresholded alternative to get_dis_loss_Euclidean; currently unused.
        dis_map = get_peak_distance_map(Mask)
        dis_loss = Mask * dis_map * (dis_map > 4).float()
        # avoid the trivial solution, keep the peak pos with a high value, 4 is the dis thresh
        peak_loss = (F.relu(0.5 - Mask)).pow(2) * (dis_map <= 4).float()
        return torch.mean(dis_loss), torch.mean(peak_loss)

    def get_dis_loss_Euclidean(Mask):
        K = 8.0
        # Target response: 1 at the peak, decaying as K / (K + distance).
        dis_scale = K / (K + get_peak_distance_map(Mask))
        dis_loss = torch.mean((Mask - dis_scale).pow(2))
        # The peak term is disabled for this variant; a zero placeholder is returned.
        peak_loss = torch.zeros(1, device=Mask.device)
        return dis_loss, peak_loss

    dis_loss_1, peak_loss_1 = get_dis_loss_Euclidean(Mask_1)
    dis_loss_2, peak_loss_2 = get_dis_loss_Euclidean(Mask_2)

    div_loss_1 = get_div_loss(Mask_1, Mask_2)
    div_loss_2 = get_div_loss(Mask_2, Mask_1)

    return ((dis_loss_1 + dis_loss_2) / 2,
            (div_loss_1 + div_loss_2) / 2,
            (peak_loss_1 + peak_loss_2) / 2)



if __name__ == "__main__":
    # --- Step A (disabled): train Phase 1 and Phase 2 ------------------------
    # opt.channel_cluster = 'models_AWA1/MA_CNN_STEP1_PP/train_channel_cluster_STEP1_2Group.pkl'
    # opt.resume_STEP1    = 'models_AWA1/MA_CNN_STEP1_PP/ZSL_Finetune_STEP1_Epoch10.tar'
    # train()

    # --- Step B (disabled): inspect masks and channel signals ----------------
    # opt.resume_Phase2 = 'models_AWA1/MA_CNN_STEP2_PP/ZSL_Finetune_STEP2_Phase2_CHG_2Group_Epoch3.tar'
    # visualize_test()

    # --- Step C: generate the bounding boxes of the parts --------------------
    # To also save the images with the boxes drawn on them, pass a folder such
    # as save_image_folder='AWA1_data_box_half_PP'; otherwise pass None.
    # size='half': the detected part is half of the detected object box.
    # size='fix' : a fixed 64-pixel part in the 448 by 448 image.
    phase2_checkpoint = 'models_AWA1/MA_CNN_STEP2_PP/ZSL_Finetune_STEP2_Phase2_CHG_2Group_Epoch3.tar'
    opt.resume_Phase2 = phase2_checkpoint
    test_generate_rec_Img(save_image_folder='AWA1_data_box_half_PP', size='half')
