import sys

import equinox as eqx
import jax
import jax.numpy as jnp
from _corrected_stepper import CorrectedStepper

sys.path.append("../apebench/apebench")

from exponax import exponax as ex  # noqa: E402
from pdequinox import pdequinox as pdeqx  # noqa: E402

# Module-level experiment constants.
# NOTE(review): none of these are referenced in the visible part of this file;
# the seed, sample-count and step values equal the corresponding `FOULearner`
# field defaults -- confirm whether external scripts import them.

# PRNG seeds (same values as the `train_seed` / `test_seed` defaults).
TRAIN_IC_SEED, TEST_IC_SEED = 1337, 1338

# Sample counts (same values as `num_train_samples` / `num_test_samples`).
TRAIN_NUM_SAMPLES, TEST_NUM_SAMPLES = 5, 50

# Number of steps (same value as the rollout-length defaults of `FOULearner`).
N_STEPS_IN_DATA = 200

# (lower, upper) bounds; presumably the range over which stencil coefficients
# are scanned -- TODO confirm against the consuming script.
STENCIL_RANGES = (-2.0, 2.0)


class FOULearner(eqx.Module):
    """
    Setup for fitting a two-weight periodic convolution stencil to data.

    Reference trajectories are produced by rolling out the stepper returned by
    `get_ref_stepper` from randomly drawn initial conditions. The trainable
    model is a bias-free `pdeqx.conv.PhysicsConv` whose two kernel weights are
    the optimization variables; `perform_newton_opt` optimizes them with
    Newton's method and `perform_test` evaluates a given pair of weights.
    """

    # Number of grid points, forwarded to the steppers and to `ex.build_ic_set`.
    num_points: int = 20
    # Difficulty of the reference stepper; also defines the FOU stencil
    # `[1 - difficulty, difficulty]` (see `get_fou_stencil`).
    difficulty: float = 0.75
    # Factor applied to `difficulty` for the coarse stepper of "correct" tasks.
    coarse_proportion: float = 0.5

    # Initial condition specification; format documented in `get_ic_generator`.
    ic_config: str = "fourier;5;true;true"

    # Training data: number of initial conditions, number of reference rollout
    # steps (also the upper limit for the training rollout), and PRNG seed.
    num_train_samples: int = 5
    train_max_rollout_steps: int = 200
    train_seed: int = 1337
    # Number of Newton iterations performed by `perform_newton_opt`.
    num_optim_iter: int = 100

    # Test data: number of initial conditions, PRNG seed, and rollout length.
    num_test_samples: int = 50
    test_seed: int = 1338
    test_rollout_steps: int = 200

    # Either "predict" (the network itself is the stepper) or "correct;MODE"
    # (the network is combined with a coarse stepper, see `build_stepper`).
    task_config: str = "predict"
    # The FOU stencil only makes sense for prediction

    def get_ic_generator(self):
        """
        Overwrite for custom initial condition generation.

        Uses the `ic_config` attribute to determine the type of initial
        condition generation. All generators are created for one spatial
        dimension.

        Allows for the following options:
            - `fourier;CUTOFF;ZERO_MEAN;MAX_ONE` for a truncated Fourier series
                with CUTOFF (int) number of modes, ZERO_MEAN (bool) for zero
                mean, and MAX_ONE (bool) for having the initial condition being
                at max in (-1, 1) but not clamped to it
            - `diffused;INTENSITY;ZERO_MEAN;MAX_ONE` for a diffused noise with
                INTENSITY (float) for the intensity, ZERO_MEAN (bool) for zero
                mean, and MAX_ONE (bool) for having the initial condition being
                at max in (-1, 1) but not clamped to it
            - `grf;POWERLAW_EXPONENT;ZERO_MEAN;MAX_ONE` for a Gaussian random
                field with POWERLAW_EXPONENT (float) for the powerlaw exponent,
                ZERO_MEAN (bool) for zero mean, and MAX_ONE (bool) for having
                the initial condition being at max in (-1, 1) but not clamped to
                it
            - `clamp;LOWER_BOUND;UPPER_BOUND;CONFIG` for clamping the
                configuration to the range of LOWER_BOUND (float) to UPPER_BOUND
                (float) and then using the configuration CONFIG for the
                generation of the initial condition

        The leading keyword and the boolean fields are case-insensitive; any
        boolean field other than "true" is treated as false.

        Returns:
            The initial condition generator described by `ic_config`.

        Raises:
            ValueError: If the (inner) configuration does not start with
                `fourier`, `diffused` or `grf`, or if a numeric field cannot
                be converted.
            IndexError: If a known configuration has too few `;`-separated
                fields.
        """

        def _as_bool(text):
            # Only the literal "true" (any casing) enables a flag.
            return text.lower() == "true"

        def _get_single_channel(config):
            ic_args = config.split(";")
            kind = ic_args[0].lower()

            if kind == "fourier":
                cutoff = int(ic_args[1])
                zero_mean = _as_bool(ic_args[2])
                max_one = _as_bool(ic_args[3])
                return ex.ic.RandomTruncatedFourierSeries(
                    num_spatial_dims=1,
                    cutoff=cutoff,
                    # Zero mean is realized by collapsing the offset range.
                    offset_range=(0.0, 0.0) if zero_mean else (-0.5, 0.5),
                    max_one=max_one,
                )

            if kind == "diffused":
                intensity = float(ic_args[1])
                zero_mean = _as_bool(ic_args[2])
                max_one = _as_bool(ic_args[3])
                return ex.ic.DiffusedNoise(
                    num_spatial_dims=1,
                    intensity=intensity,
                    zero_mean=zero_mean,
                    max_one=max_one,
                )

            if kind == "grf":
                powerlaw_exponent = float(ic_args[1])
                zero_mean = _as_bool(ic_args[2])
                max_one = _as_bool(ic_args[3])
                return ex.ic.GaussianRandomField(
                    num_spatial_dims=1,
                    powerlaw_exponent=powerlaw_exponent,
                    zero_mean=zero_mean,
                    max_one=max_one,
                )

            raise ValueError("Unknown IC configuration")

        ic_args = self.ic_config.split(";")
        if ic_args[0].lower() != "clamp":
            return _get_single_channel(self.ic_config)

        lower_bound = float(ic_args[1])
        upper_bound = float(ic_args[2])
        # Everything after the two bounds is the wrapped configuration.
        inner_gen = _get_single_channel(";".join(ic_args[3:]))
        return ex.ic.ClampingICGenerator(
            inner_gen,
            limits=(lower_bound, upper_bound),
        )

    def get_ref_stepper(self):
        """
        Build the reference stepper that produces the training/test data.

        Returns:
            A first-order (`order=1`) linear stepper on `num_points` grid
            points with difficulty `self.difficulty`.
        """
        return ex.normalized.DiffultyLinearStepperSimple(
            1, self.num_points, difficulty=self.difficulty, order=1
        )

    def get_coarse_stepper(self):
        """
        Build the coarse stepper used by the "correct" task.

        Identical to the reference stepper except that its difficulty is
        scaled by `self.coarse_proportion`.
        """
        return ex.normalized.DiffultyLinearStepperSimple(
            1,
            self.num_points,
            difficulty=self.difficulty * self.coarse_proportion,
            order=1,
        )

    def get_fou_stencil(self):
        """
        Return the FOU stencil `[1 - difficulty, difficulty]` as a 1D array.

        The two entries are used as `(left, center)` by `build_fou_stepper`
        and as the default starting point of `perform_newton_opt`.
        """
        return jnp.array([1 - self.difficulty, self.difficulty])

    def build_network(self, left, center):
        """
        Build the two-weight periodic convolution with the given weights.

        Args:
            left: First kernel weight.
            center: Second kernel weight.
                NOTE(review): which grid neighbor each weight multiplies
                depends on how `PhysicsConv` pads even kernel sizes -- confirm
                that the names match the actual stencil layout.

        Returns:
            A bias-free `pdeqx.conv.PhysicsConv` with periodic boundary whose
            weight is `[[[left, center]]]`.
        """
        stencil = jnp.array([[[left, center]]])
        # Positional arguments are presumably (num_spatial_dims, in_channels,
        # out_channels, kernel_size) -- TODO confirm against pdequinox.
        conv = pdeqx.conv.PhysicsConv(
            1,
            1,
            1,
            2,
            # The key only seeds the initial weight, which is replaced below.
            key=jax.random.PRNGKey(1337),
            use_bias=False,
            boundary_mode="periodic",
        )
        return eqx.tree_at(lambda leaf: leaf.weight, conv, stencil)

    def build_stepper(self, left, center):
        """
        Build the stepper to be trained/evaluated according to `task_config`.

        Args:
            left: First kernel weight, forwarded to `build_network`.
            center: Second kernel weight, forwarded to `build_network`.

        Returns:
            The bare network for `task_config == "predict"`, or a
            `CorrectedStepper` wrapping the coarse stepper and the network for
            `task_config == "correct;MODE"`.

        Raises:
            ValueError: If the task is neither "predict" nor "correct".
            IndexError: If the task is "correct" but no `;MODE` is given.
        """
        network = self.build_network(left, center)

        task_args = self.task_config.split(";")
        task = task_args[0]

        if task == "predict":
            return network

        if task == "correct":
            return CorrectedStepper(
                coarse_stepper=self.get_coarse_stepper(),
                network=network,
                mode=task_args[1],
            )

        raise ValueError("Unknown task argument")

    def build_fou_stepper(self):
        """Build the stepper that uses the FOU stencil as its weights."""
        left, center = self.get_fou_stencil()
        return self.build_stepper(left, center)

    def _build_reference_trajectories(self, *, num_samples, seed, num_steps):
        """
        Draw initial conditions and roll them out with the reference stepper.

        Shared implementation of `get_train_data` and `get_test_data`. Uses
        `get_ic_generator` and `get_ref_stepper`, so overriding those methods
        affects both data sets.

        Args:
            num_samples: Number of initial conditions to draw.
            seed: Integer seed for the PRNG key of the initial conditions.
            num_steps: Number of reference steps per trajectory.

        Returns:
            The batched trajectories including the initial state, i.e. with
            `num_samples` entries on axis 0 and `num_steps + 1` states on
            axis 1.
        """
        ic_set = ex.build_ic_set(
            self.get_ic_generator(),
            num_points=self.num_points,
            num_samples=num_samples,
            key=jax.random.PRNGKey(seed),
        )
        rollout_fn = ex.rollout(
            self.get_ref_stepper(),
            num_steps,
            include_init=True,
        )
        return jax.vmap(rollout_fn)(ic_set)

    def get_train_data(self):
        """
        Create the reference trajectories used for training.

        Uses `num_train_samples`, `train_seed` and `train_max_rollout_steps`;
        the initial state is included in each trajectory.
        """
        return self._build_reference_trajectories(
            num_samples=self.num_train_samples,
            seed=self.train_seed,
            num_steps=self.train_max_rollout_steps,
        )

    def get_test_data(self):
        """
        Create the reference trajectories used for testing.

        Uses `num_test_samples`, `test_seed` and `test_rollout_steps`; the
        initial state is included in each trajectory.
        """
        return self._build_reference_trajectories(
            num_samples=self.num_test_samples,
            seed=self.test_seed,
            num_steps=self.test_rollout_steps,
        )

    def _masked_rollout_loss(self, diff_trj, rollout):
        """
        Reduce a trajectory of errors to a scalar, using the first steps only.

        Args:
            diff_trj: Rank-4 error array with the samples on axis 0 and the
                `train_max_rollout_steps` rollout steps on axis 1 (axes 2 and
                3 are presumably channel and space -- TODO confirm).
            rollout: Number of leading rollout steps that contribute.

        Returns:
            The squared error, averaged over axes 0 and 2 and summed over the
            remaining axes, with all steps beyond `rollout` zeroed out.
        """
        step_ids = jnp.arange(1, self.train_max_rollout_steps + 1).reshape(
            1, -1, 1, 1
        )
        # Steps are masked rather than sliced off, so the array shapes do not
        # depend on `rollout`.
        diff_trj_reduced = jnp.where(step_ids < rollout + 1, diff_trj, 0.0)
        return jnp.sum(jnp.mean(jnp.square(diff_trj_reduced), axis=(0, 2)))

    def get_loss_fn(self, rollout: int = 1, use_div: bool = False):
        """
        Create the training loss as a function of a stepper.

        The returned function always rolls the stepper out for
        `train_max_rollout_steps` steps from the training initial conditions
        and only lets the first `rollout` steps contribute to the loss.

        Args:
            rollout: Number of rollout steps that contribute to the loss.
            use_div: If False, predictions are compared against the
                precomputed reference trajectory. If True (diverted chain),
                each predicted state is compared against the reference stepper
                applied to the previous *predicted* state.

        Returns:
            A function `loss_fn(stepper)` returning a scalar loss.
        """
        max_steps = self.train_max_rollout_steps

        trj_full = self.get_train_data()
        trj_substacked = jax.vmap(ex.stack_sub_trajectories, in_axes=(0, None))(
            trj_full, max_steps + 1
        )
        # Merge the per-sample windows into a single batch axis.
        trj_ref_all = jnp.concatenate(trj_substacked)

        ic = trj_ref_all[:, 0]
        ref = trj_ref_all[:, 1:]

        if not use_div:

            def loss_fn(stepper):
                trj_pred_all = jax.vmap(ex.rollout(stepper, max_steps))(ic)
                return self._masked_rollout_loss(trj_pred_all - ref, rollout)

        else:
            # Only the diverted chain needs the reference stepper at loss time.
            ref_stepper = self.get_ref_stepper()

            def loss_fn(stepper):
                trj_pred_all = jax.vmap(
                    ex.rollout(stepper, max_steps, include_init=True)
                )(ic)
                # One reference step from every state the model visited,
                # batched over samples and over time.
                produced_ref = jax.vmap(jax.vmap(ref_stepper))(trj_pred_all[:, :-1])
                return self._masked_rollout_loss(
                    trj_pred_all[:, 1:] - produced_ref, rollout
                )

        return loss_fn

    def perform_newton_opt(
        self,
        rollout: int = 1,
        use_div: bool = False,
        init: "jax.Array | None" = None,
        *,
        return_history: bool = True,
    ):
        """
        Optimize the two stencil weights with Newton's method.

        Performs `num_optim_iter` undamped Newton steps on the loss returned by
        `get_loss_fn(rollout, use_div)`, viewed as a function of the weight
        pair passed to `build_stepper`.

        Args:
            rollout: Forwarded to `get_loss_fn`.
            use_div: Forwarded to `get_loss_fn`.
            init: Starting weights as an array of two entries; defaults to the
                FOU stencil if None.
            return_history: If True, return the iterates of all Newton steps
                instead of only the final one.

        Returns:
            If `return_history` is True, the stacked iterates after each step
            (leading axis of length `num_optim_iter`; `init` itself is not
            included). Otherwise the final iterate.
        """
        loss_fn = self.get_loss_fn(rollout, use_div)

        def loss_wrapper(param):
            stepper = self.build_stepper(*param)
            return loss_fn(stepper)

        if init is None:
            init = self.get_fou_stencil()

        # The loss value itself is not needed for the Newton update.
        grad_fn = jax.grad(loss_wrapper)
        hess_fn = jax.hessian(loss_wrapper)

        def newton_step(param):
            grad = grad_fn(param)
            hess = hess_fn(param)
            delta = jnp.linalg.solve(hess, -grad)
            return param + delta

        def scan_fn(p, _):
            p_new = newton_step(p)
            # Carry the new iterate forward and also record it in the history.
            return p_new, p_new

        p_final, p_history = jax.lax.scan(
            scan_fn, init, None, length=self.num_optim_iter
        )

        return p_history if return_history else p_final

    def get_test_rollout(self, param):
        """
        Roll out the stepper built from `param` on the test initial conditions.

        Args:
            param: Pair of stencil weights, unpacked into `build_stepper`.

        Returns:
            The predicted trajectories over `test_rollout_steps` steps,
            including the initial state.
        """
        stepper = self.build_stepper(*param)
        ic = self.get_test_data()[:, 0]
        rollout_fn = ex.rollout(
            stepper,
            self.test_rollout_steps,
            include_init=True,
        )
        return jax.vmap(rollout_fn)(ic)

    def perform_test(self, param):
        """
        Evaluate the stepper built from `param` against the test data.

        Args:
            param: Pair of stencil weights, unpacked into `build_stepper`.

        Returns:
            `ex.metrics.mean_nRMSE` between prediction and reference, computed
            separately for each of the `test_rollout_steps` rollout steps
            (the initial state is excluded).
        """
        stepper = self.build_stepper(*param)
        trj = self.get_test_data()
        ic, ref = trj[:, 0], trj[:, 1:]
        trj_pred = jax.vmap(ex.rollout(stepper, self.test_rollout_steps))(ic)
        # Map over the time axis (axis 1) so every step gets its own metric.
        return jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(trj_pred, ref)
