import torch
import os
import random
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from util.fragments import get_multiview_fragments



class WPCDataset(Dataset):
    """Multi-view image dataset driven by a WPC index file.

    Each sample is a group of ``num_view`` consecutive rows in the index
    file; every row provides one view image path (``imgs_path``) and a
    ``score``. ``__getitem__`` returns the stacked, transformed view images,
    the fragments built from them, and the sample's score divided by 5.
    """

    # ImageNet default normalization statistics.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, mode: str, fold=0, input_size=512, crop_size=224, num_view=6,
                 index_root='./data/index/WPC'):
        """Load the index file for ``mode`` and build the image transforms.

        Args:
            mode: 'train', 'test' or 'total'; selects which index file to load.
            fold: fold number used in the train/test index file name.
            input_size: value embedded in the train/test index file name.
            crop_size: side length of the center crop applied to every view.
            num_view: number of consecutive index rows forming one sample.
            index_root: directory containing the index files.

        Raises:
            KeyError: if ``mode`` is not one of the supported modes.
        """
        mode_files = {
            'train': 'fold_{}_train_{}view_{}.csv'.format(fold, num_view, input_size),
            'test': 'fold_{}_test_{}view_{}.csv'.format(fold, num_view, input_size),
            'total': 'total.xlsx',
        }
        try:
            index_file_name = mode_files[mode]
        except KeyError as err:
            raise KeyError(
                f'unknown mode {mode!r}; expected one of {sorted(mode_files)}'
            ) from err
        index_file_path = os.path.join(index_root, index_file_name)
        # The 'total' index is an Excel workbook; read_csv cannot parse it.
        if index_file_path.endswith(('.xlsx', '.xls')):
            self.file = pd.read_excel(index_file_path)
        else:
            self.file = pd.read_csv(index_file_path)
        self.crop_size = crop_size
        self.num_view = num_view

        # Every mode uses the same deterministic pipeline: 'train' applies a
        # CenterCrop too (no random-crop augmentation).
        normalize = transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD)
        self.transform = transforms.Compose([
            transforms.CenterCrop(self.crop_size),
            transforms.ToTensor(),
            normalize,
        ])
        self.transform_frags = transforms.Compose([normalize])

    def __len__(self):
        """Number of complete ``num_view``-row groups in the index file."""
        return len(self.file['score']) // self.num_view

    def _row_range(self, idx):
        """Map sample index ``idx`` (negatives allowed) to its [start, stop) row span."""
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f'sample index {idx} out of range for {n} samples')
        start = pos * self.num_view
        return start, start + self.num_view

    def _view_paths(self, idx):
        """Return the image paths of the ``num_view`` rows forming sample ``idx``."""
        start, stop = self._row_range(idx)
        # Positional slicing, independent of the DataFrame's index labels.
        return list(self.file['imgs_path'].iloc[start:stop])

    def _load_view(self, path):
        """Open one view image as RGB, close the file, and apply ``self.transform``."""
        with Image.open(path) as im:
            return self.transform(im.convert('RGB'))

    def _mos(self, idx):
        """Return the score of sample ``idx`` divided by 5, as a float32 tensor."""
        start, _ = self._row_range(idx)
        # Read the score from the sample's first view row, i.e. the same rows
        # whose images are loaded. Assumes all view rows of a sample share one
        # score -- TODO confirm against the index files.
        score = self.file['score'].iloc[start]
        # NOTE(review): /5 presumably rescales a 0-5 score range -- confirm.
        return torch.tensor(score, dtype=torch.float32) / 5

    def __getitem__(self, idx):
        """Return ``(imgs, frags, mos)`` for sample ``idx``.

        ``imgs`` stacks the transformed views along a new leading dimension,
        giving shape ``(num_view, 3, crop_size, crop_size)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        imgs = torch.stack([self._load_view(path) for path in self._view_paths(idx)])
        frags = get_multiview_fragments(imgs).squeeze(0)
        # NOTE(review): ``imgs`` were already normalized by ``self.transform``,
        # so this normalizes the fragments a second time -- confirm that
        # get_multiview_fragments expects this.
        frags = self.transform_frags(frags)
        return imgs, frags, self._mos(idx)