import torch
import torch.nn as nn
from torch.nn.functional import relu
import numpy as np

import exp_utils as PQ
import rl_utils


# Hyper-parameters for SLangevinOptimizer, declared as nested PQ.BaseFLAGS
# classes. Fields marked "unreferenced" are not read by any code visible in
# this file; they may be consumed elsewhere in the project.
class FLAGS(PQ.BaseFLAGS):
    # Temperature range used by SLangevinOptimizer.set_temperature, which
    # interpolates log-linearly from `max` (p = 0) to `min` (p = 1).
    class temperature(PQ.BaseFLAGS):
        max = 0.1
        min = 0.001

    # Unreferenced in the visible code.
    class filter(PQ.BaseFLAGS):
        top_k = 10000
        pool = False

    n_steps = 1  # unreferenced in the visible code
    method = 'grad'  # unreferenced in the visible code
    # Rate of the multiplicative per-sample step-size (tau) adaptation in
    # SLangevinOptimizer.step.
    lr = 0.01
    # Number of parallel samples; leading dimension of z, tau, alpha, n_failure.
    batch_size = 1000
    extend_region = 0.0  # unreferenced in the visible code
    barrier_coef = 0.  # unreferenced in the visible code
    L_neg_coef = 1  # unreferenced in the visible code
    # When true, SLangevinOptimizer.step re-draws samples whose failure counter
    # exceeds its threshold.
    resample = False

    # Maximum number of iterations of SLangevinOptimizer.project_back.
    n_proj_iters = 10
    precond = False  # unreferenced in the visible code


class SLangevinOptimizer(nn.Module):
    """Batch of Metropolis-adjusted Langevin chains over encoded states.

    Each of the ``FLAGS.batch_size`` rows of ``self.z`` is one chain. A row is
    mapped to a state with ``state_box.decode`` and scored with
    ``crabs.obj_eval``; the chains target ``hard_obj / temperature``. A
    proposal is only adopted when it passes the Metropolis-Hastings test *and*
    its ``L`` value is negative.

    Collaborator requirements visible from this class:
      * ``state_box`` exposes ``device``, ``shape`` and ``decode(z)``.
      * ``crabs.barrier(s)`` returns a tensor with one value per sample.
      * ``crabs.obj_eval(s)`` returns a mapping with the keys ``'hard_obj'``,
        ``'L'``, ``'constraint'`` and ``'mask'``.

    NOTE(review): the broadcasting against ``tau`` of shape ``(batch, 1)`` and
    the reductions over ``dim=-1`` presume ``state_box.shape`` is
    one-dimensional -- confirm against the caller.
    """

    FLAGS = FLAGS

    def __init__(self, crabs, state_box):
        """Allocate per-chain buffers on ``state_box.device`` and randomise them."""
        super().__init__()
        self.crabs = crabs
        self.temperature = FLAGS.temperature.max
        self.state_box = state_box

        device = state_box.device

        self.z = nn.Parameter(torch.zeros(FLAGS.batch_size, *state_box.shape, device=device), requires_grad=True)
        # tau and alpha live on the same device as z; they are combined with z
        # and its gradients in `step`, so a CPU default would break on other
        # devices unless the module was moved explicitly.
        self.tau = nn.Parameter(torch.full([FLAGS.batch_size, 1], 1e-2, device=device), requires_grad=False)
        self.alpha = nn.Parameter(torch.full([FLAGS.batch_size], 3.0, device=device), requires_grad=False)
        self.opt = torch.optim.Adam([self.z])
        self.max_s = torch.zeros(state_box.shape, device=device)
        self.min_s = torch.zeros(state_box.shape, device=device)

        # Indices of chains whose (tempered) objective is non-negative after
        # the latest `step`; falls back to [0] so it is never empty.
        self.mask = torch.tensor([0], dtype=torch.int64, device=device)
        self.n_failure = torch.zeros(FLAGS.batch_size, dtype=torch.int64, device=device)
        self.n_resampled = 0

        # Optimizer used only by `project_back`.
        self.adam = torch.optim.Adam([self.z], betas=(0, 0.999), lr=0.001)
        self.since_last_reset = 0
        self.reinit()

    @property
    def s(self):
        """Current samples decoded by ``state_box.decode``."""
        return self.state_box.decode(self.z)

    def reinit(self):
        """Redraw ``z`` uniformly in [-1, 1] and reset ``tau``, ``alpha`` and the step counter."""
        PQ.log.debug("[SGradOpt] reinit")
        nn.init.uniform_(self.z, -1., 1.)
        nn.init.constant_(self.tau, 0.01)
        nn.init.constant_(self.alpha, 3.0)
        self.since_last_reset = 0

    def set_temperature(self, p):
        """Set the temperature by log-linear interpolation.

        Args:
            p: interpolation weight; 0 gives ``FLAGS.temperature.max`` and 1
                gives ``FLAGS.temperature.min``.
        """
        t_max = FLAGS.temperature.max
        t_min = FLAGS.temperature.min
        self.temperature = np.exp(np.log(t_max) * (1 - p) + np.log(t_min) * p)

    def pdf(self, z):
        """Evaluate the tempered objective at encoded samples ``z``.

        Returns:
            A pair ``(hard_obj / temperature, result)`` where ``result`` is
            the full mapping returned by ``crabs.obj_eval``.
        """
        s = self.state_box.decode(z)
        result = self.crabs.obj_eval(s)
        return result['hard_obj'] / self.temperature, result

    def project_back(self, should_print=False):
        """Take up to ``FLAGS.n_proj_iters`` Adam steps on ``relu(barrier - 0.03)``.

        Iteration stops as soon as fewer than 1000 samples have a barrier
        value above 0.03.

        NOTE(review): the literal 1000 equals the default ``FLAGS.batch_size``;
        it is unclear whether it is meant to track the batch size -- confirm
        before changing it.

        Args:
            should_print: accepted for interface compatibility; not used.
        """
        for _ in range(FLAGS.n_proj_iters):
            # Gradients are needed even if the caller runs under no_grad.
            with torch.enable_grad():
                L = self.crabs.barrier(self.s)
                if (L > 0.03).sum() < 1000:
                    break
                loss = relu(L - 0.03)
                self.adam.zero_grad()
                loss.sum().backward()
                self.adam.step()

    @torch.no_grad()
    def resample(self, f: torch.Tensor, idx):
        """Overwrite the chains at ``idx`` with chains drawn according to ``softmax(f)``.

        Args:
            f: per-chain scores; sources are drawn (with replacement) with
                probability ``softmax(f)``.
            idx: indices of the chains to replace. Nothing happens when empty.
        """
        if len(idx) == 0:
            return
        new_idx = f.softmax(0).multinomial(len(idx), replacement=True)
        self.z[idx] = self.z[new_idx]
        self.tau[idx] = self.tau[new_idx]
        self.n_failure[idx] = 0
        self.n_resampled += len(idx)

    def step(self):
        """Run one projected Metropolis-adjusted Langevin update on every chain.

        Side effects: updates ``self.z``, ``self.tau``, ``self.mask``,
        ``self.since_last_reset``, the ``opt_s/*`` meters and, when
        ``FLAGS.resample`` is set, ``self.n_failure`` / ``self.n_resampled``.

        Returns:
            ``{'optimal': float}`` -- the largest ``hard_obj`` among the
            samples as they were *before* this update.
        """
        self.since_last_reset += 1
        self.project_back()
        tau = self.tau
        a = self.z

        f_a, a_info = self.pdf(a)
        grad_a = torch.autograd.grad(f_a.sum(), a)[0]

        # Langevin proposal: b = a + tau * grad + sqrt(2 * tau) * w.
        w = torch.randn_like(a)
        b = a + tau * grad_a + (tau * 2).sqrt() * w
        b = b.detach().requires_grad_()
        f_b, b_info = self.pdf(b)
        grad_b = torch.autograd.grad(f_b.sum(), b)[0]
        going_out = (a_info['L'] < 0) & (b_info['L'] > 0)
        PQ.meters['opt_s/going_out'] += going_out.to(torch.float32).mean()

        PQ.meters['opt_s/out_to_in'] += ((a_info['L'] > 0) & (b_info['L'] < 0)).sum().item() / FLAGS.batch_size

        with torch.no_grad():
            # Proposal log-densities up to a shared constant, both of the form
            # -||x' - x - tau * grad(x)||^2 / (4 * tau). For the forward move
            # the residual is sqrt(2 * tau) * w, which reduces to -||w||^2 / 2.
            log_p_a_to_b = -w.norm(dim=-1)**2 / 2
            log_p_b_to_a = -((a - b - tau * grad_b)**2).sum(dim=-1) / tau[:, 0] / 4
            log_ratio = (f_b + log_p_b_to_a) - (f_a + log_p_a_to_b)
            ratio = log_ratio.clamp(max=0).exp()[:, None]
            sampling = torch.rand_like(ratio) < ratio
            # A proposal is adopted only if it passes the Metropolis-Hastings
            # test and has L < 0; the recorded objective must follow the same
            # rule so that it always describes the state actually kept.
            accepted = sampling & (b_info['L'][:, None] < 0)
            b = torch.where(accepted, b, a)
            new_f_b = torch.where(accepted[:, 0], f_b, f_a)
            # The meter tracks the Metropolis-Hastings test itself, which is
            # the quantity the tau adaptation below steers.
            PQ.meters['opt_s/accept'] += sampling.sum().item() / FLAGS.batch_size

            self.mask = torch.nonzero(new_f_b >= 0)[:, 0]
            if len(self.mask) == 0:
                self.mask = torch.tensor([0], dtype=torch.int64, device=new_f_b.device)

            self.z.set_(b)
            # Grow tau when the acceptance probability exceeds 0.574 and
            # shrink it otherwise.
            self.tau.mul_(FLAGS.lr * (ratio - 0.574) + 1)
            if FLAGS.resample:
                # Count consecutive steps with objective below -100 and
                # replace chains that exceed 1000 of them.
                self.n_failure[new_f_b >= -100] = 0
                self.n_failure += 1
                self.resample(new_f_b, torch.nonzero(self.n_failure > 1000)[:, 0])
        return {
            'optimal': a_info['hard_obj'].max().item(),
        }

    @torch.no_grad()
    def evaluate(self, *, step):
        """Log summary statistics of the current samples and reset the meters.

        Writes ``opt_s/hardD``, ``opt_s/inside`` and ``opt_s/P_accept`` to
        ``PQ.writer``, emits one debug line, purges all ``opt_s/`` meters and
        zeroes ``self.n_resampled``.

        Args:
            step: global step passed to ``PQ.writer.add_scalar``.

        Returns:
            ``{'inside': int}`` -- number of samples with ``constraint <= 0``.
        """
        result = self.crabs.obj_eval(self.s)
        L_s = result['L']
        hardD_s = result['hard_obj'].max().item()
        inside = (result['constraint'] <= 0).sum().item()
        cut_size = result['mask'].sum().item()

        geo_mean_tau = self.tau.log().mean().exp().item()
        max_tau = self.tau.max().item()
        PQ.writer.add_scalar('opt_s/hardD', hardD_s, global_step=step)
        PQ.writer.add_scalar('opt_s/inside', inside / FLAGS.batch_size, global_step=step)
        PQ.writer.add_scalar('opt_s/P_accept', PQ.meters['opt_s/accept'].mean, global_step=step)

        # Quartiles of L restricted to the samples that satisfy the constraint.
        L_inside = L_s.cpu().numpy()
        L_inside = L_inside[np.where(result['constraint'].cpu() <= 0)]
        L_dist = np.percentile(L_inside, [25, 50, 75]) if len(L_inside) else []
        PQ.log.debug(f"[S Langevin]: temperature = {self.temperature:.3f}, hardD = {hardD_s:.6f}, "
                     f"inside/cut = {inside}/{cut_size}, "
                     f"tau = [geo mean {geo_mean_tau:.3e}, max {max_tau:.3e}], "
                     f"Pr[accept] = {PQ.meters['opt_s/accept'].mean:.3f}, "
                     f"Pr[going out] = {PQ.meters['opt_s/going_out'].mean:.3f}, "
                     f"L 25/50/75% = {L_dist}, "
                     )
        PQ.meters.purge('opt_s/')
        self.n_resampled = 0

        return {
            'inside': inside
        }
