import os
import torch
import random
import torch.utils.data as data
import numpy as np
import copy
from PIL import Image
import json

# Filename suffixes recognised by is_image_file(). Matching there is
# case-sensitive, so upper- and lower-case spellings are listed separately.
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
]

def read_dataset(json_file, image_folder):
    """Build a list of caption samples from a COCO-style captions JSON file.

    Every entry of ``json_data['annotations']`` must provide ``image_id`` and
    ``caption``. The image file name is derived from the id as a 12-digit
    zero-padded ``.jpg`` name (e.g. ``000000179765.jpg``; presumably the id is
    an int). Annotations whose image file is absent from ``image_folder`` are
    skipped, and a warning is printed once per missing image.

    Args:
        json_file: Path to the annotations JSON file (read as UTF-8).
        image_folder: Directory containing the image files.

    Returns:
        A list of dicts with keys ``'img_path'``, ``'file_name'`` and
        ``'caption'``. There is one dict per annotation whose image exists, in
        annotation order.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    samples = []
    # Several annotations can point at the same image. Cache the existence
    # check per path so each file is stat'ed (and reported missing) only once.
    path_exists = {}
    for annotation in json_data['annotations']:
        image_id = annotation['image_id']
        caption = annotation['caption']
        file_name = f"{image_id:012}.jpg"
        image_path = os.path.join(image_folder, file_name)

        if image_path not in path_exists:
            path_exists[image_path] = os.path.exists(image_path)
            if not path_exists[image_path]:
                print(f"Image not found for image_id: {image_id}, {image_path}")

        if path_exists[image_path]:
            samples.append({'img_path': image_path, 'file_name': file_name, 'caption': caption})

    return samples

def is_image_file(filename):
    """Return True when *filename* ends with one of the suffixes in IMG_EXTENSIONS.

    The check is a plain, case-sensitive suffix comparison; only the exact
    spellings listed in IMG_EXTENSIONS are accepted.
    """
    accepted_suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(accepted_suffixes)

class COCO(data.Dataset):
    """Map-style dataset of (image, caption) samples built from COCO caption annotations.

    Samples come from ``read_dataset``: one per caption annotation whose image
    file exists, so an image with several captions yields several samples.

    Args:
        configs: Object exposing ``data_path`` (image directory, e.g.
            ``trainsets/coco/train2017``) and ``json_file`` (captions JSON, e.g.
            ``trainsets/coco/annotations/captions_train2017.json``).
        transform: Optional callable applied to the loaded RGB PIL image;
            ``None`` returns the PIL image unchanged.
        train: Stored as ``self.train``; not otherwise used by this class.
    """

    def __init__(self, configs, transform=None, train=True):
        self.configs = configs
        self.data_path = configs.data_path
        self.json_file = configs.json_file
        self.transform = transform
        self.train = train
        self.data = read_dataset(json_file=self.json_file, image_folder=self.data_path)

    def __getitem__(self, index):
        """Return the sample at *index*.

        Returns:
            A dict with ``'image'`` (RGB image, passed through ``transform`` if
            one is set), ``'image_name'`` (file name) and ``'text_captions'``
            (the caption text).
        """
        item = self.data[index]

        # Use a context manager so the file handle is released immediately;
        # convert() loads the pixel data before the file is closed.
        with Image.open(item['img_path']) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return {'image': image, 'image_name': item['file_name'], 'text_captions': item['caption']}

    def __len__(self):
        """Return the number of (image, caption) samples."""
        return len(self.data)


if __name__ == '__main__':
    # Smoke test: build the COCO caption dataset from CLI paths and print batches.
    import argparse

    from torchvision import transforms

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="./trainsets/coco/train2017/")
    parser.add_argument("--json_file", type=str, default="./trainsets/coco/annotations/captions_train2017.json")
    parser.add_argument("--max_batches", type=int, default=None,
                        help="Stop after this many batches (default: iterate the whole dataset).")
    config = parser.parse_args()

    trans = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 rescales them to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    coco_dataset = COCO(config, transform=trans)
    print(len(coco_dataset))
    # `data` is the module-level `torch.utils.data` alias imported at the top of the file.
    coco_dataloader = data.DataLoader(dataset=coco_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i, batch in enumerate(coco_dataloader):
        if config.max_batches is not None and i >= config.max_batches:
            break
        print(batch['image_name'])
        print(batch['image'].shape)
        print(batch['text_captions'])
