import itertools
from collections import defaultdict, namedtuple
from pathlib import Path
from typing import Sequence

import pandas as pd
import torch
import torch.nn.functional as F
from datasets import DatasetDict, load_from_disk
from pytorch_lightning import Trainer, seed_everything
from sklearn.discriminant_analysis import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Normalizer
from sklearn.svm import SVC
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, CosineSimilarity, MeanAbsoluteError, MeanMetric
from tqdm import tqdm

from nn_core.common import PROJECT_ROOT

from rel2abs.modules.simple_classifier import Classifier
from rel2abs.modules.translation import LatentTranslation
from rel2abs.utils import DatasetConfig, DatasetParams

# Device used for decoders and evaluation tensors; falls back to the CPU when CUDA is unavailable
# so the script does not crash on machines without a GPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"


def _encoded_data_key(dataset_params: "DatasetParams") -> str:
    """Return the dataset key derived from ``dataset_params``.

    The key is the underscore-joined string form of every field of ``dataset_params``
    except ``hf_key``, skipping fields whose value is None.
    """
    return "_".join(
        str(value) for field, value in dataset_params._asdict().items() if field != "hf_key" and value is not None
    )


def _build_dataset_config(
    dataset_params: "DatasetParams",
    *,
    label_column: str,
    encoding_column_template: str,
    encoders: Sequence[str],
    directory_name: str = "",
) -> "DatasetConfig":
    """Assemble the DatasetConfig for ``dataset_params`` after checking its encoded data exists.

    Args:
        dataset_params: parameters identifying the dataset; they determine the config key.
        label_column: name of the label column in the encoded dataset.
        encoding_column_template: ``str.format`` template with an ``{encoder}`` placeholder
            that yields the encoding column of each encoder.
        encoders: names of the encoders whose encodings are stored for this dataset.
        directory_name: sub-directory of ``PROJECT_ROOT / "data" / "encoded_data"`` holding the
            encoded data. When empty, the dataset key itself is used as the directory name.

    Raises:
        FileNotFoundError: if the encoded data directory does not exist.
    """
    data_key = _encoded_data_key(dataset_params)
    directory: Path = PROJECT_ROOT / "data" / "encoded_data" / (directory_name or data_key)
    # Explicit check rather than ``assert`` so the validation is not stripped under ``python -O``.
    if not directory.exists():
        raise FileNotFoundError(f"Encoded data directory not found: {directory}")
    return DatasetConfig(
        key=data_key,
        directory=directory,
        label_column=label_column,
        encoding_column_template=encoding_column_template,
        encoders=encoders,
    )


def data_config(
    dataset_name: str,
):
    """Return the DatasetConfig describing the pre-encoded data of ``dataset_name``.

    Supported names are ``fashion_mnist``, ``mnist``, ``cifar10``, and any name starting
    with ``cifar100``. For the ``cifar100`` names, fine-grained labels are used when the
    name contains ``"fine"`` and coarse labels otherwise. ``trec``, ``dbpedia_14``,
    ``n24news_text`` and ``n24news_image`` are also supported.

    Raises:
        FileNotFoundError: if the encoded data directory for the dataset does not exist.
        NotImplementedError: if ``dataset_name`` is not supported.
    """
    domain2encoders = {
        "vision": [
            "rexnet_100",
            "vit_base_patch16_224",
            "vit_base_patch16_384",
            "vit_base_resnet50_384",
            "vit_small_patch16_224",
            "openai/clip-vit-base-patch32",
        ],
        "text": [
            "bert-base-cased",
            "bert-base-uncased",
            "google/electra-base-discriminator",
            "roberta-base",
            "albert-base-v2",
            "xlm-roberta-base",
            "openai/clip-vit-base-patch32",
        ],
    }
    # Kept as the int 1 (despite the float annotation) so keys/directories read "..._1", not "..._1.0".
    perc: float = 1

    if dataset_name in ("fashion_mnist", "mnist", "cifar10"):
        return _build_dataset_config(
            DatasetParams(dataset_name, None, "train", "test", perc, (dataset_name,)),
            label_column="label",
            encoding_column_template="{encoder}",
            encoders=domain2encoders["vision"],
            directory_name=f"{dataset_name}_train_test_{perc}",
        )

    if dataset_name.startswith("cifar100"):
        dataset_params = DatasetParams("cifar100", "fine" in dataset_name, "train", "test", perc, ("cifar100",))
        return _build_dataset_config(
            dataset_params,
            label_column="fine_label" if dataset_params.fine_grained else "coarse_label",
            encoding_column_template="{encoder}",
            encoders=domain2encoders["vision"],
            # The fine and coarse variants share the same encoded-data directory.
            directory_name=f"{dataset_params.name}_train_test_{perc}",
        )

    if dataset_name == "trec":
        return _build_dataset_config(
            DatasetParams("trec", False, "train", "test", perc, ("trec",)),
            label_column="target",
            encoding_column_template="{encoder}_mean_encoding",
            encoders=domain2encoders["text"],
        )

    if dataset_name == "dbpedia_14":
        return _build_dataset_config(
            DatasetParams("dbpedia_14", None, "train", "test", perc, ("dbpedia_14",)),
            label_column="target",
            encoding_column_template="{encoder}_mean_encoding",
            encoders=domain2encoders["text"],
        )

    # TODO: support "amazon_reviews_multi" (label column and encoding template still to be defined).

    if dataset_name == "n24news_text":
        return _build_dataset_config(
            DatasetParams("n24news_text", False, "train", "test", perc, ("n24news",)),
            label_column="label",
            encoding_column_template="{encoder}_cls_encoding",
            encoders=domain2encoders["text"],
            directory_name="N24News",
        )

    if dataset_name == "n24news_image":
        return _build_dataset_config(
            DatasetParams("n24news_image", False, "train", "test", perc, ("n24news",)),
            label_column="label",
            encoding_column_template="{encoder}",
            encoders=(
                "vit_base_patch16_224",
                "rexnet_100",
                "vit_base_patch16_384",
                "vit_small_patch16_224",
                "vit_base_resnet50_384",
                "cspdarknet53",
            ),
            directory_name="N24News",
        )

    raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")


class SVCModel(nn.Module):
    """Wrap a fitted classifier exposing ``predict`` so it can be used like a torch decoder.

    The wrapped model's hard predictions are returned as one-hot rows. Taking
    ``argmax(dim=1)`` of the output therefore recovers the predicted (non-negative
    integer) labels.
    """

    def __init__(self, model):
        """Store ``model``, e.g. a fitted scikit-learn pipeline."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Predict labels for the rows of ``x``; return them one-hot encoded on ``x``'s device.

        ``x`` is detached, moved to the CPU and converted to NumPy for ``model.predict``.
        When the model exposes a non-empty ``classes_`` (as fitted scikit-learn classifiers
        do), the one-hot width is fixed to ``max(classes_) + 1``. The output shape then does
        not depend on which labels happen to be predicted in a given batch, and empty
        batches work. Otherwise the width is inferred from the predictions.
        """
        preds = torch.as_tensor(self.model.predict(x.detach().cpu().numpy()))
        classes = getattr(self.model, "classes_", None)
        num_classes = int(max(classes)) + 1 if classes is not None and len(classes) > 0 else -1
        return F.one_hot(preds, num_classes=num_classes).to(x.device)


def run(
    dataset: str,
    decoder_type: str,
    std_correction_options: Sequence[bool],
    l2_norm_options: Sequence[bool],
    anchor_num_options: Sequence[int],
    seeds: Sequence[int],
):
    dataset_config: DatasetConfig = data_config(dataset)
    result_dir = PROJECT_ROOT / "results" / "stitching" / decoder_type / dataset_config.key
    original_result = result_dir / "original.tsv"
    absolute_result = result_dir / "absolute_stitching.tsv"
    ortho_result = result_dir / "orthogonal_stitching.tsv"
    lstsq_result = result_dir / "lstsq_stitching.tsv"
    lstsq_ortho_result = result_dir / "lstsq_ortho_stitching.tsv"

    if (
        original_result.exists()
        and absolute_result.exists()
        and ortho_result.exists()
        and lstsq_result.exists()
        and lstsq_ortho_result.exists()
    ):
        print(f"Skipping {dataset} | {decoder_type}")
        return

    data: DatasetDict = load_from_disk(dataset_path=str(dataset_config.directory))
    if dataset_config.key.startswith("dbpedia_14"):
        data = DatasetDict(
            train=data["train"].train_test_split(train_size=0.1, stratify_by_column=dataset_config.label_column)[
                "train"
            ],
            test=data["test"].train_test_split(train_size=0.1, stratify_by_column=dataset_config.label_column)["train"],
        )

    tensor_columns = {
        column
        for column in data["train"].column_names
        if any(column.startswith(encoder) for encoder in dataset_config.encoders)
    }
    tensor_columns.add(dataset_config.label_column)
    data.set_format(columns=tensor_columns, output_all_columns=True, type="torch")

    encoder2decoder = {}

    for encoder in tqdm(dataset_config.encoders):
        seed_everything(42)

        label_column: str = dataset_config.label_column
        encoding_column: str = dataset_config.encoding_column_template.format(encoder=encoder)
        columns_to_drop = [
            column
            for column in data["train"].column_names
            if column not in {encoding_column, dataset_config.label_column, "data"}
        ]
        encoder_data = data.remove_columns(column_names=columns_to_drop)
        fit_data = encoder_data["train"].train_test_split(train_size=0.9, seed=42, stratify_by_column=label_column)
        train_data, val_data, test_data = fit_data["train"], fit_data["test"], encoder_data["test"]

        # orig_space: torch.Tensor = train_data[encoding_column]
        # orig_space_norm = orig_space.norm(dim=1, p=2)
        # train_space_std = orig_space_norm.std(dim=0)

        train_loader = DataLoader(train_data, batch_size=3000, pin_memory=False, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_data, batch_size=3000, pin_memory=False, shuffle=False, num_workers=4)
        # test_loader = DataLoader(test_data, batch_size=64, pin_memory=True, shuffle=False, num_workers=8)
        if decoder_type == "svm":
            model = make_pipeline(StandardScaler(), Normalizer(), SVC(gamma="auto", kernel="linear", random_state=42))
            for batch in train_loader:
                X, y = batch[encoding_column], batch[label_column]
                model.fit(X, y)
            model = SVCModel(model)
        elif decoder_type == "linear":
            model = Classifier(
                input_dim=train_data[encoding_column].size(1),
                num_classes=train_data.features[label_column].num_classes,
                lr=1e-4,
                deep=True,
                x_feature=encoding_column,
                y_feature=label_column,
            )

            trainer = Trainer(
                accelerator="auto",
                devices=1,
                max_epochs=5,
                logger=None,
                # callbacks=[RichProgressBar()],
                enable_progress_bar=False,
            )
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        else:
            raise NotImplementedError

        encoder2decoder[encoder] = model

    result_dir.mkdir(exist_ok=True, parents=True)

    if not original_result.exists():
        # Original (without stitching) performance computation
        Result = namedtuple(
            "Result",
            ["space", "score", "norm_mean", "norm_std"],
        )

        original_results = []
        for (encoding_space, decoder) in encoder2decoder.items():
            decoder.to(DEVICE)

            encoding_space_column = dataset_config.encoding_column_template.format(encoder=encoding_space)
            label_column = dataset_config.label_column

            encoder_test_data = data["test"].remove_columns(
                [
                    column
                    for column in data["test"].column_names
                    if column not in {"index", label_column, encoding_space_column}
                ]
            )
            test_loader = DataLoader(encoder_test_data, batch_size=64, pin_memory=True, shuffle=False, num_workers=8)

            score = Accuracy(task="multiclass", num_classes=encoder_test_data.features[label_column].num_classes).to(
                DEVICE
            )
            mean_norm: MeanMetric = MeanMetric().to(DEVICE)
            std_norm: MeanMetric = MeanMetric().to(DEVICE)

            with torch.no_grad():
                for test_batch in test_loader:
                    y = test_batch[label_column].to(DEVICE)
                    x = test_batch[encoding_space_column].to(DEVICE)

                    logits = decoder(x)
                    preds = torch.argmax(logits, dim=1)

                    score.update(preds, y)
                    mean_norm.update(x.norm(p=2, dim=-1))
                    std_norm.update(x.std(dim=-1))

            original_results.append(
                Result(
                    space=encoding_space,
                    score=score.compute().cpu().item(),
                    norm_mean=mean_norm.compute().cpu().item(),
                    norm_std=std_norm.compute().cpu().item(),
                )
            )

            decoder.cpu()

        original_results = pd.DataFrame(original_results)
        original_results.to_csv(result_dir / "original.tsv", sep="\t", index=False)

    if not absolute_result.exists():
        Result = namedtuple(
            "Result",
            [
                "encoding_space",
                "decoding_space",
                "similarity",
                "mse",
                "norm_diff",
                "score",
                "seed",
            ],
        )
        batch_size: int = 5000

        results = []
        for ((encoding_space, decoder1), (decoding_space, decoder2)), seed in tqdm(
            itertools.product(itertools.product(list(encoder2decoder.items()), list(encoder2decoder.items())), seeds)
        ):
            seed_everything(seed)
            decoder1.to(DEVICE)
            decoder2.to(DEVICE)

            encoding_space_column = dataset_config.encoding_column_template.format(encoder=encoding_space)
            decoding_space_column = dataset_config.encoding_column_template.format(encoder=decoding_space)

            test_data = data["test"].remove_columns(
                [
                    column
                    for column in data["test"].column_names
                    if column not in {"index", label_column, encoding_space_column, decoding_space_column}
                ]
            )

            # if test_data[encoding_space_column].size(1) != test_data[decoding_space_column].size(1):
            #     continue

            test_loader = DataLoader(test_data, batch_size=batch_size, pin_memory=True, shuffle=False, num_workers=0)
            score_metric = Accuracy(task="multiclass", num_classes=test_data.features[label_column].num_classes).to(
                DEVICE
            )
            reconstruction_similarity = CosineSimilarity(reduction="mean").to(DEVICE)
            reconstruction_mse = MeanAbsoluteError(reduction="mean").to(DEVICE)
            reconstruction_norm_diff = MeanMetric().to(DEVICE)

            for test_batch in test_loader:
                y = test_batch[label_column].to(DEVICE)

                encoding_x = test_batch[encoding_space_column].to(DEVICE)  # .double()
                decoding_x = test_batch[decoding_space_column].to(DEVICE)  # .double()

                if encoding_x.size(1) != decoding_x.size(1):
                    padded = torch.zeros_like(decoding_x)
                    padded[:, : encoding_x.size(1)] = encoding_x[:, : decoding_x.size(1)]
                    encoding_x = padded
                # elif encoding_x.size(1) > decoding_x.size(1):
                #     padded = torch.zeros_like(encoding_x)
                #     padded[:, : decoding_x.size(1)] = decoding_x
                #     decoding_x = padded

                rescaled_logits = decoder2(encoding_x)
                rescaled_preds = torch.argmax(rescaled_logits, dim=1)

                score_metric.update(rescaled_preds, y)
                reconstruction_similarity.update(encoding_x, decoding_x)
                reconstruction_mse.update(encoding_x, decoding_x)

                reconstruction_norm_diff.update((decoding_x.norm(p=2, dim=-1) - encoding_x.norm(p=2, dim=-1)).abs())

            results.append(
                Result(
                    encoding_space=encoding_space,
                    decoding_space=decoding_space,
                    similarity=reconstruction_similarity.compute().cpu().item(),
                    mse=reconstruction_mse.compute().cpu().item(),
                    score=score_metric.compute().cpu().item(),
                    norm_diff=reconstruction_norm_diff.compute().cpu().item(),
                    seed=seed,
                )
            )
            pd.DataFrame(results).to_csv(result_dir / "absolute_stitching.tsv", sep="\t", index=False)

            decoder1.cpu()
            decoder2.cpu()

        results = pd.DataFrame(results)
        results.to_csv(result_dir / "absolute_stitching.tsv", sep="\t", index=False)

    if not ortho_result.exists():
        # SVD
        Result = namedtuple(
            "Result",
            [
                "seed",
                "num_anchors",
                "encoding_space",
                "decoding_space",
                "similarity",
                "mse",
                "score",
                "norm_diff",
                "centering",
                "std_correction",
                "l2_norm",
                "sigma_rank",
            ],
        )
        batch_size: int = 4000

        # seed_options: Sequence[int] = (
        #     42,
        #     # 0,
        # )
        centering_options = [
            True,
        ]

        results = []
        for (
            (encoding_space, decoder1),
            (decoding_space, decoder2),
            num_anchors,
            std_correction,
            centering,
            l2_norm,
            seed,
        ) in tqdm(
            list(
                itertools.product(
                    list(encoder2decoder.items()),
                    list(encoder2decoder.items()),
                    anchor_num_options,
                    std_correction_options,
                    centering_options,
                    l2_norm_options,
                    seeds,
                )
            )
        ):
            seed_everything(seed)
            encoding_space_column = dataset_config.encoding_column_template.format(encoder=encoding_space)
            decoding_space_column = dataset_config.encoding_column_template.format(encoder=decoding_space)

            anchor_data = data["train"].shuffle(seed=seed).select(list(range(num_anchors)))

            encoder_test_data = data["test"].remove_columns(
                [
                    column
                    for column in data["test"].column_names
                    if column not in {"index", label_column, encoding_space_column, decoding_space_column}
                ]
            )
            test_loader = DataLoader(
                encoder_test_data, batch_size=batch_size, pin_memory=True, shuffle=False, num_workers=0
            )

            # orig_acc = Accuracy(task="multiclass", num_classes=encoder_test_data.features["label"].num_classes).to(DEVICE)
            rescaled_acc = Accuracy(
                task="multiclass", num_classes=encoder_test_data.features[label_column].num_classes
            ).to(DEVICE)
            reconstruction_similarity = CosineSimilarity(reduction="mean").to(DEVICE)
            reconstruction_mse = MeanAbsoluteError(reduction="mean").to(DEVICE)
            reconstruction_norm_diff = MeanMetric().to(DEVICE)

            encoding_anchors: torch.Tensor = anchor_data[encoding_space_column].to(DEVICE)  # .double()
            decoding_anchors: torch.Tensor = anchor_data[decoding_space_column].to(DEVICE)  # .double()

            latent_translation = LatentTranslation(
                seed=seed, centering=centering, std_correction=std_correction, l2_norm=l2_norm, method="svd"
            )
            try:
                latent_translation.fit(encoding_anchors=encoding_anchors, decoding_anchors=decoding_anchors)
            except:  # noqa
                continue

            latent_translation.to(DEVICE)
            decoder1.to(DEVICE)
            decoder2.to(DEVICE)

            translate_info = defaultdict(list)
            with torch.no_grad():
                for test_batch in test_loader:
                    y = test_batch[label_column].to(DEVICE)

                    encoding_x = test_batch[encoding_space_column].to(DEVICE)  # .double()
                    decoding_x = test_batch[decoding_space_column].to(DEVICE)  # .double()

                    rec_decoding_x = latent_translation.transform(X=encoding_x, compute_info=False)["target"]

                    reconstruction_similarity.update(decoding_x, rec_decoding_x)
                    reconstruction_mse.update(decoding_x, rec_decoding_x)

                    rescaled_logits = decoder2(rec_decoding_x.float())
                    rescaled_preds = torch.argmax(rescaled_logits, dim=1)

                    reconstruction_norm_diff.update(
                        (decoding_x.norm(p=2, dim=-1) - rec_decoding_x.norm(p=2, dim=-1)).abs()
                    )
                    rescaled_acc.update(rescaled_preds, y)

            # TODO: translate_info merge

            results.append(
                Result(
                    num_anchors=num_anchors,
                    seed=seed,
                    encoding_space=encoding_space,
                    decoding_space=decoding_space,
                    similarity=reconstruction_similarity.compute().cpu().item(),
                    mse=reconstruction_mse.compute().cpu().item(),
                    score=rescaled_acc.compute().cpu().item(),
                    std_correction=std_correction,
                    norm_diff=reconstruction_norm_diff.compute().cpu().item(),
                    centering=centering,
                    l2_norm=l2_norm,
                    sigma_rank=latent_translation.sigma_rank,
                    **translate_info,
                )
            )
            pd.DataFrame(results).to_csv(result_dir / "orthogonal_stitching.tsv", sep="\t", index=False)

            decoder1.cpu()
            decoder2.cpu()

        results = pd.DataFrame(results)
        results.to_csv(result_dir / "orthogonal_stitching.tsv", sep="\t", index=False)

    if not lstsq_result.exists():
        # Least Square
        Result = namedtuple(
            "Result",
            [
                "seed",
                "num_anchors",
                "encoding_space",
                "decoding_space",
                "similarity",
                "mse",
                "score",
                "norm_diff",
                "centering",
                "std_correction",
                "l2_norm",
            ],
        )
        batch_size: int = 4000

        centering_options = [
            True,
        ]

        results = []
        for (
            (encoding_space, decoder1),
            (decoding_space, decoder2),
            num_anchors,
            std_correction,
            centering,
            l2_norm,
            seed,
        ) in tqdm(
            list(
                itertools.product(
                    list(encoder2decoder.items()),
                    list(encoder2decoder.items()),
                    anchor_num_options,
                    std_correction_options,
                    centering_options,
                    l2_norm_options,
                    seeds,
                )
            )
        ):
            seed_everything(seed)
            encoding_space_column = dataset_config.encoding_column_template.format(encoder=encoding_space)
            decoding_space_column = dataset_config.encoding_column_template.format(encoder=decoding_space)

            anchor_data = data["train"].shuffle(seed=seed).select(list(range(num_anchors)))

            encoder_test_data = data["test"].remove_columns(
                [
                    column
                    for column in data["test"].column_names
                    if column not in {"index", label_column, encoding_space_column, decoding_space_column}
                ]
            )
            test_loader = DataLoader(
                encoder_test_data, batch_size=batch_size, pin_memory=True, shuffle=False, num_workers=0
            )

            # orig_acc = Accuracy(task="multiclass", num_classes=encoder_test_data.features["label"].num_classes).to(DEVICE)
            rescaled_acc = Accuracy(
                task="multiclass", num_classes=encoder_test_data.features[label_column].num_classes
            ).to(DEVICE)
            reconstruction_similarity = CosineSimilarity(reduction="mean").to(DEVICE)
            reconstruction_mse = MeanAbsoluteError(reduction="mean").to(DEVICE)
            reconstruction_norm_diff = MeanMetric().to(DEVICE)

            encoding_anchors: torch.Tensor = anchor_data[encoding_space_column].to(DEVICE)  # .double()
            decoding_anchors: torch.Tensor = anchor_data[decoding_space_column].to(DEVICE)  # .double()

            latent_translation = LatentTranslation(
                seed=seed, centering=centering, std_correction=std_correction, l2_norm=l2_norm, method="lstsq"
            )
            latent_translation.fit(encoding_anchors=encoding_anchors, decoding_anchors=decoding_anchors)

            latent_translation.to(DEVICE)
            decoder1.to(DEVICE)
            decoder2.to(DEVICE)

            translate_info = defaultdict(list)
            with torch.no_grad():
                for test_batch in test_loader:
                    y = test_batch[label_column].to(DEVICE)

                    encoding_x = test_batch[encoding_space_column].to(DEVICE)  # .double()
                    decoding_x = test_batch[decoding_space_column].to(DEVICE)  # .double()

                    rec_decoding_x = latent_translation.transform(X=encoding_x, compute_info=False)["target"]

                    try:
                        reconstruction_similarity.update(decoding_x, rec_decoding_x)
                        reconstruction_mse.update(decoding_x, rec_decoding_x)

                        reconstruction_norm_diff.update(
                            (decoding_x.norm(p=2, dim=-1) - rec_decoding_x.norm(p=2, dim=-1)).abs()
                        )
                        rescaled_logits = decoder2(rec_decoding_x.float())
                        rescaled_preds = torch.argmax(rescaled_logits, dim=1)
                        rescaled_acc.update(rescaled_preds, y)
                    except ValueError:
                        # TODO: error when the matrix anchor has too many linearly dependent vectors
                        continue

            results.append(
                Result(
                    num_anchors=num_anchors,
                    seed=seed,
                    encoding_space=encoding_space,
                    decoding_space=decoding_space,
                    similarity=reconstruction_similarity.compute().cpu().item(),
                    mse=reconstruction_mse.compute().cpu().item(),
                    score=rescaled_acc.compute().cpu().item(),
                    std_correction=std_correction,
                    norm_diff=reconstruction_norm_diff.compute().cpu().item(),
                    centering=centering,
                    l2_norm=l2_norm,
                    **translate_info,
                )
            )
            pd.DataFrame(results).to_csv(result_dir / "lstsq_stitching.tsv", sep="\t", index=False)

            decoder1.cpu()
            decoder2.cpu()

        results = pd.DataFrame(results)
        results.to_csv(result_dir / "lstsq_stitching.tsv", sep="\t", index=False)

    if not lstsq_ortho_result.exists():
        # Least Square
        Result = namedtuple(
            "Result",
            [
                "seed",
                "num_anchors",
                "encoding_space",
                "decoding_space",
                "similarity",
                "mse",
                "score",
                "norm_diff",
                "centering",
                "std_correction",
                "l2_norm",
            ],
        )
        batch_size: int = 4000

        centering_options = [
            True,
        ]

        results = []
        for (
            (encoding_space, decoder1),
            (decoding_space, decoder2),
            num_anchors,
            std_correction,
            centering,
            l2_norm,
            seed,
        ) in tqdm(
            list(
                itertools.product(
                    list(encoder2decoder.items()),
                    list(encoder2decoder.items()),
                    anchor_num_options,
                    std_correction_options,
                    centering_options,
                    l2_norm_options,
                    seeds,
                )
            )
        ):
            seed_everything(seed)
            encoding_space_column = dataset_config.encoding_column_template.format(encoder=encoding_space)
            decoding_space_column = dataset_config.encoding_column_template.format(encoder=decoding_space)

            anchor_data = data["train"].shuffle(seed=seed).select(list(range(num_anchors)))

            encoder_test_data = data["test"].remove_columns(
                [
                    column
                    for column in data["test"].column_names
                    if column not in {"index", label_column, encoding_space_column, decoding_space_column}
                ]
            )
            test_loader = DataLoader(
                encoder_test_data, batch_size=batch_size, pin_memory=True, shuffle=False, num_workers=0
            )

            # orig_acc = Accuracy(task="multiclass", num_classes=encoder_test_data.features["label"].num_classes).to(DEVICE)
            rescaled_acc = Accuracy(
                task="multiclass", num_classes=encoder_test_data.features[label_column].num_classes
            ).to(DEVICE)
            reconstruction_similarity = CosineSimilarity(reduction="mean").to(DEVICE)
            reconstruction_mse = MeanAbsoluteError(reduction="mean").to(DEVICE)
            reconstruction_norm_diff = MeanMetric().to(DEVICE)

            encoding_anchors: torch.Tensor = anchor_data[encoding_space_column].to(DEVICE)  # .double()
            decoding_anchors: torch.Tensor = anchor_data[decoding_space_column].to(DEVICE)  # .double()

            latent_translation = LatentTranslation(
                seed=seed, centering=centering, std_correction=std_correction, l2_norm=l2_norm, method="lstsq+ortho"
            )
            try:
                latent_translation.fit(encoding_anchors=encoding_anchors, decoding_anchors=decoding_anchors)
            except:  # noqa
                continue

            latent_translation.to(DEVICE)
            decoder1.to(DEVICE)
            decoder2.to(DEVICE)

            translate_info = defaultdict(list)
            with torch.no_grad():
                for test_batch in test_loader:
                    y = test_batch[label_column].to(DEVICE)

                    encoding_x = test_batch[encoding_space_column].to(DEVICE)  # .double()
                    decoding_x = test_batch[decoding_space_column].to(DEVICE)  # .double()

                    rec_decoding_x = latent_translation.transform(X=encoding_x, compute_info=False)["target"]

                    try:
                        reconstruction_similarity.update(decoding_x, rec_decoding_x)
                        reconstruction_mse.update(decoding_x, rec_decoding_x)

                        reconstruction_norm_diff.update(
                            (decoding_x.norm(p=2, dim=-1) - rec_decoding_x.norm(p=2, dim=-1)).abs()
                        )
                        rescaled_logits = decoder2(rec_decoding_x.float())
                        rescaled_preds = torch.argmax(rescaled_logits, dim=1)
                        rescaled_acc.update(rescaled_preds, y)
                    except ValueError:
                        # TODO: error when the matrix anchor has too many linearly dependent vectors
                        continue

            results.append(
                Result(
                    num_anchors=num_anchors,
                    seed=seed,
                    encoding_space=encoding_space,
                    decoding_space=decoding_space,
                    similarity=reconstruction_similarity.compute().cpu().item(),
                    mse=reconstruction_mse.compute().cpu().item(),
                    score=rescaled_acc.compute().cpu().item(),
                    std_correction=std_correction,
                    norm_diff=reconstruction_norm_diff.compute().cpu().item(),
                    centering=centering,
                    l2_norm=l2_norm,
                    **translate_info,
                )
            )
            pd.DataFrame(results).to_csv(result_dir / "lstsq_ortho_stitching.tsv", sep="\t", index=False)

            decoder1.cpu()
            decoder2.cpu()

        results = pd.DataFrame(results)
        results.to_csv(result_dir / "lstsq_ortho_stitching.tsv", sep="\t", index=False)


if __name__ == "__main__":
    # Entry point: run the stitching experiments for every (dataset, decoder type)
    # pair in the grid below, using the seeds and anchor counts configured here.

    # Seed 42 first, then seeds 0, 1, 2, 3.
    seeds = (42, *range(4))

    # Number of anchors to fit each translation on. For a sweep, use e.g.
    # torch.arange(0, 1501, 100)[1:].tolist()
    anchor_num_options: Sequence[int] = (1500,)

    # Uncomment entries to include further datasets in the grid.
    selected_datasets = (
        # "cifar10",
        # "cifar100_fine",
        # "cifar100_coarse",
        # "fashion_mnist",
        # "mnist",
        # "dbpedia_14",
        # "trec",
        # "n24news_image",
        "n24news_text",
    )
    decoder_types = ("svm", "linear")

    progress = tqdm(itertools.product(selected_datasets, decoder_types))
    for dataset, decoder_type in progress:
        progress.set_description(f"{dataset} | {decoder_type}")
        run(
            dataset=dataset,
            decoder_type=decoder_type,
            seeds=seeds,
            std_correction_options=[False],
            l2_norm_options=[True],
            anchor_num_options=anchor_num_options,
        )
