""""""
"""
Step 1: pretrain the CNN backbone, then cluster its channels with K-means.
(1) First call train() to pretrain the CNN.
(2) Then resume the trained net, extract the position of the maximum activation
    of every channel, and run K-means clustering on those positions,
e.g.  resume = 'models_CUB/MA_CNN_STEP1_PP/ZSL_Finetune_STEP1_Epoch95.tar'
      channel_clustering_Conv5_4(ngroup=2, resume=resume)
"""

import argparse
import json
import pickle
from sklearn.cluster import KMeans
from collections import Counter

import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from models_MA import VGG19_Finetune_STEP1and2_2Group
from dataset import Dataset_MA_STEP1
from utils import *

# ---------------------------------------------------------------------------
# Command-line options. They are parsed at import time into the module-level
# `opt` namespace, which train() and channel_clustering_Conv5_4() read.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Hardware / data selection
parser.add_argument('--gpu', type=str, default='0',
                    help='index of GPU to use')
parser.add_argument('--dataset', type=str, default='CUB',
                    help='which dataset: CUB, AWA1, AWA2, FLO')

# Hyper-parameters
parser.add_argument('--BATCH_SIZE', type=int, default=64,
                    help='batch size')
parser.add_argument('--LEARNING_RATE', type=float, default=0.0001,
                    help='base learning rate')
parser.add_argument('--MOMENTUM', type=float, default=0.9,
                    help='base momentum')
parser.add_argument('--WEIGHT_DECAY', type=float, default=0.0005,
                    help='base weight decay')
parser.add_argument('--nepoch', type=int, default=100,
                    help='number of epoch, 100 for CUB, 10 for AWA1/AWA2')

# Display and logging
parser.add_argument('--disp_interval', type=int, default=20,
                    help='display interval')
parser.add_argument('--save_interval', type=int, default=5,
                    help='Epoch save model interval')

# Experiment bookkeeping
parser.add_argument('--split', type=str, default='PP',
                    help='split mode; standard split or proposed split, SP/PP')
parser.add_argument('--exp_info', type=str, default='MA_CNN_STEP1',
                    help='sub folder name for this exp')
parser.add_argument('--resume', type=str, default=None,
                    help='path of model to resume')
parser.add_argument('--manualseed', type=int,
                    help='if use random seed to fix result')

opt = parser.parse_args()

# Echo the effective configuration as pretty-printed JSON.
print('Running parameters:')
print(
    json.dumps(
        vars(opt),
        indent=4,
        separators=(',', ': '),
    )
)

# Restrict CUDA to the requested device(s).
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
# Extra field attached to the namespace; nothing in the visible code reads it.
opt.start_epoch = 0

def train():
    """Fine-tune the network with a cross-entropy classification loss (STEP 1).

    All settings come from the module-level ``opt`` namespace.

    Side effects:
        * creates ``models_MA/<dataset>/<exp_info>_<split>/`` (or
          ``models_MA/<dataset>/<split>/`` when ``opt.exp_info`` is empty);
        * writes a timestamped ``log_*.txt`` file into that directory;
        * saves a ``ZSL_Finetune_STEP1_Epoch<N>.tar`` checkpoint whenever the
          epoch number is a multiple of ``opt.save_interval``.

    When ``opt.resume`` is set, the checkpoint is handed to ``resume_model``
    and its return value is used as the first epoch to run; otherwise
    training starts at epoch 1 and runs through ``opt.nepoch`` inclusive.
    """
    # Create the whole output hierarchy at once; existing directories are fine.
    exp_folder = opt.exp_info + '_' + opt.split if opt.exp_info else opt.split
    output_dir = os.path.join('models_MA/{}'.format(opt.dataset), exp_folder)
    os.makedirs(output_dir, exist_ok=True)

    lr = opt.LEARNING_RATE
    fout = output_dir + '/log_{}.txt'.format(strftime("%a, %d %b %Y %H:%M:%S", gmtime()))
    with open(fout, 'w') as f:
        f.write('Training Start:\n')
        f.write('TRAIN.LEARNING_RATE: {}\n'.format(opt.LEARNING_RATE))
        f.write('TRAIN.MOMENTUM: {}\n'.format(opt.MOMENTUM))
        f.write('TRAIN.BATCH_SIZE: {}\n'.format(opt.BATCH_SIZE))

    # load dataset and net
    dataset = Dataset_MA_STEP1(opt)
    opt.train_ncls = dataset.train_ncls
    net = VGG19_Finetune_STEP1and2_2Group(opt)

    start_epoch = resume_model(opt.resume, net, fout) if opt.resume else 1

    if torch.cuda.device_count() > 1:
        c_print("Let's use {} GPUs!".format(torch.cuda.device_count()), color='red', attrs=['bold'])
        net = nn.DataParallel(net)
    net = net.cuda()
    print(net)

    criterion_cls = nn.CrossEntropyLoss().cuda()
    optimizer = set_optimizer_CLS(net, opt, lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True, cooldown=2, min_lr=1e-6)

    # Running statistics for the periodic display; reset after every display.
    train_loss = 0.0
    train_acc = 0.0
    step_cnt = 0
    # Defined up front so a checkpoint can be written even if no display line
    # has been produced yet (e.g. fewer batches than `disp_interval`).
    log_text = ''
    n_batch = len(dataset.dataloader)

    t = Timer()
    t.tic()
    net.train()
    for epoch in range(start_epoch, opt.nepoch + 1):
        print(time2str())
        with open(fout, 'a') as f:
            f.write(time2str() + '\n')

        # Per-epoch totals that drive the plateau scheduler.
        epoch_loss = 0.0
        epoch_steps = 0

        for i_batch, sample_batched in enumerate(dataset.dataloader):
            imgs = sample_batched['imgs']
            labels = sample_batched['labels']

            pred = net(imgs.cuda(), 'STEP1')
            loss = criterion_cls(pred, labels.cuda())

            # Batch accuracy is computed on the CPU against the CPU labels.
            pred_lb = np.argmax(pred.detach().cpu().numpy(), axis=1)
            accuracy = (pred_lb == labels.numpy()).mean()

            loss_value = loss.item()
            train_loss += loss_value
            train_acc += accuracy
            step_cnt += 1
            epoch_loss += loss_value
            epoch_steps += 1

            # backward
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 10.)
            optimizer.step()

            if i_batch % opt.disp_interval == 0 and i_batch:
                duration = t.toc(average=False)
                inv_fps = duration / step_cnt
                log_text = 'Epoch:{:2d} [{:3d}/{:3d}], Tloss {:.4f}  acc: {:.2f} (lr: {}, {:.2f}s per iteration)'.format(
                    epoch, i_batch, n_batch, train_loss/step_cnt, train_acc*100/step_cnt,
                    scheduler.optimizer.param_groups[0]['lr'],  inv_fps)
                c_print(log_text, color='green', attrs=['bold'])
                with open(fout, 'a') as f:
                    f.write(log_text + '\n')
                # reset the counter
                train_loss = 0.0
                train_acc = 0.0
                step_cnt = 0
                t.tic()

        if (epoch % opt.save_interval == 0) and epoch:
            save_name = os.path.join(output_dir, 'ZSL_Finetune_STEP1_Epoch{}.tar'.format(epoch))
            # Save the bare module so the checkpoint keys carry no wrapper prefix.
            net2save = net.module if isinstance(net, nn.DataParallel) else net
            torch.save({
                'epoch': epoch + 1,  # the epoch a resumed run continues from
                'model_state_dict': net2save.state_dict(),
                'convBase_state_dict': net2save.convBase.state_dict(),
                'log':   log_text,
                'optimizer': optimizer.state_dict(),
            }, save_name)
            print('save model: {}'.format(save_name))

        # Step on the mean loss of the epoch instead of the last mini-batch
        # loss, which is too noisy a plateau signal; skip empty epochs.
        if epoch_steps:
            scheduler.step(epoch_loss / epoch_steps)


def set_optimizer_CLS(net, opt, lr):
    """Build the SGD optimizer used for classification fine-tuning.

    Five parameter groups are created: ``convBase.conv3``, ``convBase.conv4``
    and ``convBase.conv5`` train at ``lr``, ``classifier`` at ``2 * lr`` and
    ``FC_Top`` at ``5 * lr``. Parameters outside these groups are not handed
    to the optimizer.

    Args:
        net: the model, either bare or wrapped in ``nn.DataParallel`` /
            ``DistributedDataParallel``.
        opt: namespace providing ``MOMENTUM`` and ``WEIGHT_DECAY``.
        lr: base learning rate.

    Returns:
        The configured ``torch.optim.SGD`` instance.
    """
    # Unwrap based on what `net` actually is, not on how many GPUs are
    # visible: a bare model on a multi-GPU host has no `.module`, and a
    # wrapped model on a single-GPU host still needs unwrapping.
    if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        net = net.module
    c_print("Learning rate changed to {}".format(lr), color='red', attrs=['bold'])
    optimizer = torch.optim.SGD([
        {'params': net.convBase.conv3.parameters()},
        {'params': net.convBase.conv4.parameters()},
        {'params': net.convBase.conv5.parameters()},
        {'params': net.classifier.parameters(), 'lr': lr * 2.0},
        {'params': net.FC_Top.parameters(), 'lr': lr * 5.0},
    ], lr=lr,  momentum=opt.MOMENTUM, weight_decay=opt.WEIGHT_DECAY)
    return optimizer


def channel_clustering_Conv5_4(ngroup=2, resume=None):
    """Cluster the 512 Conv5_4 channels by where each one fires most strongly.

    STEP 1: for every training sample, extract the Conv5_4 feature tensor
    (14*14*512) and record, per channel, the (row, col) position of its
    maximum activation. Each sample contributes two rows (rows, then cols),
    so after transposition every channel is described by a vector of length
    ``2 * dataset.num_train_sample``.

    STEP 2: run K-means on those channel vectors and pickle the positions,
    the label vector and the one-hot label matrix of shape ``(ngroup, 512)``
    to ``models_<dataset>/<exp_info>_<split>/
    train_channel_cluster_STEP1_<ngroup>Group.pkl``.

    Args:
        ngroup: number of channel clusters.
        resume: checkpoint path passed to ``resume_model``.
    """
    n_channel = 512   # number of Conv5_4 channels
    map_width = 14    # width of the 14x14 Conv5_4 feature map

    # load dataset and net
    dataset = Dataset_MA_STEP1(opt)
    opt.train_ncls = dataset.train_ncls
    net = VGG19_Finetune_STEP1and2_2Group(opt)
    resume_model(resume, net)

    net.cuda()
    # Pure feature extraction: keep any dropout / normalisation layers in
    # inference mode so the extracted positions are deterministic.
    net.eval()

    max_activation_pos = np.zeros((dataset.num_train_sample * 2, n_channel), dtype=int)
    loader = dataset.dataloader_train_for_test
    n_batch = len(loader)
    cnt = 0
    print('Extracting pos for clustering...')
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(loader):
            if i_batch % 10 == 0 and i_batch:
                print("[{}/{}]".format(i_batch, n_batch))
            # forward
            visual_embed = net.extract_CONV5_4_feature(sample_batched['imgs'].cuda())
            flat = visual_embed.view(visual_embed.shape[0], visual_embed.shape[1], -1).cpu().numpy()

            # Flat index of the strongest response of every channel: (batch, channel).
            pos_idx = flat.argmax(axis=2)
            row_idx = pos_idx // map_width
            col_idx = pos_idx % map_width

            # Interleave per sample as [rows, cols] -> (2 * batch, channel).
            block = np.stack((row_idx, col_idx), axis=1).reshape(-1, pos_idx.shape[1])
            max_activation_pos[cnt:cnt + block.shape[0]] = block
            cnt += block.shape[0]

    # One row per channel, one column per (sample, row/col) coordinate.
    max_activation_pos = max_activation_pos.transpose()
    print('max_activation_pos: ', max_activation_pos.shape)

    # `n_jobs` is intentionally not passed: it was deprecated in scikit-learn
    # 0.23 and removed in 1.0, where passing it raises a TypeError.
    kmeans = KMeans(n_clusters=ngroup, verbose=1, random_state=None).fit(max_activation_pos)
    print('Channel Clustering Stats:', Counter(kmeans.labels_))

    channel_clustering_label = kmeans.labels_
    channel_clustering_label_matrix = np.zeros((ngroup, n_channel))
    # One-hot encode: entry [g, c] is 1 when channel c belongs to group g.
    channel_clustering_label_matrix[
        channel_clustering_label, np.arange(len(channel_clustering_label))] = 1
    # NOTE: the matrix is pickled as a CUDA tensor, exactly as before.
    channel_clustering_label_matrix = torch.from_numpy(channel_clustering_label_matrix.astype(np.float32)).cuda()

    # NOTE(review): train() writes checkpoints under 'models_MA/<dataset>/...'
    # while this result goes under 'models_<dataset>/...' - confirm which
    # layout the later steps expect.
    out_dir = os.path.join('models_'+opt.dataset, opt.exp_info+'_'+opt.split if opt.exp_info else opt.split)
    # Make sure the target exists so the clustering result is not lost.
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, 'train_channel_cluster_STEP1_{}Group.pkl'.format(ngroup))
    with open(out_file, 'wb') as fout:
        pickle.dump({'max_activation_pos': max_activation_pos,
                     'channel_clustering_label': channel_clustering_label,
                     'channel_clustering_label_matrix': channel_clustering_label_matrix}, fout)
    c_print("Save Result of Channel Clustering STEP 1 to {}".format(out_file), color='red', attrs=['bold'])


if __name__ == "__main__":
    # Stage (1): uncomment to pretrain the backbone first.
    # train()

    # Stage (2): cluster the channels of an already trained checkpoint.
    channel_clustering_Conv5_4(
        ngroup=2,
        resume='models_CUB/MA_CNN_STEP1_PP/ZSL_Finetune_STEP1_Epoch100.tar',
    )
