import torch
from torch.autograd import Variable


class PointEstimate:
    """Running average of a model's softmax predictions on a fixed test set.

    Each call to ``evaluation`` runs ``model`` over ``testloader`` and folds
    the softmax output into ``marginal`` as an incremental mean. Accuracy is
    then measured using the argmax of that averaged prediction. Calling
    ``evaluation`` repeatedly (e.g. after each change to ``model``'s weights)
    therefore scores the average of all predictions seen so far.
    """

    def __init__(self, size_of_marginal, model, testloader, enable_cuda, cuda_num):
        """Create zeroed running-average state.

        Args:
            size_of_marginal: shape passed to ``torch.zeros`` for the running
                average; it must broadcast with each batch's softmax output.
                ``evaluation`` takes ``max`` over dim 1, so it must be at least
                2-D (presumably (num_samples, num_classes) -- confirm with caller).
            model: callable mapping an input batch to logits.
            testloader: iterable of ``(x, y)`` batches.
            enable_cuda: if true, state and batches are placed on a CUDA device.
            cuda_num: CUDA device index passed to ``.cuda()``; ``None`` means
                the current default device.
        """
        self.enable_cuda = enable_cuda
        self.cuda_num = cuda_num
        # Starting the counter at 1 means the first update() replaces the
        # zero-initialised marginal with the first prediction exactly.
        self.counter = torch.ones(1)
        self.marginal = torch.zeros(size_of_marginal)
        if self.enable_cuda:
            # Use the same device evaluation() moves batches to. A bare .cuda()
            # would pick the default GPU and clash with inputs on cuda_num.
            self.counter = self.counter.cuda(cuda_num)
            self.marginal = self.marginal.cuda(cuda_num)
        self.testloader = testloader
        self.model = model
        self.softmax = torch.nn.Softmax(dim=-1)

    def update(self, predict):
        """Fold ``predict`` into the running mean and advance the counter.

        Uses the incremental-mean form m_k = m_{k-1} + (p_k - m_{k-1}) / k.
        """
        self.marginal = self.marginal + 1 / self.counter * (predict - self.marginal)
        self.counter.add_(1)

    def get_marginal(self):
        """Return the current running-average prediction tensor."""
        return self.marginal

    def evaluation(self):
        """Update the running average with ``model``'s predictions and score it.

        NOTE(review): ``update`` is called once per batch. That is only
        meaningful if ``testloader`` yields a single batch shaped like
        ``size_of_marginal``. With several batches, predictions for
        *different* inputs would be averaged together -- TODO confirm.

        Returns:
            Accuracy in percent (float) of ``argmax`` over dim 1 of the running
            average, compared against the labels.

        Raises:
            ZeroDivisionError: if ``testloader`` yields no samples.
        """
        correct = 0
        total = 0
        # Predictions are only stored, never backpropagated. Skipping graph
        # construction saves memory without changing any value.
        with torch.no_grad():
            for x, y in self.testloader:
                if self.enable_cuda:
                    x, y = x.cuda(self.cuda_num), y.cuda(self.cuda_num)
                likelihood = self.softmax(self.model(x))
                self.update(likelihood.detach())
                _, yhat = torch.max(self.get_marginal(), 1)
                total += y.size(0)
                correct += int((yhat == y).sum().item())
        return 100 * float(correct) / float(total)
