import argparse
import yaml
import cv2
import numpy as np
import os
from lpips import LPIPS
from pytorch_fid import fid_score

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from utils.datasets import Img2ImgDataset
from nets.Unets import ConditionalUNet, UnconditionalUNet
from models.IRSDE import IRSDE
from models.CM import KarrasDenoiser as KD
from metric import cal_metric


def tensor_to_ndarray(tensor):
    """Convert an image tensor to a ``uint8`` NumPy array for saving.

    Singleton dimensions (e.g. a batch of 1) are squeezed away first. If the
    result is 3-D with 3 leading channels it is treated as channels-first and
    transposed to (H, W, C); any other shape (e.g. a 2-D grayscale map) is
    converted as-is. Values are assumed to be nominally in [0, 1]: they are
    scaled by 255, clipped to [0, 255] and truncated to ``uint8``.

    Args:
        tensor: image tensor on any device; gradients are detached.

    Returns:
        A ``numpy.uint8`` array.
    """
    array = torch.squeeze(tensor.detach().cpu()).numpy()
    # Only permute genuine CHW data: a 2-D map whose first dimension happens
    # to be 3 has no channel axis and cannot be transposed with 3 axes.
    if array.ndim == 3 and array.shape[0] == 3:
        array = array.transpose(1, 2, 0)
    # np.uint8 on a float array truncates toward zero after clipping.
    return np.uint8(np.clip(array * 255, 0, 255))

def load_img(img_path, data_transforms=None, device=None):
    """Load an image file as a batched tensor.

    Args:
        img_path: path of the image file to read.
        data_transforms: callable applied to the opened PIL image; defaults to
            ``transforms.ToTensor()``. It must fully read the image (as
            ``ToTensor``/``Resize`` do), because the file is closed afterwards.
        device: device to move the result to. When ``None``, falls back to the
            module-level ``device`` global (defined by this script's
            ``__main__`` section) and to ``"cpu"`` if no such global exists.

    Returns:
        The transformed image with a new leading dimension of size 1.
    """
    if data_transforms is None:
        data_transforms = transforms.Compose([transforms.ToTensor()])
    if device is None:
        # Keep the historical behaviour of using the script-level `device`
        # without raising NameError when this module is imported.
        device = globals().get("device", "cpu")
    # Context manager closes the underlying file handle promptly.
    with Image.open(img_path) as img:
        tensor = data_transforms(img)
    return torch.unsqueeze(tensor, 0).to(device)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str, help="Path to option YAML file.")
    # Accepted for compatibility with torch launchers; distributed evaluation
    # is not wired up in this script.
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', -1), type=int)
    args = parser.parse_args()

    # NOTE(review): FullLoader can build some Python objects from tags; only
    # pass trusted option files via -opt.
    with open(args.opt, mode="r") as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # Module-level on purpose: load_img() falls back to this global.
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print("using: ", device)

    model_path = opt['path']['checkpoint']['main_model']

    model = ConditionalUNet(**opt["network"]['setting'])
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)  # strict=True
    print('main model has been loaded!')

    loss_fn = LPIPS(net='alex')
    loss_fn.to(device)

    model = model.to(device)
    # nn.DataParallel requires the wrapped module's parameters to live on
    # device_ids[0], so the primary device must come first in the list.
    # On CPU, device_ids=None makes DataParallel a plain pass-through wrapper.
    if device.type == "cuda":
        device_ids = [device.index] + [
            idx for idx in range(torch.cuda.device_count()) if idx != device.index
        ]
    else:
        device_ids = None
    model = nn.DataParallel(model, device_ids=device_ids)

    data_transform = transforms.Compose([transforms.Resize((128, 128)),
                                         transforms.ToTensor()])

    sde = IRSDE(**opt['sde'], device=device)
    sde.set_model(model)

    cm = KD(**opt['cm'], device=device)

    # Number of sampling rounds; the timestep is halved after each round.
    M = 4

    input_path = opt['path']['data']['test']['input_path']
    GT_path = opt['path']['data']['test']['GT_path']
    save_path = opt['path']['save']['test'] + f'{M}/'
    # Image.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    names = sorted(os.listdir(input_path))
    lpips_all = 0.0

    model.eval()
    for name in tqdm(names):
        with torch.no_grad():
            lq = load_img(os.path.join(input_path, name), data_transforms=data_transform).to(device)
            GT = load_img(os.path.join(GT_path, name), data_transforms=data_transform).to(device)
            sde.set_mu(lq)

            step = opt['sde']['T']
            for m in range(M):
                if m == 0:
                    # First round starts from sde.noise_state of the degraded input.
                    xt = sde.noise_state(lq).to(device)
                else:
                    # Later rounds re-noise the previous estimate at the smaller step.
                    xt = sde.generate_noise_states(x0, step)
                t = torch.tensor(step).view(1, 1, 1, 1).to(device)
                x0_pred = sde.predict_x0(xt, t)
                x0 = cm.denoise(model, xt, lq, t, x0_pred)
                step = int(step / 2)
            output = x0

            # LPIPS compares images in [-1, 1] by default; these tensors come
            # from ToTensor and are in [0, 1], so let LPIPS rescale them.
            lpips_val = loss_fn(output, GT, normalize=True)
        lpips_all += lpips_val.item()

        image = Image.fromarray(tensor_to_ndarray(output))
        image.save(os.path.join(save_path, name))

    num_images = len(names)
    lpips_ave = lpips_all / num_images if num_images else float('nan')
    print("lpips is ", lpips_ave)

    psnr, ssim = cal_metric(save_path, GT_path)
    print("psnr is ", psnr)
    print("ssim is ", ssim)

    fid_value = fid_score.calculate_fid_given_paths([GT_path, save_path], batch_size=10, dims=2048, device=device)
    print('FID is ', fid_value)






        