import unittest

import torch
from torch import nn
from gpytorch.kernels import ScaleKernel, MaternKernel

from bbo.algorithms import (
    create_wrapper,
    WrapperMean,
    WrapperKernel,
)


class WrapperMeanTest(unittest.TestCase):
    """Tests for ``WrapperMean`` built on top of an ``'mlp'`` input wrapper."""

    # Input feature dimensionality; must match the MLP's ``in_features``.
    DIM = 10

    def setUp(self):
        wrapper = create_wrapper(
            'mlp',
            config={
                'in_features': self.DIM,
                'hidden_features': [8, 4],
            },
        )
        # Maps the MLP's last hidden width (4) to a single scalar mean.
        final_layer = nn.Linear(4, 1)
        self.mean_fn = WrapperMean(wrapper, final_layer)

    def test_parameters(self):
        """The wrapper's MLP parameters are registered on the mean module."""
        names = [name for name, _ in self.mean_fn.named_parameters()]
        self.assertTrue(
            any(name.startswith('wrapper.mlp') for name in names),
            msg=f'no wrapper.mlp parameters registered; got {names}',
        )

    def test_shape(self):
        """The mean drops the feature axis and keeps both batch axes in order."""
        # Distinct batch sizes so a transposed/permuted output is detected.
        bs1, bs2 = 20, 15
        x = torch.randn((bs1, bs2, self.DIM))
        m = self.mean_fn(x)
        self.assertEqual(m.shape, (bs1, bs2))


class WrapperKernelTest(unittest.TestCase):
    """Tests for ``WrapperKernel`` with the ``'kumar'`` and ``'mlp'`` wrappers."""

    # Input feature dimensionality; must match the MLP's ``in_features``.
    DIM = 10

    def setUp(self):
        # Each WrapperKernel gets its own base kernel instance so the two
        # kernels under test do not share (and silently couple) their
        # ScaleKernel / MaternKernel hyperparameters.

        # kumar: wrapper output is fed to the base kernel unchanged.
        kumar_wrapper = create_wrapper('kumar', dict())
        self.kumar_kernel = WrapperKernel(
            ScaleKernel(MaternKernel()), kumar_wrapper, nn.Identity(),
        )

        # mlp: DIM -> [8, 4] hidden, then projected to a 2-d feature space.
        mlp_wrapper = create_wrapper(
            'mlp',
            config={
                'in_features': self.DIM,
                'hidden_features': [8, 4],
            },
        )
        mlp_final_layer = nn.Sequential(
            nn.Linear(4, 2),
            nn.Tanh(),
        )
        self.mlp_kernel = WrapperKernel(
            ScaleKernel(MaternKernel()), mlp_wrapper, mlp_final_layer,
        )

    def test_kumar_parameters(self):
        """The kumar wrapper's ``alpha``/``beta`` are registered on the kernel."""
        parameters_dict = dict(self.kumar_kernel.named_parameters())
        self.assertIn('wrapper.alpha', parameters_dict)
        self.assertIn('wrapper.beta', parameters_dict)

    def test_mlp_parameters(self):
        """The mlp wrapper's parameters are registered on the kernel."""
        names = [name for name, _ in self.mlp_kernel.named_parameters()]
        self.assertTrue(
            any(name.startswith('wrapper.mlp') for name in names),
            msg=f'no wrapper.mlp parameters registered; got {names}',
        )

    def test_shape(self):
        """``kernel(x1, x2)`` yields a (len(x1), len(x2)) covariance."""
        bs1, bs2 = 30, 20
        # NOTE(review): randn inputs fall outside [0, 1]; if the kumar wrapper
        # assumes unit-cube inputs its values may be NaN. Only shape is
        # checked here — confirm whether torch.rand is more appropriate.
        x1 = torch.randn((bs1, self.DIM))
        x2 = torch.randn((bs2, self.DIM))
        kernels = {'kumar': self.kumar_kernel, 'mlp': self.mlp_kernel}
        for name, kernel in kernels.items():
            # subTest reports which kernel failed and keeps checking the rest.
            with self.subTest(kernel=name):
                ret = kernel(x1, x2)
                self.assertEqual(ret.shape, (bs1, bs2))