import os
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import PIL
from PIL import Image
from torchvision import transforms
import torch
import pdb
import numpy as np
import clip
import logging
import random
from torchvision.transforms import InterpolationMode


class TwoAFC(Dataset):
    """2AFC dataset of image triplets described by a CSV file.

    Each sample is ``(img_ref, img_left, img_right, p, id)``: three preprocessed
    images, the CSV's positional column 2 cast to ``np.float32``, and positional
    column 0. Image paths are read from positional columns 4 (reference),
    5 (left) and 6 (right). Only rows with ``votes > 5`` are kept.
    """

    def __init__(self, csv_file, split, use_augmentations=False, load_size=224, interpolation=transforms.InterpolationMode.BICUBIC,
                 model_type="mae", preprocess='DEFAULT'):
        """Load the CSV and keep only the rows belonging to ``split``.

        Args:
            csv_file: Path to the CSV. It must contain ``votes`` and ``split``
                columns, plus ``is_imagenet`` for the test splits and
                ``left_vote``/``right_vote`` for ``'val'``.
            split: ``'train'``, ``'val'``, ``'test_imagenet'`` or ``'test_no_imagenet'``.
            use_augmentations: When True and ``split == 'train'``, one random
                crop/flip is shared by the three images of a sample. It only
                takes effect with ``preprocess='DEFAULT'``.
            load_size: Output side length for ``'DEFAULT'`` preprocessing.
            interpolation: Resize interpolation for ``'DEFAULT'`` preprocessing.
            model_type: Stored as ``self.model_type``; not otherwise used in this class.
            preprocess: One of ``'DEFAULT'``, ``'LPIPS'``, ``'DISTS'``, ``'PSNR'``,
                ``'SSIM'``, ``'CLIP_EX'``.

        Raises:
            ValueError: If ``split`` is not one of the names above.
        """
        self.csv = pd.read_csv(csv_file)
        self.csv = self.csv[self.csv['votes'] > 5]
        self.preprocess = preprocess
        self.split = split
        self.model_type = model_type
        self.use_augmentations = use_augmentations

        if split == 'train':
            self.csv = self.csv[self.csv['split'] == 'train']
        elif split == 'val':
            self.csv = self.csv[self.csv['split'] == 'val']
            logging.info("Val set size: %d", len(self.csv.index))
            logging.info("Human p for full dataset: %s", self.csv[['left_vote', 'right_vote']].max(axis=1).mean())
        elif split in ('test_imagenet', 'test_no_imagenet'):
            self.csv = self.csv[self.csv['split'] == 'test']
            # Elementwise pandas comparison against a bool: `==` is required, not `is`.
            want_imagenet = split == 'test_imagenet'
            self.csv = self.csv[self.csv['is_imagenet'] == want_imagenet]
        else:
            raise ValueError(f'Invalid split: {split}')

        self.load_size = load_size
        self.interpolation = interpolation

    def __len__(self):
        return len(self.csv)

    def _extractor_preprocess(self, pil_img, same_transform=None):
        """Convert to RGB, then apply ``same_transform`` or a square resize + ToTensor."""
        pil_img = pil_img.convert('RGB')
        if same_transform is not None:
            return same_transform.forward(pil_img)
        t = transforms.Compose([
            transforms.Resize((self.load_size, self.load_size), interpolation=self.interpolation),
            transforms.ToTensor()
        ])
        return t(pil_img)

    def lpips_preprocess(self, pil_img, cent=1., factor=0.5):
        """ToTensor, then ``x / factor - cent`` (defaults map [0, 1] to [-1, 1])."""
        return transforms.ToTensor()(pil_img) / factor - cent

    def dists_preprocess(self, pil_img):
        """Resize to 256x256 (default interpolation), then ToTensor."""
        pil_img = transforms.Resize((256, 256))(pil_img)
        pil_img = transforms.ToTensor()(pil_img)
        return pil_img

    def psnr_preprocess(self, pil_img):
        """ToTensor only, at the image's native size."""
        return transforms.ToTensor()(pil_img)

    def ssim_preprocess(self, pil_img):
        """ToTensor only, at the image's native size."""
        return transforms.ToTensor()(pil_img)

    def clip_ex_preprocess(self, pil_img):
        """Apply ``self.clip_preprocess``, which must be assigned by the caller.

        Raises:
            AttributeError: If no ``clip_preprocess`` callable has been set.
        """
        clip_preprocess = getattr(self, 'clip_preprocess', None)
        if clip_preprocess is None:
            raise AttributeError("preprocess='CLIP_EX' requires a 'clip_preprocess' callable "
                                 "to be assigned on the dataset before use")
        return clip_preprocess(pil_img)

    def gen_preprocess(self, pil_img, same_transform=None):
        """Dispatch to the preprocessing selected by ``self.preprocess``.

        ``same_transform`` is only honoured by ``'DEFAULT'``.

        Raises:
            ValueError: If ``self.preprocess`` is not a known method name.
        """
        if self.preprocess == 'DEFAULT':
            return self._extractor_preprocess(pil_img, same_transform=same_transform)
        elif self.preprocess == 'LPIPS':
            return self.lpips_preprocess(pil_img)
        elif self.preprocess == 'DISTS':
            return self.dists_preprocess(pil_img)
        elif self.preprocess == 'PSNR':
            return self.psnr_preprocess(pil_img)
        elif self.preprocess == 'SSIM':
            return self.ssim_preprocess(pil_img)
        elif self.preprocess == 'CLIP_EX':
            return self.clip_ex_preprocess(pil_img)
        else:
            raise ValueError("Unknown preprocessing method")

    def _load_and_preprocess(self, path, same_transform=None):
        """Open the image at ``path``, preprocess it, and close the file handle."""
        with Image.open(path) as pil_img:
            return self.gen_preprocess(pil_img, same_transform=same_transform)

    def __getitem__(self, idx):
        # Nominal image side length used to size the shared augmentation crop
        # (presumably the native resolution of the dataset images -- confirm).
        image_size = 768
        sample_id = self.csv.iloc[idx, 0]
        p = np.float32(self.csv.iloc[idx, 2])
        ref_path = self.csv.iloc[idx, 4]
        left_path = self.csv.iloc[idx, 5]
        right_path = self.csv.iloc[idx, 6]

        same_transform = None
        # Open the reference once. It is used both to draw the shared crop
        # parameters and as the first output image, and is closed afterwards.
        with Image.open(ref_path) as ref_pil:
            if self.use_augmentations and self.split == 'train':
                same_transform = transforms_class(ref_pil, image_size, self.load_size, self.interpolation)
            img_ref = self.gen_preprocess(ref_pil, same_transform=same_transform)
        img_left = self._load_and_preprocess(left_path, same_transform=same_transform)
        img_right = self._load_and_preprocess(right_path, same_transform=same_transform)
        return img_ref, img_left, img_right, p, sample_id


class transforms_class:
    """A random crop + optional horizontal flip drawn once and replayed on many images.

    The crop window (a square of side ``int(0.95 * image_size)``) and the flip
    decision are sampled at construction time. Every image passed to
    :meth:`forward` then receives the identical flip, crop, resize to
    ``(load_size, load_size)`` and tensor conversion.
    """

    def __init__(self, pil_img, image_size, load_size, interpolation):
        crop_fraction = 0.95
        side = int(crop_fraction * image_size)
        # Crop offsets/size come from torchvision's RNG; the flip coin comes from
        # Python's `random`. Draw them in this order.
        self.i, self.j, self.h, self.w = transforms.RandomCrop.get_params(pil_img, output_size=(side, side))
        self.load_size = load_size
        self.interpolation = interpolation
        self.is_horizontal_flip = random.random() > 0.5

    def forward(self, pil_img):
        """Apply the stored flip, crop and resize, and return a tensor."""
        tf = transforms.functional
        flipped = tf.hflip(pil_img) if self.is_horizontal_flip else pil_img
        cropped = tf.crop(flipped, self.i, self.j, self.h, self.w)
        target = (self.load_size, self.load_size)
        resized = tf.resize(cropped, target, interpolation=self.interpolation)
        return tf.to_tensor(resized)


class THINGS(Dataset):
    """Image-triplet dataset whose triplets are listed in a text file.

    Each non-blank line of ``txt_file`` holds three whitespace-separated image
    names. Each name is loaded from ``<image_dir>/<name>.png``, and a sample is
    the three preprocessed images.
    """

    def __init__(self, txt_file, load_size=224, interpolation=transforms.InterpolationMode.BICUBIC, preprocess='DEFAULT',
                 image_dir='util/things_src_images'):
        """Read the triplet listing.

        Args:
            txt_file: Path to the listing file.
            load_size: Output side length for ``'DEFAULT'`` preprocessing.
            interpolation: Resize interpolation for ``'DEFAULT'`` preprocessing.
            preprocess: One of ``'DEFAULT'``, ``'LPIPS'``, ``'DISTS'``, ``'PSNR'``,
                ``'SSIM'``, ``'CLIP_EX'``.
            image_dir: Directory holding the ``<name>.png`` images.
        """
        with open(txt_file, "r") as file:
            # Blank lines cannot be unpacked into three names; skip them so
            # __len__ only counts usable triplets.
            self.txt = [line for line in file if line.strip()]
        self.load_size = load_size
        self.interpolation = interpolation
        self.preprocess = preprocess
        self.image_dir = image_dir

    def __len__(self):
        return len(self.txt)

    def _extractor_preprocess(self, pil_img, same_transform=None):
        """Convert to RGB, then apply ``same_transform`` or a square resize + ToTensor."""
        pil_img = pil_img.convert('RGB')
        if same_transform is not None:
            return same_transform.forward(pil_img)
        t = transforms.Compose([
            transforms.Resize((self.load_size, self.load_size), interpolation=self.interpolation),
            transforms.ToTensor()
        ])
        return t(pil_img)

    def lpips_preprocess(self, pil_img, cent=1., factor=0.5):
        """ToTensor, then ``x / factor - cent`` (defaults map [0, 1] to [-1, 1])."""
        return transforms.ToTensor()(pil_img) / factor - cent

    def dists_preprocess(self, pil_img):
        """Resize to 256x256 (default interpolation), then ToTensor."""
        pil_img = transforms.Resize((256, 256))(pil_img)
        pil_img = transforms.ToTensor()(pil_img)
        return pil_img

    def psnr_preprocess(self, pil_img):
        """ToTensor only, at the image's native size."""
        return transforms.ToTensor()(pil_img)

    def ssim_preprocess(self, pil_img):
        """ToTensor only, at the image's native size."""
        return transforms.ToTensor()(pil_img)

    def clip_ex_preprocess(self, pil_img):
        """Apply ``self.clip_preprocess``, which must be assigned by the caller.

        Raises:
            AttributeError: If no ``clip_preprocess`` callable has been set.
        """
        clip_preprocess = getattr(self, 'clip_preprocess', None)
        if clip_preprocess is None:
            raise AttributeError("preprocess='CLIP_EX' requires a 'clip_preprocess' callable "
                                 "to be assigned on the dataset before use")
        return clip_preprocess(pil_img)

    def gen_preprocess(self, pil_img, same_transform=None):
        """Dispatch to the preprocessing selected by ``self.preprocess``.

        Raises:
            ValueError: If ``self.preprocess`` is not a known method name.
        """
        if self.preprocess == 'DEFAULT':
            return self._extractor_preprocess(pil_img, same_transform=same_transform)
        elif self.preprocess == 'LPIPS':
            return self.lpips_preprocess(pil_img)
        elif self.preprocess == 'DISTS':
            return self.dists_preprocess(pil_img)
        elif self.preprocess == 'PSNR':
            return self.psnr_preprocess(pil_img)
        elif self.preprocess == 'SSIM':
            return self.ssim_preprocess(pil_img)
        elif self.preprocess == 'CLIP_EX':
            return self.clip_ex_preprocess(pil_img)
        else:
            raise ValueError("Unknown preprocessing method")

    def _load_and_preprocess(self, name):
        """Open ``<image_dir>/<name>.png``, preprocess it, and close the file handle."""
        with Image.open(os.path.join(self.image_dir, f'{name}.png')) as pil_img:
            return self.gen_preprocess(pil_img)

    def __getitem__(self, idx):
        # Unpack first so a malformed line fails before any image is opened.
        name_1, name_2, name_3 = self.txt[idx].split()

        im_1 = self._load_and_preprocess(name_1)
        im_2 = self._load_and_preprocess(name_2)
        im_3 = self._load_and_preprocess(name_3)

        return im_1, im_2, im_3


class BAPPS(Dataset):
    """BAPPS-style 2AFC dataset read from a directory tree.

    For every selected subset directory ``<root_dir>/<subset>/``, each file in
    ``judge/`` yields one sample. A sample is the preprocessed same-stem ``.png``
    images from ``ref/``, ``p0/`` and ``p1/``, plus the judge file loaded with
    ``np.load``.
    """

    def __init__(self, root_dir, load_size=224, interpolation=transforms.InterpolationMode.LANCZOS, model_type="mae",
                 preprocess='DEFAULT', inc_distort=True, inc_real=True):
        """Index the sample files under ``root_dir``.

        Args:
            root_dir: Directory containing the subset directories.
            load_size: Output side length for ``'DEFAULT'`` preprocessing.
            interpolation: Resize interpolation used by every preprocessing method.
            model_type: Selects ``self.mean``/``self.std`` by substring match.
            preprocess: One of ``'DEFAULT'``, ``'LPIPS'``, ``'DISTS'``, ``'PSNR'``,
                ``'SSIM'``, ``'CLIP_EX'``.
            inc_distort: Include the ``cnn`` and ``traditional`` subsets.
            inc_real: Include the ``color``, ``deblur``, ``superres`` and
                ``frameinterp`` subsets.
        """
        distort = ["cnn", "traditional"]
        real = ["color", "deblur", "superres", "frameinterp"]
        data_types = []
        if inc_distort:
            data_types += distort
        if inc_real:
            data_types += real
        self.judge_paths = []
        self.p0_paths = []
        self.p1_paths = []
        self.ref_paths = []

        for dt in data_types:
            type_dir = os.path.join(root_dir, dt)
            judge_dir = os.path.join(type_dir, "judge")
            # scandir yields entries in arbitrary, filesystem-dependent order. Sort
            # so an index maps to the same sample on every machine and run, and
            # close the iterator promptly. Hidden files (e.g. ".DS_Store") are
            # not samples.
            with os.scandir(judge_dir) as entries:
                names = sorted(entry.name for entry in entries if not entry.name.startswith("."))
            for name in names:
                png_name = name.split(".")[0] + ".png"
                self.judge_paths.append(os.path.join(judge_dir, name))
                self.p0_paths.append(os.path.join(type_dir, "p0", png_name))
                self.p1_paths.append(os.path.join(type_dir, "p1", png_name))
                self.ref_paths.append(os.path.join(type_dir, "ref", png_name))

        self.load_size = load_size
        self.interpolation = interpolation
        self.preprocess = preprocess
        # Normalization statistics by backbone family. They are not used inside
        # this class; presumably callers read them.
        imagenet_stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        clip_stats = ((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        if 'dino' in model_type:
            self.mean, self.std = imagenet_stats
        elif 'clip' in model_type:  # also matches 'open_clip'
            self.mean, self.std = clip_stats
        elif 'mae' in model_type:
            self.mean, self.std = imagenet_stats
        else:
            self.mean = (0.5, 0.5, 0.5)
            self.std = (0.5, 0.5, 0.5)

    def __len__(self):
        return len(self.judge_paths)

    def _extractor_preprocess(self, pil_img, same_transform=None):
        """Convert to RGB, then apply ``same_transform`` or a square resize + ToTensor."""
        pil_img = pil_img.convert('RGB')
        if same_transform is not None:
            return same_transform.forward(pil_img)
        t = transforms.Compose([
            transforms.Resize((self.load_size, self.load_size), interpolation=self.interpolation),
            transforms.ToTensor()
        ])
        return t(pil_img)

    def lpips_preprocess(self, pil_img, cent=1., factor=0.5):
        """RGB, resize to 224x224, ToTensor, then ``x / factor - cent``."""
        pil_img = pil_img.convert('RGB')
        pil_img = transforms.Resize((224, 224), interpolation=self.interpolation)(pil_img)
        return transforms.ToTensor()(pil_img) / factor - cent

    def dists_preprocess(self, pil_img):
        """RGB, resize to 256x256, ToTensor."""
        pil_img = pil_img.convert('RGB')
        pil_img = transforms.Resize((256, 256), interpolation=self.interpolation)(pil_img)
        pil_img = transforms.ToTensor()(pil_img)
        return pil_img

    def psnr_preprocess(self, pil_img):
        """RGB, resize to 224x224, ToTensor."""
        pil_img = pil_img.convert('RGB')
        pil_img = transforms.Resize((224, 224), interpolation=self.interpolation)(pil_img)
        return transforms.ToTensor()(pil_img)

    def ssim_preprocess(self, pil_img):
        """RGB, resize to 224x224, ToTensor."""
        pil_img = pil_img.convert('RGB')
        pil_img = transforms.Resize((224, 224), interpolation=self.interpolation)(pil_img)
        return transforms.ToTensor()(pil_img)

    def clip_ex_preprocess(self, pil_img):
        """Apply ``self.clip_preprocess``, which must be assigned by the caller.

        Raises:
            AttributeError: If no ``clip_preprocess`` callable has been set.
        """
        clip_preprocess = getattr(self, 'clip_preprocess', None)
        if clip_preprocess is None:
            raise AttributeError("preprocess='CLIP_EX' requires a 'clip_preprocess' callable "
                                 "to be assigned on the dataset before use")
        return clip_preprocess(pil_img)

    def gen_preprocess(self, pil_img, same_transform=None):
        """Dispatch to the preprocessing selected by ``self.preprocess``.

        ``same_transform`` is only honoured by ``'DEFAULT'``. The keyword is
        accepted for parity with the other datasets in this module.

        Raises:
            ValueError: If ``self.preprocess`` is not a known method name.
        """
        if self.preprocess == 'DEFAULT':
            return self._extractor_preprocess(pil_img, same_transform=same_transform)
        elif self.preprocess == 'LPIPS':
            return self.lpips_preprocess(pil_img)
        elif self.preprocess == 'DISTS':
            return self.dists_preprocess(pil_img)
        elif self.preprocess == 'PSNR':
            return self.psnr_preprocess(pil_img)
        elif self.preprocess == 'SSIM':
            return self.ssim_preprocess(pil_img)
        elif self.preprocess == 'CLIP_EX':
            return self.clip_ex_preprocess(pil_img)
        else:
            raise ValueError("Unknown preprocessing method")

    def _load_and_preprocess(self, path):
        """Open the image at ``path``, preprocess it, and close the file handle."""
        with Image.open(path) as pil_img:
            return self.gen_preprocess(pil_img)

    def __getitem__(self, idx):
        judge = np.load(self.judge_paths[idx])
        im_left = self._load_and_preprocess(self.p0_paths[idx])
        im_right = self._load_and_preprocess(self.p1_paths[idx])
        im_ref = self._load_and_preprocess(self.ref_paths[idx])
        return im_ref, im_left, im_right, judge

def _bapps_real(root_dir, model_type="mae", preprocess="DEFAULT", **kwargs):
    """Build a BAPPS dataset with only the real subsets (color, deblur, superres, frameinterp).

    Extra keyword arguments (e.g. ``load_size``, ``interpolation``) are forwarded to ``BAPPS``.
    """
    return BAPPS(root_dir=root_dir, model_type=model_type, preprocess=preprocess, inc_distort=False, **kwargs)


def _bapps_distort(root_dir, model_type="mae", preprocess="DEFAULT", **kwargs):
    """Build a BAPPS dataset with only the distortion subsets (cnn, traditional).

    Extra keyword arguments (e.g. ``load_size``, ``interpolation``) are forwarded to ``BAPPS``.
    """
    return BAPPS(root_dir=root_dir, model_type=model_type, preprocess=preprocess, inc_real=False, **kwargs)


# Registry of dataset factories keyed by name.
DATASETS = {
    '2afc': TwoAFC,
    'things': THINGS,
    'bapps_real': _bapps_real,
    'bapps_distort': _bapps_distort,
    'bapps': BAPPS
}