import torch
from torch import Tensor
from typing import Callable
import torch.nn.functional as F
from attack_utils import get_tokenizer, get_attack_template, get_conv_template
from util import unravel_index
import asyncio
    
class ZouAttack:
    """Adversarial token-sequence search against a chat model (GCG-style, after Zou et al.).

    Token-substitution candidates are ranked with gradients from a *proxy* model
    and scored on an *eval* model; when ``proxy_model_name`` is omitted both
    names refer to the same model.

    NOTE(review): the DPC objects (``eval_dpc`` / ``proxy_dpc``) are project
    types whose contract is not visible here; both must currently be supplied by
    the caller, since ``__init__`` calls ``proxy_dpc.get_embedding_table()``
    unconditionally.
    """

    def __init__(
        self, eval_model_name, proxy_model_name=None, eval_dpc=None, proxy_dpc=None,
        device=None
    ):
        """Load tokenizers / conversation templates and cache the proxy embedding table.

        Args:
            eval_model_name: name of the model whose outputs are attacked.
            proxy_model_name: model used for gradients; defaults to ``eval_model_name``.
            eval_dpc: handle used for eval-model forward passes and KV caching.
            proxy_dpc: handle used for proxy-model forward/backward passes; must not be None.
            device: device on which token tensors and the embedding table are placed.
        """
        proxy_model_name = proxy_model_name or eval_model_name
        self.eval_model_name = eval_model_name
        self.proxy_model_name = proxy_model_name

        # TODO: create DPCs if needed
        self.eval_dpc, self.proxy_dpc = eval_dpc, proxy_dpc

        self.eval_tokenizer = get_tokenizer(eval_model_name)
        self.proxy_tokenizer = get_tokenizer(proxy_model_name)
        self.eval_conv_template = get_conv_template(eval_model_name)
        self.proxy_conv_template = get_conv_template(proxy_model_name)
        self.device = device

        # Proxy token-embedding matrix, presumably (vocab, dim); rows are indexed
        # by token id to build `inputs_embeds` and to score substitutions.
        self.embedding_table = self.proxy_dpc.get_embedding_table().to(self.device)

    async def do_attack(self, command, response, *, ntok=20, **kwargs):
        """Search for ``ntok`` adversarial tokens that make the eval model emit ``response``.

        The adversarial tokens are placed between the template prefix and suffix
        returned by ``get_attack_template``; the target tokens follow the suffix.

        Args:
            command: user request passed to ``get_attack_template``.
            response: target completion passed to ``get_attack_template``.
            ntok: number of adversarial tokens to optimize.
            **kwargs: forwarded to the inner ``zou_attack`` search
                (``topk``, ``batch_size``, ``max_iter``, ``init_token_id``).

        Returns:
            ``(success, iteration, loss, adversarial_token_ids)``. On success the
            loss and ids belong to the candidate whose greedy decode matched the
            whole target; otherwise they are the best (lowest-loss) suffix seen.
        """
        embedding_table = self.embedding_table
        local_dev = self.device

        eval_tokens = get_attack_template(
            command, response, self.eval_conv_template, self.eval_tokenizer
        )
        eval_prefix_tensor, eval_suffix_tensor, eval_target_tensor = (
            torch.tensor(tokens, device=local_dev) for tokens in eval_tokens
        )
        n_eval_target = len(eval_target_tensor)

        # These ids index the *proxy* embedding table and label the proxy loss,
        # so they must come from the proxy tokenizer / conversation template.
        proxy_tokens = get_attack_template(
            command, response, self.proxy_conv_template, self.proxy_tokenizer
        )
        proxy_prefix_tensor, proxy_suffix_tensor, proxy_target_tensor = (
            torch.tensor(tokens, device=local_dev) for tokens in proxy_tokens
        )

        async def loss_with_grad(input_ids, inspect=False):
            """Proxy-model target loss and its gradient over the adversarial positions.

            ``input_ids`` must have shape (1, n_adv). Returns ``(loss * k, grad * k)``
            where k is the number of target tokens and ``grad`` is sliced to the
            adversarial positions (presumably the gradient w.r.t. ``inputs_embeds``).
            ``inspect`` is accepted but unused.
            """
            (bs, n_adv), k = input_ids.shape, len(proxy_target_tensor)
            assert bs == 1
            P = proxy_prefix_tensor.expand(bs, -1)
            S = proxy_suffix_tensor.expand(bs, -1)
            T = proxy_target_tensor.expand(bs, -1)
            out, grad = await self.proxy_dpc.forward_with_grad(
                inputs_embeds=embedding_table[torch.hstack([P, input_ids, S, T])],
                # -100 masks every non-target position out of the loss.
                labels=torch.hstack(
                    [
                        torch.full_like(P, -100),
                        torch.full_like(input_ids, -100),
                        torch.full_like(S, -100),
                        T,
                    ]
                ),
                device=local_dev,
            )
            start = len(proxy_prefix_tensor)
            # Scaling by k presumably turns a mean-over-targets loss into a sum.
            return out.loss * k, (grad * k)[:, start : start + n_adv]

        async def eval_loss_kv_batch(input_ids):
            """Score a batch of adversarial suffixes on the eval model.

            Runs ``[input_ids, suffix, target]`` on top of the cached prefix state.
            Returns, per row: the summed negative log-likelihood of the target, the
            number of leading target tokens reproduced by greedy decoding (equal to
            the target length on full success), and the greedy argmax tokens at
            the target positions.
            """
            k = len(eval_target_tensor)
            n = len(input_ids)
            out = await self.eval_dpc.forward(
                torch.hstack(
                    [
                        input_ids,
                        eval_suffix_tensor.expand(n, -1),
                        eval_target_tensor.expand(n, -1),
                    ]
                ),
                cache_load=prompt_kvref,
                device=local_dev,
            )
            # Logits at position t predict token t+1, so the k target tokens are
            # predicted by the k positions immediately before the last one.
            logprobs = F.log_softmax(out.logits[:, -k - 1 : -1], dim=-1)
            labels = eval_target_tensor.expand(n, -1).unsqueeze(-1).to(logprobs.device)
            cum_logprobs = logprobs.gather(-1, labels).squeeze(-1).sum(1)

            greedy_tokens = logprobs.argmax(-1)
            mismatches = greedy_tokens != labels.squeeze(-1)
            # Keep only the first mismatch per row; max() then yields its index,
            # or (False, 0) when the row matches the target completely.
            first_mismatch = (mismatches.cumsum(-1) == 1) & mismatches
            any_mistake, mistake_idx = first_mismatch.max(-1)
            return (
                -cum_logprobs,
                mistake_idx + ~any_mistake * k,
                greedy_tokens,
            )

        # NOTE(review): torch's context-manager decorators wrap the call that
        # *creates* the coroutine, not the awaited body, so inference mode is most
        # likely not active inside these coroutines. Left unchanged on purpose:
        # holding inference mode across `await` would also affect other tasks
        # running on the same thread.
        @torch.inference_mode()
        async def zou_attack(
            loss,
            proxy_loss_with_grad,
            ntok,
            topk=256,
            batch_size=1024,
            max_iter=500,
            init_token_id=1738,
        ):
            """Greedy coordinate search with gradient-ranked single-token swaps.

            Each iteration ranks substitutions with the proxy gradient, samples
            ``batch_size`` single-token swaps among the ``topk`` best per position,
            scores them with ``loss`` and moves to the lowest-loss candidate.
            """
            vsize = embedding_table.shape[0]
            dev = embedding_table.device
            batch = torch.full((1, ntok), init_token_id, device=dev)
            # Seed `best` with the starting suffix so the failure return is
            # well-defined even if no loss ever beats inf (e.g. NaN losses).
            best_l, best = float("inf"), batch
            i = -1
            for i in range(max_iter):
                _, grad = await proxy_loss_with_grad(batch)
                # First-order estimate of the loss change for putting token v at
                # position p: grad_p . (e_v - e_current_p); shape (1, ntok, vsize).
                scores = grad @ embedding_table.T
                scores -= (grad * embedding_table[batch]).sum(2, keepdim=True)

                # Flat indices into the (ntok, vsize) grid of the `topk` lowest
                # (most loss-reducing) substitutions at every position.
                candidates = (
                    scores.topk(min(topk, vsize), dim=-1, largest=False).indices
                    + torch.arange(ntok, device=dev)[:, None] * vsize
                ).ravel()

                # Sample up to `batch_size` candidates uniformly without replacement.
                pick = torch.randn(len(candidates), device=dev).argsort()[:batch_size]
                idxs, ids = unravel_index(candidates[pick], scores.shape[1:]).T
                test_batch = batch.repeat(len(pick), 1)
                test_batch.scatter_(-1, idxs[:, None], ids[:, None])

                # Score in chunks of at most 1024 rows, concurrently.
                chunk = min(1024, len(test_batch))
                results = await asyncio.gather(
                    *(loss(tbatch) for tbatch in test_batch.split(chunk))
                )
                losses = torch.cat([r[0] for r in results])
                greedy_scores = torch.cat([r[1] for r in results])
                greedy_tokens = torch.cat([r[2] for r in results])

                l, best_idx = losses.min(0)
                best_idx = int(best_idx)
                batch = test_batch[None, best_idx]

                best_in_str = self.eval_tokenizer.decode(batch[0])
                best_out_str = self.eval_tokenizer.decode(greedy_tokens[best_idx])

                if greedy_scores.max().item() == n_eval_target:
                    # Report the candidate that actually decodes the full target;
                    # it need not be the lowest-loss one.
                    win = int(greedy_scores.argmax())
                    print("success")
                    return True, i, losses[win].item(), test_batch[win].tolist()

                if l < best_l:
                    best_l, best = l.item(), batch
                print(
                    f"{i}: {best_l:05f} {l.item():05f} {greedy_scores.max().item()} {best_in_str!r} -> {best_out_str!r}"
                )

            return False, i, best_l, best[0].tolist()

        @torch.inference_mode()
        async def bfs_attack(
            loss,
            proxy_loss_with_grad,
            ntok,
            batch_size=128,
            queue_size=100,
            batches_per_node=1,
            max_iter=500 * 1024//128,
        ):
            """Best-first search over suffixes using a bounded min-max priority queue.

            Repeatedly expands the lowest-loss queued suffix into
            ``batch_size * batches_per_node`` gradient-ranked single-token swaps and
            keeps at most ``queue_size`` of the best candidates.
            """
            from minmaxheap import MinMaxHeap

            dev = embedding_table.device
            h = MinMaxHeap()
            xs = torch.full((1, ntok), 0, device=dev)
            for l, greedy_score, greedy_tokens, x in zip(*(await loss(xs)), xs):
                h.insert((l.item(), x.tolist(), 0, greedy_score, greedy_tokens.tolist()))

            best_l, best = float("inf"), xs[0]
            i = -1
            for i in range(max_iter):
                if len(h) == 0:
                    break

                l, x, age, greedy_score, greedy_tokens = h.popmin()
                x = torch.tensor(x, device=dev)
                in_str = self.eval_tokenizer.decode(x)
                out_str = self.eval_tokenizer.decode(greedy_tokens)
                print(f"{i}: {best_l:05f} {l:05f} {age} {greedy_score} {in_str!r} -> {out_str!r}")
                if l < best_l:
                    best_l, best = l, x

                _, grad = await proxy_loss_with_grad(x[None])
                scores = grad @ embedding_table.T
                scores -= (grad * embedding_table[x[None]]).sum(2, keepdim=True)
                # Pool of the k most loss-reducing swaps, from which B are sampled.
                k = min(batch_size * 200, scores.numel())
                B = min(batch_size * batches_per_node, k)
                _, top_swaps = (-scores.ravel()).topk(k)
                _, randperms = torch.rand(k, device=dev).topk(B)
                selected_swaps = top_swaps.gather(-1, randperms)
                idxs, ids = unravel_index(selected_swaps, scores.shape[1:]).T

                test_batch = x.unsqueeze(0).repeat(B, 1)
                test_batch.scatter_(-1, idxs.unsqueeze(-1), ids.unsqueeze(-1))
                test_minibatches = test_batch.split(128)
                tasks = [asyncio.create_task(loss(tbatch)) for tbatch in test_minibatches]
                try:
                    for task, tbatch in zip(tasks, test_minibatches):
                        for l2, greedy_score, greedy_tokens, x2 in zip(*(await task), tbatch):
                            if greedy_score == n_eval_target:
                                print("success")
                                return True, i, l2.item(), x2.tolist()

                            l2 = l2.item()
                            if len(h) < queue_size:
                                h.insert((l2, x2.tolist(), age + 1, greedy_score, greedy_tokens.tolist()))
                                continue

                            # Queue full: replace the current worst entry if better.
                            worst, *_ = h.peekmax()
                            if l2 < worst:
                                h.popmax()
                                h.insert((l2, x2.tolist(), age + 1, greedy_score, greedy_tokens.tolist()))
                finally:
                    # Don't leave scoring tasks running after an early return or error.
                    for task in tasks:
                        task.cancel()

            return False, i, best_l, best.tolist()

        prompt_kvref = self.eval_dpc.get_kv_cache_ref()
        try:
            # Run the fixed prefix once with cache_save; every candidate batch then
            # reuses it via cache_load and only feeds adversarial + suffix + target.
            await self.eval_dpc.forward(
                input_ids=eval_prefix_tensor.unsqueeze(0), cache_save=prompt_kvref, device=local_dev
            )
            # Alternative search strategy:
            #   return await bfs_attack(eval_loss_kv_batch, loss_with_grad, ntok, **kwargs)
            return await zou_attack(eval_loss_kv_batch, loss_with_grad, ntok, **kwargs)
        finally:
            # Drop our reference to the cached prefix state even if the attack raises.
            del prompt_kvref



