# Compute only the exact (closed-form) linear-regression (LR) fit between representations of different seeds.

import os
import logging
import numpy as np
import sys

import torch
import torch.nn as nn
import geoopt
from torch.utils.data import DataLoader, TensorDataset
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from multiprocessing import Pool


# Runtime configuration.
DEVICE = "cpu"
eval_mode = "all"  # either 'all' or 'last'
# Directory from which representations are loaded and into which every output is written.
ROOT_PATH = "representations-large"


# Send log records to stdout with a timestamp, level and logger name.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


def get_lora_matrix(u, s, vh, rank_removed):
    """Rebuild a truncated-SVD approximation that drops the trailing singular triplets.

    Given thin-SVD factors of ``A = u @ diag(s) @ vh`` (as produced by
    ``np.linalg.svd(A, full_matrices=False)``), keep the first
    ``len(s) - rank_removed`` singular triplets and return their product.

    Args:
        u: Left singular vectors, one column per entry of ``s``.
        s: 1-D array of singular values. Assumed sorted in descending order
            (as numpy returns them), so the dropped tail holds the smallest values.
        vh: Right singular vectors, one row per entry of ``s``.
        rank_removed: Number of trailing singular triplets to discard. ``0``
            reproduces the full reconstruction. Any value ``>= len(s)`` yields
            an all-zero matrix of the original shape.

    Returns:
        Array with shape ``(u.shape[0], vh.shape[1])``.

    Raises:
        ValueError: If ``rank_removed`` is negative.
    """
    if rank_removed < 0:
        raise ValueError(f"rank_removed must be non-negative, got {rank_removed}")
    # Clamp at 0 so that removing more components than exist gives a zero matrix
    # rather than wrapping around through negative slicing.
    keep = max(len(s) - rank_removed, 0)
    # Scaling the columns of u by s equals u @ np.diag(s) without building a dense
    # k x k diagonal matrix. The input s is never modified.
    return (u[:, :keep] * s[:keep]) @ vh[:keep, :]



def compute_lr_exact(train_x, train_y):
    """Fit ordinary least squares from ``train_x`` to ``train_y`` and score it on the same data.

    Args:
        train_x: Input matrix with one row per sample.
        train_y: Targets with one entry (1-D) or one row (2-D) per sample.

    Returns:
        Tuple ``(r2, mse, max_err)``:
            r2: ``LinearRegression.score`` (coefficient of determination) on the training data.
            mse: ``mean_squared_error`` between targets and predictions.
            max_err: Largest per-sample L2 norm of the residual ``train_y - y_pred``.
    """
    model = LinearRegression(fit_intercept=True)
    model.fit(train_x, train_y)
    y_pred = model.predict(train_x)
    # Reshape residuals to one row per sample so the per-sample norm is also defined
    # for 1-D targets. For 2-D targets this reshape is a no-op.
    residuals = (np.asarray(train_y) - y_pred).reshape(len(y_pred), -1)
    return (
        model.score(train_x, train_y),  # report R^2
        mean_squared_error(y_true=train_y, y_pred=y_pred),  # report mean squared error
        np.max(np.linalg.norm(residuals, axis=1)),  # compute max error (l2 norm)
    )


def train_model(
    task,
    seed_x,
    num_seeds=25,
    ranks_removed=(0, 1, 8, 77, 154, 230, 307, 384, 461, 538, 614, 691, 730),  # 0, 1, 1%, 10%, 20%, ..., 90%, 95%
    root_path=None,
):
    """Fit exact linear maps from one seed's representations to every other seed's.

    The inputs are precomputed rank-reduced matrices named
    ``seed-{seed}-task-{task}-train-lora-{rank}.npy`` in ``root_path``. For each
    target seed ``seed_y`` in ``range(num_seeds)`` other than ``seed_x``, and each
    pair ``(rank_removed_x, rank_removed_y)`` from ``ranks_removed``, this
    function loads those matrices and fits ``compute_lr_exact(x, y)``.

    The nested results are saved as
    ``seedx-{seed_x}-task-{task}-train-lora-train-losses.npy`` in ``root_path``.
    They are indexed ``[seed_y position][rank_removed_x index][rank_removed_y index]``,
    and each entry is the ``(R^2, MSE, max L2 error)`` tuple from
    ``compute_lr_exact``.

    Args:
        task: Task name used in the file names.
        seed_x: Seed whose representations are used as regression inputs.
        num_seeds: Seeds ``0 .. num_seeds - 1`` are used as targets (``seed_x`` excluded).
        ranks_removed: Rank-removal levels for both inputs and targets.
        root_path: Directory to read from and write to. Defaults to ``ROOT_PATH``.
    """
    if root_path is None:
        root_path = ROOT_PATH

    def _load_lora(seed, rank_removed):
        # Load one precomputed rank-reduced matrix for this task.
        return np.load(os.path.join(root_path, f"seed-{seed}-task-{task}-train-lora-{rank_removed}.npy"))

    logger.info(f'TASK {task} seed_x: {seed_x}')
    overall_losses = []
    for seed_y in (s for s in range(num_seeds) if s != seed_x):
        logger.info(f'seedy: {seed_y}   seedx: {seed_x}')
        losses_seedy = []
        for rank_removed_x in ranks_removed:
            logger.info(f"rank_removed_x: {rank_removed_x}")
            x = _load_lora(seed_x, rank_removed_x)
            train_losses = []
            for rank_removed_y in ranks_removed:
                logger.info(f"rank_removed_y: {rank_removed_y}")
                y = _load_lora(seed_y, rank_removed_y)
                train_losses.append(compute_lr_exact(x, y))
            losses_seedy.append(train_losses)
        overall_losses.append(losses_seedy)
    np.save(
        os.path.join(root_path, f"seedx-{seed_x}-task-{task}-train-lora-train-losses.npy"),
        np.asarray(overall_losses),
    )


if __name__ == "__main__":
    # Set verbosity for this script's logger and the Hugging Face libraries, and fix RNG seeds.
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()
    set_seed(42)

    # Phase 1: for every task and seed, compute one thin SVD of the saved training
    # representation and save a truncated reconstruction for each removal level.
    ranks_removed = [0, 1, 8, 77, 154, 230, 307, 384, 461, 538, 614, 691, 730]  # 0, 1, 1%, 10%, 20%, 30%, 40%, 50%, 60%, 70%, 80%, 90%, 95%

    logger.info("Computing rank-removed matrices")
    for task in ["sst2", "mrpc"]:
        logger.info(f'TASK {task}')
        for seed in range(25):
            logger.info(f'seed: {seed}')
            # NOTE(review): ``[-1]`` takes the last entry along the first axis of the
            # saved tensor (presumably the final layer) - confirm against the producer.
            representation = torch.load(
                os.path.join(ROOT_PATH, f"seed-{seed}-task-{task}-train"), map_location="cpu"
            ).numpy()[-1]
            u, s, vh = np.linalg.svd(representation, full_matrices=False)
            for rank_removed in ranks_removed:
                lora = get_lora_matrix(u, s, vh, rank_removed)
                np.save(os.path.join(ROOT_PATH, f"seed-{seed}-task-{task}-train-lora-{rank_removed}.npy"), lora)

    logger.info("Rank-removed matrices computed, starting training...")
    # Phase 2: fit the exact linear maps in parallel, one job per (task, seed_x).
    tasks = ["mrpc", "sst2"]
    seeds_x = range(25)
    with Pool(processes=16) as pool:
        pool.starmap(
            train_model,
            [(task, seed_x) for task in tasks for seed_x in seeds_x],
        )
