import inspect
from typing import Any, Protocol, Sequence, TypedDict

from flax.struct import dataclass
from jax import Array
from jaxtyping import PyTree, PRNGKeyArray  # noqa
import numpy as np


class SystemDefinitionDict(TypedDict):
    """
    Dictionary form of a system definition.

    Both keys are required: a ``name`` string and a ``config`` mapping whose
    values are not constrained at the type level.
    """

    # Identifier of the system definition.
    name: str
    # Arbitrary configuration values, keyed by string.
    config: dict[str, Any]


SystemDefinitionTuple = tuple[str, *tuple[Any, ...]]
SystemDefinition = SystemDefinitionDict | SystemDefinitionTuple
SystemDefinitions = Sequence[SystemDefinition]

Position = tuple[float, float, float] | list[float] | np.ndarray
Charges = tuple[int, ...] | np.ndarray
Spins = tuple[int, int]


@dataclass
class SystemConfigs:
    """
    Spin and charge configuration for a collection of molecules.

    Entry ``i`` of ``spins`` and entry ``i`` of ``charges`` are paired together
    (see ``sub_configs``). Both tuples presumably have the same length. Note
    that ``n_mols`` counts ``spins`` only, and ``sub_configs`` uses ``zip``,
    which stops at the shorter of the two.
    """

    # One (int, int) spin pair per molecule.
    spins: tuple[Spins, ...]
    # One charge sequence per molecule; its length is that molecule's n_nuc.
    charges: tuple[Charges, ...]

    @property
    def n_elec(self):
        """Tuple with the sum of each molecule's spin pair."""
        return tuple(map(sum, self.spins))

    @property
    def n_nuc(self):
        """Tuple with the number of charges listed for each molecule."""
        return tuple(map(len, self.charges))

    @property
    def sub_configs(self):
        """Tuple of single-molecule ``SystemConfigs``, one per (spin, charge) pair."""
        return tuple(
            SystemConfigs(spins=(spin,), charges=(charge,))
            for spin, charge in zip(self.spins, self.charges)
        )

    @property
    def n_mols(self):
        """Number of molecules, taken as the length of ``spins``."""
        return len(self.spins)

    def flat_charges(self):
        """Concatenate every molecule's charges into a single NumPy array."""
        per_molecule = self.charges
        return np.concatenate(per_molecule)


class MolProperty:
    """
    Base class for molecular properties.
    """

    @property
    def value(self) -> Array:
        """
        Returns the value of the property.
        """
        raise NotImplementedError()

    @property
    def key(self) -> str:
        """
        Returns the key of the property.
        """
        raise NotImplementedError()

    def update(self, **kwargs):
        """
        Updates the property with given keyword arguments.

        Args:
            **kwargs: Keyword arguments to update the property.

        Returns:
            The updated property.
        """
        params = inspect.signature(self._update).parameters
        if any(k not in kwargs for k in params.keys()):
            return None
        return self._update(**{k: v for k, v in kwargs.items() if k in params})

    def _update(self) -> Array:
        """
        Updates the property.

        Returns:
            Value of the updated property.
        """
        raise NotImplementedError()

    def state_dict(self) -> dict:
        """
        Returns the state dictionary of the property.
        """
        raise NotImplementedError()

    def load_state_dict(self, state: dict) -> None:
        """
        Loads the state dictionary to the property.

        Args:
            state: State dictionary to load.
        """
        raise NotImplementedError()


class MolPropertyConstructor(Protocol):
    """Structural type for a zero-argument callable that yields a ``MolProperty``."""

    def __call__(self) -> MolProperty:
        """Build and return a ``MolProperty``."""


class ChunkSizeFunction(Protocol):
    """
    Structural type for a callable mapping one system's spins and charges to an int.

    The returned value is a chunk size; how it is used is up to the caller
    (not visible here).
    """

    def __call__(self, s: Spins, c: Charges) -> int:
        """Return the chunk size for spins ``s`` and charges ``c``."""


class OrbitalMatchingFunction(Protocol):
    """
    Structural type for a callable relating network orbitals to HF orbitals.

    Implementations receive:

    - a sequence of network-orbital PyTrees;
    - a sequence of ``(Array, Array)`` HF orbital pairs;
    - the ``SystemConfigs`` for the batch;
    - a ``cache`` sequence of PyTrees.

    They return either two nested sequences of arrays, or those two plus a
    sequence of PyTrees. The third element is presumably an updated cache;
    confirm against the implementations.
    """

    def __call__(
        self,
        nn_orbitals: Sequence[PyTree],
        hf_orbitals: Sequence[tuple[Array, Array]],
        config: SystemConfigs,
        cache: Sequence[PyTree],
    ) -> (
        tuple[Sequence[Sequence[Array]], Sequence[Sequence[Array]]]
        | tuple[Sequence[Sequence[Array]], Sequence[Sequence[Array]], Sequence[PyTree]]
    ):
        """Relate ``nn_orbitals`` to ``hf_orbitals`` for the systems in ``config``."""


class HfOrbitalFunction(Protocol):
    """
    Structural type for a callable that evaluates orbitals at electron positions.

    The name suggests Hartree-Fock orbitals. The two returned arrays are not
    described here; presumably one per spin channel, which should be confirmed.
    """

    def __call__(
        self,
        electrons: Array,
    ) -> tuple[Array, Array]:
        """Evaluate the orbitals for ``electrons`` and return a pair of arrays."""
