from typing import Any, Dict, Tuple
import torch
from lightning import LightningModule
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import MaxMetric, MeanMetric, MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError, \
    MinMetric, Accuracy, Precision, Recall, F1Score, AUROC
from torchmetrics.functional import mean_squared_error, mean_absolute_error
from .components.revin import RevIN
from .PowerGPT import patch_masking, MaskMSELoss, create_patch
import torch.nn as nn

def get_model(model):
    """Unwrap a data-parallel container and return the underlying module.

    If ``model`` is a ``DistributedDataParallel`` or ``nn.DataParallel`` wrapper,
    its ``.module`` attribute is returned; any other object is returned as-is.
    """
    parallel_wrappers = (DistributedDataParallel, nn.DataParallel)
    if not isinstance(model, parallel_wrappers):
        return model
    return model.module

class PowerGPTModule(LightningModule):
    """Example of a `LightningModule` for MNIST classification.

    A `LightningModule` implements 8 key methods:

    ```python
    def __init__(self):
    # Define initialization code here.

    def setup(self, stage):
    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
    # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
    # The complete training step.

    def validation_step(self, batch, batch_idx):
    # The complete validation step.

    def test_step(self, batch, batch_idx):
    # The complete test step.

    def predict_step(self, batch, batch_idx):
    # The complete predict step.

    def configure_optimizers(self):
    # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
            self,
            name,
            net: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            scheduler: torch.optim.lr_scheduler,
    ) -> None:
        """Initialize a `PowerGPTModule`.

        :param name: Run/model name. It is only recorded in ``self.hparams`` by
            ``save_hyperparameters``; ``__init__`` does not otherwise use it.
        :param net: The network to train. It must expose ``patch_len``,
            ``stride``, ``mask_ratio``, ``head_type`` and ``n_vars``.
            ``head_type`` selects the task and loss: ``'pretrain'`` and
            ``'imputation'`` use a masked MSE, ``'prediction'`` uses MSE, and
            any other value (the anomaly head) uses cross-entropy.
        :param optimizer: The optimizer to use for training.
        :param scheduler: The learning rate scheduler to use for training.
        """
        super().__init__()
        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net
        self.patch_len = self.net.patch_len
        self.stride = self.net.stride
        self.mask_ratio = self.net.mask_ratio
        head_type = self.net.head_type

        # loss function
        if head_type in ('pretrain', 'imputation'):
            self.criterion = MaskMSELoss()
        elif head_type == 'prediction':
            self.criterion = nn.MSELoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

        stages = ('train', 'val', 'test')

        # Regression metrics (pretrain, imputation, prediction).
        # Attributes are named ``{stage}_{group}{metric}`` and registered in this
        # order: train_mse, val_mse, test_mse, train_mae, ..., test_mape, then the
        # same nine per node category (train_zb_mse, ..., test_province_mape).
        # The category groups are the per-``node_attr`` breakdown (codes 0-5)
        # that the step methods fill for the prediction and imputation heads.
        # The explicit names matter: the step methods look these attributes up.
        regression_metrics = (
            ('mse', MeanSquaredError),
            ('mae', MeanAbsoluteError),
            ('mape', MeanAbsolutePercentageError),
        )
        groups = ('', 'zb_', 'gb_', 'industry_', 'area_', 'city_', 'province_')
        for group in groups:
            for metric_name, metric_cls in regression_metrics:
                for stage in stages:
                    # setattr on an nn.Module registers the metric as a child module
                    setattr(self, f"{stage}_{group}{metric_name}", metric_cls())

        # Binary classification metrics for the anomaly head, registered as
        # train_acc, train_pre, train_rec, train_f1, train_auroc, then val_*, test_*.
        classification_metrics = (
            ('acc', Accuracy),
            ('pre', Precision),
            ('rec', Recall),
            ('f1', F1Score),
            ('auroc', AUROC),
        )
        for stage in stages:
            for metric_name, metric_cls in classification_metrics:
                setattr(self, f"{stage}_{metric_name}", metric_cls(task='binary', average='macro'))

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # Best-so-far validation loss. Every criterion selected above (MSE,
        # masked MSE, cross-entropy) is lower-is-better, so this must be a
        # MinMetric for all heads; a MaxMetric would record the worst epoch.
        self.val_loss_best = MinMetric()

        self.val_mse_best = MinMetric()
        self.val_mae_best = MinMetric()
        self.val_mape_best = MinMetric()

        self.val_acc_best = MaxMetric()
        self.val_pre_best = MaxMetric()
        self.val_rec_best = MaxMetric()
        self.val_f1_best = MaxMetric()
        self.val_auroc_best = MaxMetric()

        # Non-affine RevIN over the network's input variables; used by
        # on_before_batch_transfer ('norm') and model_step ('denorm').
        self.revin = RevIN(num_features=net.n_vars, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped network ``self.net``.

        Despite the annotation, ``model_step`` calls this with the whole
        pre-processed batch object, so ``x`` is whatever ``self.net`` accepts.

        :param x: Input passed unchanged to ``self.net``.
        :return: The raw output of ``self.net``; ``model_step`` post-processes
            it per head type.
        """
        network = self.net
        output = network(x)
        return output

    def on_train_start(self) -> None:
        """Clear validation state accumulated during Lightning's sanity check.

        By default Lightning runs a few validation steps before training begins.
        Resetting each running validation metric together with its best-so-far
        tracker keeps those sanity-check batches out of the real results.
        """
        tracked = ("loss", "mse", "mae", "mape", "acc", "pre", "rec", "f1", "auroc")
        for metric_name in tracked:
            getattr(self, f"val_{metric_name}").reset()
            getattr(self, f"val_{metric_name}_best").reset()

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Normalise and patch the raw series in ``batch`` for the current head.

        ``batch`` is modified in place and returned:

        * ``prediction``: RevIN-normalise ``x``, split it into patches, and keep
          only the first ``batch.batch_size`` entries of ``node_attr``.
        * ``anomaly``: patch ``x`` without normalisation and flatten ``y``.
        * ``imputation``: RevIN-normalise ``x`` with ``val_mask``, patch it,
          truncate ``node_attr``, and store ``val_mask.unsqueeze(-1)`` as ``mask``.
        * anything else (pretrain): RevIN-normalise, then build masked patches
          with ``patch_masking`` (using ``self.mask_ratio``). ``y`` becomes the
          patch targets and ``mask`` is repeated out to ``patch_len``.

        NOTE(review): the ``'norm'`` calls here pair with ``'denorm'`` calls in
        ``model_step``. Presumably RevIN keeps the statistics from this call
        until then; confirm. Lightning invokes this hook before moving the
        batch to the device, so verify those statistics end up on the same
        device as the predictions.
        """
        head_type = self.net.head_type
        series = batch.x

        if head_type == "prediction":
            normed = self.revin(series, 'norm')
            patches, _ = create_patch(normed, stride=self.stride, patch_len=self.patch_len)
            batch.node_attr = batch.node_attr[:batch.batch_size]
            batch.x = patches
        elif head_type == "anomaly":
            patches, _ = create_patch(series, stride=self.stride, patch_len=self.patch_len)
            batch.x = patches
            batch.y = batch.y.reshape(-1)
        elif head_type == "imputation":
            normed = self.revin(series, 'norm', mask=batch.val_mask)
            patches, _ = create_patch(normed, stride=self.stride, patch_len=self.patch_len)
            batch.x = patches
            batch.node_attr = batch.node_attr[:batch.batch_size]
            batch.mask = batch.val_mask.unsqueeze(-1)
        else:
            normed = self.revin(series, 'norm')
            patches, patch_targets, patch_mask = patch_masking(
                normed, stride=self.stride, patch_len=self.patch_len, mask_ratio=self.mask_ratio
            )
            expanded_mask = patch_mask.unsqueeze(-1).repeat(1, 1, 1, self.patch_len)
            batch.x = patches
            batch.y = patch_targets
            batch.mask = expanded_mask
        return batch

    def model_step(
            self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the network on ``batch`` and compute the head-specific loss.

        Only the first ``batch.batch_size`` rows of outputs and targets are
        scored (presumably the seed nodes of a sampled subgraph; confirm).
        ``batch.y``, and ``batch.mask`` where used, are truncated in place.

        * ``prediction``: the output is RevIN-denormalised and the loss is
          ``self.criterion(preds, y)``.
        * ``anomaly``: the loss is computed on the network output; the returned
          predictions are its argmax indices along the last dimension.
        * ``imputation``: the output is reshaped to ``(bs, -1, 1)`` and
          denormalised; the loss is ``self.criterion(preds, y, mask)``.
        * anything else (pretrain): masked loss on the output without denormalisation.

        :param batch: The batch object prepared by ``on_before_batch_transfer``.
        :return: ``(loss, preds, targets)``, where ``targets`` is the truncated ``batch.y``.
        """
        head_type = self.net.head_type
        n_seed = batch.batch_size

        if head_type == "prediction":
            batch.y = batch.y[:n_seed]
            raw = self.forward(batch)
            preds = self.revin(raw, 'denorm')[:n_seed]
            loss = self.criterion(preds, batch.y)
        elif head_type == "anomaly":
            logits = self.forward(batch)[:n_seed]
            batch.y = batch.y[:n_seed]
            loss = self.criterion(logits, batch.y)
            preds = torch.max(logits, dim=-1).indices
        elif head_type == "imputation":
            batch.y = batch.y[:n_seed]
            raw = self.forward(batch)
            flat = raw.reshape(raw.shape[0], -1, 1)
            preds = self.revin(flat, 'denorm')[:n_seed]
            batch.mask = batch.mask[:n_seed]
            loss = self.criterion(preds, batch.y, batch.mask)
        else:
            batch.y = batch.y[:n_seed]
            preds = self.forward(batch)[:n_seed]
            batch.mask = batch.mask[:n_seed]
            loss = self.criterion(preds, batch.y, batch.mask)

        return loss, preds, batch.y

    def training_step(
            self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the training loss and update/log the training metrics.

        ``batch`` is the object prepared by ``on_before_batch_transfer``, not a
        plain tuple. Metrics depend on the head type:

        * ``prediction``: MSE/MAE/MAPE over all outputs, plus a breakdown by
          ``batch.node_attr`` category.
        * ``imputation``: the same metrics restricted to ``batch.mask``
          positions, plus the per-category breakdown.
        * ``anomaly``: accuracy, precision, recall, F1 and AUROC. AUROC is fed
          the argmax class indices from ``model_step``, not scores.
        * anything else (pretrain): MSE/MAE/MAPE on masked positions only.

        Per-category metrics are updated and logged only for categories that
        occur in the batch.

        :param batch: The pre-processed batch.
        :param batch_idx: The index of the current batch (unused).
        :return: The loss to back-propagate.
        """
        head_type = self.net.head_type
        # node_attr code -> metric-name infix
        # (type2index = {'zb': 0, 'gb': 1, 'hy': 2, 'qy': 3, 'hz': 4, 'zj': 5})
        categories = ("zb", "gb", "industry", "area", "city", "province")
        regression = ("mse", "mae", "mape")
        classification = ("acc", "pre", "rec", "f1", "auroc")

        loss, preds, targets = self.model_step(batch)
        self.train_loss(loss)

        if head_type == "anomaly":
            for metric_name in classification:
                getattr(self, f"train_{metric_name}")(preds, targets)
            logged = ("loss",) + classification
        else:
            if head_type == "prediction":
                scored_preds, scored_targets = preds, targets
            else:
                # imputation and pretrain only score the masked positions
                mask = batch.mask
                scored_preds, scored_targets = preds[mask], targets[mask]
            for metric_name in regression:
                getattr(self, f"train_{metric_name}")(scored_preds, scored_targets)

            if head_type in ("prediction", "imputation"):
                node_type = batch.node_attr
                for code, category in enumerate(categories):
                    in_category = node_type == code
                    if torch.sum(in_category) == 0:
                        continue  # category absent from this batch
                    if head_type == "prediction":
                        cat_preds = preds[in_category]
                        cat_targets = targets[in_category]
                    else:
                        observed = mask[in_category] == 1
                        cat_preds = preds[in_category][observed]
                        cat_targets = targets[in_category][observed]
                    for metric_name in regression:
                        getattr(self, f"train_{category}_{metric_name}")(cat_preds, cat_targets)
                    for metric_name in regression:
                        self.log(f"train/{category}_{metric_name}",
                                 getattr(self, f"train_{category}_{metric_name}"),
                                 on_step=False, on_epoch=True, prog_bar=True)
            logged = ("loss",) + regression

        for metric_name in logged:
            self.log(f"train/{metric_name}", getattr(self, f"train_{metric_name}"),
                     on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning end-of-training-epoch hook; deliberately does nothing here."""

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Compute the validation loss and update/log the validation metrics.

        ``batch`` is the object prepared by ``on_before_batch_transfer``, not a
        plain tuple. Metrics depend on the head type:

        * ``prediction``: MSE/MAE/MAPE over all outputs, plus a breakdown by
          ``batch.node_attr`` category.
        * ``imputation``: the same metrics restricted to ``batch.mask``
          positions, plus the per-category breakdown.
        * ``anomaly``: accuracy, precision, recall, F1 and AUROC. AUROC is fed
          the argmax class indices from ``model_step``, not scores.
        * anything else (pretrain): MSE/MAE/MAPE on masked positions only.

        Per-category metrics are updated and logged only for categories that
        occur in the batch.

        :param batch: The pre-processed batch.
        :param batch_idx: The index of the current batch (unused).
        """
        head_type = self.net.head_type
        # node_attr code -> metric-name infix
        # (type2index = {'zb': 0, 'gb': 1, 'hy': 2, 'qy': 3, 'hz': 4, 'zj': 5})
        categories = ("zb", "gb", "industry", "area", "city", "province")
        regression = ("mse", "mae", "mape")
        classification = ("acc", "pre", "rec", "f1", "auroc")

        loss, preds, targets = self.model_step(batch)
        self.val_loss(loss)

        if head_type == "anomaly":
            for metric_name in classification:
                getattr(self, f"val_{metric_name}")(preds, targets)
            logged = ("loss",) + classification
        else:
            if head_type == "prediction":
                scored_preds, scored_targets = preds, targets
            else:
                # imputation and pretrain only score the masked positions
                mask = batch.mask
                scored_preds, scored_targets = preds[mask], targets[mask]
            for metric_name in regression:
                getattr(self, f"val_{metric_name}")(scored_preds, scored_targets)

            if head_type in ("prediction", "imputation"):
                node_type = batch.node_attr
                for code, category in enumerate(categories):
                    in_category = node_type == code
                    if torch.sum(in_category) == 0:
                        continue  # category absent from this batch
                    if head_type == "prediction":
                        cat_preds = preds[in_category]
                        cat_targets = targets[in_category]
                    else:
                        observed = mask[in_category] == 1
                        cat_preds = preds[in_category][observed]
                        cat_targets = targets[in_category][observed]
                    for metric_name in regression:
                        getattr(self, f"val_{category}_{metric_name}")(cat_preds, cat_targets)
                    for metric_name in regression:
                        self.log(f"val/{category}_{metric_name}",
                                 getattr(self, f"val_{category}_{metric_name}"),
                                 on_step=False, on_epoch=True, prog_bar=True)
            logged = ("loss",) + regression

        for metric_name in logged:
            self.log(f"val/{metric_name}", getattr(self, f"val_{metric_name}"),
                     on_step=False, on_epoch=True, prog_bar=True)


    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        loss = self.val_loss.compute()  # get current val acc
        self.val_loss_best(loss)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/loss_best", self.val_loss_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        if self.net.head_type == "prediction":
            loss, preds, targets = self.model_step(batch)
            y, t = batch.y, batch.node_attr
            # update and log metrics
            self.test_loss(loss)
            self.test_mse(preds, y)
            self.test_mae(preds, y)
            self.test_mape(preds, y)

            # zb
            if torch.sum(t == 0) != 0:
                m = t == 0
                self.test_zb_mse(preds[m], y[m])
                self.test_zb_mae(preds[m], y[m])
                self.test_zb_mape(preds[m], y[m])
                self.log("test/zb_mse", self.test_zb_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/zb_mae", self.test_zb_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/zb_mape", self.test_zb_mape, on_step=False, on_epoch=True, prog_bar=True)

            # gb
            if torch.sum(t == 1) != 0:
                m = t == 1
                self.test_gb_mse(preds[m], y[m])
                self.test_gb_mae(preds[m], y[m])
                self.test_gb_mape(preds[m], y[m])
                self.log("test/gb_mse", self.test_gb_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/gb_mae", self.test_gb_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/gb_mape", self.test_gb_mape, on_step=False, on_epoch=True, prog_bar=True)

            # industry
            if torch.sum(t == 2) != 0:
                m = t == 2
                self.test_industry_mse(preds[m], y[m])
                self.test_industry_mae(preds[m], y[m])
                self.test_industry_mape(preds[m], y[m])
                self.log("test/industry_mse", self.test_industry_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/industry_mae", self.test_industry_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/industry_mape", self.test_industry_mape, on_step=False, on_epoch=True, prog_bar=True)

            # area
            if torch.sum(t == 3) != 0:
                m = t == 3
                self.test_area_mse(preds[m], y[m])
                self.test_area_mae(preds[m], y[m])
                self.test_area_mape(preds[m], y[m])
                self.log("test/area_mse", self.test_area_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/area_mae", self.test_area_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/area_mape", self.test_area_mape, on_step=False, on_epoch=True, prog_bar=True)

            # city
            if torch.sum(t == 4) != 0:
                m = t == 4
                self.test_city_mse(preds[m], y[m])
                self.test_city_mae(preds[m], y[m])
                self.test_city_mape(preds[m], y[m])
                self.log("test/city_mse", self.test_city_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/city_mae", self.test_city_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/city_mape", self.test_city_mape, on_step=False, on_epoch=True, prog_bar=True)

            # province
            if torch.sum(t == 5) != 0:
                m = t == 5
                self.test_province_mse(preds[m], y[m])
                self.test_province_mae(preds[m], y[m])
                self.test_province_mape(preds[m], y[m])
                self.log("test/province_mse", self.test_province_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/province_mae", self.test_province_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/province_mape", self.test_province_mape, on_step=False, on_epoch=True, prog_bar=True)

        elif self.net.head_type == "anomaly":
            loss, preds, targets = self.model_step(batch)
            self.test_loss(loss)
            self.test_acc(preds, targets)
            self.test_pre(preds, targets)
            self.test_rec(preds, targets)
            self.test_f1(preds, targets)
            self.test_auroc(preds, targets)

        elif self.net.head_type == "imputation":
            t = batch.node_attr
            loss, preds, targets = self.model_step(batch)
            mask = batch.mask

            self.test_loss(loss)
            self.test_mse(preds[mask], targets[mask])
            self.test_mae(preds[mask], targets[mask])
            self.test_mape(preds[mask], targets[mask])

            if torch.sum(t == 0) != 0:
                m = t == 0
                preds_zb = preds[m][mask[m] == 1]
                targets_zb = targets[m][mask[m] == 1]
                self.test_zb_mse(preds_zb, targets_zb)
                self.test_zb_mae(preds_zb, targets_zb)
                self.test_zb_mape(preds_zb, targets_zb)
                self.log("test/zb_mse", self.test_zb_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/zb_mae", self.test_zb_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/zb_mape", self.test_zb_mape, on_step=False, on_epoch=True, prog_bar=True)

            # gb
            if torch.sum(t == 1) != 0:
                m = t == 1
                preds_gb = preds[m][mask[m] == 1]
                targets_gb = targets[m][mask[m] == 1]
                self.test_gb_mse(preds_gb, targets_gb)
                self.test_gb_mae(preds_gb, targets_gb)
                self.test_gb_mape(preds_gb, targets_gb)
                self.log("test/gb_mse", self.test_gb_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/gb_mae", self.test_gb_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/gb_mape", self.test_gb_mape, on_step=False, on_epoch=True, prog_bar=True)

            # industry
            if torch.sum(t == 2) != 0:
                m = t == 2
                preds_industry = preds[m][mask[m] == 1]
                targets_industry = targets[m][mask[m] == 1]
                self.test_industry_mse(preds_industry, targets_industry)
                self.test_industry_mae(preds_industry, targets_industry)
                self.test_industry_mape(preds_industry, targets_industry)
                self.log("test/industry_mse", self.test_industry_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/industry_mae", self.test_industry_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/industry_mape", self.test_industry_mape, on_step=False, on_epoch=True, prog_bar=True)

            # area
            if torch.sum(t == 3) != 0:
                m = t == 3
                preds_area = preds[m][mask[m] == 1]
                targets_area = targets[m][mask[m] == 1]
                self.test_area_mse(preds_area, targets_area)
                self.test_area_mae(preds_area, targets_area)
                self.test_area_mape(preds_area, targets_area)
                self.log("test/area_mse", self.test_area_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/area_mae", self.test_area_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/area_mape", self.test_area_mape, on_step=False, on_epoch=True, prog_bar=True)

            # city
            if torch.sum(t == 4) != 0:
                m = t == 4
                preds_city = preds[m][mask[m] == 1]
                targets_city = targets[m][mask[m] == 1]
                self.test_city_mse(preds_city, targets_city)
                self.test_city_mae(preds_city, targets_city)
                self.test_city_mape(preds_city, targets_city)
                self.log("test/city_mse", self.test_city_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/city_mae", self.test_city_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/city_mape", self.test_city_mape, on_step=False, on_epoch=True, prog_bar=True)

            # province
            if torch.sum(t == 5) != 0:
                m = t == 5
                preds_province = preds[m][mask[m] == 1]
                targets_province = targets[m][mask[m] == 1]
                self.test_province_mse(preds_province, targets_province)
                self.test_province_mae(preds_province, targets_province)
                self.test_province_mape(preds_province, targets_province)
                self.log("test/province_mse", self.test_province_mse, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/province_mae", self.test_province_mae, on_step=False, on_epoch=True, prog_bar=True)
                self.log("test/province_mape", self.test_province_mape, on_step=False, on_epoch=True, prog_bar=True)
        else:
            loss, preds, targets = self.model_step(batch)
            mask = batch.mask
            # update and log metrics
            self.test_loss(loss)
            self.test_mse(preds[mask], targets[mask])
            self.test_mae(preds[mask], targets[mask])
            self.test_mape(preds[mask], targets[mask])

        if self.net.head_type != "anomaly":
            self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/mse", self.test_mse, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/mae", self.test_mae, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/mape", self.test_mape, on_step=False, on_epoch=True, prog_bar=True)
        else:
            self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/pre", self.test_pre, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/rec", self.test_rec, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/auroc", self.test_auroc, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass


    def freeze(self):
        """Freeze every parameter of the wrapped network except those of its ``head``.

        The network is first unwrapped from (Distributed)DataParallel via ``get_model``.
        If the unwrapped network has no ``head`` attribute, nothing is changed.

        NOTE(review): this shadows ``LightningModule.freeze`` with different semantics
        (the head stays trainable); confirm callers expect this.
        """
        model = get_model(self.net)
        if not hasattr(model, 'head'):
            return
        for param in model.parameters():
            param.requires_grad = False
        # Turn gradients back on for the head only, so training updates just the head.
        for param in model.head.parameters():
            param.requires_grad = True

    def unfreeze(self):
        """Make every parameter of the (DataParallel-unwrapped) network trainable again."""
        model = get_model(self.net)
        for param in model.parameters():
            param.requires_grad = True

    def configure_optimizers(self) -> Dict[str, Any]:

        """Configures optimizers and learning-rate schedulers to be used for training.

        Normally you'd need one, but in the case of GANs or similar you might need multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """

        optimizer = self.hparams.optimizer([
            {'params': self.net.backbone.parameters(), 'lr': 1e-7},
            # {'params': self.net.gnn_layers.parameters(), 'lr': 3e-4},
            {'params': self.net.head.parameters(), 'lr': 3e-4},
        ])
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


if __name__ == "__main__":
    # Nothing to run directly: this module only defines the LightningModule.
    pass
