from abc import ABC
from typing import Optional, Any, Sequence, Dict

import optuna
from torch import Tensor, nn

from adl4cv.parameters.params import HyperParameterSet, DefinitionSet, DefinitionSpace, HyperParameterSpace
from adl4cv.utils.utils import SerializableEnum


class LossType(SerializableEnum):
    """
    Enumeration of the loss functions that can be selected.

    Every member's value is the same string as the member's name.
    """
    # Cross-entropy loss; used by CrossEntropyDefinition in this file.
    CrossEntropy = "CrossEntropy"
    # NOTE(review): no definition for the two center-loss variants is visible in this
    # part of the file - presumably they are implemented elsewhere; confirm.
    CenterLoss = "CenterLoss"
    BatchCenterLoss = "BatchCenterLoss"


class LossDefinition(DefinitionSet, ABC):
    """
    Base class for the definition of a loss.

    Pairs a :class:`LossType` with the hyperparameters of that loss and hands both,
    unchanged, to :class:`DefinitionSet`.
    """

    def __init__(self, type: Optional[LossType] = None, hyperparams: Optional[HyperParameterSet] = None):
        """
        Creates a new LossDefinition
        :param type: The kind of loss this definition describes
        :param hyperparams: The hyperparameters belonging to the loss
        """
        # NOTE(review): `type` shadows the builtin inside this method; the name is kept
        # because renaming it would break callers that pass it by keyword.
        super().__init__(type, hyperparams)


# ----------------------------------- CrossEntropy -----------------------------------


class CrossEntropyHyperParameterSet(HyperParameterSet):
    """HyperParameterSet holding the settings of the cross-entropy loss"""

    def __init__(self,
                 weight: Optional[Tensor] = None,
                 ignore_index: int = -100,
                 reduction: str = 'mean',
                 **kwargs: Any):
        """
        Creates new HyperParameterSet

        The three named parameters carry the same names as the constructor arguments
        of :class:`torch.nn.CrossEntropyLoss`, to which CrossEntropyDefinition passes
        this set.

        :param weight: Optional tensor of rescaling weights (one per class in PyTorch)
        :param ignore_index: Target value excluded from the loss in PyTorch
        :param reduction: Reduction applied to the loss ('none', 'mean' or 'sum' in PyTorch)
        :param kwargs: Any further arguments, forwarded to :class:`HyperParameterSet`
        """
        super().__init__(**kwargs)

        # Stored after the base initialisation, in the same order as the signature.
        self.weight, self.ignore_index, self.reduction = weight, ignore_index, reduction


class CrossEntropyDefinition(LossDefinition):
    """Definition of the cross-entropy loss"""

    def __init__(self, hyperparams: Optional[CrossEntropyHyperParameterSet] = None):
        """
        Creates a new CrossEntropyDefinition
        :param hyperparams: The hyperparameters of the loss. When omitted (or None) a
            fresh CrossEntropyHyperParameterSet with default values is created.
        """
        # Build the default per call. A `CrossEntropyHyperParameterSet()` placed in the
        # signature is evaluated only once, when the function is defined, so every
        # definition created without arguments would share (and mutate) one object.
        if hyperparams is None:
            hyperparams = CrossEntropyHyperParameterSet()
        super().__init__(LossType.CrossEntropy, hyperparams)

    def instantiate(self, *args, **kwargs):
        """
        Instantiates the loss described by this definition
        :param args: Positional arguments forwarded to :class:`torch.nn.CrossEntropyLoss`
        :param kwargs: Additional keyword arguments forwarded to the loss. A key that is
            also present in the hyperparameters raises a TypeError (duplicate keyword).
        :return: The created :class:`torch.nn.CrossEntropyLoss`
        """
        # `self.hyperparams` is unpacked as keyword arguments, so it has to behave like
        # a mapping - presumably provided by HyperParameterSet; confirm.
        return nn.CrossEntropyLoss(*args, **self.hyperparams, **kwargs)
