# Compute only the exact (closed-form) linear-regression (LR) solution.

import os
import logging
import numpy as np
import sys

import torch
import torch.nn as nn
import geoopt
from torch.utils.data import DataLoader, TensorDataset
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from multiprocessing import Pool


# Route all log records to stdout with a timestamped, level-tagged format.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Device passed as `map_location` when loading saved representations.
DEVICE = "cpu"
# Which slices to fit: "all" fits one regression per index along dim 0,
# "last" fits only the final one.
eval_mode = "all"


def compute_lr_exact(train_x, train_y):
    """Fit an ordinary least-squares map from ``train_x`` to ``train_y``.

    The fit uses an intercept. It is evaluated on the same data it was
    fitted on.

    Args:
        train_x: Array-like of shape (n_samples, n_features). CPU torch
            tensors are accepted and converted to NumPy.
        train_y: Array-like of shape (n_samples,) or (n_samples, n_targets).

    Returns:
        Tuple ``(r2, mse, max_l2_error)``:

        * ``r2``: the R^2 score on the training data.
        * ``mse``: the mean squared error of the predictions.
        * ``max_l2_error``: the largest per-sample L2 norm of the residual.
          For a 1-D target this is the largest absolute error.
    """
    # Convert once up front so sklearn and NumPy both work on plain arrays.
    # Subtracting NumPy predictions from a torch tensor is fragile.
    x = np.asarray(train_x)
    y = np.asarray(train_y)

    model = LinearRegression(fit_intercept=True)
    model.fit(x, y)
    y_pred = model.predict(x)

    residual = y - y_pred
    # Flatten each sample's residual so the per-sample norm is defined for
    # both 1-D (single target) and 2-D (multi-target) y.
    per_sample_err = np.linalg.norm(residual.reshape(residual.shape[0], -1), axis=1)

    return (
        float(model.score(x, y)),  # R^2
        float(mean_squared_error(y_true=y, y_pred=y_pred)),  # mean squared error
        float(np.max(per_sample_err)),  # max error (l2 norm)
    )


def train_model(task, seed):
    """Fit exact linear maps from one seed's representations to every other seed's.

    Loads the train representations for ``seed`` and, for every other seed
    in ``range(25)``, the representations for that seed. It then calls
    ``compute_lr_exact`` on them. With ``eval_mode == "all"`` it fits one
    regression per index along dimension 0 (presumably layers; confirm
    against how the representation files are produced). With
    ``eval_mode == "last"`` it fits only the final index.

    The resulting metric tuples are saved as a NumPy array to
    ``ROOT_PATH/v2-metrics-lr-{task}-small-seedx-{seed}.npy``. Rows are
    ordered by other-seed first, then by index along dimension 0.

    Args:
        task: Task name used in the representation file names (e.g. "sst2").
        seed: Seed whose representations are used as regression inputs.

    Raises:
        ValueError: If the module-level ``eval_mode`` is not "all" or "last".
    """
    # log_level = training_args.get_process_log_level()
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()

    # Validate before loading anything so a bad mode fails fast.
    if eval_mode not in ("all", "last"):
        raise ValueError(f"eval_mode must be 'all' or 'last', got {eval_mode=}")

    # global vars
    set_seed(42)
    ROOT_PATH = "representations-large"

    logger.info(f"Task: {task}")
    logger.info(f"Seed x: {seed}")
    train_losses_layers = []
    train_x = torch.load(
        os.path.join(ROOT_PATH, f"seed-{seed}-task-{task}-train"),
        map_location=DEVICE,
    )
    for seed_y in [s for s in range(25) if s != seed]:
        logger.info(f"Seed y: {seed_y}")
        train_y = torch.load(
            os.path.join(ROOT_PATH, f"seed-{seed_y}-task-{task}-train"),
            map_location=DEVICE,
        )
        if eval_mode == "last":
            # Do NOT rebind train_x here. It is reused for every seed_y, and
            # reassigning it would drop one dimension per iteration,
            # corrupting every fit after the first.
            train_losses_layers.append(
                compute_lr_exact(
                    train_x=train_x.select(0, -1),
                    train_y=train_y.select(0, -1),
                )
            )
        else:  # eval_mode == "all" (validated above)
            for i in range(train_x.size(0)):
                train_losses_layers.append(
                    compute_lr_exact(
                        train_x=train_x.select(0, i),
                        train_y=train_y.select(0, i),
                    )
                )
    with open(
        os.path.join(ROOT_PATH, f"v2-metrics-lr-{task}-small-seedx-{seed}.npy"), "wb"
    ) as f:
        np.save(f, np.array(train_losses_layers))


if __name__ == "__main__":
    # Configure logging and seeding in the parent process the same way each
    # worker does in train_model.
    main_log_level = logging.INFO
    logger.setLevel(main_log_level)
    datasets.utils.logging.set_verbosity(main_log_level)
    trfl.set_verbosity(main_log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()
    set_seed(42)

    # One job per (task, seed) pair, with tasks in the outer loop.
    jobs = []
    for task_name in ("sst2", "mrpc"):
        for seed_x in range(25):
            jobs.append((task_name, seed_x))

    with Pool(processes=16) as worker_pool:
        worker_pool.starmap(train_model, jobs)
