from sklearn.preprocessing import StandardScaler
import torch
from copy import deepcopy
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Batch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm


# Lookup table from a two-letter type code to its integer index.
type2index = {
    'zb': 0,
    'gb': 1,
    'hy': 2,
    'qy': 3,
    'cs': 4,
    'zj': 5,
}


class SubDataset(Dataset):
    """Sliding-window view over one preprocessed dataset file.

    Every item is a deep copy of the loaded data object (minus its full
    series) with three attributes attached: ``x`` (the input window), ``y``
    (the ``PRED_LEN`` steps that immediately follow it) and ``x_cov`` (the
    covariates over the same span as ``x``). The last axis of ``total_x`` is
    the one that gets windowed.
    """

    def __init__(self, dataset_name, step, SEQ_LEN, LABEL_LEN, PRED_LEN,
                 pretrain=True, normalize=True, delta=96):
        """Load the series and covariates, then set up the window geometry.

        Args:
            dataset_name: Stem of the ``.pt`` file in the pretrain directory.
            step: Stride between the start positions of consecutive windows.
            SEQ_LEN: Base input length; ``delta`` is added when ``pretrain``.
            LABEL_LEN: Stored unchanged; not used by this class itself.
            PRED_LEN: Number of target steps following each input window.
            pretrain: When true, the input window is lengthened by ``delta``.
            normalize: When true, each row of ``total_x`` is standardised.
            delta: Extra input steps used in pretrain mode.

        Raises:
            ValueError: If ``step`` is smaller than 1.
        """
        data = torch.load('/data/tsh/PowerGPT/preprocess/pretrain' + f'/{dataset_name}.pt')
        cov = torch.load('/data/tsh/PowerGPT/preprocess/cov/cov_pretrain.pt')
        self.step = step
        self.data = deepcopy(data)
        # Replicate the covariates once per row of total_x.
        # NOTE(review): assumes cov is 2-D so the result is 3-D - confirm.
        self.cov = cov.unsqueeze(0).repeat(self.data.total_x.shape[0], 1, 1)
        print(self.data)
        self.total_x = deepcopy(self.data.total_x)
        if normalize:
            # StandardScaler standardises column-wise, so transpose to scale
            # each row of total_x independently, then transpose back.
            self.total_x = StandardScaler().fit_transform(self.total_x.transpose(1, 0)).transpose(1, 0)
        # The template is deep-copied for every item; drop the full series
        # from it so those copies do not carry it along.
        del self.data.total_x

        self.pretrain = pretrain
        self.delta = delta
        self._configure_windows(SEQ_LEN, LABEL_LEN, PRED_LEN)

    def _configure_windows(self, SEQ_LEN, LABEL_LEN, PRED_LEN):
        """Derive window sizes and the item count from the current state.

        Reads ``self.step``, ``self.pretrain``, ``self.delta`` and
        ``self.total_x``; writes ``SEQ_LEN``, ``LABEL_LEN``, ``PRED_LEN``,
        ``WINDOW_LENGTH`` and ``LEN``.
        """
        if self.step < 1:
            raise ValueError(f'step must be >= 1, got {self.step}')

        if self.pretrain:
            self.SEQ_LEN = SEQ_LEN + self.delta
        else:
            self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        # Use the effective input length (including delta) so that the
        # window size matches what __getitem__ actually slices.
        self.WINDOW_LENGTH = self.SEQ_LEN + PRED_LEN

        spare = self.total_x.shape[-1] - self.WINDOW_LENGTH
        # A series shorter than one window yields an empty dataset rather
        # than a negative length.
        self.LEN = spare // self.step + 1 if spare >= 0 else 0

    def __getitem__(self, index):
        """Return a copy of the data object holding window ``index``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.LEN
        if not 0 <= index < self.LEN:
            raise IndexError(f'index out of range for dataset of length {self.LEN}')

        s_begin = index * self.step
        s_end = s_begin + self.SEQ_LEN
        # Targets start right where the input window ends.
        r_begin = s_end
        r_end = r_begin + self.PRED_LEN

        data = deepcopy(self.data)
        data.x = torch.Tensor(self.total_x[:, s_begin:s_end])
        data.y = torch.Tensor(self.total_x[:, r_begin:r_end])
        data.x_cov = torch.Tensor(self.cov[:, :, s_begin:s_end])
        return data

    def __len__(self):
        """Return the number of complete windows available."""
        return self.LEN
class PretrainPowerDataset:
    """Collection of SubDataset instances that share one window configuration."""

    def __init__(self, SEQ_LEN=4800, LABEL_LEN=48, PRED_LEN=7):
        """Remember the window sizes handed to every sub-dataset added later."""
        self.datasets = []
        self.SEQ_LEN, self.LABEL_LEN, self.PRED_LEN = SEQ_LEN, LABEL_LEN, PRED_LEN
        self.WINDOW_LENGTH = self.SEQ_LEN + self.PRED_LEN

    def add_dataset(self, dataset_name, step=1):
        """Create a SubDataset for ``dataset_name`` and append it to ``datasets``."""
        self.datasets.append(
            SubDataset(
                dataset_name=dataset_name,
                step=step,
                SEQ_LEN=self.SEQ_LEN,
                LABEL_LEN=self.LABEL_LEN,
                PRED_LEN=self.PRED_LEN,
            )
        )

    def get_dataset(self):
        """Gather every sample of every registered dataset into one collated batch."""
        samples = []
        for sub in self.datasets:
            # The identity collate keeps each chunk as a plain list of samples.
            loader = DataLoader(sub, batch_size=10, shuffle=False, collate_fn=lambda items: items)
            for chunk in tqdm(loader):
                samples.extend(chunk)
        return collate_fn(samples)


def collate_fn(batch):
    """Merge a list of samples into one Batch and append a trailing axis to x and y."""
    merged = Batch.from_data_list(batch)
    # Only x and y are reshaped here; every other attribute is left exactly
    # as Batch.from_data_list produced it.
    for field in ("x", "y"):
        setattr(merged, field, getattr(merged, field).unsqueeze(-1))
    return merged