import os

import numpy as np
import torch
import xarray as xr
import math
from scipy import io
from torch.utils.data.dataset import Dataset

from ntldm.networks.networks import SinusoidalPosEmb
from ntldm.utils.utils import standardize_array
from sklearn.decomposition import PCA

class NER_BCI(Dataset):
    """
    Dataset for the Inria EEG BCI data.

    Loads ``{patient_id}_data.npy`` from ``filepath`` and standardizes it over
    axes (0, 2). Each item is a dict with a float32 ``"signal"`` tensor and,
    when ``with_time_emb`` is True, a ``"cond"`` entry holding a sinusoidal
    positional embedding over the time axis.

    Supports positional embeddings.

    Args:
        patient_id: Identifier used to build the file name ``{patient_id}_data.npy``.
        with_time_emb: If True, each item also carries the embedding under ``"cond"``.
        cond_time_dim: Dimensionality passed to ``SinusoidalPosEmb``.
        filepath: Directory containing the ``.npy`` file. Required despite the
            ``None`` default (kept for backward compatibility of the signature).

    Raises:
        ValueError: If ``filepath`` is not provided.
    """

    def __init__(
        self,
        patient_id,
        with_time_emb=False,
        cond_time_dim=32,
        filepath=None,
    ):
        super().__init__()
        if filepath is None:
            # Without this check os.path.join(None, ...) fails with an opaque TypeError.
            raise ValueError(
                "filepath must be the directory containing the patient data file"
            )

        self.with_time_emb = with_time_emb
        self.cond_time_dim = cond_time_dim
        # Fixed signal geometry; the positional embedding below is built for
        # exactly `signal_length` time steps.
        self.signal_length = 260
        self.num_channels = 56

        temp_array = np.load(os.path.join(filepath, f"{patient_id}_data.npy"))
        # NOTE(review): presumably laid out as (samples, channels, time), i.e.
        # standardized over samples and time per channel — confirm against
        # standardize_array.
        self.data_array = standardize_array(temp_array, ax=(0, 2))

        temp_emb = SinusoidalPosEmb(cond_time_dim).forward(
            torch.arange(self.signal_length)
        )
        # Swap the first two axes so time comes last; presumably yields
        # (cond_time_dim, signal_length) — TODO confirm SinusoidalPosEmb output.
        self.emb = torch.transpose(temp_emb, 0, 1)

    def __getitem__(self, index, cond_channel=None):
        """Return ``{"signal": float32 tensor}`` plus ``"cond"`` when enabled.

        ``cond_channel`` is accepted for interface compatibility and ignored.
        """
        return_dict = {}
        return_dict["signal"] = torch.from_numpy(np.float32(self.data_array[index]))
        cond = self.get_cond()
        if cond is not None:
            return_dict["cond"] = cond
        return return_dict

    def get_cond(self):
        """Return the shared time embedding if enabled, else ``None``."""
        cond = None
        if self.with_time_emb:
            cond = self.emb
        return cond

    def __len__(self):
        """Number of samples (size of the first axis of ``data_array``)."""
        return len(self.data_array)


class NER_BCI_PCA(Dataset):
    """
    Dataset for the Inria EEG BCI data, projected onto principal components.

    PCA is applied to the data over the channel dimension (every
    (sample, time) point is one observation over all channels) and only the
    first ``n_pca_components`` component scores are kept. Each ``"signal"``
    therefore has ``n_pca_components`` rows rather than the raw channels.

    Supports positional embeddings.

    Args:
        patient_id: Identifier used to build the file name ``{patient_id}_data.npy``.
        with_time_emb: If True, each item also carries the embedding under ``"cond"``.
        cond_time_dim: Dimensionality passed to ``SinusoidalPosEmb``.
        filepath: Directory containing the ``.npy`` file. Required despite the
            ``None`` default (kept for backward compatibility of the signature).
        n_pca_components: Number of principal components to keep.

    Attributes:
        pca: The fitted sklearn ``PCA``. To map scores back to channel space,
            reshape an item batch ``(n, k, t)`` to ``(n * t, k)`` via
            ``transpose(0, 2, 1)``, call ``pca.inverse_transform``, then reshape
            to ``(n, t, channels)`` and transpose back to ``(n, channels, t)``.

    Raises:
        ValueError: If ``filepath`` is not provided.
    """

    def __init__(
        self,
        patient_id,
        with_time_emb=False,
        cond_time_dim=32,
        filepath=None,
        n_pca_components=10,
    ):
        super().__init__()
        if filepath is None:
            # Without this check os.path.join(None, ...) fails with an opaque TypeError.
            raise ValueError(
                "filepath must be the directory containing the patient data file"
            )

        self.with_time_emb = with_time_emb
        self.cond_time_dim = cond_time_dim
        # Fixed signal geometry; the positional embedding below is built for
        # exactly `signal_length` time steps.
        self.signal_length = 260
        # Channel count of the raw recordings. The returned signals have
        # `n_pca_components` rows, not `num_channels`.
        self.num_channels = 56
        self.n_pca_components = n_pca_components

        # shape e.g. (340, 56, 260): samples x channels x time
        temp_array = np.load(os.path.join(filepath, f"{patient_id}_data.npy"))
        standardized = standardize_array(temp_array, ax=(0, 2))

        n_samples, n_channels, n_time = standardized.shape
        # Move channels last and merge samples and time so each row is one
        # (sample, time) observation over all channels.
        observations = standardized.transpose(0, 2, 1).reshape(
            n_samples * n_time, n_channels
        )

        # Fit and project in a single pass; no separate fit()/transform() needed.
        pca = PCA(n_components=self.n_pca_components)
        flat_scores = pca.fit_transform(observations)

        # Restore the (samples, components, time) layout of the PC score series.
        self.data_array = flat_scores.reshape(
            n_samples, n_time, self.n_pca_components
        ).transpose(0, 2, 1)
        self.pca = pca

        temp_emb = SinusoidalPosEmb(cond_time_dim).forward(
            torch.arange(self.signal_length)
        )
        # Swap the first two axes so time comes last; presumably yields
        # (cond_time_dim, signal_length) — TODO confirm SinusoidalPosEmb output.
        self.emb = torch.transpose(temp_emb, 0, 1)

    def __getitem__(self, index, cond_channel=None):
        """Return ``{"signal": float32 tensor}`` plus ``"cond"`` when enabled.

        ``cond_channel`` is accepted for interface compatibility and ignored.
        """
        return_dict = {}
        return_dict["signal"] = torch.from_numpy(np.float32(self.data_array[index]))
        cond = self.get_cond()
        if cond is not None:
            return_dict["cond"] = cond
        return return_dict

    def get_cond(self):
        """Return the shared time embedding if enabled, else ``None``."""
        cond = None
        if self.with_time_emb:
            cond = self.emb
        return cond

    def __len__(self):
        """Number of samples (size of the first axis of ``data_array``)."""
        return len(self.data_array)
