import torch
import torch.nn as nn
from torch.nn.functional import relu
import numpy as np

import exp_utils as PQ
import rl_utils


# Hyper-parameters read by the classes in this file (via `FLAGS.<name>`).
class FLAGS(PQ.BaseFLAGS):
    # End points of the temperature schedule: `set_temperature(p)` interpolates
    # log-linearly from `max` (at p = 0) to `min` (at p = 1).
    class temperature(PQ.BaseFLAGS):
        max = 0.1
        min = 0.001

    class filter(PQ.BaseFLAGS):
        # `step` stores the indices of the `top_k` highest-scoring samples in
        # `mask`, unless `top_k == batch_size`, in which case it stores the
        # indices whose score is >= 0.  Note the default here (10000) is larger
        # than the default `batch_size` (1000).
        top_k = 10000
        # When true, accepted samples with a positive score are also written to
        # the `SPool` ring buffer and `filtered_s` returns that buffer.
        pool = False

    # `n_steps`, `method`, `extend_region`, `barrier_coef` and `L_neg_coef` are
    # not read anywhere in this file -- presumably consumed by other modules.
    n_steps = 1
    method = 'grad'
    # Gain of the multiplicative step-size (`tau`) adaptation.
    lr = 0.01
    # Number of states sampled in parallel (leading dimension of `s` and `tau`).
    batch_size = 1000
    extend_region = 0.0
    barrier_coef = 0.
    L_neg_coef = 1
    # When true, `step` / `step_` replace samples whose failure counter is too
    # high by copies of other samples.
    resample = False

    # Maximum number of optimizer iterations performed by `project_back`.
    n_proj_iters = 10
    # When true, `step` preconditions the Langevin proposal with `nmPQ.std`.
    precond = False


class SPool(nn.Module):
    """Fixed-capacity ring buffer of states.

    The buffer is seeded with a detached copy of ``init_s``; its capacity is
    ``len(init_s)``.  ``t`` is the row at which the next write starts.
    """

    def __init__(self, init_s):
        super().__init__()
        # Registered as a non-trainable Parameter holding a private copy.
        self.s = nn.Parameter(init_s.detach().clone(), requires_grad=False)
        self.t = 0

    @torch.no_grad()
    def add(self, s):
        """Write the rows of ``s`` into the buffer, wrapping around at the end.

        The result is the same as writing the rows one at a time at positions
        ``t, t + 1, ...`` modulo the capacity, so when ``s`` has more rows than
        the buffer only its last ``capacity`` rows survive.  Empty input (or
        an empty buffer) is a no-op.
        """
        capacity = len(self.s)
        n = len(s)
        if capacity == 0 or n == 0:
            return

        t = self.t
        if n > capacity:
            # Rows that would be overwritten later in the same call are
            # skipped; advance the write head past them first.
            t = (t + n - capacity) % capacity
            s = s[-capacity:]
            n = capacity

        # At most two contiguous writes: up to the end of the buffer, then the
        # remainder from the start.
        head = min(n, capacity - t)
        self.s[t:t + head] = s[:head]
        tail = n - head
        if tail:
            self.s[:tail] = s[head:]
        self.t = (t + n) % capacity


class SLangevinOptimizer(nn.Module):
    FLAGS = FLAGS

    def __init__(self, obj_eval, nmPQ):
        """Create a batch of ``FLAGS.batch_size`` states and their step sizes.

        Args:
            obj_eval: callable mapping a batch of states to a dict of results
                (``pdf`` reads its ``'hard_obj'`` entry); ``project_back``
                additionally uses ``obj_eval.safe_invariant.L``.
            nmPQ: object exposing ``mean`` and ``std`` tensors.  States are
                initialised as ``mean + std * noise`` and ``std`` serves as
                the square root of the diagonal preconditioner.
        """
        super().__init__()
        self.obj_eval = obj_eval
        self.temperature = FLAGS.temperature.max
        self.nmPQ = nmPQ

        device = nmPQ.mean.device
        s_init = nmPQ.mean + torch.randn(FLAGS.batch_size, *nmPQ.std.shape, device=device) * nmPQ.std
        self.s = nn.Parameter(s_init, requires_grad=True)
        # One step size per sample.  Created on the same device as the states
        # so the two can be combined without an explicit `.to(...)`.
        self.tau = nn.Parameter(
            torch.full([FLAGS.batch_size, 1], 1e-2, device=s_init.device),
            requires_grad=False,
        )
        # Not used by any method of this class; kept because external code may
        # access the attribute.
        self.opt = torch.optim.Adam([self.s])
        self.pool = SPool(self.s)

        # Indices of the currently selected samples (updated by `step`).
        self.mask = torch.tensor([0], dtype=torch.int64)
        # Per-sample counter driving `resample`.
        self.n_failure = torch.zeros(FLAGS.batch_size, dtype=torch.int64, device=device)
        self.n_resampled = 0

        # Optimizer used by `project_back`.
        self.adam = torch.optim.Adam([self.s], betas=(0, 0.999), lr=0.001)

    def reinit(self):
        """Redraw every state from a standard normal and reset all step sizes to 0.01."""
        with torch.no_grad():
            self.s.normal_(0.0, 1.0)
            self.tau.fill_(0.01)

    def set_temperature(self, p):
        """Set the temperature by log-linear interpolation.

        ``p = 0`` yields ``FLAGS.temperature.max`` and ``p = 1`` yields
        ``FLAGS.temperature.min``.
        """
        t_hi = FLAGS.temperature.max
        t_lo = FLAGS.temperature.min
        log_temperature = np.log(t_hi) * (1 - p) + np.log(t_lo) * p
        self.temperature = np.exp(log_temperature)

    @property
    def filtered_s(self):
        """Return the pool's buffer when ``FLAGS.filter.pool`` is set, else the live states."""
        return self.pool.s if FLAGS.filter.pool else self.s

    def pdf(self, s):
        """Evaluate the objective at ``s``.

        Returns:
            A pair ``(score, info)`` where ``info`` is the dict returned by
            ``obj_eval`` and ``score`` is ``info['hard_obj']`` divided by the
            current temperature (the samplers use it as a log-density).
        """
        info = self.obj_eval(s)
        score = info['hard_obj'] / self.temperature
        return score, info

    # def project_back(self, should_print):
    #     indices = (self.L(self.s) > 1).nonzero()[:, 0]
    #
    #     n_outside = []
    #     accepts = []
    #     for i in range(10):
    #         if should_print: n_outside.append(len(indices))
    #         if len(indices) < 1000:
    #             break
    #         with torch.enable_grad():
    #             s = self.s[indices].detach().requires_grad_()
    #             L_old = self.L(s)
    #             grad = torch.autograd.grad(L_old.sum(), s)[0]
    #         with torch.no_grad():
    #             # delta_s = (L_old[:, None] - 1) * grad / grad.norm(dim=-1, keepdim=True) * 0.1
    #             delta_s = 0.01 * grad / grad.norm(dim=-1, keepdim=True)
    #             L_new = self.L(s - delta_s)
    #             # if should_print:
    #             #     for lr in [0.003, 0.01, 0.03, 0.1]:
    #             #         accepts.append((self.L(s - delta_s * lr) < L_old).sum().item())
    #             self.s[indices[L_new < L_old]] -= delta_s[L_new < L_old]
    #             indices = indices[(1 < L_new) & (L_new < L_old)]
    #
    #     if should_print:
    #         PQ.log.info(f"[project back] {n_outside}, accepts = {accepts}")

    def project_back(self, should_print=False):
        """Pull states with a large barrier value back towards ``L <= 1.03``.

        Runs at most ``FLAGS.n_proj_iters`` steps of ``self.adam`` on
        ``relu(L - 1.03)`` summed over the batch, and stops as soon as fewer
        than 1000 states exceed the threshold.  ``should_print`` is accepted
        but not used.
        """
        margin = 0.03
        for _ in range(FLAGS.n_proj_iters):
            with torch.enable_grad():
                barrier = self.obj_eval.safe_invariant.L(self.s)
                n_violating = (barrier > 1 + margin).sum()
                if n_violating < 1000:
                    return
                excess = relu(barrier - 1 - margin)
                self.adam.zero_grad()
                excess.sum().backward()
                self.adam.step()

            # with torch.enable_grad():
            #     L = self.L(self.s)
            #     loss = relu(L - 1)
            #     if (L > 1).sum() < 1000:
            #         break
            #     grad = torch.autograd.grad(loss.sum(), self.s)[0]
            # with torch.no_grad():
            #     self.s.sub_(0.01 * grad)

    @torch.no_grad()
    def resample(self, f: torch.Tensor, idx):
        """Replace the samples at ``idx`` by copies of other samples.

        Donor rows are drawn with replacement with probability
        ``softmax(f)``; both the state and its step size are copied, the
        failure counter of the replaced rows is cleared and ``n_resampled``
        is increased by the number of replaced rows.  Does nothing when
        ``idx`` is empty.
        """
        n_replace = len(idx)
        if n_replace == 0:
            return
        weights = torch.softmax(f, 0)
        donors = torch.multinomial(weights, n_replace, replacement=True)
        self.s[idx] = self.s[donors]
        self.tau[idx] = self.tau[donors]
        self.n_failure[idx] = 0
        self.n_resampled += n_replace

    def step_projection(self, s, lr):
        """Take one gradient-descent step on the barrier function.

        Args:
            s: batch of states to move.
            lr: step size(s), broadcast against ``s``.

        Returns:
            ``s - lr * grad(L)(s) * nmPQ.std``.
        """
        # This class defines no `L` attribute of its own, so the original
        # `self.L(...)` call can only work if one was attached from outside.
        # Honour such an attribute when present, otherwise use the same barrier
        # function `project_back` uses.
        barrier_fn = getattr(self, 'L', None)
        if barrier_fn is None:
            barrier_fn = self.obj_eval.safe_invariant.L

        with torch.enable_grad():
            s = s.detach().requires_grad_()
            grad = torch.autograd.grad(barrier_fn(s).sum(), s)[0]
            return s - lr * grad * self.nmPQ.std

    def step_langevin(self, s, tau):
        """Run one preconditioned Metropolis-adjusted Langevin step on ``s``.

        The proposal is
        ``b = a + tau * P * grad f(a) + sqrt(2 * tau) * sqrt(P) * z`` with
        ``sqrt(P) = nmPQ.std`` and ``z`` standard normal; it is accepted with
        the Metropolis-Hastings probability.

        Args:
            s: batch of current states.
            tau: per-sample step sizes, one column per row of ``s``.

        Returns:
            ``(new_s, new_tau, info)`` where ``info['f']`` is the score of the
            returned states and ``info['optimal']`` is a scalar summary
            computed from the ``'L'`` and ``'U'`` entries at the old states.
        """
        sqrt_precond = self.nmPQ.std
        precond = sqrt_precond**2

        with torch.enable_grad():
            a = s.detach().requires_grad_()
            f_a, a_info = self.pdf(a)
            grad_a = torch.autograd.grad(f_a.sum(), a)[0]

        z = torch.randn_like(a)
        b = a + tau * precond * grad_a + (tau * 2).sqrt() * sqrt_precond * z
        with torch.enable_grad():
            b = b.detach().requires_grad_()
            f_b, b_info = self.pdf(b)
            grad_b = torch.autograd.grad(f_b.sum(), b)[0]

        with torch.no_grad():
            # Forward proposal log-density: the displacement from the drifted
            # mean is sqrt(2 * tau) * sqrt(P) * z, so
            # ||.||^2_{P^-1} / (4 * tau) = ||z||^2 / 2.  (The factor 1/2 was
            # missing, which made it inconsistent with the reverse density.)
            log_p_a_to_b = -z.norm(dim=-1)**2 / 2
            log_p_b_to_a = -((a - b - tau * precond * grad_b)**2 / precond).sum(dim=-1) / tau[:, 0] / 4
            log_ratio = (f_b + log_p_b_to_a) - (f_a + log_p_a_to_b)
            ratio = log_ratio.clamp(max=0).exp()[:, None]
            sampling = torch.rand_like(ratio) < ratio
            b = torch.where(sampling, b, a)
            new_f_b = torch.where(sampling[:, 0], f_b, f_a)
            PQ.meters['opt_s/accept'] += sampling.sum().item() / FLAGS.batch_size

            if FLAGS.filter.pool:
                self.pool.add(b[sampling[:, 0] & (f_b > 0)])
                PQ.meters['opt_s/n_s_valid'] += (new_f_b > 0).sum().item()
                PQ.meters['opt_s/n_pool_valid'] += (self.pdf(self.filtered_s)[0] > 0).sum().item()

        # Step sizes grow when the acceptance probability exceeds 0.574 and
        # shrink otherwise.
        return b, tau * (FLAGS.lr * (ratio - 0.574) + 1), {
            'f': new_f_b,
            'optimal': torch.where(a_info['L'] < 1, a_info['U'] - 1, -a_info['L'] - 100).max().item(),
        }

    def step_(self):
        """Alternative update: project outside samples, Langevin-step inside ones.

        Samples with ``L > 1`` take one ``step_projection`` step; samples with
        ``L <= 1`` take one ``step_langevin`` step.  When ``FLAGS.resample``
        is set, samples whose failure counter exceeds 5 are replaced.

        Returns:
            The info dict of ``step_langevin``, or ``{'optimal': -100.}`` when
            no sample is inside.
        """
        s = self.s
        tau = self.tau

        # This class defines no `L` attribute of its own; use one attached from
        # outside if present, otherwise the barrier `project_back` uses.
        barrier_fn = getattr(self, 'L', None)
        if barrier_fn is None:
            barrier_fn = self.obj_eval.safe_invariant.L

        L = barrier_fn(s)
        idx_o = (L > 1).nonzero()[:, 0]
        idx_i = (L <= 1).nonzero()[:, 0]

        with torch.no_grad():
            if len(idx_o):
                self.s[idx_o] = self.step_projection(s[idx_o], tau[idx_o])
            if len(idx_i):
                self.s[idx_i], self.tau[idx_i], info = self.step_langevin(s[idx_i], tau[idx_i])
            else:
                info = {'optimal': -100.}

        if FLAGS.resample:
            self.n_failure[idx_o] = 0
            self.n_failure += 1
            # `info['f']` only covers the inside samples (and is absent when
            # there are none).  Scatter it into a full-batch score vector so
            # the donor indices drawn by `resample` address the right rows;
            # outside samples get -inf and are therefore never picked.
            if 'f' in info:
                f_all = info['f'].new_full((len(s),), float('-inf'))
                f_all[idx_i] = info['f']
                self.resample(f_all, torch.nonzero(self.n_failure > 5)[:, 0])
        return info

    def step(self):
        """Run one projection pass followed by one Metropolis-adjusted Langevin step.

        The whole batch ``self.s`` is updated in place together with the
        per-sample step sizes ``self.tau``; ``self.mask`` is refreshed with
        the indices of the selected samples.

        Returns:
            ``{'optimal': max hard_obj over the states before the move}``.
        """
        self.project_back()
        tau = self.tau
        a = self.s
        if FLAGS.precond:
            sqrt_precond = self.nmPQ.std
            precond = sqrt_precond**2
        else:
            precond = 1
            sqrt_precond = 1

        # Gradients are needed even if the caller runs under `no_grad`.
        with torch.enable_grad():
            f_a, a_info = self.pdf(a)
            grad_a = torch.autograd.grad(f_a.sum(), a)[0]

        z = torch.randn_like(a)
        b = a + tau * precond * grad_a + (tau * 2).sqrt() * sqrt_precond * z
        with torch.enable_grad():
            b = b.detach().requires_grad_()
            f_b, b_info = self.pdf(b)
            grad_b = torch.autograd.grad(f_b.sum(), b)[0]

        PQ.meters['opt_s/out_to_in'] += ((a_info['L'] > 1) & (b_info['L'] < 1)).sum().item() / FLAGS.batch_size

        with torch.no_grad():
            # Forward proposal log-density: the displacement from the drifted
            # mean is sqrt(2 * tau) * sqrt(P) * z, so
            # ||.||^2_{P^-1} / (4 * tau) = ||z||^2 / 2.  (The factor 1/2 was
            # missing, which made it inconsistent with the reverse density.)
            log_p_a_to_b = -z.norm(dim=-1)**2 / 2
            log_p_b_to_a = -((a - b - tau * precond * grad_b)**2 / precond).sum(dim=-1) / tau[:, 0] / 4
            log_ratio = (f_b + log_p_b_to_a) - (f_a + log_p_a_to_b)
            ratio = log_ratio.clamp(max=0).exp()[:, None]
            sampling = torch.rand_like(ratio) < ratio
            # A proposal that passes the MH test is still vetoed when its score
            # is <= -1000.  Use one mask for both the state and its score so
            # `new_f_b` always describes the state actually kept.
            accepted = sampling & (f_b[:, None] > -1000)
            b = torch.where(accepted, b, a)
            new_f_b = torch.where(accepted[:, 0], f_b, f_a)
            PQ.meters['opt_s/accept'] += sampling.sum().item() / FLAGS.batch_size

            if FLAGS.filter.pool:
                self.pool.add(b[sampling[:, 0] & (f_b > 0)])
                PQ.meters['opt_s/n_s_valid'] += (new_f_b > 0).sum().item()
                PQ.meters['opt_s/n_pool_valid'] += (self.pdf(self.filtered_s)[0] > 0).sum().item()

            if FLAGS.filter.top_k != FLAGS.batch_size:
                # `topk` raises when k exceeds the number of samples.
                k = min(FLAGS.filter.top_k, len(new_f_b))
                _, self.mask = new_f_b.topk(k)
            else:
                self.mask = torch.nonzero(new_f_b >= 0)[:, 0]
                if len(self.mask) == 0:
                    # Never leave the mask empty; fall back to the first sample.
                    self.mask = torch.tensor([0], dtype=torch.int64)

            self.s.set_(b)
            # Step sizes grow when the acceptance probability exceeds 0.574 and
            # shrink otherwise.
            self.tau.mul_(FLAGS.lr * (ratio - 0.574) + 1)  # .clamp_(max=1.0)
            if FLAGS.resample:
                # Samples scoring >= -100 reset their counter; the others are
                # replaced once they have failed more than 1000 times in a row.
                self.n_failure[new_f_b >= -100] = 0
                self.n_failure += 1
                self.resample(new_f_b, torch.nonzero(self.n_failure > 1000)[:, 0])
        return {
            'optimal': a_info['hard_obj'].max().item(),
        }

    # def step(self):  # metropolis
    #     tau = self.tau
    #
    #     a = self.s
    #     f_a = self.pdf(a)
    #     b = (a + (tau * 2).sqrt() * torch.randn_like(a)).detach()
    #     f_b = self.pdf(b)
    #
    #     with torch.no_grad():
    #         log_ratio = f_b - f_a
    #         ratio = log_ratio.clamp(max=0).exp()[:, None]
    #         b = torch.where(torch.rand_like(ratio) <= ratio, b, a)
    #         self.s.set_(b)
    #
    #         self.tau.mul_(FLAGS.lr * (ratio - 0.574) + 1)

    @torch.no_grad()
    def evaluate(self, *, step):
        """Log summary statistics of the current batch and reset the meters.

        Writes three scalars to ``PQ.writer`` at ``global_step=step``, emits a
        debug line, purges the ``'opt_s/'`` meters and zeroes
        ``n_resampled``.

        Returns:
            ``{'inside': number of samples whose 'constraint' entry is <= 0}``.
        """
        stats = self.obj_eval(self.s)
        barrier = stats['L']
        best_obj = stats['hard_obj'].max().item()
        n_inside = (stats['constraint'] <= 0).sum().item()
        n_cut = stats['mask'].sum().item()

        tau_geo_mean = self.tau.log().mean().exp().item()
        tau_max = self.tau.max().item()
        PQ.writer.add_scalar('opt_s/hardD', best_obj, global_step=step)
        PQ.writer.add_scalar('opt_s/inside', n_inside / FLAGS.batch_size, global_step=step)
        PQ.writer.add_scalar('opt_s/P_accept', PQ.meters['opt_s/accept'].mean, global_step=step)

        # Quartiles of the barrier value over the samples satisfying the constraint.
        barrier_np = barrier.cpu().numpy()
        barrier_inside = barrier_np[np.where(stats['constraint'].cpu() <= 0)]
        if len(barrier_inside):
            quartiles = np.percentile(barrier_inside, [25, 50, 75])
        else:
            quartiles = []

        message = (
            f"[S Langevin]: temperature = {self.temperature:.3f}, hardD = {best_obj:.6f}, "
            f"inside/cut = {n_inside}/{n_cut}, "
            f"tau = [geo mean {tau_geo_mean:.3e}, max {tau_max:.3e}], "
            # f"Pr[out => in] = {PQ.meters['opt_s/out_to_in'].mean:.6f}, "
            f"Pr[accept] = {PQ.meters['opt_s/accept'].mean:.3f}, "
            # f"# valid = [s = {PQ.meters['opt_s/n_s_valid'].mean:.0f}, "
            # f"pool = {PQ.meters['opt_s/n_pool_valid'].mean:.0f}], "
            f"L 25/50/75% = {quartiles}, "
            f"resampled = {self.n_resampled}"
        )
        PQ.log.debug(message)
        PQ.meters.purge('opt_s/')
        self.n_resampled = 0

        return {'inside': n_inside}
