import torch
#from retina_preproc import RetinaBlurFilter
from retina_blur2 import RetinaBlurFilter
import numpy as np
from PIL import Image

def convert_image_ndarray_to_tensor(img):
    """Turn a channels-last image array into a channels-first tensor.

    For a 3-D ``(H, W, C)`` array the result has shape ``(C, H, W)``.
    ``torch.from_numpy`` shares memory with *img*, so the returned tensor is
    a view of the same buffer with the same dtype; nothing is copied.
    """
    as_tensor = torch.from_numpy(img)
    # (H, W, C) -> (H, C, W): move channels one step left.
    rows_chan_cols = torch.transpose(as_tensor, 1, 2)
    # (H, C, W) -> (C, H, W): bring channels to the front.
    return torch.transpose(rows_chan_cols, 0, 1)

def convert_image_tensor_to_ndarray(img):
    """Convert a channels-first image tensor back to a channels-last ndarray.

    For a 3-D ``(C, H, W)`` tensor the result has shape ``(H, W, C)``.

    The tensor is detached from autograd and moved to host memory first,
    because ``Tensor.numpy()`` raises for CUDA tensors and for tensors that
    require grad. For a CPU tensor that does not require grad, both calls are
    no-ops, so the returned array still shares memory with *img*.
    """
    host = img.detach().cpu()
    # (C, H, W) -> (H, C, W) -> (H, W, C)
    return host.transpose(0, 1).transpose(1, 2).numpy()

# Load the sample image and force 3-channel RGB: the filter below is built
# with a 3-channel input_shape, and RGBA / greyscale files would otherwise
# produce 4- or 1-channel arrays. The context manager closes the file handle.
with Image.open('animals-cats-cute-45170-min-1024x569.jpg') as _src:
    img = _src.convert('RGB')

# Scale so the shorter side becomes 224 px. round() rather than int():
# int() truncates, so a float product like 223.9999999 would yield 223 px and
# the 224x224 crop below would then reach past the image edge (PIL fills the
# out-of-range area instead of raising).
s = 224 / min(img.size)
img = img.resize((round(img.size[0] * s), round(img.size[1] * s)))
# Top-left (not centred) 224x224 crop.
img = img.crop((0, 0, 224, 224))
# uint8 (H, W, C) -> float (H, W, C) scaled into [0, 1].
img = np.asarray(img) / 255

# NOTE(review): input_shape is [3, 1600, 1600] while the batch fed below is
# 224x224 -- confirm RetinaBlurFilter is expected to handle this mismatch.
p = RetinaBlurFilter.ModelParams(RetinaBlurFilter, input_shape=[3, 1600, 1600], cone_std=0.12, rod_std=.09, max_rod_density=.12, scale=8, view_scale=6, loc_mode='random_uniform_2', min_res=33, max_res=360)
# NOTE: `filter` shadows the builtin; the name is kept so any later code that
# refers to it keeps working.
filter = p.cls(p).cuda()
print(filter)

# (H, W, C) ndarray -> (C, H, W) float32 tensor on the GPU.
ptimg = convert_image_ndarray_to_tensor(img).float().cuda()
# Add a batch dim; scale_factor=1 leaves the spatial size unchanged.
ptimg = torch.nn.functional.interpolate(ptimg.unsqueeze(0), scale_factor=1)
# Replicate the single image 128 times along the batch dimension.
stacked_ptimg = torch.repeat_interleave(ptimg, 128, 0)
print(stacked_ptimg.shape)
# Invoke the module itself rather than .forward() so registered hooks run.
fimg = filter(stacked_ptimg)
