import time
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from botorch.models.model import Model

from botorch.test_functions.base import BaseTestProblem
from botorch.test_functions.synthetic import Griewank, Hartmann, Rosenbrock
from torch import Tensor

from ..robust_gp.experiment_utils import (
    Bow,
    constant_outlier_generator,
    CorruptedTestProblem,
    Friedman,
    normal_outlier_corruption,
    student_t_corruption,
    SumOfSines,
    uniform_corruption,
    uniform_input_corruption,
)
from ..robust_gp.models import (
    get_power_transformed_model,
    get_robust_model,
    get_student_t_model,
    get_trimmed_mll_model,
    get_vanilla_model,
    get_winsorized_model,
)


def replication_setup(  # noqa
    method_name: str,
    function_name: str,
    outlier_generator_name: str,
    outlier_fraction: float,
    noise_std: Optional[float] = None,
    outlier_generator_kwargs: Optional[Dict[str, Any]] = None,
    dtype: torch.dtype = torch.double,
    device: Optional[torch.device] = None,
) -> Tuple[CorruptedTestProblem, Callable, int, bool, Tensor]:
    """Build the corrupted objective used by one benchmark replication.

    Args:
        method_name: Name of the modeling method. Not used by this function
            (presumably accepted to keep a uniform call signature).
        function_name: Base test function name, forwarded to
            ``get_base_test_problem``.
        outlier_generator_name: One of ``"constant"``, ``"uniform_input"``,
            ``"uniform"``, ``"normal"`` or ``"student-t"``.
        outlier_fraction: Forwarded to ``CorruptedTestProblem``.
        noise_std: Observation noise std forwarded to the base test problem.
        outlier_generator_kwargs: Keyword arguments for the chosen corruption
            function. The key set must match exactly:
            ``constant`` -> {"constant"}, ``uniform_input`` -> {},
            ``uniform`` -> {"lower", "upper"}, ``normal`` -> {"noise_std"},
            ``student-t`` -> {"df", "scale"}. ``None`` is treated as ``{}``.
        dtype: dtype the returned ``bounds`` are cast to.
        device: device the returned ``bounds`` are moved to.

    Returns:
        ``(objective_function, outlier_generator, dim, minimize, bounds)``,
        where ``outlier_generator(f, X, bounds)`` forwards to the selected
        corruption function with the given kwargs.

    Raises:
        ValueError: If ``outlier_generator_name`` is unknown, or if the
            kwargs key set does not match what that generator expects.
    """
    # Treat a missing kwargs dict as "no kwargs" (calling ``.keys()`` on None
    # used to crash). Copying also snapshots the values so that later mutation
    # of the caller's dict cannot change the generator's behavior.
    generator_kwargs: Dict[str, Any] = dict(outlier_generator_kwargs or {})

    # generator name -> (corruption function, exact set of required kwargs)
    generators = {
        "constant": (constant_outlier_generator, frozenset({"constant"})),
        "uniform_input": (uniform_input_corruption, frozenset()),
        "uniform": (uniform_corruption, frozenset({"lower", "upper"})),
        "normal": (normal_outlier_corruption, frozenset({"noise_std"})),
        "student-t": (student_t_corruption, frozenset({"df", "scale"})),
    }
    if outlier_generator_name not in generators:
        raise ValueError(f"Unknown outlier generator name: {outlier_generator_name}")
    corruption_fn, expected_keys = generators[outlier_generator_name]
    # dict key views compare equal to sets with the same members.
    if generator_kwargs.keys() != expected_keys:
        raise ValueError(
            f"Unknown outlier generator kwargs: {outlier_generator_kwargs}"
        )

    def outlier_generator(f, X, bounds):
        # Forward with keyword arguments, exactly like the per-generator calls.
        return corruption_fn(f=f, X=X, bounds=bounds, **generator_kwargs)

    base_test_problem, dim, minimize, bounds = get_base_test_problem(
        function_name=function_name, dtype=dtype, device=device, noise_std=noise_std
    )
    # Test problem
    objective_function = CorruptedTestProblem(
        base_test_problem=base_test_problem,
        outlier_generator=outlier_generator,
        outlier_fraction=outlier_fraction,
    )
    return objective_function, outlier_generator, dim, minimize, bounds


def get_base_test_problem(
    function_name: str,
    dtype: torch.dtype = torch.double,
    device: Optional[torch.device] = None,
    noise_std: Optional[float] = None,
) -> Tuple[BaseTestProblem, int, bool, Tensor]:
    """Instantiate a named base test problem.

    Args:
        function_name: One of ``"hartmann6"``, ``"bow"``, ``"friedman5"``,
            ``"friedman10"``, ``"griewank{2,4,8}"``, ``"rosenbrock{2,4,8}"``,
            ``"sumofsines1"`` or ``"sumofsines2"``.
        dtype: dtype the returned ``bounds`` are cast to.
        device: device the returned ``bounds`` are moved to.
        noise_std: Observation noise std passed to the problem constructor.

    Returns:
        ``(base_test_problem, dim, minimize, bounds)``. Only ``bounds`` is cast
        to ``dtype``/``device`` here; the problem object is constructed without
        a dtype/device argument.

    Raises:
        ValueError: If ``function_name`` is not recognized.
    """
    # name -> (problem class, input dimension, minimize flag, extra ctor kwargs)
    specs = {
        "hartmann6": (Hartmann, 6, True, {"bounds": [(0.0, 1.0)] * 6}),
        "bow": (Bow, 1, False, {"negate": True}),  # regression target
        "friedman5": (Friedman, 5, True, {}),
        "friedman10": (Friedman, 10, True, {}),
        "griewank2": (Griewank, 2, True, {}),
        "griewank4": (Griewank, 4, True, {}),
        "griewank8": (Griewank, 8, True, {}),
        "rosenbrock2": (Rosenbrock, 2, True, {}),
        "rosenbrock4": (Rosenbrock, 4, True, {}),
        "rosenbrock8": (Rosenbrock, 8, True, {}),
        "sumofsines1": (SumOfSines, 1, True, {}),  # regression target
        "sumofsines2": (SumOfSines, 2, True, {}),  # regression target
    }
    if function_name not in specs:
        raise ValueError(f"Unknown function name: {function_name}")
    problem_cls, dim, minimize, extra_kwargs = specs[function_name]

    base_test_problem = problem_cls(dim=dim, noise_std=noise_std, **extra_kwargs)
    if function_name == "hartmann6":
        # Unit cube, built explicitly to match the bounds handed to Hartmann.
        bounds = torch.cat((torch.zeros(1, dim), torch.ones(1, dim)))
    else:
        bounds = base_test_problem.bounds

    bounds = bounds.to(dtype=dtype, device=device)
    return base_test_problem, dim, minimize, bounds


def model_fit_helper(method_name: str, X: Tensor, Y: Tensor, minimize: bool) -> Model:
    """Return the model produced by the factory that matches ``method_name``.

    ``Y`` is handed to the factory unchanged (the original note states that
    standardization is done in the model constructors). Whether the factory
    also fits the model is not visible from this function.

    Args:
        method_name: ``"vanilla"``, ``"student_t"``, ``"trimmed_mll"``,
            ``"power_transform"``, ``"winsorize"``, or one of the
            ``"relevance_pursuit*"`` variants listed below.
        X: Training inputs.
        Y: Training targets.
        minimize: Only consulted for ``"winsorize"``, which receives
            ``winsorize_lower=not minimize``.

    Returns:
        The model returned by the selected ``get_*_model`` factory.

    Raises:
        ValueError: If ``method_name`` is not recognized.
    """
    # Relevance-pursuit variants differ only in the flags given to
    # get_robust_model. ``reset_dense_parameters`` is passed only for the
    # "_reinit" variants; the others rely on get_robust_model's own default.
    bwd = {"use_forward_algorithm": False, "convex_parameterization": False}
    fwd = {"use_forward_algorithm": True, "convex_parameterization": False}
    bwd_cvx = {"use_forward_algorithm": False, "convex_parameterization": True}
    fwd_cvx = {"use_forward_algorithm": True, "convex_parameterization": True}
    reinit = {"reset_dense_parameters": True}
    robust_flags = {
        "relevance_pursuit": bwd,
        "relevance_pursuit_bwd": bwd,
        "relevance_pursuit_fwd": fwd,
        "relevance_pursuit_cvx": bwd_cvx,
        "relevance_pursuit_fwd_cvx": fwd_cvx,
        "relevance_pursuit_cvx_reinit": {**bwd_cvx, **reinit},
        "relevance_pursuit_fwd_cvx_reinit": {**fwd_cvx, **reinit},
        "relevance_pursuit_reinit": {**bwd, **reinit},
        "relevance_pursuit_fwd_reinit": {**fwd, **reinit},
    }
    if method_name in robust_flags:
        return get_robust_model(X=X, Y=Y, **robust_flags[method_name])

    if method_name == "winsorize":
        return get_winsorized_model(X=X, Y=Y, winsorize_lower=not minimize)

    # Factories that only need the training data.
    data_only_factories = {
        "vanilla": get_vanilla_model,
        "student_t": get_student_t_model,
        "trimmed_mll": get_trimmed_mll_model,
        "power_transform": get_power_transformed_model,
    }
    factory = data_only_factories.get(method_name)
    if factory is None:
        raise ValueError(f"Unknown method name: {method_name}")
    return factory(X=X, Y=Y)
