from os.path import join

# Root directory under which every dataset folder below is looked up.
# Alternative location used previously: "/raid/NFS_SHARE/datasets/"
_BASE_DATA_PATH = "../data"

# Normalization statistics shared by several entries below; presumably
# per-channel (mean, std) pairs -- confirm against the data-loading code.
# Tuples are immutable, so sharing one object between entries is safe.
_IMAGENET_NORMALIZE = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
_CIFAR100_NORMALIZE = ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))

# Fixed permutation of the class ids 0..99 used by the 'cifar100_icarl',
# 'cifar100_icarl_224' and 'imagenet_subset' entries. It is kept in one place so
# the copies cannot drift apart. Each entry receives its own list (via
# ``list(...)``), so mutating one entry's order cannot affect another entry.
_ICARL_CLASS_ORDER = (
    68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
    28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
    98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
    36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33,
)

# Per-dataset configuration, keyed by dataset name. Each entry holds the dataset
# location ('path') plus preprocessing options ('resize', 'pad', 'crop', 'flip',
# 'normalize', ...). Those options are interpreted by the data-loading code
# elsewhere, so confirm their exact semantics there. Options an entry leaves out
# are filled in with defaults further down in this module.
dataset_config = {
    'mnist': {
        'path': join(_BASE_DATA_PATH, 'mnist'),
        'normalize': ((0.1307,), (0.3081,)),
        # To feed MNIST as a 3x32x32 input, uncomment the next 3 lines. The
        # second 'normalize' then overrides the one above (in a dict literal
        # the last duplicate key wins) with statistics that include the padding.
        # 'extend_channel': 3,
        # 'pad': 2,
        # 'normalize': ((0.1,), (0.2752,)),    # values including padding
    },
    'svhn': {
        'path': join(_BASE_DATA_PATH, 'svhn'),
        'resize': (224, 224),
        'crop': None,
        'flip': False,
        'normalize': _IMAGENET_NORMALIZE,
    },
    'cifar100': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': _CIFAR100_NORMALIZE,
    },
    'cifar100_icarl': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': _CIFAR100_NORMALIZE,
        'class_order': list(_ICARL_CLASS_ORDER),
    },
    'cifar100_224': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': 256,
        'pad': 0,
        'crop': 224,
        'flip': True,
        'normalize': _CIFAR100_NORMALIZE,
    },
    'cifar100_icarl_224': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': 256,
        'pad': 0,
        'crop': 224,
        'flip': True,
        'normalize': _CIFAR100_NORMALIZE,
        'class_order': list(_ICARL_CLASS_ORDER),
    },
    'vggface2': {
        'path': join(_BASE_DATA_PATH, 'VGGFace2'),
        'resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169)),
    },
    'imagenet_subset': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'),
        'resize': None,
        'crop': 224,
        'flip': True,
        'normalize': _IMAGENET_NORMALIZE,
        'class_order': list(_ICARL_CLASS_ORDER),
    },
    'tiny': {
        # wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
        'path': join(_BASE_DATA_PATH, 'tiny-imagenet-200'),
        'resize': None,
        'crop': 64,
        'flip': True,
        'normalize': _IMAGENET_NORMALIZE,
    },
    'imagenet_subset_kaggle': {
        'path': join(_BASE_DATA_PATH, 'seed_1993_subset_100_imagenet'),
        'test_resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': _IMAGENET_NORMALIZE,
        # The 100 ImageNet synset ids (wnids) of this subset. Presumably position
        # i corresponds to class i, given the identity 'class_order' below --
        # confirm against the loader.
        'lbl_order': [
            'n03710193', 'n03089624', 'n04152593', 'n01806567', 'n02107574',
            'n04409515', 'n04599235', 'n03657121', 'n03942813', 'n04026417',
            'n02640242', 'n04591157', 'n01689811', 'n07614500', 'n03085013',
            'n01882714', 'n02112706', 'n04266014', 'n02786058', 'n02526121',
            'n03141823', 'n03775071', 'n04074963', 'n01531178', 'n04428191',
            'n02096177', 'n02091467', 'n02971356', 'n02116738', 'n03017168',
            'n02002556', 'n04355933', 'n02840245', 'n04371430', 'n01774384',
            'n03223299', 'n04399382', 'n02088094', 'n02033041', 'n02814860',
            'n04604644', 'n02669723', 'n03884397', 'n03250847', 'n04153751',
            'n03016953', 'n02101388', 'n01914609', 'n02128385', 'n03075370',
            'n02363005', 'n09468604', 'n02011460', 'n03785016', 'n12267677',
            'n12768682', 'n12620546', 'n01537544', 'n03532672', 'n03691459',
            'n02749479', 'n02105056', 'n02279972', 'n04442312', 'n02107908',
            'n02229544', 'n04525305', 'n02102318', 'n15075141', 'n01514668',
            'n04550184', 'n02115913', 'n02094258', 'n07892512', 'n01984695',
            'n01990800', 'n02948072', 'n02112137', 'n02123597', 'n02917067',
            'n03485407', 'n03759954', 'n02280649', 'n03290653', 'n01775062',
            'n03527444', 'n03967562', 'n01744401', 'n02128757', 'n01729322',
            'n03000247', 'n02950826', 'n03891332', 'n07831146', 'n02536864',
            'n03697007', 'n02120079', 'n02951585', 'n03109150', 'n02168699',
        ],
        'class_order': list(range(100)),
    },
    'domainnet': {
        'path': join(_BASE_DATA_PATH, 'domainnet'),
        'resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': _IMAGENET_NORMALIZE,
    },
    'cub200': {
        'path': join(_BASE_DATA_PATH, 'CUB_200_2011'),
        'resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': _IMAGENET_NORMALIZE,
    },
    'aircraft': {
        'path': join(_BASE_DATA_PATH, 'fgvc-aircraft-2013b'),
        'resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': _IMAGENET_NORMALIZE,
    },
}

# Defaults for the options a dataset entry may leave out. After the pass below,
# every entry in ``dataset_config`` defines all of these keys, so consumers can
# index ``dataset_config[name][key]`` directly. Everything defaults to None
# except 'flip', which defaults to False. Values an entry sets explicitly
# (including an explicit None) are kept.
_OPTION_DEFAULTS = {
    'test_resize': None,
    'resize': None,
    'pad': None,
    'crop': None,
    'normalize': None,
    'class_order': None,
    'extend_channel': None,
    'flip': False,
}


def _fill_missing_options(config):
    """Add, in place, every option from ``_OPTION_DEFAULTS`` missing from an entry of *config*.

    Args:
        config: mapping of dataset name -> per-dataset options dict; each options
            dict is mutated in place.
    """
    for options in config.values():
        for key, default in _OPTION_DEFAULTS.items():
            options.setdefault(key, default)


# Done through a function so no loop variables leak into the module namespace
# (and hence into ``from ... import *``).
_fill_missing_options(dataset_config)
