import asyncio
import operator as op
import threading
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import asdict, dataclass, field
from functools import partial, reduce
from typing import Generic, Optional, T, TypeVar, Union

import torch
from cupy.cuda.device import Device as CupyDevice
from torch import nn
from torch.multiprocessing import reductions
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr/
# One private event loop, driven by a daemon thread that is started lazily on
# the first call to run_async.
_loop = asyncio.new_event_loop()
_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
# Guards the lazy start of _thr: Thread.start() may only be called once, so
# two callers must not both observe "not alive" and both try to start it.
_thr_start_lock = threading.Lock()


def run_async(coro):
    """Run a coroutine on the background event loop and return its result.

    This blocks the calling thread until the coroutine has finished. Any
    exception raised inside the coroutine is re-raised in the caller.

    Args:
        coro: A coroutine object, e.g. ``run_async(fetch(url))``.

    Returns:
        Whatever the coroutine returns.

    Raises:
        RuntimeError: If called from the runner thread itself (for instance
            from inside a coroutine that is already being executed through
            ``run_async``). Blocking there would stall the loop that has to
            make progress, i.e. deadlock.
    """
    if threading.current_thread() is _thr:
        # Close the coroutine so it does not trigger a "never awaited" warning.
        if asyncio.iscoroutine(coro):
            coro.close()
        raise RuntimeError(
            "run_async() cannot be called from the Async Runner thread; "
            "await the coroutine instead"
        )

    with _thr_start_lock:
        if not _thr.is_alive():
            _thr.start()

    future = asyncio.run_coroutine_threadsafe(coro, _loop)
    return future.result()


# From torchist
def ravel_multi_index(coords: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
    r"""Converts a tensor of coordinate vectors into a tensor of flat indices.

    This is a `torch` implementation of `numpy.ravel_multi_index`.

    Args:
        coords: A tensor of coordinate vectors, (*, D).
        shape: The source shape, as any sequence of D ints
            (`torch.Size`, tuple, list, ...).

    Returns:
        The raveled indices, (*,).
    """

    # Normalise to a tuple first: a list cannot be concatenated with `(1,)`,
    # whereas `torch.Size` and plain tuples can.
    shape = coords.new_tensor(tuple(shape) + (1,))
    # Row-major strides: coefs[i] is the product of shape[i + 1:], the appended
    # 1 being the stride of the last dimension.
    coefs = shape[1:].flipud().cumprod(dim=0).flipud()

    return (coords * coefs).sum(dim=-1)


def unravel_index(indices: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
    r"""Converts a tensor of flat indices into a tensor of coordinate vectors.

    This is a `torch` implementation of `numpy.unravel_index`.

    Args:
        indices: A tensor of flat indices, (*,).
        shape: The target shape, as any sequence of D ints
            (`torch.Size`, tuple, list, ...).

    Returns:
        The unraveled coordinates, (*, D).
    """

    # Normalise to a tuple first: a list cannot be concatenated with `(1,)`,
    # whereas `torch.Size` and plain tuples can.
    shape = indices.new_tensor(tuple(shape) + (1,))
    # Row-major strides: coefs[i] is the product of shape[i + 1:], the appended
    # 1 being the stride of the last dimension.
    coefs = shape[1:].flipud().cumprod(dim=0).flipud()

    # Divide by the stride of each dimension, then wrap by that dimension's size.
    return torch.div(indices[..., None], coefs, rounding_mode="trunc") % shape[:-1]


# A path into a pytree: a sequence of string / integer components.
PyTreePath = Sequence[Union[str, int]]


def tree_normalize_path(path: PyTreePath):
    """Flatten ``path`` into a flat list of atoms.

    Every string leaf of ``path`` is split on ``"."``. Each resulting piece,
    and every non-string leaf, is passed through ``int()``; a piece for which
    ``int()`` raises ``ValueError`` is kept unchanged.

    Args:
        path: A string, or a (possibly nested) container of strings/ints.

    Returns:
        A flat list of atoms, e.g. ``"a.b.0"`` becomes ``["a", "b", 0]``.
    """

    def to_atom(piece):
        try:
            return int(piece)
        except ValueError:
            return piece

    def expand(leaf):
        # A dotted string stands for several atoms; anything else is one atom.
        if isinstance(leaf, str):
            return [to_atom(piece) for piece in leaf.split(".")]
        return to_atom(leaf)

    atoms, _ = tree_flatten(tree_map(expand, path))
    return atoms


def tree_index(tree, path: PyTreePath):
    """Walk ``tree`` along ``path`` and return the sub-object found there.

    ``path`` is first normalised with ``tree_normalize_path``. Each atom is
    then resolved against the current sub-object in this order:

    1. if the sub-object is a ``dict`` containing the atom as a key, that
       entry is used;
    2. otherwise, if the sub-object has an attribute named ``str(atom)``,
       that attribute is used;
    3. otherwise the sub-object is indexed with ``[atom]``.

    Step 1 comes first so that dictionary keys which happen to share a name
    with a ``dict`` method (``"items"``, ``"keys"``, ``"values"``, ...)
    resolve to the stored entry instead of the bound method.

    Args:
        tree: The object to index into.
        path: The path to follow.

    Returns:
        The object reached after consuming every atom of ``path``.

    Raises:
        Whatever the final ``[atom]`` lookup raises (e.g. ``KeyError``,
        ``IndexError``, ``TypeError``) when an atom cannot be resolved.
    """
    subtree = tree
    for atom in tree_normalize_path(path):
        if isinstance(subtree, dict) and atom in subtree:
            subtree = subtree[atom]
            continue

        attr_name = str(atom)
        if hasattr(subtree, attr_name):
            subtree = getattr(subtree, attr_name)
        else:
            subtree = subtree[atom]

    return subtree


def tree_multimap(f, *trees):
    """Apply ``f`` leaf-wise across several pytrees sharing one structure.

    Args:
        f: Called as ``f(leaf_0, leaf_1, ...)`` with the corresponding leaf
            of every tree, in the order the trees were given.
        *trees: One or more pytrees with identical structure.

    Returns:
        A pytree with the structure of the first tree whose leaves are the
        values returned by ``f``.

    Raises:
        AssertionError: If the trees do not all have the same structure.
        ValueError: If no tree is given (nothing to unpack).
    """
    flats, specs = zip(*(tree_flatten(tree) for tree in trees))

    spec = specs[0]
    for position, other in enumerate(specs[1:], start=1):
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; otherwise `zip` below would silently truncate
        # mismatched trees. AssertionError is kept for backward compatibility.
        if not (spec == other):
            raise AssertionError(
                f"tree_multimap: tree {position} does not have the same "
                f"structure as tree 0: {other} != {spec}"
            )

    mapped = [f(*leaves) for leaves in zip(*flats)]
    return tree_unflatten(mapped, spec)


def tree_reduce(f, tree):
    """Fold the leaves of ``tree`` into one value with the binary function ``f``.

    The leaves are combined left to right in flattening order, without an
    initial value, so a tree with no leaves makes ``reduce`` raise.
    """
    leaves = tree_flatten(tree)[0]
    return reduce(f, leaves)


def tree_linear(*terms):
    """Compute a leaf-wise linear combination of pytrees.

    Args:
        *terms: One or more ``(coefficient, tree)`` pairs. The trees are
            combined with ``tree_multimap``.

    Returns:
        A pytree whose leaves are ``sum(coefficient * leaf)`` over the terms.
    """
    assert len(terms) > 0

    coefficients = [coefficient for coefficient, _ in terms]
    operands = [tree for _, tree in terms]

    def combine(*leaves):
        scaled = (
            coefficient * leaf for coefficient, leaf in zip(coefficients, leaves)
        )
        return reduce(op.add, scaled)

    return tree_multimap(combine, *operands)


def tree_mean(*sds):
    """Return the leaf-wise arithmetic mean of the given pytrees.

    Each tree is weighted by ``1 / len(sds)`` and the terms are summed by
    ``tree_linear``.
    """
    count = len(sds)
    # The weight is computed per element so that an empty call performs no
    # division and is rejected by tree_linear instead.
    weighted = [(1 / count, sd) for sd in sds]
    return tree_linear(*weighted)


def tree_vdot(tree1, tree2):
    """Return the sum over all leaves of the element-wise product of two pytrees.

    Each pair of matching leaves contributes ``torch.sum(a * b)``; the
    per-leaf results are then added together with ``torch.add``.
    """

    def leaf_dot(x, y):
        product = x * y
        return torch.sum(product)

    per_leaf = tree_multimap(leaf_dot, tree1, tree2)
    return tree_reduce(torch.add, per_leaf)


def tree_norm(tree):
    """Return the square root of ``tree_vdot(tree, tree)``."""
    squared = tree_vdot(tree, tree)
    return torch.sqrt(squared)


def tree_cosine_sim(tree1, tree2):
    """Return ``tree_vdot(tree1, tree2)`` divided by the norms of both trees."""
    similarity = tree_vdot(tree1, tree2)
    similarity = similarity / tree_norm(tree1)
    return similarity / tree_norm(tree2)


def tree_to(tree, dev=0):
    """Call ``.to(dev)`` on every leaf of ``tree`` and return the new pytree.

    Args:
        tree: A pytree whose leaves all provide a ``.to`` method.
        dev: The argument forwarded to ``.to``; defaults to ``0``.
    """

    def move(leaf):
        return leaf.to(dev)

    return tree_map(move, tree)


# https://github.com/pytorch/pytorch/blob/main/torch/multiprocessing/reductions.py
@dataclass
class TensorRef:
    """Reference to a tensor, built from ``reductions.reduce_tensor`` metadata.

    Constructing a ``TensorRef`` stores the rebuild arguments that
    ``torch.multiprocessing.reductions.reduce_tensor`` produces for the
    (detached) tensor. Calling the reference rebuilds a tensor from them.

    For CUDA tensors the PCI bus id of the device is stored as well, so the
    device index can be patched when rebuilding in a process whose
    ``CUDA_VISIBLE_DEVICES`` differs from the one that created the reference.
    """

    # Device type of the referenced tensor, e.g. "cpu" or "cuda".
    device_type: str
    # Argument tuple returned by reductions.reduce_tensor for the tensor.
    metadata: tuple
    # PCI bus id of the CUDA device; None for non-CUDA tensors.
    device_pci_bus_id: Optional[str]

    def __init__(self, tensor: torch.Tensor):
        """Capture the rebuild metadata of ``tensor`` (detached from autograd)."""
        tensor = tensor.detach()
        self.device_type = tensor.device.type
        _, self.metadata = reductions.reduce_tensor(tensor)
        # Always define the field: the dataclass-generated __repr__ / __eq__
        # read it, and would raise AttributeError for non-CUDA tensors if it
        # were only set in the CUDA branch.
        self.device_pci_bus_id = None
        if self.device_type == "cuda":
            self.device_pci_bus_id = CupyDevice(tensor.device.index).pci_bus_id

    @classmethod
    def maybe_make_ref(cls, t):
        """Wrap ``t`` in a reference if it is a tensor; return it unchanged otherwise."""
        # nn.Parameter is a Tensor subclass, so one check covers both.
        if isinstance(t, torch.Tensor):
            return cls(t)
        return t

    def __call__(self) -> torch.Tensor:
        """Rebuild and return the referenced tensor.

        Raises:
            NotImplementedError: If the tensor lived on a device type other
                than "cpu" or "cuda".
        """
        if self.device_type == "cpu":
            return reductions.rebuild_tensor(*self.metadata)
        if self.device_type == "cuda":
            # The device index might change because this process has a different
            # value for CUDA_VISIBLE_DEVICES, so we patch the index here
            metadata = list(self.metadata)
            # Sanity check that slot 6 really is the device index.
            assert isinstance(metadata[6], int)
            metadata[6] = CupyDevice.from_pci_bus_id(self.device_pci_bus_id).id
            return reductions.rebuild_cuda_tensor(*metadata)
        # Previously this fell through and silently returned None.
        raise NotImplementedError(
            f"TensorRef cannot rebuild tensors on device type {self.device_type!r}"
        )


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            else:
                return args[0]
        if (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("device") is None
        ):
            kwargs["device"] = self.device
        return func(*args, **kwargs)
