import random
from typing import List, Dict

import metrics
import numpy as np
import torch
from lightning.pytorch import LightningModule
from sklearn.preprocessing import StandardScaler
import utils

from dct import LinearDCT
import torch.nn as nn


class AbstractLitModule(LightningModule):
    """Shared base class for the Lightning modules of this project.

    The base class stores the common hyperparameters, optionally sets up a
    DCT / inverse-DCT pair used to compress trajectories in the frequency
    domain, and implements the end-of-validation evaluation (DCI and R2
    scores) as well as the optimizer configuration.

    ``forward``, ``training_step``, ``validation_step`` and ``test_step`` are
    left to subclasses.  ``on_validation_epoch_end`` reads
    ``self.misc["pred_params"]`` and ``self.misc["gt_params"]``, which this
    class only initialises to empty lists -- presumably subclasses append the
    per-batch arrays during ``validation_step`` (confirm in the subclasses).
    """

    def __init__(
        self,
        state_dim: int,
        n_step: int,
        n_iv_steps: int,
        param_dim: int,
        n_views: int,
        learning_rate: float = 1e-5,
        eval_metrics: List[str] = None,
        code_sharing: Dict[int, List[int]] = None,
        factor_type="discrete",
        **kwargs,
    ):
        """Store hyperparameters and build the optional DCT layers.

        Args:
            state_dim: Dimension of a single state vector.
            n_step: Number of time steps per trajectory.
            n_iv_steps: Stored as ``self.n_iv_steps``; not used by this base
                class itself.
            param_dim: Dimension of the predicted parameter / latent vector.
            n_views: Number of views, forwarded to ``utils.feature_sharing_fn``.
            learning_rate: Learning rate of the Adam optimizer.
            eval_metrics: Names of the metrics evaluated at the end of each
                validation epoch; this class understands ``"dci"`` and
                ``"r2"``.  ``None`` (default) means no metric.  The sequence
                is copied, so the caller's object is never modified.
            code_sharing: ``{factor_id: [view ids]}`` map.  When ``None``,
                ``"r2"`` is dropped from the evaluated metrics.
            factor_type: ``"discrete"`` casts the ground-truth parameters to
                ``int8`` before evaluation; any other value leaves them as is.
            **kwargs: Optional settings: ``dct_layer`` (bool, default
                ``True``) and ``freq_frac_to_keep`` (float, default ``0.5``).
        """
        super().__init__()
        self.save_hyperparameters()
        self.state_dim = state_dim
        self.param_dim = param_dim
        self.n_views = n_views
        self.learning_rate = learning_rate
        self.n_iv_steps = n_iv_steps
        self.n_step = n_step

        # Buffers for values gathered during validation / evaluation.
        self.misc = {
            "pred_params": [],
            "pred_states": [],
            "gbt": [],
            "gt_params": [],
            "r2_linear": [],
            "r2_nonlinear": [],
        }

        # Work on a private copy: the list may be edited below, and neither the
        # caller's list nor a shared default object must change as a result.
        self.eval_metrics = [] if eval_metrics is None else list(eval_metrics)
        self.code_sharing = code_sharing  # {factor_id: List[view ids]} # len = GT param dim

        if self.code_sharing is None:
            print("Code sharing map is None, latents will not be averaged, make sure this is intended.")
            # The R2 evaluation is skipped when no code-sharing map is given.
            self.eval_metrics = [name for name in self.eval_metrics if name != "r2"]
            self.shared_encoding = [range(self.param_dim)]
        else:
            self.shared_encoding = list(code_sharing.keys())

        self.feature_sharing_fn = lambda params, **fn_kwargs: utils.feature_sharing_fn(
            params, num_views=self.n_views, code_sharing=self.code_sharing, **fn_kwargs
        )
        self.factor_type = factor_type

        # add dct transform if necessary:
        self.dct_layer: bool = kwargs.get("dct_layer", True)
        self.freq_frac_to_keep: float = kwargs.get("freq_frac_to_keep", 0.5)
        if self.dct_layer:
            self.dct: nn.Module = LinearDCT(self.n_step, "dct", norm="ortho").double()
            self.idct: nn.Module = LinearDCT(self.n_step, "idct", norm="ortho").double()
            # Only the kept fraction of the frequencies is fed to the model.
            self.input_dim = int(self.freq_frac_to_keep * self.n_step) * self.state_dim
        else:
            self.input_dim = self.n_step * self.state_dim

    def state_transform(self, states: torch.Tensor):
        """Apply the DCT layer to ``states`` and keep the leading frequencies.

        Only usable when the module was built with ``dct_layer=True``;
        otherwise ``self.dct`` does not exist.

        Args:
            states: Tensor whose last two axes are ``(n_step, state_dim)``,
                e.g. ``[n_views, bs, n_step, state_dim]``.

        Returns:
            Tensor with the same leading axes and last two axes
            ``(int(freq_frac_to_keep * n_step), state_dim)``.
        """
        # The last two axes are swapped before and after the DCT layer --
        # presumably because LinearDCT transforms the last axis (confirm in dct).
        freqs: torch.Tensor = self.dct(states.swapaxes(-1, -2)).swapaxes(-1, -2)
        return freqs[..., : int(self.freq_frac_to_keep * self.n_step), :]

    def state_inverse_transform(self, freqs: torch.Tensor):
        """Zero-pad truncated coefficients back to ``n_step`` and apply the IDCT layer.

        Args:
            freqs: Tensor whose last two axes are
                ``(n_freqs_to_keep, state_dim)``; any number of leading
                (batch / view) axes is accepted.

        Returns:
            Output of the IDCT layer for the padded input, with the last two
            axes swapped back to ``(n_step, state_dim)`` order.
        """
        # Fill the high frequencies that were dropped before with zeros.  The
        # padding takes all leading axes from the input so both
        # [bs, n_freqs, state_dim] and [n_views, bs, n_freqs, state_dim] work.
        n_missing = self.n_step - freqs.shape[-2]
        padding = freqs.new_zeros((*freqs.shape[:-2], n_missing, freqs.shape[-1]))
        freqs = torch.cat([freqs, padding], dim=-2)
        return self.idct(freqs.swapaxes(-1, -2)).swapaxes(-1, -2)

    def forward(self, states: torch.Tensor):
        """Not implemented in the base class."""
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        """Not implemented in the base class."""
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        """Not implemented in the base class."""
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        """Not implemented in the base class."""
        raise NotImplementedError

    def on_validation_epoch_end(self):
        """Compute and log the configured evaluation metrics.

        Concatenates the arrays stored in ``self.misc["pred_params"]`` and
        ``self.misc["gt_params"]``, shuffles the samples and then, depending
        on ``self.eval_metrics``:

        * ``"dci"``: runs ``metrics._compute_dci`` on an 80/20 split of the
          shuffled samples, logs ``disent_score`` and appends the second
          return value to ``self.misc["gbt"]``.
        * ``"r2"``: runs ``metrics._compute_r2`` on the columns selected by
          ``self.shared_encoding[0]``, logs ``r2_linear`` / ``r2_nonlinear``
          and appends the raw scores to ``self.misc``.

        Does nothing when no metric is configured.
        """
        if not self.eval_metrics:
            return

        pred_params = np.concatenate(self.misc["pred_params"], 0).squeeze()  # (ds, param_dim)
        gt_params = np.concatenate(self.misc["gt_params"], 0)  # (ds, param_dim)

        X = pred_params
        y = gt_params.astype("int8") if self.factor_type == "discrete" else gt_params
        n_train = int(len(X) * 0.8)

        # Shuffle predictions and targets with one shared permutation.  The
        # index list is shuffled with the ``random`` module, so seeding that
        # module still controls the resulting order.
        order = list(range(len(X)))
        random.shuffle(order)
        X, y = X[order], y[order]

        factor_types = [self.factor_type] * y.shape[-1]

        if "dci" in self.eval_metrics:
            # compute DCI scores: first 80% of the samples vs. the remaining 20%
            dci_score, gbt = metrics._compute_dci(
                X[:n_train].T,
                y[:n_train].T,
                X[n_train:].T,
                y[n_train:].T,
                factor_types=factor_types,
            )
            self.misc["gbt"].append(gbt)
            self.log("disent_score", round(dci_score["disentanglement"], 4))

        if "r2" in self.eval_metrics:
            # compute r2 scores
            # let the first partition encode the shared part
            Xs: List[np.ndarray] = [X[:, self.shared_encoding[0]]]
            r2_linear, r2_nonlinear = metrics._compute_r2(
                Xs,
                y.T,
                factor_types=factor_types,
            )
            self.misc["r2_linear"].append(r2_linear)
            self.misc["r2_nonlinear"].append(r2_nonlinear)
            self.log("r2_linear", round(r2_linear.max(-1).mean(), 4), on_epoch=True, batch_size=X.shape[0])
            self.log("r2_nonlinear", round(r2_nonlinear.max(-1).mean(), 4), on_epoch=True, batch_size=X.shape[0])

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
