import logging
import os
import sys
import torch
import numpy as np

from scipy.linalg import orthogonal_procrustes
import datasets
import transformers.utils.logging as trfl
from multiprocessing import Pool
from transformers import set_seed

# Base directory that is joined with the relative paths used for loading the
# stored representations and for saving the computed measures.
ROOT_PATH = "representations-large"

# NOTE(review): DEVICE and eval_mode are not referenced anywhere in the visible
# part of this file -- confirm whether they are still needed.
DEVICE = "cpu"
eval_mode = "all"  # 'all' or 'last'

logger = logging.getLogger(__name__)
# Setup logging: records go to stdout with timestamp, level and logger name.
# No level is configured here; the logger level is set in the __main__ block.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def _inv_sqrt_gram(M):
    """Return the (pseudo-)inverse square root of the Gram matrix ``M @ M.T``.

    Eigenvalues that are not strictly positive (including small negative
    values produced by round-off) are treated as zero, so the corresponding
    directions are dropped instead of being inverted.
    """
    evals, evecs = np.linalg.eigh(M @ M.T)
    inv = np.zeros_like(evals)
    positive = evals > 0
    inv[positive] = 1 / np.sqrt(evals[positive])
    # Scaling the eigenvector columns equals ``evecs @ np.diag(inv)`` without
    # materialising the dense diagonal matrix.
    return (evecs * inv) @ evecs.T


def cca_decomp(A, B):
    """Computes CCA vectors, correlations, and transformed matrices.

    Both inputs are whitened with the inverse square root of their Gram
    matrices (``A @ A.T`` and ``B @ B.T``); the singular values of the
    whitened cross-product ``A @ B.T`` are the canonical correlations.
    The inputs are used exactly as given: no centering or scaling is
    applied inside this function.

    Requires a < n and b < n.

    Args:
        A: np.array of size a x n where a is the number of neurons and n is
            the dataset size
        B: np.array of size b x n where b is the number of neurons and n is
            the dataset size
    Returns:
        u: left singular vectors for the inner SVD problem, a x a
        s: canonical correlation coefficients, length min(a, b)
        vh: right singular vectors for the inner SVD problem, b x b
        transformed_a: canonical variables for matrix A, n x a array
            (one column per canonical direction)
        transformed_b: canonical variables for matrix B, n x b array
            (one column per canonical direction)
    Raises:
        AssertionError: if either input has at least as many rows as columns.
        numpy.linalg.LinAlgError: if the SVD fails to converge even on the
            rescaled matrix.
    """
    assert A.shape[0] < A.shape[1], "A must have fewer rows than columns"
    assert B.shape[0] < B.shape[1], "B must have fewer rows than columns"

    # Each whitening matrix is needed twice (for the SVD input and for the
    # transformed outputs), so build it once.
    whiten_a = _inv_sqrt_gram(A)
    whiten_b = _inv_sqrt_gram(B)

    temp = whiten_a @ (A @ B.T) @ whiten_b

    try:
        u, s, vh = np.linalg.svd(temp)
    except np.linalg.LinAlgError:
        # The SVD did not converge: retry on a rescaled copy. Singular values
        # scale linearly with the matrix, so dividing by the same factor
        # recovers them; the singular vectors are unaffected by the scaling.
        u, s, vh = np.linalg.svd(temp * 100)
        s = s / 100

    transformed_a = (u.T @ whiten_a @ A).T
    transformed_b = (vh @ whiten_b @ B).T
    return u, s, vh, transformed_a, transformed_b


def mean_sq_cca_corr(rho):
    """Return the mean of the squared canonical correlations.

    :param rho: canonical correlation coefficients returned by cca_decomp(A,B)
    :return: sum of the element-wise squares of rho divided by len(rho)
    """
    squared = rho * rho
    # For rho coming from cca_decomp, this count is min(A.shape[0], B.shape[0]).
    count = len(rho)
    return np.sum(squared) / count


def mean_cca_corr(rho):
    """Return the arithmetic mean of the canonical correlations.

    :param rho: canonical correlation coefficients returned by cca_decomp(A,B)
    :return: sum of rho divided by len(rho)
    """
    total = np.sum(rho)
    # For rho coming from cca_decomp, this count is min(A.shape[0], B.shape[0]).
    count = len(rho)
    return total / count


def pwcca_dist(A, rho, transformed_a):
    """Computes the projection weighted CCA distance from A's point of view.

    Each canonical direction is weighted by the summed absolute inner products
    between its canonical variable and the rows of A; the weights are
    normalised to sum to one and the distance is one minus the weighted sum
    of the correlation coefficients.

    :param A: np.array of size a x n where a is the number of neurons and n is the dataset size
    :param rho: canonical correlation coefficients returned by cca_decomp(A,B)
    :param transformed_a: canonical variables for A returned by cca_decomp(A,B);
        its first dimension must equal n so that ``transformed_a.T @ A.T`` is defined
    :return: PWCCA distance
    """
    abs_projections = np.abs(transformed_a.T @ A.T)
    raw_weights = np.sum(abs_projections, axis=1)
    weights = raw_weights / np.sum(raw_weights)
    # Only the leading entries that both vectors have are combined.
    shared = min(len(weights), len(rho))
    weighted_corr = np.dot(weights[:shared], rho[:shared])
    return 1 - weighted_corr


## CKA
def lin_cka_dist(A, B):
    """
    Computes Linear CKA distance between representations A and B.

    The distance is one minus the squared Frobenius norm of ``B @ A.T``
    divided by the product of the Frobenius norms of ``A @ A.T`` and
    ``B @ B.T``. The inputs are used as given; no centering is applied here.
    """
    cross_norm = np.linalg.norm(B @ A.T, ord="fro")
    gram_a_norm = np.linalg.norm(A @ A.T, ord="fro")
    gram_b_norm = np.linalg.norm(B @ B.T, ord="fro")
    return 1 - cross_norm ** 2 / (gram_a_norm * gram_b_norm)


def _load_train_representation(task, seed):
    """Load the stored train representation for one (task, seed) pair.

    The object read from disk is indexed with ``[-1]`` (presumably a per-layer
    sequence whose last entry is the final layer -- TODO confirm), converted
    to NumPy and transposed for the similarity methods to work properly.
    """
    # NOTE(review): ROOT_PATH already ends in "representations-large", so the
    # resolved path contains that directory name twice -- confirm the layout.
    path = os.path.join(
        ROOT_PATH, f"representations-large/seed-{seed}-task-{task}-train"
    )
    stored = torch.load(path, map_location=torch.device("cpu"))
    return stored[-1].numpy().T


def compute_similarity_measures(task, seedx, num_seeds=25):
    """Compare the representation of one seed against all other seeds and save the results.

    For every other seed the CCA-based measures, the linear CKA distance and
    an orthogonal Procrustes residual are computed. One ``.npy`` file per
    measure is written below ROOT_PATH, each holding one value per compared
    seed in increasing seed order.

    Args:
        task: task name used in the representation and result file names.
        seedx: seed whose representation is the reference.
        num_seeds: seeds ``0 .. num_seeds - 1`` are compared against ``seedx``
            (``seedx`` itself is skipped). Defaults to 25.
    Returns:
        None. The results are only written to disk.
    """
    # Keys are the file-name prefixes of the saved arrays.
    # NOTE(review): "cca_distances" and "mean_cca_distances" hold mean
    # (squared) correlations as returned by mean_sq_cca_corr / mean_cca_corr,
    # not ``1 - correlation``; the file names are kept unchanged.
    results = {
        "cca_distances": [],
        "mean_cca_distances": [],
        "pwcca_distances": [],
        "cka_distances": [],
        "op_distances": [],
    }
    logger.info("seedx: %s, task: %s", seedx, task)
    train_x = _load_train_representation(task, seedx)
    if task == "mnli":
        logger.info(train_x.shape)

    for seedy in range(num_seeds):
        if seedy == seedx:
            continue
        logger.info("seedy: %s", seedy)
        train_y = _load_train_representation(task, seedy)
        if task == "mnli":
            logger.info(train_y.shape)

        _, cca_rho, _, transformed_rep1, _ = cca_decomp(train_x, train_y)
        logger.info("cca done")
        results["cca_distances"].append(mean_sq_cca_corr(cca_rho))
        logger.info("mean_sq_cca_corr done")
        results["mean_cca_distances"].append(mean_cca_corr(cca_rho))
        logger.info("mean_cca_corr done")
        results["pwcca_distances"].append(
            pwcca_dist(train_x, cca_rho, transformed_rep1)
        )
        logger.info("PWCCA done")
        results["cka_distances"].append(lin_cka_dist(train_x, train_y))
        logger.info("lin_cka_sim done")

        # NOTE(review): orthogonal_procrustes(X, Y) returns R minimising
        # ||X @ R - Y||_F, so with the transposed arrays R acts on the second
        # axis of train_x / train_y rather than on the first -- confirm that
        # this is the intended alignment.
        R = orthogonal_procrustes(train_x, train_y)[0]
        results["op_distances"].append(
            np.linalg.norm(R.T @ train_x.T - train_y.T)
        )

    # Save NumPy arrays as .npy files
    for prefix, values in results.items():
        np.save(
            os.path.join(
                ROOT_PATH,
                f"representations-large/{prefix}_task-{task}-seedx-{seedx}.npy",
            ),
            np.array(values),
        )


if __name__ == "__main__":
    # One job per (task, seed) pair; every job writes its own result files.
    tasks = ["sst2", "mrpc"]
    seeds = range(25)
    jobs = [(task, seed) for task in tasks for seed in seeds]

    # Verbosity is configured here, i.e. only in the process that runs the
    # script. Worker processes see it when they are forked from this process;
    # with the "spawn"/"forkserver" start methods this block is not re-run.
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()

    set_seed(42)

    with Pool(processes=16) as pool:
        pool.starmap(compute_similarity_measures, jobs)
