import math
import time
import numpy as np
import torch
import torch.nn.functional as F
import argparse
from pathlib import Path
from utils.logger import *
from utils.config import *
from utils.misc import *
import os
from tools import builder
from torch.utils.data import Dataset
import json
from datasets.PairDataset import PairDataset
from utils.AverageMeter import AverageMeter
from pytorch3d.ops import sample_farthest_points, knn_points
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

def write_plyfile(file_name, point_cloud):
    """Write a point cloud to ``<file_name>.ply`` as an ASCII PLY file.

    Args:
        file_name: Output path WITHOUT the ``.ply`` suffix (the suffix is
            appended here). May be a ``str`` or any ``os.PathLike``.
        point_cloud: Sized, indexable collection of points; the first three
            components of each point are written as x, y, z.

    Each coordinate is converted with ``float`` and rounded to 6 decimals.
    The header declares zero faces, i.e. the file holds vertices only.
    """
    header = "ply\nformat ascii 1.0\ncomment VCGLIB generated\nelement vertex " + str(len(point_cloud)) + \
             "\nproperty float x\nproperty float y\nproperty float z\n" \
             "element face 0\nproperty list uchar int vertex_indices\nend_header\n"
    rows = [
        ' '.join(str(round(float(point_cloud[i][axis]), 6)) for axis in range(3)) + '\n'
        for i in range(len(point_cloud))
    ]
    # The context manager guarantees the handle is closed even if a write fails.
    with open(os.fspath(file_name) + '.ply', 'w') as f:
        f.write(header)
        f.writelines(rows)

class Loss_Metric:
    """Wrapper around a single loss value where a lower loss is better.

    The constructor accepts either the raw loss value or a mapping with a
    ``'loss'`` entry (the format produced by :meth:`state_dict`).
    """

    def __init__(self, loss = 0 ):
        # isinstance also recognises dict subclasses (e.g. OrderedDict), which a
        # comparison on the exact type name would store as the loss itself.
        if isinstance(loss, dict):
            self.loss = loss['loss']
        else:
            self.loss = loss

    def better_than(self, other):
        """Return True when this loss is strictly lower than ``other``'s."""
        return bool(self.loss < other.loss)

    def state_dict(self):
        """Return ``{'loss': <value>}``, suitable for passing back to the constructor."""
        return {'loss': self.loss}

def index_points(points, idx):
    """Gather points by per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dims are kept)
    Return:
        new_points: indexed points data, [B, S, C]
    """
    num_batches = points.shape[0]
    idx_dims = list(idx.shape)
    trailing = len(idx_dims) - 1

    # One batch id per row of idx, shaped [B, 1, ...] and then tiled to idx's shape
    # so that it can be paired element-wise with idx in the advanced index below.
    batch_ids = torch.arange(num_batches, dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view([idx_dims[0]] + [1] * trailing)
    batch_ids = batch_ids.repeat([1] + idx_dims[1:])

    return points[batch_ids, idx, :]

def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...)."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what bool('') used to yield.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def get_args():
    """Parse command-line arguments and prepare the experiment directory.

    Returns:
        The parsed ``argparse.Namespace`` with two derived attributes added:
        ``experiment_path`` (a copy of ``exp_name``) and ``log_name`` (the
        stem of the ``--config`` path).

    Side effects:
        Creates ``experiment_path`` when it does not exist and prints a
        confirmation message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='yaml config file')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=8)
    # seed
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--deterministic', action='store_true', help='whether to set deterministic options for CUDNN backend.')
    # some args
    parser.add_argument('--exp_name', type=str, default='test/PCoTTA', help='experiment name')
    parser.add_argument('--loss', type=str, default='cd2', help='loss name')
    parser.add_argument('--ckpts', type=str, default=None, help='test used ckpt path')
    parser.add_argument('--val_freq', type=int, default=1, help='test freq')
    parser.add_argument('--epoch', type=str, default='300', help='')
    parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path')
    # `type=bool` would turn every non-empty string (including "False") into True,
    # so the value is parsed explicitly. A bare `--load_prototype` means True.
    parser.add_argument('--load_prototype', type=_str2bool, nargs='?', const=True, default=False, help='')

    # dataset
    parser.add_argument('--data_path', type=str, default='data', help='')
    # comment
    parser.add_argument('--comment', type=str, default='default', help='')

    parser.add_argument(
        '--resume',
        action='store_true',
        default=False,
        help = 'autoresume training (interrupted by accident)')

    args = parser.parse_args()

    args.experiment_path = args.exp_name

    args.log_name = Path(args.config).stem
    if not os.path.exists(args.experiment_path):
        # exist_ok guards against another process creating the directory between
        # the existence check and this call.
        os.makedirs(args.experiment_path, exist_ok=True)
        print('Create experiment path successfully at %s' % args.experiment_path)

    return args

def y_flip(pointcloud1):
    """Rotate a point cloud with the fixed Euler angles (0, 0, pi).

    The x and y rotations are identities, so the combined matrix is a
    half-turn about the z axis: x and y change sign (up to floating-point
    error) and z is left unchanged, despite what the function name suggests.

    Args:
        pointcloud1: torch tensor whose last dimension has size 3.

    Returns:
        numpy array holding ``pointcloud1 @ R`` (detached and moved to the CPU
        first).
    """
    angle_x, angle_y, angle_z = 0, 0, math.pi

    cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
    cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
    cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

    rot_x = np.array([[1, 0, 0],
                      [0, cos_x, -sin_x],
                      [0, sin_x, cos_x]])
    rot_y = np.array([[cos_y, 0, sin_y],
                      [0, 1, 0],
                      [-sin_y, 0, cos_y]])
    rot_z = np.array([[cos_z, -sin_z, 0],
                      [sin_z, cos_z, 0],
                      [0, 0, 1]])

    # Compose as Rz * (Ry * Rx); points are row vectors and are multiplied on the left.
    rotation = np.dot(rot_z, np.dot(rot_y, rot_x))
    points = pointcloud1.detach().cpu().numpy()
    return np.dot(points, rotation)


def find_nearest_sample(test_feature_all, test_task, simi, feature_dict, batch_size=16):
    """Pick, for every test sample, the most similar stored source sample.

    For sample ``b`` the source index with the highest score in ``simi`` is
    selected, the matching entry of ``feature_dict['local_feature']`` is looked
    up, and the stored sample whose per-token cosine similarity to the test
    feature (averaged over tokens) is largest is returned.

    Args:
        test_feature_all: Tensor of shape [B, G, C] with the test features.
        test_task: Sequence of B task names; each must be 'reconstruction',
            'denoising' or 'registration'.
        simi: Similarity scores that can be viewed as [B, N_source]
            (singleton dimensions are allowed).
        feature_dict: Mapping whose ``'local_feature'`` entry is a list of
            [N, G, C] tensors.
        batch_size: Number of stored samples scored at a time, to bound memory.

    Returns:
        dict with ``'dataset_idx'`` (argmax over the source axis of ``simi``)
        and ``'sample_idx'`` (index of the best stored sample), each a list of
        B zero-dimensional integer tensors.

    Raises:
        ValueError: if a task name is not one of the three known tasks.
    """
    task_ids = {'reconstruction': 0, 'denoising': 1, 'registration': 2}

    B = test_feature_all.shape[0]
    sample_dict = {'dataset_idx': [], 'sample_idx': []}  # each of length B

    # Only indices are returned, so no autograd graph is needed for the search.
    with torch.no_grad():
        # reshape (instead of squeeze) keeps the batch axis when B == 1.
        simi = simi.reshape(B, -1)  # B, N_source
        N_source = simi.size(1)
        nearest_idx = simi.max(dim=-1)[1]  # B

        for b in range(B):
            if test_task[b] not in task_ids:
                raise ValueError('Unknown task: %r' % (test_task[b],))
            test_id = task_ids[test_task[b]]

            # NOTE(review): this slot formula assumes feature_dict['local_feature'] is
            # laid out task-major (N_source entries per task) - confirm against the
            # code that builds or loads feature_dict.
            nearest_idx_source = N_source * test_id + int(nearest_idx[b])
            source_local_feature = F.normalize(
                feature_dict['local_feature'][nearest_idx_source], dim=-1)  # N, G, C
            test_local_feature = F.normalize(test_feature_all[b], dim=-1)  # G, C

            # Mean over tokens of the token-wise dot product. This equals the mean of
            # the diagonal of test @ source^T without materialising the G x G matrix.
            # Stored samples are scored chunk by chunk to bound memory.
            chunk_scores = [
                (chunk * test_local_feature).sum(dim=-1).mean(dim=-1)
                for chunk in source_local_feature.split(batch_size)
            ]
            sample_idx = torch.argmax(torch.cat(chunk_scores))

            sample_dict['dataset_idx'].append(nearest_idx[b])
            sample_dict['sample_idx'].append(sample_idx)

            torch.cuda.empty_cache()

    return sample_dict


def get_index_in_class(config, train_domains):
    """Group the sample indices of every training dataset by class label.

    Args:
        config: Object providing ``total_bs``, the batch size used to scan
            each dataset.
        train_domains: dict with parallel lists ``'data'`` (entries whose first
            element is the dataset), ``'dataset'`` (names) and ``'task'``.

    Returns:
        dict with parallel lists ``'index'``, ``'dataset'`` and ``'task'``;
        ``'index'[i][c]`` lists the positions in dataset ``i`` whose label is
        ``c``. Seven class buckets are allocated per dataset.
    """
    num_classes = 7
    grouped = {'index': [], 'dataset': [], 'task': []}

    for pos, task_name in enumerate(train_domains['task']):
        grouped['task'].append(task_name)
        grouped['dataset'].append(train_domains['dataset'][pos])

        # One bucket of sample indices per class label.
        buckets = [[] for _ in range(num_classes)]

        # Sequential, unshuffled scan so that batch number * batch size + offset
        # is the sample's position in the dataset.
        loader = DataLoader(train_domains['data'][pos][0], batch_size=config.total_bs,
                            shuffle=False, drop_last=False)
        for batch_no, (pointset, target, rotation, dataset_name, task, label) in enumerate(loader):
            first = batch_no * config.total_bs
            for offset in range(len(pointset)):
                assert dataset_name[offset] == grouped['dataset'][pos]
                buckets[label[offset].int()].append(first + offset)

        grouped['index'].append(buckets)

    return grouped


def ctta(args, config, base_model, logger):
    """Adapt a shift model on top of a frozen base model and report target losses.

    Stages, in order:
      1. Build one dataset/loader per (domain, task) pair. Domains listed in
         ``config.target_domain`` use the 'test' split and become targets; the
         remaining domains use the 'train' split and become sources.
      2. Merge all source datasets and all target datasets into one loader each.
      3. Extract source features with ``base_model`` (or load them from
         ``feature_dict.pth`` when ``args.load_prototype`` is set).
      4. Build the shift model and train it on the merged source data for
         ``config.pretrain_epoch + 1`` epochs.
      5. Continue training on the merged target data for
         ``config.test_epoch + 1`` epochs, recording the mean ``loss_cd`` per task.

    Args:
        args: Parsed arguments; reads ``num_workers``, ``load_prototype``,
            ``experiment_path``, ``use_gpu``, ``local_rank``, ``resume`` and
            ``start_ckpts``.
        config: Experiment configuration; reads ``target_domain``, ``dataset``,
            ``total_bs``, ``total_bs_test``, ``shift_model``, ``pretrain_epoch``,
            ``test_epoch`` and ``step_per_update``.
        base_model: Model kept in eval mode throughout; only the shift model is
            stepped by the optimizer.
        logger: Logger passed on to ``print_log`` and the builder helpers.

    Returns:
        dict with parallel lists ``'loss_cd'``, ``'task'`` and ``'epoch'`` - one
        entry per task for every target epoch.
    """
    # NOTE(review): `random` and `nn` are not imported explicitly at the top of this
    # file (presumably they leak in through the `utils` star imports); importing
    # them here makes the dependency explicit.
    import random
    import torch.nn as nn

    def _make_loader(dataset, shuffle):
        """Wrap *dataset* in a DataLoader with the shared batch/worker settings."""
        return torch.utils.data.DataLoader(dataset, batch_size=config.total_bs_test,
                                           shuffle=shuffle,
                                           drop_last=False,
                                           num_workers=int(args.num_workers),
                                           worker_init_fn=worker_init_fn)

    def _merge_datasets(entries):
        """Chain the datasets of the given [dataset, loader] entries together."""
        merged = None
        for dataset, _ in entries:
            if merged is None:
                merged = dataset
            else:
                merged = torch.utils.data.ConcatDataset([merged, dataset])
        return merged

    def _encode(pointset, target):
        """Return the first quarter of the base model's encoder output tokens."""
        local_feature_repeat = base_model(pointset, pointset, target, target, get_encoder=True)  # [B, 4G, C] per the original note
        _, seq_len, _ = local_feature_repeat.size()
        return local_feature_repeat[:, :seq_len // 4, :]

    def _backward(loss):
        """Back-propagate *loss*, averaging it first when it is not a single value."""
        # backward() without an explicit gradient only accepts one-element tensors.
        if loss.numel() != 1:
            loss = loss.mean()
        loss.backward()
        return loss

    def _step_scheduler(epoch):
        """Advance the learning-rate scheduler(s) to *epoch*."""
        if isinstance(scheduler, list):
            for item in scheduler:
                item.step(epoch)
        else:
            scheduler.step(epoch)

    domains = ['modelnet', 'shapenet', 'scannet', 'scanobjectnn']
    source_domains = domains.copy()
    target_domains = config.target_domain.split(',')
    for i in range(len(target_domains)):
        target_domains[i] = target_domains[i].strip()
        source_domains.remove(target_domains[i])
    tasks = ['reconstruction', 'denoising', 'registration']
    train_domains = {'data':[], 'dataset':[], 'task':[]}
    test_domains = {'data':[], 'dataset':[], 'task':[]}
    print('Source Domains:', source_domains)
    print('Target Domains:', target_domains)

    # One (dataset, loader) per domain and task. Target domains are read from the
    # test split in order; source domains from the train split, shuffled.
    for domain in domains:
        is_target = domain in target_domains
        registry = test_domains if is_target else train_domains
        split = 'test' if is_target else 'train'
        for task_name in tasks:
            if is_target:
                dataset = PairDataset(config.dataset.test.others, domain, task_name, split)
            else:
                dataset = PairDataset(config.dataset.train.others, domain, task_name, split)
            registry['data'].append([dataset, _make_loader(dataset, shuffle=not is_target)])
            registry['dataset'].append(domain)
            registry['task'].append(task_name)

    print('Merging source domains...')
    source_dataset = _merge_datasets(train_domains['data'])
    source_dataloader = _make_loader(source_dataset, shuffle=True)

    print('Merging target domains...')
    target_dataset = _merge_datasets(test_domains['data'])
    target_dataloader = _make_loader(target_dataset, shuffle=True)

    # get class-level index
    train_index = get_index_in_class(config, train_domains)

    base_model.eval()

    feature_dict = {'task':[], 'dataset':[], 'local_feature':[], 'prototype':[]}
    load_feature_dict = args.load_prototype

    with torch.no_grad():
        # get train prototype: each dataset and each task
        if not load_feature_dict:
            for idx, (train_dataset, train_dataloader) in enumerate(train_domains['data']):
                train_dataset_name = train_domains['dataset'][idx]
                train_task = train_domains['task'][idx]
                feature_dict['task'].append(train_task)
                feature_dict['dataset'].append(train_dataset_name)
                print('Get prototype from %s of %s task...'%(train_dataset_name, train_task))

                # Collect per-batch features and concatenate once at the end.
                feature_chunks = []
                for b, (pointset, target, rotation, dataset_name, task, label) in enumerate(train_dataloader):
                    batch_size = pointset.size(0)
                    assert train_task == task[b % batch_size]
                    pointset = pointset.cuda()
                    target = target.cuda()

                    # only use source sample
                    feature_chunks.append(_encode(pointset, target))

                local_feature_all = torch.cat(feature_chunks, dim=0)
                feature_dict['local_feature'].append(local_feature_all)
                feature_dict['prototype'].append(torch.mean(local_feature_all, dim=0))
            torch.save(feature_dict, os.path.join(args.experiment_path, 'feature_dict.pth'))
            print('Trained feautre dictionaries saved in %s'%(os.path.join(args.experiment_path, 'feature_dict.pth')))
        else:
            feature_dict = torch.load(os.path.join(args.experiment_path, 'feature_dict.pth'))

    # test-time training
    # build model
    shift_model = builder.model_builder(config.shift_model)
    if args.use_gpu:
        shift_model.to(args.local_rank)

    # parameter setting
    start_epoch = 0
    best_metrics = Loss_Metric(100000.)

    # resume ckpts
    if args.resume:
        start_epoch, best_metric = builder.resume_model(shift_model, args, logger = logger)
        best_metrics = Loss_Metric(best_metric)
    elif args.start_ckpts is not None:
        builder.load_model(shift_model, args.start_ckpts, logger = logger)

    print_log('Using Data parallel ...' , logger = logger)
    shift_model = nn.DataParallel(shift_model).cuda()
    # optimizer & scheduler
    optimizer, scheduler = builder.build_opti_sche(shift_model, config, test=True)

    if args.resume:
        builder.resume_optimizer(optimizer, args, logger = logger)

    shift_model.zero_grad()

    # pretrain on source dataset for intialization
    pretrain_epoch = config.pretrain_epoch
    for epoch in range(pretrain_epoch+1):
        epoch_start_time = time.time()
        batch_start_time = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['Loss'])

        num_iter = 0

        shift_model.train()  # set model to training mode
        base_model.eval()

        n_batches = len(source_dataloader)
        for idx, (pointset1_pc, target1, rotation1, dataset_name, task, label) in enumerate(source_dataloader):
            num_iter += 1
            batch_size = pointset1_pc.size(0)
            # random create pair
            pointset2_pc = torch.empty_like(pointset1_pc)
            target2 = torch.empty_like(target1)
            rotation2 = torch.empty_like(rotation1)
            for b in range(batch_size):
                dataset2_name = random.choice(source_domains)
                # get the prompt sample: same task and class label, random source domain
                for i in range(len(train_domains['task'])):
                    if train_domains['task'][i] == task[b] and train_domains['dataset'][i] == dataset2_name:
                        assert train_index['task'][i] == task[b] and train_index['dataset'][i] == dataset2_name
                        train_dataset2 = train_domains['data'][i][0]
                        class_index = train_index['index'][i][label[b].int()]
                        pointset2_idx = random.choice(class_index)
                        pointset2_pc[b], target2[b], rotation2[b], _, _, _= train_dataset2[pointset2_idx]
                        # `task` holds the task names of the whole batch, so the check
                        # has to look at this sample's entry (as the test-time loop does).
                        if task[b] == 'registration':
                            pointset2_origin = torch.clone(target2[b])
                            pointset2_origin = torch.from_numpy(y_flip(pointset2_origin)).float()
                            pointset2_pc[b] = torch.matmul(pointset2_origin, rotation1[b].float())  # align rotation with pointset1

            data_time.update(time.time() - batch_start_time)

            pointset1_pc = pointset1_pc.cuda()
            pointset2_pc = pointset2_pc.cuda()

            target1 = target1.cuda()
            target2 = target2.cuda()

            # GSSM shifting
            local_feature = _encode(pointset1_pc, target1)
            shifted_feature, loss_repul, _ = shift_model(local_feature, task)

            _, _, loss, loss_cd = base_model(pointset2_pc, pointset1_pc, target2, target1, shifted_feature=shifted_feature)

            loss_cd = loss_cd * 1000 # 1e3
            loss = loss + loss_repul

            loss = _backward(loss)

            # gradients are accumulated over `step_per_update` batches before stepping
            if num_iter == config.step_per_update:
                num_iter = 0
                optimizer.step()
                shift_model.zero_grad()

            losses.update([loss.item()])

            batch_time.update(time.time() - batch_start_time)
            batch_start_time = time.time()

            loss_cd = loss_cd.mean()

            if idx % 20 == 0:
                print_log('[Initalizing Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s (Loss_cd = %.3f Loss_repul = %.3f) lr = %.6f' %
                            (epoch, pretrain_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),
                            ['%.4f' % l for l in losses.val()], loss_cd.item(), loss_repul.item(), optimizer.param_groups[0]['lr']), logger = logger)
        _step_scheduler(epoch)
        epoch_end_time = time.time()

        print_log('[Initalizing] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %
            (epoch,  epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()],
            optimizer.param_groups[0]['lr']), logger = logger)

    # test-time training on target domain
    loss_dict = {'loss_cd':[], 'task':[], 'epoch':[]}

    test_epoch = config.test_epoch
    for epoch in range(test_epoch+1):
        epoch_start_time = time.time()
        batch_start_time = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()  # never updated in this loop; only read for the log line
        losses = AverageMeter(['Loss'])

        num_iter = 0

        shift_model.train()  # set model to training mode
        base_model.eval()

        sample_num = 0
        loss_cd_sums = {task_name: 0.0 for task_name in tasks}

        n_batches = len(target_dataloader)
        for idx, (pointset, target, rotation, dataset_name, task, label) in enumerate(target_dataloader):
            num_iter += 1

            pointset = pointset.cuda()
            target = target.cuda()
            rotation = rotation.cuda()

            # sample to 1024 points
            pointset, pos_idx = sample_farthest_points(pointset, K=1024)
            target = index_points(target, pos_idx)

            # GSSM shifting
            local_feature = _encode(pointset, target)
            batch_size = local_feature.size(0)
            shifted_feature, loss_repul, simi = shift_model(local_feature, task)

            # find nearest sample in nearest source domain
            sample_dict = find_nearest_sample(local_feature, task, simi, feature_dict)

            # get prompt sample
            pointset2 = torch.empty_like(pointset)
            target2 = torch.empty_like(target)
            for b in range(batch_size):
                # NOTE(review): dataset_idx is an argmax over simi's source axis and is
                # used directly as an index into the per-(domain, task) list - confirm
                # that the two index spaces agree.
                train_dataset, _ = train_domains['data'][sample_dict['dataset_idx'][b]]
                pointset2[b], target2[b], _, _, _, _ = train_dataset[sample_dict['sample_idx'][b]]
                if task[b] == 'registration':
                    pointset2_origin = torch.clone(target2[b])
                    pointset2_origin = torch.tensor(y_flip(pointset2_origin)).float().cuda()
                    pointset2[b] = torch.matmul(pointset2_origin, rotation[b])

            _, rebuild_points, loss, loss_cd = base_model(pointset2, pointset, target2, target, shifted_feature=shifted_feature)

            loss_cd = loss_cd * 1000 # 1e3
            loss = loss + loss_repul

            loss = _backward(loss)

            # gradients are accumulated over `step_per_update` batches before stepping
            if num_iter == config.step_per_update:
                num_iter = 0
                optimizer.step()
                shift_model.zero_grad()

            losses.update([loss.item()])

            batch_time.update(time.time() - batch_start_time)
            batch_start_time = time.time()

            # accumulate the per-sample loss_cd under the sample's task
            for b in range(batch_size):
                if task[b] in loss_cd_sums:
                    loss_cd_sums[task[b]] += loss_cd[b].item()

            loss_cd = loss_cd.mean()
            sample_num += batch_size

            if idx % 20 == 0:
                print_log('[Testing Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s (Loss_cd = %.3f Loss_repul = %.3f) lr = %.6f' %
                            (epoch, test_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),
                            ['%.4f' % l for l in losses.val()], loss_cd.item(), loss_repul.item(), optimizer.param_groups[0]['lr']), logger = logger)

        # The per-task mean divides by the total sample count split evenly across
        # tasks, i.e. it treats every task as contributing the same number of samples.
        sample_num /= len(tasks)
        for task_name in tasks:
            loss_dict['loss_cd'].append(loss_cd_sums[task_name] / sample_num)
            loss_dict['task'].append(task_name)
            loss_dict['epoch'].append(epoch)

        _step_scheduler(epoch)
        epoch_end_time = time.time()

        print_log('[Testing] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %
            (epoch,  epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()],
            optimizer.param_groups[0]['lr']), logger = logger)

    return loss_dict

def main():
    """Script entry point: parse arguments, load the base model, run ``ctta`` and log results.

    The per-epoch, per-task ``loss_cd`` values returned by ``ctta`` are printed
    and appended to ``results.txt`` in the current working directory.
    """
    # NOTE(review): `nn` is not imported explicitly at the top of this file
    # (presumably it leaks in through the `utils` star imports); importing it here
    # makes the dependency explicit.
    import torch.nn as nn

    # args
    args = get_args()
    # CUDA
    args.use_gpu = torch.cuda.is_available()
    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = os.path.join(args.experiment_path, f'{timestamp}-{args.seed}.log')
    logger = get_root_logger(log_file=log_file, name=args.log_name)
    # config
    config = get_config(args, logger = logger)
    # log
    log_args_to_file(args, 'args', logger=logger)
    log_config_to_file(config, 'config', logger=logger)

    print_log(args.comment)

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed + args.local_rank, deterministic=args.deterministic)  # seed + rank, for augmentation

    base_model = builder.model_builder(config.model)
    # load checkpoints
    builder.load_model(base_model, args.ckpts, logger=logger)

    if args.use_gpu:
        device = torch.device('cuda')
        base_model.to(device)

    print_log('Using Data parallel ...' , logger = logger)
    base_model = nn.DataParallel(base_model).cuda()

    loss_dict = ctta(args, config, base_model, logger)
    print('All tasks are done!')

    # The context manager makes sure the results file is flushed and closed.
    with open("results.txt", "a") as results_file:
        results_file.write('\nPrerained Epoch: %s\n'%(str(args.epoch)))

        for epoch, task, loss_cd in zip(loss_dict['epoch'], loss_dict['task'], loss_dict['loss_cd']):
            line = 'Epoch %d on %s task: loss_cd = %.3f'% (epoch, task, loss_cd)
            print(line)
            results_file.write(line + '\n')

if __name__ == "__main__":
    # Restrict the process to GPU 0. This overwrites any CUDA_VISIBLE_DEVICES value
    # already present in the environment.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    main()
