import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


# Names exported by ``from <this module> import *``.
__all__ = [
    "get_transforms",
]


def get_transforms(model):
    """Build train and eval image transforms matched to ``model``'s preprocessing.

    Normalization statistics and image sizes are read from the model's
    ``pretrained_cfg`` mapping (as attached by timm) when available. Each
    setting falls back to an ImageNet default independently in three cases:
    the attribute is missing, it is ``None`` or empty, or an individual entry
    is missing or ``None``. The defaults are mean/std from
    ``timm.data.constants``, a 224px crop and ``crop_pct`` 0.875, which gives a
    256px resize.

    Args:
        model: The model whose preprocessing should be matched. Only its
            optional ``pretrained_cfg`` attribute is read.

    Returns:
        A ``(transform_train, transform_test)`` tuple of ``T.Compose``
        pipelines:

        * train: ``RandomResizedCrop(crop) -> RandomHorizontalFlip ->
          ToTensor -> Normalize``
        * test: ``Resize(resize) -> CenterCrop(crop) -> ToTensor ->
          Normalize``

        Bicubic interpolation is always used. The config's own
        ``'interpolation'`` entry, if any, is not consulted.
    """
    # ``getattr`` covers models without the attribute (e.g. non-timm models);
    # ``or {}`` also covers a ``pretrained_cfg`` that is None or empty.
    cfg = getattr(model, 'pretrained_cfg', None) or {}

    def _cfg_value(key, default):
        # Treat a missing key and an explicit None the same way. This mirrors
        # timm's ``crop_pct or DEFAULT_CROP_PCT`` without relying on the
        # truthiness of array-like values.
        value = cfg.get(key)
        return default if value is None else value

    mean = _cfg_value('mean', IMAGENET_DEFAULT_MEAN)
    std = _cfg_value('std', IMAGENET_DEFAULT_STD)
    # Only the last element of ``input_size`` is used, so both crops are
    # square. NOTE(review): assumes timm's (C, H, W) layout; non-square
    # configs would be cropped to W x W.
    crop_size = _cfg_value('input_size', (3, 224, 224))[-1]
    crop_pct = _cfg_value('crop_pct', 0.875)
    # With the defaults, 224 / 0.875 == 256, which reproduces the classic
    # "resize 256, center-crop 224" evaluation.
    # NOTE(review): timm's own eval transform uses math.floor() here. round()
    # is kept for backward compatibility and can differ by 1px
    # (e.g. crop_pct=0.9 -> round gives 249, floor gives 248).
    resize_size = round(crop_size / crop_pct)

    bicubic = T.InterpolationMode.BICUBIC
    transform_train = T.Compose([
        T.RandomResizedCrop(crop_size, interpolation=bicubic),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    transform_test = T.Compose([
        T.Resize(resize_size, interpolation=bicubic),
        T.CenterCrop(crop_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    return transform_train, transform_test
