import asyncio
import copy
import functools
import operator as op
import uuid
import warnings
from dataclasses import asdict, dataclass, field
from typing import Callable, Hashable, Optional, T, Tuple
from weakref import ReferenceType, ref
from transformers import DynamicCache, StaticCache

import numpy as np
import ray
import torch
import torch.utils._device
from ray.train import Checkpoint
from ray.air import ScalingConfig
from ray.air.util.torch_dist import (
    TorchDistributedWorker,
    _shutdown_torch_distributed,
    init_torch_dist_process_group,
)
from ray.train.predictor import Predictor
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils._pytree import map_only, tree_map
from util import EmptyInitOnDevice, TensorRef, tree_multimap

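# WARNING: Do not import deepspeed (or, by extension, transformers) at module
# level. Doing so may lead to cryptic torch errors w.r.t. device
# ordinals/numbers; import them inside the worker methods instead.
# NOTE(review): `from transformers import DynamicCache, StaticCache` at the top
# of this file already imports transformers at module level, contrary to this
# warning. Confirm that this is safe, or move it into the functions that use it.
# import deepspeed
# import transformers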

# Ship torch tensors through Ray via TensorRef. On the receiving side the
# deserialized object is invoked with no arguments (operator.call), which
# presumably rebuilds the tensor.
ray.util.register_serializer(
    torch.Tensor,
    serializer=TensorRef,
    deserializer=op.call,
)


@dataclass
class DPCConfig:
    """Settings shared by the predictor cluster, its predictors and workers.

    Fields may be passed positionally, so their declaration order matters.
    """

    # Model id or path handed to ``from_pretrained``.
    model_name: str
    # Number of independent model replicas (worker groups); each one receives
    # an equal slice of every batch.
    num_worker_groups: int = 1
    # Number of GPU workers (Ray actors) a single replica is sharded across.
    num_workers_per_group: int = 1
    # Name of a ``torch`` dtype attribute, e.g. "float16" or "bfloat16".
    dtype: str = "float16"
    # Name of the ``transformers`` class used to load the model.
    model_class: str = "AutoModelForCausalLM"
    # Extra keyword arguments for ``deepspeed.init_inference`` (mode "inference").
    deepspeed_kwargs: dict = field(default_factory=lambda: {})
    # Wrap the FSDP model with ``torch.compile`` (mode "fsdp").
    fsdp_compile: bool = False
    # Either "inference" (DeepSpeed tensor parallel) or "fsdp".
    mode: str = "inference"


@ray.remote
class PredictionWorker(TorchDistributedWorker):
    """Ray actor that runs one shard of a distributed (DeepSpeed or FSDP) model.

    All PredictionWorkers of one worker group join the same torch.distributed
    process group (set up by the driver) and are always called together with
    identical inputs. Depending on ``config.mode`` the model is wrapped with
    ``deepspeed.init_inference`` (tensor parallel) or with FSDP.

    Key/value caches produced by the model are kept on the worker in
    ``self.kv_cache`` under caller-chosen hashable keys, so they never have to
    be shipped back to the driver.
    """

    def __init__(self, config: DPCConfig, world_size: int):
        self.config = config
        self.world_size = world_size
        self.device = None  # set in init_model once the local rank is known
        self.dtype = getattr(torch, config.dtype)
        # cache key -> past_key_values as produced by the model
        self.kv_cache = {}
        # The actor runs in its own process, so repeat the module-level tensor
        # serializer registration here (presumably Ray registrations are
        # per process).
        ray.util.register_serializer(
            torch.Tensor, serializer=TensorRef, deserializer=op.call
        )

    def init_model(self, local_rank: int):
        """Load the model onto ``cuda:<local_rank>`` and wrap it per ``config.mode``.

        * ``"inference"``: ``deepspeed.init_inference`` with tensor parallelism
          over ``world_size`` workers (plus ``config.deepspeed_kwargs``).
        * ``"fsdp"``: FullyShardedDataParallel with mixed precision in
          ``self.dtype`` (optionally ``torch.compile``-d). The input embedding
          table is copied first for ``get_embedding_table``.

        Afterwards the model is put in eval mode and all parameters are frozen.

        Raises:
            NotImplementedError: for any other ``config.mode``.
        """
        # Imported lazily on purpose; see the warning at the top of the module.
        import transformers

        transformers.logging.disable_progress_bar()

        self.local_rank = local_rank
        self.device = torch.device("cuda", index=local_rank)
        with EmptyInitOnDevice():
            # NOTE: trust_remote_code=True executes code shipped with the model
            # repository; only point model_name at trusted sources.
            model = getattr(transformers, self.config.model_class).from_pretrained(
                self.config.model_name,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            )

        if self.config.mode == "inference":
            import deepspeed

            self.model = deepspeed.init_inference(
                model,
                tp={"tp_size": self.world_size},
                dtype=self.dtype,
                **self.config.deepspeed_kwargs,
            )

        elif self.config.mode == "fsdp":
            torch.cuda.set_device(self.device)
            # Copy taken before the model is handed to FSDP below.
            self.embedding_table = (
                model.get_input_embeddings()._parameters["weight"].detach().clone()
            )
            model.gradient_checkpointing_enable()
            # NOTE(review): only LlamaDecoderLayer is auto-wrapped; other
            # architectures presumably end up as a single FSDP unit.
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={
                    transformers.models.llama.modeling_llama.LlamaDecoderLayer,
                },
            )

            self.model = FSDP(
                model,
                auto_wrap_policy=auto_wrap_policy,
                device_id=self.device.index,
                limit_all_gathers=True,
                mixed_precision=MixedPrecision(
                    param_dtype=self.dtype,
                    reduce_dtype=self.dtype,
                    buffer_dtype=self.dtype,
                    cast_forward_inputs=True,
                    keep_low_precision_grads=False,
                ),
                use_orig_params=True,
            )

            if self.config.fsdp_compile:
                self.model = torch.compile(self.model)

        else:
            raise NotImplementedError(f"unknown mode '{self.config.mode}'")

        self.model.eval()

        # Parameters are never trained here; forward_with_grad only reads the
        # gradient with respect to inputs_embeds.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def prepare_batch(self, batch):
        """Turn a numpy array into a tensor on this worker's device.

        Non-ndarray values (None, scalars, tensors, ...) are returned unchanged.
        Floating point arrays are additionally cast to the model dtype.
        """
        if not isinstance(batch, np.ndarray):
            return batch

        with warnings.catch_warnings():
            # Presumably silences torch's warning about non-writable arrays,
            # e.g. those deserialized from Ray's object store.
            warnings.simplefilter("ignore")
            batch = torch.from_numpy(batch).to(self.device)

        if torch.is_floating_point(batch):
            # Tensor.to is not in-place: the result must be kept.
            batch = batch.to(self.dtype)

        return batch

    def prepare_kv_cache(self, key, batch_size):
        """Return a fresh cache object for ``key``, broadcast to ``batch_size``.

        Every stored tensor must have a leading dimension of 1 (broadcast with
        ``expand``) or exactly ``batch_size``.

        A stored ``DynamicCache`` is converted back to legacy tuples first for
        two reasons. First, it then gets the same batch expansion. Second, the
        model receives a new cache object rather than the stored one: a
        ``DynamicCache`` is appended to in place during the forward pass, which
        would otherwise corrupt the entry stored under ``key`` for later loads.

        Raises:
            KeyError: if nothing is stored under ``key``.
            NotImplementedError: on an incompatible leading dimension.
        """

        def kv_expander(t):
            if t.shape[0] == batch_size:
                return t
            if t.shape[0] == 1:
                return t.expand(batch_size, *[-1] * (len(t.shape) - 1))
            raise NotImplementedError(
                f"Only support kv_cache size 1 or {batch_size}, not {t.shape[0]}"
            )

        pkv = self.kv_cache[key]
        if isinstance(pkv, DynamicCache):
            pkv = pkv.to_legacy_cache()
        if isinstance(pkv, tuple):
            pkv = DynamicCache.from_legacy_cache(tree_map(kv_expander, pkv))
        return pkv

    def _prepare_forward_kwargs(self, cache_load, cache_save, kwargs):
        """Move model inputs to the device, attach the loaded cache, set use_cache.

        The batch size is taken from ``input_ids`` or, if absent, from
        ``inputs_embeds``.
        """
        kwargs = {k: self.prepare_batch(v) for k, v in kwargs.items()}
        batch = kwargs.get("input_ids")
        if batch is None:
            batch = kwargs["inputs_embeds"]
        batch_size = len(batch)

        if cache_load is not None:
            kwargs["past_key_values"] = self.prepare_kv_cache(cache_load, batch_size)

        assert "use_cache" not in kwargs, "use_cache is derived from cache_save"
        kwargs["use_cache"] = cache_save is not None
        return kwargs

    def _store_kv_cache(self, out, cache_save):
        """Keep ``out.past_key_values`` on the worker and return only its key."""
        if cache_save is not None:
            self.kv_cache[cache_save] = out.past_key_values
            out.past_key_values = cache_save

    @torch.inference_mode()
    def forward(
        self,
        cache_load: Optional[Hashable] = None,
        cache_save: Optional[Hashable] = None,
        full_logits=False,
        **kwargs,
    ):
        """Run the model on this worker's inputs.

        Args:
            cache_load: key of a stored kv cache to continue from.
            cache_save: if given, the model's ``past_key_values`` are stored
                under this key and replaced by the key in the returned output.
            full_logits: unused; kept for interface compatibility.
            **kwargs: model inputs; numpy arrays are moved to the device.

        Returns:
            The model output on local rank 0, ``None`` on other ranks.
        """
        kwargs = self._prepare_forward_kwargs(cache_load, cache_save, kwargs)
        out = self.model(**kwargs)
        self._store_kv_cache(out, cache_save)

        if self.local_rank == 0:
            return out
        return None

    def forward_with_grad(
        self,
        cache_load: Optional[Hashable] = None,
        cache_save: Optional[Hashable] = None,
        full_logits=False,
        **kwargs,
    ):
        """Like ``forward``, but also backpropagate ``out.loss`` to the inputs.

        Requires ``inputs_embeds``. It presumably also needs ``labels``, so
        that the model returns a loss.

        Returns:
            ``(output, d loss / d inputs_embeds)`` on local rank 0, ``None``
            on other ranks.
        """
        kwargs = self._prepare_forward_kwargs(cache_load, cache_save, kwargs)

        # Use a fresh leaf copy of the embeddings so its .grad can be read back.
        inputs_embeds = kwargs["inputs_embeds"].clone().requires_grad_()
        kwargs["inputs_embeds"] = inputs_embeds
        out = self.model(**kwargs)
        out.loss.backward()
        grad = inputs_embeds.grad

        self._store_kv_cache(out, cache_save)

        if self.local_rank == 0:
            return out, grad
        return None

    @torch.inference_mode()
    def generate(self, batch, cache_load: Optional[Hashable] = None, **kwargs):
        """Run ``model.generate`` on ``batch``, optionally continuing a stored cache."""
        batch = self.prepare_batch(batch)

        if cache_load is not None:
            kwargs["past_key_values"] = self.prepare_kv_cache(cache_load, len(batch))

        return self.model.generate(batch, **kwargs)

    def drop_kv_cache(self, key=None):
        """Drop the cache stored under ``key``, or every cache if ``key`` is None."""
        if key is None:
            self.kv_cache = {}
        else:
            self.kv_cache.pop(key, None)

    def apply(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Run ``func(worker, *args, **kwargs)`` inside this actor."""
        return func(self, *args, **kwargs)

    @torch.inference_mode()
    def get_embedding_table(self):
        """Return the input embedding weights captured by ``init_model``.

        Raises:
            RuntimeError: if the table was not captured (only FSDP mode does so).
        """
        table = getattr(self, "embedding_table", None)
        if table is None:
            raise RuntimeError(
                "embedding table is only captured when the model is initialized "
                f"with mode='fsdp' (current mode: {self.config.mode!r})"
            )
        return table


# TODO: This Predictor should be part of Ray AIR.
class DeepSpeedPredictor:
    """A single model replica sharded across one group of PredictionWorkers.

    Construct with ``await DeepSpeedPredictor.create(config)``. The plain
    constructor only builds the scaling config and starts no workers.
    """

    # How long __del__ waits for the workers to exit gracefully before they
    # are force-killed.
    _GRACEFUL_EXIT_TIMEOUT_S = 10.0

    def __init__(self, config: DPCConfig) -> None:
        self.config = config
        self.workers = []
        self.pg = None
        self.scaling_config = ScalingConfig(
            use_gpu=True,
            num_workers=config.num_workers_per_group,
            resources_per_worker={"CPU": 1},
            trainer_resources={"CPU": 0},
        )

    # Async alternate constructor. functools.wraps exposes __init__'s
    # parameters to introspection.
    @classmethod
    @functools.wraps(__init__)
    async def create(cls, *args, **kwargs):
        self = cls(*args, **kwargs)
        await self.init_worker_group(self.scaling_config)
        return self

    def __del__(self):
        # Best-effort teardown. Every step is attempted even if an earlier one
        # fails, so that workers and the placement group are not leaked.
        if not ray.is_initialized():
            # Ray has already been shut down (e.g. at interpreter exit).
            return
        workers = getattr(self, "workers", None) or []

        # aka: shutdown_torch_dist_process_group(self.workers)
        try:
            ray.get([w.execute.remote(_shutdown_torch_distributed) for w in workers])
        except Exception as err:
            warnings.warn(f"torch.distributed shutdown failed: {err!r}")

        # Politely ask all workers to exit, waiting (bounded) for every one of
        # them. PredictionWorker defines no `exit` method, so use Ray's built-in
        # graceful actor termination.
        if workers:
            try:
                exit_refs = [w.__ray_terminate__.remote() for w in workers]
                ray.wait(
                    exit_refs,
                    num_returns=len(exit_refs),
                    timeout=self._GRACEFUL_EXIT_TIMEOUT_S,
                )
            except Exception as err:
                warnings.warn(f"graceful worker exit failed: {err!r}")

        # force workers to exit
        for worker in workers:
            try:
                ray.kill(worker)
            except Exception as err:
                warnings.warn(f"failed to kill worker: {err!r}")

        # Release the resources reserved for this worker group.
        pg = getattr(self, "pg", None)
        if pg is not None:
            try:
                ray.util.remove_placement_group(pg)
            except Exception as err:
                warnings.warn(f"failed to remove placement group: {err!r}")

    async def init_worker_group(self, scaling_config: ScalingConfig):
        """Create the worker group.

        Each worker in the group communicates with the other workers through
        the torch distributed backend. The worker group is inelastic: a failure
        of one worker destroys the entire group. Each worker in the group
        receives the same input data.
        """
        config = self.config

        # Start a placement group for the workers.
        self.pg = scaling_config.as_placement_group_factory().to_placement_group()

        worker_cls = PredictionWorker.options(
            num_cpus=scaling_config.num_cpus_per_worker,
            num_gpus=scaling_config.num_gpus_per_worker,
            resources=scaling_config.additional_resources_per_worker,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=self.pg, placement_group_capture_child_tasks=True
            ),
        )

        # Create the prediction workers.
        self.workers = [
            worker_cls.remote(config, scaling_config.num_workers)
            for _ in range(scaling_config.num_workers)
        ]

        # Initialize torch distributed process group for the workers.
        self.local_ranks = init_torch_dist_process_group(self.workers, backend="nccl")

        # Initialize the model on each worker.
        # see https://github.com/ray-project/ray/issues/7815
        futures = [
            worker.init_model.remote(local_rank)
            for worker, local_rank in zip(self.workers, self.local_ranks)
        ]
        await asyncio.gather(*futures)

    @staticmethod
    def _put_arrays(kwargs):
        """Put numpy arrays into the object store once so all workers share them."""
        return {
            k: (ray.put(v) if isinstance(v, np.ndarray) else v)
            for k, v in kwargs.items()
        }

    async def forward(self, **kwargs):
        """Run ``forward`` on every worker and return the first worker's result."""
        kwargs_ref = self._put_arrays(kwargs)
        futures = [worker.forward.remote(**kwargs_ref) for worker in self.workers]

        return (await asyncio.gather(*futures))[0]

    async def forward_with_grad(self, **kwargs):
        """Run ``forward_with_grad`` on every worker; return the first worker's result."""
        kwargs_ref = self._put_arrays(kwargs)
        futures = [
            worker.forward_with_grad.remote(**kwargs_ref) for worker in self.workers
        ]

        return (await asyncio.gather(*futures))[0]

    async def generate(self, batch, **kwargs):
        """Run ``generate`` on every worker and return the first worker's result."""
        batch_ref = ray.put(batch)
        futures = [
            worker.generate.remote(batch_ref, **kwargs) for worker in self.workers
        ]

        return (await asyncio.gather(*futures))[0]

    def drop_kv_cache(self, key=None):
        # Blocking on purpose, so that it can be called from __del__; it is
        # fast anyway.
        futures = [worker.drop_kv_cache.remote(key) for worker in self.workers]
        ray.get(futures)


class DeepSpeedPredictorCluster:
    """Data-parallel front end over ``config.num_worker_groups`` predictors.

    Every batch is padded and split evenly across the predictors (one model
    replica each). Their outputs are concatenated back together and trimmed to
    the original batch size. Construct with
    ``await DeepSpeedPredictorCluster.create(config)``.
    """

    # Per-example arguments that are split across the predictors; all other
    # keyword arguments are forwarded unchanged to every predictor.
    _BATCHED_ARGNAMES = (
        "input_ids",
        "attention_mask",
        "position_ids",
        "inputs_embeds",
        "labels",
    )

    def __init__(self, config):
        self.config = config
        self.predictors = []

    @classmethod
    @functools.wraps(__init__)
    async def create(cls, *args, **kwargs):
        self = cls(*args, **kwargs)
        await self.init_predictors()
        return self

    async def init_predictors(self):
        """Start all predictors (worker groups) concurrently."""
        self.predictors = await asyncio.gather(
            *[
                DeepSpeedPredictor.create(self.config)
                for _ in range(self.config.num_worker_groups)
            ]
        )

    def prepare_batch(self, batch):
        """Split a tensor batch into ``num_worker_groups`` equal sub-batches.

        The batch is padded by repeating its last row until its length is a
        multiple of ``num_worker_groups``. It is then reshaped to
        ``(num_worker_groups, -1, *batch.shape[1:])`` and returned as a numpy
        array. bfloat16 is upcast to float32 because numpy has no bfloat16.
        """
        n = self.config.num_worker_groups
        pad = -len(batch) % n
        padding = batch[-1].expand(pad, *batch.shape[1:])
        # torch.cat along dim 0 (unlike vstack) also handles 1-D batches.
        padded = torch.cat([batch, padding], dim=0)
        split = padded.view(n, -1, *batch.shape[1:]).detach().cpu()
        if split.dtype == torch.bfloat16:
            split = split.float()
        return split.numpy()

    def unwrap_kvcache_refs(self, kwargs):
        """Replace DPCKVCacheRef values of cache_load/cache_save by their keys."""
        for name in ("cache_load", "cache_save"):
            if isinstance(kwargs.get(name, None), DPCKVCacheRef):
                kwargs[name] = kwargs[name].key
        return kwargs

    def process_outputs(self, outs, orig_batch_size, device=None):
        """Merge one output leaf from every predictor into a single value.

        Tensors are moved to ``device`` (default: the first tensor's device),
        concatenated along the batch dimension and trimmed to
        ``orig_batch_size``, which drops the padding added by ``prepare_batch``.
        Returns ``None`` when the first value is ``None`` or not a tensor.
        """
        out0 = next(iter(outs))
        if not isinstance(out0, torch.Tensor):
            # None and non-tensor leaves (presumably e.g. the cache key that
            # workers substitute for past_key_values) are not merged.
            return None

        dev = out0.device if device is None else device
        moved = [out.to(dev) for out in outs]
        if out0.dim() == 0:
            # NOTE(review): scalars such as `loss` are stacked to shape
            # (num_worker_groups, 1), one value per predictor, not reduced.
            return torch.vstack(moved)[:orig_batch_size]
        return torch.cat(moved, dim=0)[:orig_batch_size]

    def _split_batched_kwargs(self, kwargs):
        """Pop the per-example arguments from ``kwargs`` and split them.

        Returns:
            ``(n, per_predictor)``. ``n`` is the original batch size.
            ``per_predictor`` holds one kwargs dict per predictor, each with
            that predictor's sub-batches.

        Raises:
            ValueError: if none of the batched arguments is given.
        """
        n = None
        batched = {}
        for argname in self._BATCHED_ARGNAMES:
            if kwargs.get(argname) is not None:
                arg = kwargs.pop(argname)
                n = len(arg)
                batched[argname] = self.prepare_batch(arg)

        if not batched:
            raise ValueError(
                f"at least one of {self._BATCHED_ARGNAMES} must be provided"
            )

        per_predictor = [
            dict(zip(batched.keys(), vals)) for vals in zip(*batched.values())
        ]
        return n, per_predictor

    def _merge_model_outputs(self, outputs, n, device):
        """Merge per-predictor model outputs leaf by leaf into one output object."""
        out0 = outputs[0]
        merged = tree_multimap(
            lambda *leaves: self.process_outputs(leaves, n, device),
            *map(dict, outputs),
        )
        return out0.__class__(**merged)

    async def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Model-style forward over the whole cluster (see ``forward_inner``)."""
        return await self.forward_inner(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

    async def forward_inner(self, device=None, **kwargs):
        """Split the batched inputs, run every predictor, merge their outputs."""
        kwargs = self.unwrap_kvcache_refs(kwargs)
        n, per_predictor = self._split_batched_kwargs(kwargs)

        outputs = await asyncio.gather(
            *[
                pred.forward(**bkwargs, **kwargs)
                for pred, bkwargs in zip(self.predictors, per_predictor)
            ]
        )
        return self._merge_model_outputs(outputs, n, device)

    async def forward_with_grad(self, device=None, **kwargs):
        """Like ``forward_inner``, also returning the merged input-embedding grads."""
        kwargs = self.unwrap_kvcache_refs(kwargs)
        n, per_predictor = self._split_batched_kwargs(kwargs)

        outputs = await asyncio.gather(
            *[
                pred.forward_with_grad(**bkwargs, **kwargs)
                for pred, bkwargs in zip(self.predictors, per_predictor)
            ]
        )

        merged = self._merge_model_outputs([out for out, _ in outputs], n, device)
        grads = self.process_outputs([grad for _, grad in outputs], n, device)
        return merged, grads

    async def generate(self, input_ids, device=None, **kwargs):
        """Split ``input_ids`` across the predictors and merge the generations.

        Raises:
            NotImplementedError: if predictors return neither a ModelOutput
                nor a tensor.
        """
        import transformers

        kwargs = self.unwrap_kvcache_refs(kwargs)

        n = len(input_ids)

        subbatches = self.prepare_batch(input_ids)
        outputs = await asyncio.gather(
            *[p.generate(b, **kwargs) for p, b in zip(self.predictors, subbatches)]
        )

        out0 = outputs[0]
        if isinstance(out0, transformers.utils.ModelOutput):
            return self._merge_model_outputs(outputs, n, device)
        if isinstance(out0, torch.Tensor):
            return self.process_outputs(outputs, n, device)
        raise NotImplementedError(f"Unknown type {type(out0)}")

    def drop_kv_cache(self, key=None):
        """Drop ``key`` (or every cache when ``key`` is None) on all predictors."""
        for predictor in self.predictors:
            predictor.drop_kv_cache(key)

    def get_kv_cache_ref(self):
        """Create a new cache handle that drops its cache when garbage collected."""
        return DPCKVCacheRef(uuid.uuid4().int, ref(self))

    def get_embedding_table(self):
        """Fetch the input embedding table from the first predictor's workers."""
        return ray.get(
            [
                worker.get_embedding_table.remote()
                for worker in self.predictors[0].workers
            ]
        )[0]


@dataclass
class DPCKVCacheRef:
    """Handle to a kv cache stored on the workers of a DeepSpeedPredictorCluster.

    Pass it as ``cache_save``/``cache_load``; the cluster replaces it with
    ``key``. When the handle is garbage collected, the cached entry is dropped
    from the cluster if the cluster still exists.
    """

    key: Hashable
    # Weak reference, so outstanding cache handles do not keep the cluster alive.
    dpcref: "ReferenceType[DeepSpeedPredictorCluster]"

    def __del__(self):
        cluster = self.dpcref()
        # The cluster may already have been collected; there is nothing left
        # to drop through it in that case.
        if cluster is not None:
            cluster.drop_kv_cache(self.key)
