import platform
# print('python_version ==', platform.python_version())
import torch
# print('torch.__version__ ==', torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import time
import argparse
import numpy as np
from renhd import *
from evaluation import *
import os
import random
from model_zoo_ext import *
import torch.multiprocessing as mp
from torch.multiprocessing import Manager, Lock
import time

#############################################################################
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` calls ``bool()`` on the raw string, so every non-empty
    value -- including ``"False"`` and ``"0"`` -- parses as ``True``. This
    converter understands the usual spellings instead.

    Args:
        value: the raw option value (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, as it was under bool('').
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Set up the hyperparameters of the experiments.
parser = argparse.ArgumentParser(description='HMCPT on CNN tested on CIFAR10 appending noise')
# Mini-batch sizes of the train / test data loaders.
parser.add_argument('--train-batch-size', type=int, default=64)
parser.add_argument('--test-batch-size', type=int, default=10000)
# Burn-in stops at the first epoch boundary reached after this many iterations.
parser.add_argument('--num-burn-in', type=int, default=30000)
# Upper bound on burn-in epochs and number of exchange/training rounds.
parser.add_argument('--num-epochs', type=int, default=50)
# Replica 0 evaluates on the test set every this many training iterations.
parser.add_argument('--evaluation-interval', type=int, default=50)
# Passes over the training loader per training round.
parser.add_argument('--iter-per-epoch', type=int, default=10)#30
# Forwarded to the CNN and RENHD constructors.
parser.add_argument('--eta-theta', type=float, default=1.7e-8)
parser.add_argument('--c-theta', type=float, default=0.01)
parser.add_argument('--gamma-theta', type=float, default=1)
# Replica j starts at temperature xi_base ** j.
parser.add_argument('--xi-base', type=float, default=2)
# Weight of the squared-parameter penalty added to the loss.
parser.add_argument('--prior-precision', type=float, default=1e-3)
# Fraction of each training batch whose labels are replaced by random ones.
parser.add_argument('--permutation', type=float, default=0.0)
parser.add_argument('--enable-cuda', type=_str2bool, default=True)
# NOTE(review): only referenced by a commented-out set_device call in this file.
parser.add_argument('--device-num', type=int, default=4)
# Number of replicas = visible GPUs * this value.
parser.add_argument('--num-thread-per-gpu', type=int,default=1)
# Prefix of the checkpoint file names / directory that receives them.
parser.add_argument('--renhd-model', type=str,default='cnn_renhd')
parser.add_argument('--check-point-path',type=str,default='check-point')
# Seed applied to the CUDA random number generators.
parser.add_argument('--seed', type=int, default=10)
args = parser.parse_args()
# print (args)

#############################################################################
# Number of visible CUDA devices. Replicas are mapped onto GPUs with
# ``thread_id % num_gpu`` and the replica count is derived from it, so the
# name must exist even on a CUDA-less host (device_count() is 0 there);
# binding it only inside the ``is_available()`` branch made every later
# reference fail with NameError.
num_gpu = torch.cuda.device_count()
if torch.cuda.is_available():
    # torch.cuda.set_device(args.device_num)
    torch.cuda.manual_seed_all(args.seed)  # seed the RNG of every GPU

# exist_ok=True already tolerates an existing directory (and concurrent
# creation by several processes), so no separate existence check is needed.
os.makedirs(os.path.abspath(args.check_point_path), exist_ok=True)

#############################################################################
# Load the CIFAR-10 train / test splits (the train split is downloaded on
# first use into ./cifar10-dataset).
# NOTE(review): a single mean/std value is given for three-channel images and
# the constants look like the usual MNIST statistics -- confirm this is the
# intended normalisation for CIFAR-10.
_to_normalised_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

_train_set = datasets.CIFAR10(
    './cifar10-dataset', train=True, download=True,
    transform=_to_normalised_tensor)
train_loader = torch.utils.data.DataLoader(
    _train_set,
    batch_size=args.train_batch_size,
    shuffle=True,
    drop_last=True)

_test_set = datasets.CIFAR10(
    './cifar10-dataset', train=False,
    transform=_to_normalised_tensor)
# drop_last=True: a trailing incomplete test batch is discarded.
test_loader = torch.utils.data.DataLoader(
    _test_set,
    batch_size=args.test_batch_size,
    shuffle=False,
    drop_last=True)

# Number of training examples.
N = len(_train_set)

#############################################################################
def train(model,thread_id,xi,shared_dict,sampler):
    """Run one training round for a single replica and checkpoint it.

    The replica is placed on GPU ``thread_id % num_gpu``, makes
    ``args.iter_per_epoch`` passes over ``train_loader`` (calling
    ``sampler.update(xi)`` after every mini-batch), then publishes its energy
    in ``shared_dict[thread_id]`` and saves the state dicts of the model and
    of the sampler under ``args.check_point_path``.

    Args:
        model: network to train; must expose ``outputdim``.
        thread_id: replica index. Replica 0 additionally prints progress and
            evaluates on the test set every ``args.evaluation_interval``
            iterations.
        xi: temperature handed to the sampler.
        shared_dict: mapping shared with the parent process.
        sampler: object providing ``cuda``, ``resample_momenta``, ``update``,
            ``get_z_theta`` and ``state_dict``.
    """
    # NOTE(review): this flag is only forwarded to the estimator; the model,
    # sampler and batches below are moved to the GPU unconditionally.
    use_cuda = args.enable_cuda and torch.cuda.is_available()
    device_idx = thread_id % num_gpu
    print("Running the thread_{} on the GPU_{} with temperature {}".format(
        thread_id, device_idx, xi))

    def checkpoint(tag):
        """Return (file name, full path) of this replica's ``tag`` checkpoint."""
        name = "{}{}_{}{}_permutation.pt".format(
            args.renhd_model, thread_id, tag, args.permutation)
        return name, os.path.join(args.check_point_path, name)

    model.cuda(device_idx)
    num_labels = model.outputdim

    sampler.cuda(device_idx)
    sampler.resample_momenta(xi)

    started_at = time.time()
    estimator = FullyBayesian(
        (len(test_loader.dataset), num_labels),
        model,
        test_loader,
        use_cuda,
        device_idx)

    is_reporter = thread_id == 0
    step = 0
    total_loss = 0
    for sweep in range(1, 1 + args.iter_per_epoch):
        if is_reporter:
            print("#####################################################")
            print("Train: This is the iter {}".format(sweep))

        for x, y in train_loader:
            step += 1
            batch_size = x.data.size(0)
            if args.permutation > 0.0:
                # Overwrite the leading fraction of the labels with random ones.
                n_noisy = int(args.permutation * batch_size)
                y = y.clone()
                y.data[:n_noisy] = torch.LongTensor(
                    np.random.choice(num_labels, n_noisy))
            x = x.cuda(device=device_idx)
            y = y.cuda(device=device_idx)

            model.zero_grad()
            loss = F.cross_entropy(model(x), y)
            # loss = 1 / xi * loss
            # Squared-parameter penalty weighted by the prior precision.
            for param in model.parameters():
                loss += args.prior_precision * param.pow(2).sum()
            loss.backward()
            total_loss += loss.data

            # update params
            sampler.update(xi)

            if is_reporter and step % args.evaluation_interval == 0:
                print(
                    'loss:{:6.4f}; thermostats_param:{:6.3f}; tElapsed:{:6.3f}'.format(
                        loss.data.item(),
                        sampler.get_z_theta(),
                        time.time() - started_at))
                # evaluation
                acc = estimator.evaluation()
                print('This is the accuracy: %{:6.2f} of thread_{}'.format(acc, thread_id))
                # Switch back to training mode after the evaluation.
                model.train()
                # NOTE(review): momenta are resampled here only for replica 0.
                sampler.resample_momenta(xi)

    # Send the energy to the shared memory.
    U = 1 / xi * total_loss / args.iter_per_epoch / N
    shared_dict[thread_id] = U.item()

    # Save the trained model (unwrapped if it carries a `.module` attribute).
    model_to_save = getattr(model, 'module', model)
    _, output_model_file = checkpoint("thread")
    torch.save(model_to_save.state_dict(), output_model_file)

    # Save the sampler state the same way.
    opti_to_save = getattr(sampler, 'module', sampler)
    optiname, output_opti_file = checkpoint("optimizer")
    torch.save(opti_to_save.state_dict(), output_opti_file)
    print("Successfully save the optimizer " + optiname)

    # print("Finishing thread_"+str(thread_id))
#############################################################################

def burn_in(model,thread_id,xi,shared_dict,sampler):
    """Burn in one replica and save its model checkpoint.

    The replica runs on GPU ``thread_id % num_gpu`` for at most
    ``args.num_epochs`` epochs. The iteration budget ``args.num_burn_in`` is
    only checked between epochs, so the epoch in which it is reached is
    always completed. The energy of the last mini-batch is written to
    ``shared_dict[thread_id]`` and the model state dict is saved under
    ``args.check_point_path`` (the sampler state is not saved here).

    Args:
        model: network to burn in; must expose ``outputdim``.
        thread_id: replica index.
        xi: temperature handed to the sampler.
        shared_dict: mapping shared with the parent process.
        sampler: object providing ``cuda``, ``resample_momenta`` and ``update``.
    """
    device_idx = thread_id % num_gpu
    model.cuda(device_idx)
    num_labels = model.outputdim

    sampler.cuda(device_idx)
    sampler.resample_momenta(xi)

    steps_done = 0
    for epoch in range(1, 1 + args.num_epochs):
        print ("##########################################################################")
        print ("Burn-in: This is the epoch {} of thread {}".format(epoch,thread_id))
        for x, y in train_loader:
            batch_size = x.data.size(0)
            if args.permutation > 0.0:
                # Overwrite the leading fraction of the labels with random ones.
                n_noisy = int(args.permutation * batch_size)
                y = y.clone()
                y.data[:n_noisy] = torch.LongTensor(
                    np.random.choice(num_labels, n_noisy))
            x = x.cuda(device=device_idx)
            y = y.cuda(device=device_idx)

            model.zero_grad()
            loss = F.cross_entropy(model(x), y)
            # Squared-parameter penalty weighted by the prior precision.
            for param in model.parameters():
                loss += args.prior_precision * param.pow(2).sum()
            # Energy of this mini-batch; only the last value is reported.
            U = 1 / xi * loss.data / args.train_batch_size # * N
            loss.backward()

            # update params
            sampler.update(xi)

            steps_done += 1
        if steps_done >= args.num_burn_in:
            break
    print("Finishing burn-in thread_{}".format(thread_id))
    # Send the energy to the shared memory.
    shared_dict[thread_id] = U.item()

    # Save the model (unwrapped if it carries a `.module` attribute).
    model_to_save = getattr(model, 'module', model)
    filename = "{}{}_thread{}_permutation.pt".format(
        args.renhd_model, thread_id, args.permutation)
    output_model_file = os.path.join(args.check_point_path, filename)
    torch.save(model_to_save.state_dict(), output_model_file)
    print("Successfully save the model " + filename)

#############################################################################


def attempt_replica_exchanges(xi_ladder, energies, odd_exchange, xi_change):
    """Propose temperature swaps between neighbouring replicas.

    Pairs ``(i, i + 1)`` are visited for ``i = 1, 3, 5, ...`` when
    ``odd_exchange`` is true and for ``i = 0, 2, 4, ...`` otherwise. A swap is
    accepted when ``1 / (1 + exp(-deltaE))`` exceeds a uniform random draw,
    with ``deltaE = (E_i - E_j) * (1 / xi_i - 1 / xi_j)``. Accepted swaps
    exchange the two entries of ``xi_ladder`` in place.

    Args:
        xi_ladder: mutable list of per-replica temperatures (modified in place).
        energies: mapping from replica index to its last reported energy.
        odd_exchange: select the odd (True) or even (False) pairing.
        xi_change: list that receives the temperature of replica 0 whenever
            replica 0 is visited, before any swap of that round is applied.

    Returns:
        Tuple ``(attempted, accepted)`` with the number of proposed and of
        accepted swaps. Both are 0 when the pairing contains no valid pair.
    """
    num_thread = len(xi_ladder)
    attempted = 0
    accepted = 0
    for thread_id in range(1 if odd_exchange else 0, num_thread, 2):
        if thread_id == 0:
            xi_change.append(xi_ladder[thread_id])
        if thread_id + 1 >= num_thread:
            continue  # last replica has no right-hand neighbour
        attempted += 1
        deltaE = (energies[thread_id] - energies[thread_id + 1]) * \
                 (1 / xi_ladder[thread_id] - 1 / xi_ladder[thread_id + 1])
        # exp() may overflow to inf for very negative deltaE; the acceptance
        # probability then correctly evaluates to 0, so silence the warning.
        with np.errstate(over='ignore'):
            boltzmann = np.exp(-deltaE)
        acceptance = 1 / (1 + boltzmann) - random.random()
        # acceptance = min(1,np.exp(deltaE))- random.random()
        print("This is np.exp(-deltaE): " + str(boltzmann))
        if acceptance > 0:
            accepted += 1
            xi_ladder[thread_id + 1], xi_ladder[thread_id] = \
                xi_ladder[thread_id], xi_ladder[thread_id + 1]
            print("Successfully exchange replicas between {} and {}" \
                    .format(thread_id, thread_id + 1))
        else:
            print("Failed to exchange replicas between {} and {}" \
                    .format(thread_id, thread_id + 1))
    return attempted, accepted


if __name__ == '__main__':
    ctx = mp.get_context("spawn")
    shared_dict = mp.Manager().dict()

    num_thread = num_gpu * args.num_thread_per_gpu
    print("{} threads will be running simultaneously!".format(num_thread))
    # define the temperature ladder
    xi_ladder = [args.xi_base ** j for j in range(num_thread)]
    model = CNN(N, args.eta_theta, args.c_theta, args.gamma_theta)
    sampler = RENHD(model, N, args.eta_theta, args.c_theta, args.gamma_theta)

    # Number of replicas restored from an earlier run. Nothing restores
    # checkpoints before burn-in at the moment, so burn-in always runs.
    burnin = 0

    if burnin < num_thread:
        ## burn-in ...
        processes = []
        for thread_id in range(num_thread):
            p = ctx.Process(target=burn_in, args=(model, thread_id, \
                                                xi_ladder[thread_id], shared_dict,sampler))  #
            p.start()
            processes.append(p)

        for p in processes:
            # To wait until a process has completed its work and exited
            p.join()
            if p.exitcode != 0:
                # A failed replica leaves no energy / checkpoint behind.
                print("Warning: burn-in process {} exited with code {}".format(
                    p.name, p.exitcode))

    odd_exchange = True
    suc_num = 0
    exc_num = 0
    xi_change = []  # record the temperature of the first replica

    tic_train = time.time()
    for epoch in range(1, 1 + args.num_epochs):
        print("###########################################################################")
        print("Train: This is the depoch " + str(epoch))

        ##### exchange temperature #####
        if bool(shared_dict):
            attempted, accepted = attempt_replica_exchanges(
                xi_ladder, shared_dict, odd_exchange, xi_change)
            exc_num += attempted
            suc_num += accepted
            # Alternate between the odd and the even pairing every round.
            # The previous flag-based toggle was never reset: the pairing
            # stayed fixed until the first accepted swap, and with two
            # replicas the (empty) odd pairing was kept forever.
            odd_exchange = not odd_exchange
            # No pair may have been proposed yet (e.g. one or two replicas
            # on an odd round); avoid dividing by zero in that case.
            if exc_num > 0:
                print("Successful exchange ratio: ",str(suc_num/exc_num))
        ################################################################################
        ## load checkpoints and start new threads
        ################################################################################
        # NOTE(review): the same `model` / `sampler` objects are reloaded in
        # place for every replica and then handed to a spawned process.
        # torch.multiprocessing may share tensor storage with the child, so a
        # child that has not yet copied its weights to the GPU could observe
        # the checkpoint loaded for a later replica -- confirm that every
        # replica really starts from its own checkpoint.
        processes = []
        for thread_id in range(num_thread):
            ## Load a trained model that you have fine-tuned
            filename = str(args.renhd_model) + str(thread_id) + "_thread"+\
               str(args.permutation)+"_permutation.pt"
            output_model_file = os.path.join(args.check_point_path, filename)
            if os.path.exists(output_model_file):
                state_dict = torch.load(output_model_file)
                model.load_state_dict(state_dict)
                # print("Successfully load the model " + str(filename))

            optiname = str(args.renhd_model) + str(thread_id) + "_optimizer"+\
               str(args.permutation)+"_permutation.pt"
            output_opti_file = os.path.join(args.check_point_path, optiname)
            if os.path.exists(output_opti_file):
                state_dict = torch.load(output_opti_file)
                sampler.load_state_dict(state_dict)
                print("Successfully load the optimizer " + str(optiname))

            p = ctx.Process(target=train,args=(model,thread_id,\
                                                xi_ladder[thread_id],shared_dict,sampler))#
            p.start()
            processes.append(p)

        for p in processes:
            # To wait until a process has completed its work and exited
            p.join()
            if p.exitcode != 0:
                # Its energy and checkpoints are stale for the next round.
                print("Warning: training process {} exited with code {}".format(
                    p.name, p.exitcode))

    f_name = str(args.renhd_model)+ str(args.permutation)+"_permutation.npy"
    output_f_name = os.path.join(args.check_point_path, f_name)
    np.save(output_f_name,xi_change)
    print("Saving temperature of the standard replica: "+str(f_name))
    tor_train = time.time()
    # "{:2f}" set a field width of 2, not a precision; print two decimals.
    print("Training time:{:.2f} s".format(tor_train-tic_train))
