import torch
import numpy as np

from src.dl.models.skeleton import Skeleton


class IdentityModule(torch.nn.Module):
    """Pass-through module that returns its input unchanged.

    The only state it carries is ``out_features``, which is set to the
    ``dim`` given at construction time.
    """

    def __init__(self, dim):
        """Store ``dim`` as ``out_features``; no parameters are created."""
        super().__init__()
        self.out_features = dim

    def forward(self, xs):
        """Return ``xs`` itself (the same object, not a copy)."""
        return xs


class InnerGradDescentModule(torch.nn.Module):
    """One step of the inner gradient-descent update on ``z``.

    The step is ``z - lr * (z @ A.T - x) @ A``, where ``x`` is the learnable
    parameter owned by this module (initialised from ``x_init``) and ``A`` is
    stored as a plain attribute (not a parameter or registered buffer).
    """

    def __init__(self, learning_rate, x_init, A):
        """Wrap ``x_init`` in a trainable parameter and remember ``lr`` and ``A``."""
        super().__init__()
        self.x = torch.nn.Parameter(x_init)
        self.lr = learning_rate
        self.A = A

    def forward(self, z, x):
        """Return the updated iterate for the current ``z``.

        The ``x`` argument is accepted but not read; the update uses the
        module's own parameter ``self.x`` instead.
        """
        # Residual of the inner linear system at the current iterate.
        residual = z @ self.A.T - self.x
        # Scale first, then project through A (same operation order as
        # ``self.lr * residual @ self.A``).
        scaled_step = (self.lr * residual) @ self.A
        return z - scaled_step


class BilevelQuadraticTestCase:
    """Bilevel quadratic problem used to exercise a fixed-point solver.

    The inner problem drives ``z @ A.T`` towards ``x`` via gradient-descent
    steps (see ``InnerGradDescentModule``); the outer loss is
    ``0.5 * sum(z ** 2)``. The class also provides closed-form reference
    values for the fixed point and for the outer gradient w.r.t. ``x``.
    """

    def __init__(self, A, solver_fn, use_backwards_solver, dim, x_init_value, lr,
                 num_additional_unroll_steps_after_implicit_forward=0):
        """Store the problem definition and solver configuration as attributes."""
        # Problem definition.
        self.A = A
        self.dim = dim
        self.x_init_value = x_init_value
        self.lr = lr
        # Solver configuration.
        self.solver_fn = solver_fn
        self.use_backwards_solver = use_backwards_solver
        self.num_additional_unroll_steps_after_implicit_forward = (
            num_additional_unroll_steps_after_implicit_forward
        )

    def construct_skeleton(self):
        """Build the ``Skeleton`` wired with the identity preprocessor and the inner GD cell."""
        # The backward solver is dropped only when the flag is literally
        # ``False``; any other value keeps ``solver_fn`` for the backward pass.
        if self.use_backwards_solver is False:
            backward_solver = None
        else:
            backward_solver = self.solver_fn

        preprocessor = IdentityModule(dim=self.dim)
        gd_cell = InnerGradDescentModule(
            learning_rate=self.lr,
            x_init=self.x_init_value,
            A=self.A,
        )

        n_extra_steps = self.num_additional_unroll_steps_after_implicit_forward
        return Skeleton(
            input_preprocessor=preprocessor,
            cell=gd_cell,
            classifier_layer=None,
            forward_solver=self.solver_fn,
            backward_solver=backward_solver,
            num_additional_unroll_steps_after_implicit_forward=n_extra_steps,
        )

    def get_fixed_point_and_grad(self):
        """Run the skeleton, backprop the outer loss, and collect the results.

        Returns:
            A tuple ``(z, x_grad, model_dict)``: the skeleton's output, the
            gradient read from ``skeleton.f.x.grad`` after backpropagating
            ``0.5 * sum(z ** 2)``, and the second value returned by the
            skeleton call.
        """
        skeleton = self.construct_skeleton()

        # Per the original author's note, only the shape of this input is
        # relevant; its values are never used.
        dummy_input = torch.zeros(1, self.dim)
        z, model_dict = skeleton(dummy_input)

        # Outer objective: half the squared norm of the returned point.
        outer_loss = 0.5 * (z ** 2).sum()
        outer_loss.backward()

        return z, skeleton.f.x.grad, model_dict

    def get_correct_grad(self, z):
        """Return the reference outer gradient w.r.t. ``x``: ``z @ inv(A)``."""
        A_inv = torch.linalg.inv(self.A)
        return z @ A_inv

    def get_inner_loss(self, z):
        """Return the summed squared residual ``sum((z @ A.T - x_init_value) ** 2)``."""
        residual = z @ self.A.T - self.x_init_value
        return torch.sum(residual ** 2)

    def get_correct_fixed_point(self):
        """Return the reference fixed point ``x_init_value @ inv(A.T)``."""
        A_transpose_inv = torch.linalg.inv(self.A.T)
        return self.x_init_value @ A_transpose_inv


def test_solver_correctness_using_bilevel_quadratic_problem(solver_fn, use_backward_solver=True, *,
                                                            debug_on_failure=False):
    """Check a solver against the bilevel quadratic problem for several matrices.

    For each of three fixed 2x2 matrices ``A`` a ``BilevelQuadraticTestCase``
    is built with ``x`` initialised to ones and a learning rate of 0.1. The
    skeleton is run once and three checks are recorded:

    * ``"backwards_mode_correctness"``: ``model_logs["inferred_backward_mode"]``
      equals ``"implicit"`` when the backward solver is enabled and
      ``"unrolled"`` otherwise.
    * ``"fixed_point_correctness"``: the inner loss at the returned point is
      within ``1e-4`` of zero.
    * ``"grad_value_correctness"``: the gradient w.r.t. ``x`` matches the
      closed-form reference within ``1e-4``.

    Args:
        solver_fn: Solver passed to the skeleton as forward solver (and as
            backward solver when ``use_backward_solver`` is truthy).
        use_backward_solver: Whether the backward pass should use
            ``solver_fn``. Interpreted by truthiness.
        debug_on_failure: When true, drop into the debugger via
            ``breakpoint()`` after a failed fixed-point or gradient check.
            Defaults to ``False`` so non-interactive runs never block.

    Returns:
        A list with one dict per matrix, each mapping the three check names
        above to a boolean.
    """
    curr_lr = 0.1
    fixed_point_tolerance = 1e-4
    grad_tolerance = 1e-4

    # Normalise to a real bool: the test case disables the backward solver
    # only when the flag ``is False``, so a falsy non-bool (0, None) would
    # otherwise enable it there while being treated as disabled here.
    use_backward_solver = bool(use_backward_solver)
    expected_backward_mode = "implicit" if use_backward_solver else "unrolled"

    As = [
        torch.tensor([[2., 0.], [0., 0.5]]),
        torch.tensor([[2., 0.], [1., 0.5]]),
        torch.tensor([[1., 0.5], [0.5, 1.]]),
    ]

    # TODO: consider a dict keyed by (curr_A, use_backward_solver) instead of a list.
    correctness_dicts = []
    for curr_A in As:
        # Instantiate the bilevel quadratic test case.
        curr_dim = curr_A.shape[0]
        test_case = BilevelQuadraticTestCase(
            A=curr_A,
            solver_fn=solver_fn,
            use_backwards_solver=use_backward_solver,
            dim=curr_dim,
            x_init_value=torch.ones((1, curr_dim)),
            lr=curr_lr,
        )

        # Get the fixed point and gradients.
        z, curr_x_grad, model_logs = test_case.get_fixed_point_and_grad()

        # ____ Check the results. ____
        correctness_checks = {}

        # Check that the correct gradient computation means is selected.
        correctness_checks["backwards_mode_correctness"] = (
            model_logs["inferred_backward_mode"] == expected_backward_mode
        )

        # Check that the fixed point is correct: the inner loss must vanish.
        curr_inner_loss = test_case.get_inner_loss(z=z)
        correctness_checks["fixed_point_correctness"] = torch.allclose(
            curr_inner_loss,
            torch.zeros_like(curr_inner_loss),
            atol=fixed_point_tolerance,
        )
        if not correctness_checks["fixed_point_correctness"]:
            print(f"Correct fixed point: {test_case.get_correct_fixed_point()}")
            print(f"Estimated fixed point: {z}")
            if debug_on_failure:
                breakpoint()

        # Check that the gradients are computed correctly.
        correct_x_grad = test_case.get_correct_grad(z=z)
        correctness_checks["grad_value_correctness"] = torch.allclose(
            curr_x_grad,
            correct_x_grad,
            atol=grad_tolerance,
        )
        if not correctness_checks["grad_value_correctness"]:
            print(f"Estimated gradients: {curr_x_grad}")
            print(f"Correct gradients: {correct_x_grad}")
            if debug_on_failure:
                breakpoint()

        correctness_dicts.append(correctness_checks)
    return correctness_dicts
