import numpy as np
import torch


def MeanUpsample(x, scale):
    """Upsample ``x`` by replicating every pixel into a ``scale x scale`` block.

    This is integer-factor nearest-neighbour upsampling. Averaging each
    ``scale x scale`` block of the result gives back ``x`` (up to
    floating-point rounding), so it serves as a right inverse of block-mean
    downsampling.

    Args:
        x: 4-D tensor of shape ``(n, c, h, w)``.
        scale: Integer upsampling factor applied to both spatial dimensions.

    Returns:
        Tensor of shape ``(n, c, scale * h, scale * w)`` with the same dtype
        and device as ``x``.
    """
    n, c, h, w = x.shape
    # Add singleton axes after h and w and broadcast them to `scale` (no
    # allocation), then materialise the interleaved result with a single copy.
    # Unlike adding to a CPU float32 zeros buffer, this keeps x's dtype and
    # device and avoids a host->device transfer.
    out = x.view(n, c, h, 1, w, 1).expand(n, c, h, scale, w, scale)
    return out.reshape(n, c, h * scale, w * scale)

class SuperResolution:
    """Block-average super-resolution degradation operator and its pseudo-inverse.

    ``A`` average-pools the spatial dimensions down to
    ``img_dim // ratio`` on each side. ``Ap`` upsamples by ``ratio`` through
    pixel replication (``MeanUpsample``). When the input is
    ``img_dim x img_dim`` and ``img_dim`` is divisible by ``ratio``, each
    pooled pixel is the mean of one ``ratio x ratio`` block, so
    ``A(Ap(y))`` reproduces ``y`` up to floating-point rounding.

    Args:
        channels: Number of image channels (stored only).
        img_dim: Spatial side length of the high-resolution image.
        ratio: Integer downsampling factor (typically 2 or 4).
        device: Device the operator is meant for (stored only).
    """

    def __init__(self, channels, img_dim, ratio, device):  # ratio = 2 or 4
        self.channels = channels
        self.img_dim = img_dim
        self.ratio = ratio
        self.device = device
        # Pool to the low-res size implied by img_dim, so that upsampling by
        # `ratio` lands back on img_dim instead of a fixed 256.
        low_res = self.img_dim // self.ratio
        self.A = torch.nn.AdaptiveAvgPool2d((low_res, low_res))
        self.Ap = lambda z: MeanUpsample(z, self.ratio)

    def downsampling(self, img):
        """Return ``img`` average-pooled to ``img_dim // ratio`` per side."""
        return self.A(img)

    def upsampling(self, img):
        """Return ``img`` upsampled by ``ratio`` via pixel replication."""
        return self.Ap(img)