from typing import Optional, cast

import torch
import torch.nn.functional as F
from torch import nn

from ..encoders import Encoder, EncoderWithAction
from .base import DropContinuousQFunction
from .utility import compute_huber_loss, compute_reduce, pick_value_by_action


class DropContinuousMeanQFunction(DropContinuousQFunction, nn.Module):  # type: ignore
    """Continuous-action Q-function whose head is also conditioned on an embedding.

    ``encoder(x, action)`` produces a feature vector. That vector is
    concatenated with an extra embedding ``e`` along dim 1 and passed through
    a small MLP head whose final layer has a single output unit (one Q-value
    per row).

    Args:
        encoder: state-action encoder. Its ``get_feature_size()`` sets part
            of the head's input width, and its ``action_size`` is exposed via
            :attr:`action_size`.
        embedding_size: width of the embedding concatenated to the encoder
            features before the head.
        hidden_size: width of the head's hidden ``Linear`` layers. The
            default of 512 reproduces the previously hard-coded architecture,
            so existing constructions and checkpoints are unaffected.
    """

    _encoder: EncoderWithAction
    _action_size: int
    _embedding_size: int
    _hidden_size: int
    _fc: nn.Sequential

    def __init__(
        self,
        encoder: EncoderWithAction,
        embedding_size: int,
        *,
        hidden_size: int = 512,
    ) -> None:
        super().__init__()
        self._encoder = encoder
        self._action_size = encoder.action_size
        self._embedding_size = embedding_size
        self._hidden_size = hidden_size
        # The layer order is kept exactly as before, so Sequential indices
        # (and therefore state_dict keys) are unchanged.
        # NOTE(review): there is no activation between the last two Linear
        # layers, so together they compute a single affine map. Inserting a
        # ReLU there would shift the final layer's index and break loading of
        # existing checkpoints. Confirm whether this was intentional.
        self._fc = nn.Sequential(
            nn.Linear(encoder.get_feature_size() + embedding_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return Q-values for ``(x, action)`` conditioned on embedding ``e``.

        The encoder features and ``e`` are concatenated along dim 1, so they
        must agree on every other dimension. Presumably both are
        ``(batch, features)`` with ``e`` of width ``embedding_size`` (TODO
        confirm). The output's last dimension is 1.
        """
        h = self._encoder(x, action)
        return cast(torch.Tensor, self._fc(torch.cat([h, e], dim=1)))

    def compute_error(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        embeddings: torch.Tensor,
        rewards: torch.Tensor,
        target: torch.Tensor,
        terminals: torch.Tensor,
        gamma: float = 0.99,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Squared one-step TD error between Q(s, a, e) and a bootstrapped target.

        The regression target is ``rewards + gamma * target * (1 - terminals)``.
        ``target`` is used as given and is not detached here; callers
        presumably compute it without gradients (verify at call sites).

        Args:
            observations, actions, embeddings: inputs to :meth:`forward`.
            rewards: immediate rewards added to the discounted bootstrap.
            target: bootstrap value for the next state.
            terminals: where ``1 - terminals`` is 0, the bootstrap term is
                dropped.
            gamma: discount factor.
            reduction: forwarded to ``compute_reduce`` to reduce the
                per-element squared errors.

        Returns:
            The reduced squared error.
        """
        value = self.forward(observations, actions, embeddings)
        y = rewards + gamma * target * (1 - terminals)
        # Compute the unreduced loss and delegate reduction to the shared
        # helper, so every Q-function applies `reduction` the same way.
        loss = F.mse_loss(value, y, reduction="none")
        return compute_reduce(loss, reduction)

    def compute_error_pre(
        self,
        next_observations: torch.Tensor,
        next_actions: torch.Tensor,
        embeddings: torch.Tensor,
        rewards: torch.Tensor,
        terminals: torch.Tensor,
        target_pre: torch.Tensor,
        Inits: torch.Tensor,
        Rs: torch.Tensor,
        gamma: float = 0.99,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Squared error between a TD expression built on next-step Q and a given target.

        This method evaluates this network at the *next* state-action pair:
        ``value = Q(next_observations, next_actions, embeddings)``. It forms
        ``y = rewards + gamma * value * (1 - terminals)`` and penalises the
        squared difference between ``target_pre * (1 - Inits) + Rs`` and
        ``y``. Gradients reach this network through ``value``, i.e. the
        bootstrapped side.

        Args:
            next_observations, next_actions, embeddings: inputs to
                :meth:`forward`.
            rewards: added to the discounted next-step value.
            terminals: where ``1 - terminals`` is 0, the discounted value
                term is dropped.
            target_pre: value that is masked by ``1 - Inits``. Presumably
                this is a (target-network) Q-value for the preceding step —
                confirm with the caller.
            Inits: mask; where ``1 - Inits`` is 0, ``target_pre`` is
                ignored. Presumably this marks episode-initial steps — confirm.
            Rs: additive term on the masked side.
            gamma: discount factor.
            reduction: forwarded to ``compute_reduce``.

        Returns:
            The reduced squared error.
        """
        value = self.forward(next_observations, next_actions, embeddings)
        y = rewards + gamma * value * (1 - terminals)
        loss = F.mse_loss(target_pre * (1 - Inits) + Rs, y, reduction="none")
        return compute_reduce(loss, reduction)

    def compute_target(
        self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return :meth:`forward` unchanged (no ``no_grad`` or detach is applied here)."""
        return self.forward(x, action, e)

    @property
    def action_size(self) -> int:
        """Action dimensionality, taken from ``encoder.action_size`` at construction."""
        return self._action_size

    @property
    def encoder(self) -> EncoderWithAction:
        """The state-action encoder passed to the constructor."""
        return self._encoder
