import os
import torch
import random
import torch.utils.data as data
import numpy as np
import copy
from PIL import Image
import json

# File-name suffixes treated as images. Matching is case-sensitive, which is
# why both the lower- and upper-case spellings are listed.
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def is_image_file(filename):
    """Return True when *filename* ends with one of ``IMG_EXTENSIONS``."""
    accepted_suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(accepted_suffixes)

class JourneyDB(data.Dataset):
    """Image/text dataset backed by a JourneyDB-style JSON-lines annotation file.

    Each non-blank line of the annotation file must be a JSON object. The
    fields read are ``img_path`` (relative to ``configs.data_path``),
    ``prompt`` and ``Task2.Caption``.

    Args:
        configs: Object exposing ``data_path`` (image root directory, e.g.
            ``trainsets/JourneyDB/train``) and ``json_file`` (annotation file
            name relative to ``data_path``, e.g. ``re_train_anno.jsonl``).
        transform: Optional callable applied to each loaded RGB PIL image;
            ``None`` returns the PIL image unchanged.
        train: Accepted for interface compatibility; not used by this class.

    Raises:
        ValueError: If an annotation record has no (or an empty) ``img_path``.
            Malformed JSON lines raise ``json.JSONDecodeError``.
    """

    def __init__(self, configs,  transform, train=True):
        self.data_path = configs.data_path
        self.json_file = os.path.join(configs.data_path, configs.json_file)
        self.transform = transform
        self.data = []

        # Decode explicitly as UTF-8 so non-ASCII prompts do not depend on the
        # platform's default locale encoding.
        with open(self.json_file, 'r', encoding='utf-8') as file:
            for line_no, line in enumerate(file, start=1):
                line = line.strip()
                # Skip blank lines (e.g. a trailing newline at EOF) instead of
                # letting json.loads('') raise.
                if not line:
                    continue
                self.data.append(self._parse_record(json.loads(line), line_no))

    def _parse_record(self, data_item, line_no):
        """Extract the img_path / prompt / caption fields from one record."""
        img_path = data_item.get('img_path')
        if not img_path:
            # Fail at load time with a precise location rather than with an
            # opaque TypeError from os.path.join inside __getitem__.
            raise ValueError(
                f"{self.json_file}:{line_no}: annotation record has no 'img_path'"
            )
        # `or ''` covers both missing keys and explicit JSON nulls: None values
        # cannot be batched by the default DataLoader collate function.
        prompt = data_item.get('prompt') or ''
        task2 = data_item.get('Task2') or {}
        caption = (task2.get('Caption') or '') if isinstance(task2, dict) else ''
        return {'img_path': img_path, 'prompt': prompt, 'caption': caption}

    def __getitem__(self, index):
        """Load and return sample ``index``.

        Returns:
            dict with keys ``'image'`` (RGB PIL image, or the output of
            ``transform``), ``'image_name'`` (file-name component of the
            record's ``img_path``), ``'text_prompts'`` and ``'text_captions'``.
        """
        item = self.data[index]
        img_path = item['img_path']
        full_img_path = os.path.join(self.data_path, img_path)

        # The context manager releases the file handle deterministically,
        # including when decoding or conversion fails.
        with Image.open(full_img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return {
            'image': image,
            'image_name': os.path.basename(img_path),
            'text_prompts': item['prompt'],
            'text_captions': item['caption'],
        }

    def __len__(self):
        """Return the number of annotation records loaded."""
        return len(self.data)


if __name__ == '__main__':
    # Smoke test: load the validation split and print every sample's fields.

    import argparse

    from torchvision import transforms

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="./trainsets/JourneyDB/valid/")
    # JourneyDB reads `configs.json_file`; `--json_name` is kept as an alias
    # for existing command lines, but both store into `json_file`.
    parser.add_argument("--json_file", "--json_name", dest="json_file",
                        type=str, default="re_valid_anno.jsonl")
    config = parser.parse_args()

    trans = transforms.Compose([
        # transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    JourneyDB_dataset = JourneyDB(config, transform=trans)
    print(len(JourneyDB_dataset))
    JourneyDB_dataloader = data.DataLoader(dataset=JourneyDB_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i, image_data in enumerate(JourneyDB_dataloader):
        print(image_data['image_name'])
        print(image_data['image'].shape)
        print(image_data['text_prompts'])
        print(image_data['text_captions'])
