import math

from typing import Dict

import numpy as np
import pandas as pd
import torch
from botorch.test_functions.base import BaseTestProblem
from sklearn.svm import SVR
from torch import nn, optim, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from xgboost import XGBRegressor

from .rover import ConstantOffsetFn, create_large_domain


class PyTorchCNNProblem(BaseTestProblem):
    """CNN hyperparameter-tuning problem whose evaluations may be outliers.

    Each point in ``[0, 1]^5`` encodes (learning rate, momentum, weight decay,
    scheduler step size, scheduler gamma) for an SGD-trained CNN. The returned
    value is whatever ``PyTorchCNNRunner.train_and_evaluate`` reports for that
    configuration. With probability ``outlier_fraction`` an evaluation is run
    with ``io_failure=True``, which yields an outlier observation.

    NOTE(review): ``negate`` is forwarded to the base class, but this class
    overrides ``forward`` and never applies it itself -- confirm whether
    callers expect negated values.
    """

    def __init__(
        self,
        outlier_fraction: float,
        negate: bool = False,
        dataset: str = "MNIST",
    ) -> None:
        """Download the data set and set up the training runner.

        Args:
            outlier_fraction: Probability that a single evaluation is run with
                a simulated I/O failure.
            negate: Passed through to ``BaseTestProblem``.
            dataset: Either ``"MNIST"`` or ``"FashionMNIST"``. The misspelled
                ``"FashionMINST"`` is still accepted for backward
                compatibility.

        Raises:
            ValueError: If ``dataset`` is not a known data set name.
        """
        self.dim = 5
        self._bounds = [(0.0, 1.0) for _ in range(self.dim)]
        super().__init__(negate=negate, noise_std=None)

        if dataset == "MNIST":
            dataset_fn = datasets.MNIST
        elif dataset in ("FashionMNIST", "FashionMINST"):
            # "FashionMINST" was the only spelling accepted historically; keep
            # it working so existing callers do not break.
            dataset_fn = datasets.FashionMNIST
        else:
            raise ValueError(f"Unknown dataset name: {dataset}")

        train_set = dataset_fn(
            root="./data", train=True, download=True, transform=transforms.ToTensor()
        )
        test_set = dataset_fn(
            root="./data", train=False, download=True, transform=transforms.ToTensor()
        )
        self.pytorch_cnn_problem = PyTorchCNNRunner(
            name=dataset, train_set=train_set, test_set=test_set
        )
        self.outlier_fraction = outlier_fraction

        # Cache (X, Y) pairs to avoid reevaluating the same thing twice
        self.cache: Dict[tuple, float] = {}

    @staticmethod
    def _cache_key(x: Tensor) -> tuple:
        """Return an exact, hashable key for the 1-d point ``x``.

        ``str(x)`` must not be used here: the printed form of a tensor is
        rounded for display, so distinct nearby points would share a key.
        """
        return tuple(x.tolist())

    def forward(
        self, X: Tensor, noise: bool = False, force_no_io_failure: bool = False
    ) -> Tensor:
        """Train and evaluate the CNN for every row of ``X``.

        Args:
            X: A ``d``- or ``n x d``-dim tensor of points in the unit cube.
            noise: Must be False; observation noise is not supported.
            force_no_io_failure: If True, never simulate an I/O failure.

        Returns:
            A 1-d tensor with one value per row of ``X``.

        Raises:
            ValueError: If ``noise`` is True.
        """
        if noise:
            raise ValueError("Noise is not supported for this problem.")

        if X.ndim == 1:
            X = X.clone().unsqueeze(0)
        assert X.ndim == 2

        vals = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
        for i, x in enumerate(X):
            key = self._cache_key(x)
            if key in self.cache:
                vals[i] = self.cache[key]
                print("In cache!")
                continue

            lr = 10 ** (-4 + 3 * x[0].item())  # Map from [0, 1] to [1e-4, 1e-1]
            momentum = x[1].item()
            weight_decay = x[2].item()
            step_size = int(1 + 99 * x[3].item())  # Map from [0, 1] to [1, 100]
            gamma = x[4].item()
            if force_no_io_failure:
                io_failure = False
            else:
                io_failure = torch.rand(1).item() < self.outlier_fraction
            val = self.pytorch_cnn_problem.train_and_evaluate(
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay,
                step_size=step_size,
                gamma=gamma,
                io_failure=io_failure,
            )
            if math.isnan(val):
                print("Encountered NaN, setting accuracy to 0")
                val = 0.0
            vals[i] = val
            if not io_failure:
                # Only cache full evaluations (i.e., no IO failure)
                self.cache[key] = val
        return vals

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Evaluate ``X`` with simulated I/O failures disabled."""
        return self.forward(X, force_no_io_failure=True)


class RoverProblem(BaseTestProblem):
    """Rover objective on the unit cube with randomly injected crashes.

    Inputs live in ``[0, 1]^dim`` and are rescaled to ``[lb, ub]`` before being
    handed to the underlying ``ConstantOffsetFn`` objective. With probability
    ``outlier_fraction`` a point is evaluated with ``crash=True``.
    """

    def __init__(self, dim: int, outlier_fraction: float, **tkwargs) -> None:
        """Build the rover domain and objective.

        Args:
            dim: Input dimension; must be even (``dim // 2`` points are passed
                to ``create_large_domain``).
            outlier_fraction: Probability that a single evaluation crashes.
            **tkwargs: Forwarded to ``torch.ones`` when building ``lb``/``ub``.
        """
        assert dim % 2 == 0

        # Box the unit cube is mapped onto (set lb to zero if you don't want
        # to go left/down).
        self.ub = 4 / dim * torch.ones(dim, **tkwargs)
        self.lb = -0.5 * 4 / dim * torch.ones(dim, **tkwargs)

        self.f_max = 5
        self.domain = create_large_domain(n_points=dim // 2)
        self._objective = ConstantOffsetFn(self.domain, self.f_max)

        # Attributes the base problem needs before it is initialized
        self.dim = dim
        self.outlier_fraction = outlier_fraction
        self._bounds = [(0.0, 1.0)] * self.dim
        super().__init__(negate=False, noise_std=None)

    def forward(
        self, X: Tensor, noise: bool = False, force_no_crash: bool = False
    ) -> Tensor:
        """Evaluate the rover objective for every row of ``X``.

        Args:
            X: A ``d``- or ``n x d``-dim tensor of points in the unit cube.
            noise: Must be False; observation noise is not supported.
            force_no_crash: If True, never simulate a crash.

        Returns:
            A 1-d tensor with one objective value per row of ``X``.

        Raises:
            ValueError: If ``noise`` is True.
        """
        if noise:
            raise ValueError("Noise is not supported for this problem.")

        points = X.clone().unsqueeze(0) if X.ndim == 1 else X
        assert points.ndim == 2

        # Rescale from the unit cube to the objective's own domain
        scaled = self.lb + (self.ub - self.lb) * points

        rewards = torch.zeros(scaled.shape[0], dtype=scaled.dtype, device=scaled.device)
        for idx, point in enumerate(scaled):
            # Short-circuit so no random number is drawn when crashes are disabled
            crashed = (not force_no_crash) and (
                torch.rand(1).item() < self.outlier_fraction
            )
            rewards[idx] = self._objective(point.cpu().numpy(), crash=crashed)
            if not crashed:
                continue
            true_reward = self._objective(point.cpu().numpy(), crash=False)
            print(
                f"Crashed reward: {rewards[idx]:.2f}, True reward: {true_reward:.2f}"
            )
        return rewards

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Evaluate ``X`` with crashes disabled."""
        return self.forward(X, force_no_crash=True)


class SVMProblem(BaseTestProblem):
    """SVR hyperparameter-tuning problem whose evaluations may be outliers.

    Each point in ``[0, 1]^3`` encodes (epsilon, C, gamma) of an RBF-kernel
    ``sklearn.svm.SVR``. The value of a point is the test RMSE of the fitted
    model. With probability ``outlier_fraction`` the model is fitted on a
    small random subset of the training rows only, which simulates an I/O
    failure and yields an outlier observation.
    """

    def __init__(
        self, dataset_path: str, outlier_fraction: float, n_features=100
    ) -> None:
        """Load the data and initialize the problem.

        Args:
            dataset_path: Path to the "slice_localization_data.csv" file from
                the "Relative location of CT slices on axial axis" UCI data
                set, see
                https://archive.ics.uci.edu/dataset/206/relative+location+of+ct+slices+on+axial+axis
            outlier_fraction: Probability that a single evaluation is run with
                a simulated I/O failure.
            n_features: Number of input features kept by ``load_uci_data``.
        """
        self.train_x, self.train_y, self.test_x, self.test_y = load_uci_data(
            path=dataset_path, seed=0, n_features=n_features
        )
        self.dim = 3
        self._bounds = [(0.0, 1.0) for _ in range(self.dim)]
        super().__init__(negate=False, noise_std=None)
        self.outlier_fraction = outlier_fraction
        # Cache (X, Y) pairs to avoid reevaluating the same thing twice
        self.cache: Dict[tuple, float] = {}

    @staticmethod
    def _cache_key(x: Tensor) -> tuple:
        """Return an exact, hashable key for the 1-d point ``x``.

        ``str(x)`` must not be used here: the printed form of a tensor is
        rounded for display, so distinct nearby points would share a key.
        """
        return tuple(x.tolist())

    def forward(
        self, X: Tensor, noise: bool = False, force_no_io_failure: bool = False
    ) -> Tensor:
        """Fit an SVR for every row of ``X`` and return its test RMSE.

        Args:
            X: A ``d``- or ``n x d``-dim tensor of points in the unit cube.
            noise: Must be False; observation noise is not supported.
            force_no_io_failure: If True, never simulate an I/O failure.

        Returns:
            A 1-d tensor with one RMSE value per row of ``X``.

        Raises:
            ValueError: If ``noise`` is True.
        """
        if noise:
            raise ValueError("Noise is not supported for this problem.")

        if X.ndim == 1:
            X = X.clone().unsqueeze(0)
        assert X.ndim == 2

        vals = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
        for i, x in enumerate(X):
            key = self._cache_key(x)
            if key in self.cache:
                vals[i] = self.cache[key]
                print("In cache!")
                continue

            # Convert the 0-d tensors to plain floats: scikit-learn validates
            # hyperparameter types and does not accept torch tensors.
            epsilon = float(0.01 * 10 ** (2 * x[-3]))  # x = 0.5 gives 0.1
            C = float(0.01 * 10 ** (4 * x[-2]))  # x = 0.5 gives 1.0
            gamma = float(
                (1 / self.train_x.shape[-1]) * 0.1 * 10 ** (2 * x[-1])
            )  # x = 0.5 gives 1 / n_features
            if force_no_io_failure:
                io_failure = False
            else:
                io_failure = torch.rand(1).item() < self.outlier_fraction

            model = SVR(
                C=C,
                epsilon=epsilon,
                gamma=gamma,
                tol=0.001,
                cache_size=1000,
                verbose=True,
            )
            if io_failure:
                # Simulated failure: fit on a random subset of the rows only
                num_rows = torch.randint(100, 1000, size=(1,)).item()
                inds = np.random.permutation(self.train_x.shape[0])[:num_rows]
                model.fit(self.train_x[inds, :], self.train_y[inds].copy())
            else:
                model.fit(self.train_x, self.train_y.copy())

            pred = model.predict(self.test_x)
            mse = ((pred - self.test_y) ** 2).mean(axis=0)
            val = math.sqrt(mse)
            if io_failure:
                print(f"Crashed reward: {val:.2f}")
            vals[i] = val
            if not io_failure:
                # Only cache full evaluations (i.e., no IO failure)
                self.cache[key] = val
        return vals

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Evaluate ``X`` with simulated I/O failures disabled."""
        return self.forward(X, force_no_io_failure=True)


def load_uci_data(path: str, seed: int, n_features: int) -> tuple:
    """Load, rescale, subsample and split the UCI CT-slice regression data.

    Args:
        path: Path to the "slice_localization_data.csv" file from the
            "Relative location of CT slices on axial axis" UCI data set, see
            https://archive.ics.uci.edu/dataset/206/relative+location+of+ct+slices+on+axial+axis
        seed: Seed of the permutation that selects which rows are kept.
        n_features: Number of features to keep, ranked by XGBoost feature
            importance.

    Returns:
        A tuple ``(train_x, train_y, test_x, test_y)`` of numpy arrays; the
        first half of the kept rows is the training set, the rest the test set.
    """
    df = pd.read_csv(path, sep=",")
    data = df.to_numpy()

    # Inputs are all columns but the last, rescaled per column to [-1, 1]
    X = data[:, :-1]
    X -= X.min(axis=0)
    X = X[:, X.max(axis=0) > 1e-6]  # Throw away constant dimensions
    X = X / (X.max(axis=0) - X.min(axis=0))
    X = 2 * X - 1
    assert X.min() == -1 and X.max() == 1

    # Standardize targets (the last column)
    y = data[:, -1]
    y = (y - y.mean()) / y.std()

    # Only keep 10,000 data points and n_features features. The row subset is
    # drawn with the caller-provided seed (previously hard-coded to 0, which
    # silently ignored the `seed` argument).
    shuffled_indices = np.random.RandomState(seed).permutation(X.shape[0])[:10000]
    X, y = X[shuffled_indices], y[shuffled_indices]

    # Use Xgboost to figure out feature importances and keep only the most important features
    xgb = XGBRegressor(max_depth=8).fit(X, y)
    inds = (-xgb.feature_importances_).argsort()
    X = X[:, inds[:n_features]]

    # Train/Test split
    train_n = int(math.floor(0.50 * X.shape[0]))
    train_x, train_y = X[:train_n], y[:train_n]
    test_x, test_y = X[train_n:], y[train_n:]

    return train_x, train_y, test_x, test_y


class PyTorchCNNRunner:
    """Trains a small CNN with SGD and reports its test-set accuracy."""

    def __init__(self, name: str, train_set: Dataset, test_set: Dataset) -> None:
        """Wrap the data sets in data loaders and pick the compute device.

        Args:
            name: Name of the data set; only stored on the instance.
            train_set: Data set iterated once per ``train_and_evaluate`` call.
            test_set: Data set used to compute the accuracy.
        """
        self.name = name
        # Both loaders use DataLoader's default settings
        self.train_loader: DataLoader = DataLoader(train_set)
        self.test_loader: DataLoader = DataLoader(test_set)
        self.results: Dict[int, float] = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class CNN(nn.Module):
        """One conv layer followed by two fully connected layers (10 classes)."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
            # 8 * 8 * 20 is the flattened size after conv + pooling; this
            # presumably assumes 1 x 28 x 28 inputs -- TODO confirm
            self.fc1 = nn.Linear(8 * 8 * 20, 64)
            self.fc2 = nn.Linear(64, 10)

        def forward(self, x: Tensor) -> Tensor:
            """Return per-class log-probabilities for the image batch ``x``."""
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 3, 3)
            x = x.view(-1, 8 * 8 * 20)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=-1)

    def train_and_evaluate(
        self,
        lr: float,
        momentum: float,
        weight_decay: float,
        step_size: int,
        gamma: float,
        io_failure: bool,
    ) -> float:
        """Train a fresh CNN for one pass over the data and return its accuracy.

        Args:
            lr: SGD learning rate.
            momentum: SGD momentum.
            weight_decay: SGD weight decay.
            step_size: ``StepLR`` period; the scheduler is stepped once per
                training batch.
            gamma: ``StepLR`` multiplicative decay factor.
            io_failure: If True, training stops early after a randomly drawn
                number of batches to simulate an I/O failure.

        Returns:
            The fraction of correctly classified test examples.
        """
        net = self.CNN()
        net.to(device=self.device)

        # Train
        net.train()
        criterion = nn.NLLLoss(reduction="sum")
        optimizer = optim.SGD(
            net.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=step_size, gamma=gamma
        )
        counter = 0
        # On an I/O failure, training is cut off once `counter` exceeds a
        # random threshold; otherwise the whole training loader is consumed.
        max_batches = (
            float(torch.randint(100, 1000, size=(1,)).item())
            if io_failure
            else float("inf")
        )
        for inputs, labels in self.train_loader:
            counter += 1

            inputs = inputs.to(device=self.device)
            labels = labels.to(device=self.device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if counter > max_batches:
                print(f"Terminating after {counter} steps because of I/O failure")
                break

        # Evaluate
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                # The model lives on self.device, so the test batches must be
                # moved there too (otherwise evaluation fails on CUDA).
                inputs = inputs.to(device=self.device)
                labels = labels.to(device=self.device)
                outputs = net(inputs)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(correct, total)
        return correct / total
