import os
from fastai.vision.data import ImageDataLoaders
from fastai.vision.augment import CropPad, RandomCrop, Flip, setup_aug_tfms
from fastai.vision.data import Pipeline, ToTensor
from timm.data.auto_augment import rand_augment_transform
from adversarialML.biologically_inspired_models.src.utils import load_json


# Root of the local tiny-imagenet-200 copy; every list file below lives here and
# every path entry inside those lists is resolved relative to it.
dataset_dir = '/home/hippo/workhorse3/tiny-imagenet-200/'


def _read_path_list(list_name):
    """Return ``dataset_dir``-joined paths for each non-blank line of a list file.

    Args:
        list_name: file name, relative to ``dataset_dir``, holding one image path
            per line.

    Returns:
        A list of paths in file order. Blank lines are skipped: joining an empty
        entry onto ``dataset_dir`` would otherwise yield the directory itself as
        a bogus image path.
    """
    with open(os.path.join(dataset_dir, list_name), encoding='utf-8') as fh:
        entries = (line.strip() for line in fh)
        return [os.path.join(dataset_dir, entry) for entry in entries if entry]


def _read_class_index(list_name):
    """Map each id listed in ``list_name`` to a contiguous integer class index.

    Indices follow first-occurrence order; a repeated id keeps its first index
    and does not consume a new one. Blank lines are skipped so they cannot
    register an empty-string class and shift every later index.
    """
    index = {}
    with open(os.path.join(dataset_dir, list_name), encoding='utf-8') as fh:
        for line in fh:
            wnid = line.strip()
            if wnid:
                index.setdefault(wnid, len(index))
    return index


train_file_list = _read_path_list('train_paths.txt')
# NOTE: the "test" list is read from val_paths.txt.
test_file_list = _read_path_list('val_paths.txt')
wnid2idx = _read_class_index('wnids.txt')
        
# Per-item training augmentation, applied in order: CropPad(72) brings each
# image to 72 pixels (center crop or pad), then RandomCrop(64) takes a random
# 64-pixel crop from it.
train_tfms = [CropPad(72), RandomCrop(64)]

def _wnid_label(path):
    """Return the class index for an image path located under ``dataset_dir``.

    The class key is the second component of the path relative to
    ``dataset_dir``. For the current ``dataset_dir``
    ('/home/hippo/workhorse3/tiny-imagenet-200/') this is the same component
    as ``path.split('/')[6]``, but it no longer depends on how deep the
    dataset directory sits. Presumably entries look like ``train/<wnid>/...``,
    so that component is the wnid directory; confirm against train_paths.txt.

    A named function is used instead of a lambda because lambdas cannot be
    pickled, which matters if the DataLoaders/Learner is ever serialized.

    Raises:
        KeyError: if the extracted component is not a key of ``wnid2idx``.
    """
    relative = os.path.relpath(str(path), dataset_dir)
    return wnid2idx[relative.split(os.sep)[1]]


# Build the training loader from the *training* list (it was previously fed
# test_file_list by mistake). fastai holds out a random 5% of these files as
# its validation set; no seed is passed, so that split differs between runs.
train_loader = ImageDataLoaders.from_path_func(
    dataset_dir,
    train_file_list,
    _wnid_label,
    valid_pct=0.05,
    bs=128,
    val_bs=128,
    item_tfms=setup_aug_tfms(train_tfms),
)
# TODO: build a separate test loader from test_file_list, without the training
# augmentation, labelled with _wnid_label.
