import os
import sys
import joblib
import torch
import numpy as np
from tqdm.auto import tqdm
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from model import init_model, set_weights, get_grads


def _copy_trace_into(arrays, keys, index, trace, shrink_factor):
    """Copy every statistic found in ``trace`` into ``array[index]``.

    ``arrays`` and ``keys`` are walked in lockstep. Keys whose name contains
    ``'feature_importance'`` are copied as stored; all other series are
    subsampled with a stride of ``shrink_factor`` first.
    """
    for array, key in zip(arrays, keys):
        if key not in trace:
            # Statistic missing from this checkpoint: keep the zero fill.
            continue
        values = trace[key]
        if 'feature_importance' not in key:
            values = values[::shrink_factor]
        array[index] = np.array(values)


def load_stats(configs, pt=True, num_ft_lrs=8, shrink_factor=5):
    """Collect per-seed training statistics from checkpoints, with caching.

    The statistics are read from the ``'trace'`` entry of each checkpoint and
    stacked along a leading seed axis (one seed per element of ``configs``).
    The result is cached with joblib next to the run directories (in the
    parent directory of ``configs[0].savedir``). Rows of an existing cache are
    assumed to correspond to the leading entries of ``configs``; only the
    seeds missing from the cache are read from checkpoints.

    Args:
        configs: sequence of run configs, one per seed. Learning rates,
            iteration counts and ``data_protocol`` are taken from
            ``configs[0]``.
        pt: if True load pre-training checkpoints from ``config.savedir``,
            otherwise fine-tuning checkpoints from ``config.ft_savedir``.
        num_ft_lrs: number of leading learning rates used as fine-tuning
            learning rates (only relevant when ``pt`` is False).
        shrink_factor: stride used to subsample the logged series (every key
            except the feature-importance ones).

    Returns:
        Tuple of seven arrays, in this order: ``train_loss``, ``test_loss``,
        ``train_acc``, ``test_acc``, ``stoch_grad_norm``,
        ``feature_importance``, ``train_feature_importance``. Leading axes
        are ``(seed, lr)`` for pre-training and ``(seed, pt_lr, ft_lr)`` for
        fine-tuning. Statistics absent from a checkpoint trace stay zero.
    """
    cache_file = os.path.join(
        os.path.split(configs[0].savedir)[0],
        f'cached_{"pt" if pt else "ft"}_stats.pickle'
    )
    n_seeds = len(configs)
    read_seeds = 0
    cached_stats = None
    if os.path.isfile(cache_file):
        cached_stats = joblib.load(cache_file)

        read_seeds = cached_stats[0].shape[0]
        if read_seeds == n_seeds:
            return cached_stats
        if read_seeds > n_seeds:
            # The cache covers more seeds than requested. Copying it into the
            # smaller arrays below would fail, so serve the leading rows and
            # leave the larger cache file untouched.
            return tuple(array[:n_seeds] for array in cached_stats)

    lrs = configs[0].lrs
    total_iters = configs[0].pt_iters if pt else configs[0].ft_iters
    iters = total_iters // (configs[0].log_iters * shrink_factor)
    ckpt_iters = total_iters // configs[0].ckpt_iters
    leading = (n_seeds, len(lrs)) if pt else (n_seeds, len(lrs), num_ft_lrs)
    s1 = leading + (iters + 1,)
    s2 = leading + (ckpt_iters, len(configs[0].data_protocol))

    all_arrays = tuple([np.zeros(s1) for _ in range(5)] + [np.zeros(s2) for _ in range(2)])
    keys = [
        'train_loss', 'test_loss', 'train_acc', 'test_acc',
        'stoch_grad_norm', 'feature_importance', 'train_feature_importance'
    ]
    if read_seeds > 0:
        for array, cached_array in zip(all_arrays, cached_stats):
            array[:read_seeds] = cached_array
    # Drop the reference so the cached copy can be freed before loading more.
    cached_stats = None

    desc = f'Loading {"pre-training" if pt else "fine-tuning"} stats'
    for s, config in enumerate(tqdm(configs[read_seeds:], desc=desc), read_seeds):
        for i, pt_lr in enumerate(lrs):
            if pt:
                # Only the trace is read and turned into NumPy arrays, so
                # load onto the CPU regardless of where it was saved from.
                ckpt = torch.load(
                    f'{config.savedir}/pt_lr={pt_lr:.3e}.pt', map_location='cpu'
                )
                _copy_trace_into(all_arrays, keys, (s, i), ckpt['trace'], shrink_factor)
                continue

            for j, ft_lr in enumerate(lrs[:num_ft_lrs]):
                ckpt = torch.load(
                    f'{config.ft_savedir}/pt_lr={pt_lr:.3e}-ft_lr={ft_lr:.3e}.pt',
                    map_location='cpu'
                )
                _copy_trace_into(all_arrays, keys, (s, i, j), ckpt['trace'], shrink_factor)

    # nbytes is the size of the array data itself.
    stats_size = sum(array.nbytes for array in all_arrays) / 10 ** 6
    print(f'{"Pre-training" if pt else "Fine-tuning"} stats size: {stats_size:.2f} Mb')
    joblib.dump(all_arrays, cache_file)

    return all_arrays


def calculate_final_grad_norms(config, ft_lr=None):
    """Average mini-batch gradient norm of the final checkpoint for each lr.

    For every learning rate in ``config.lrs`` the matching checkpoint is
    loaded (the pre-training one when ``ft_lr`` is None, otherwise the
    fine-tuning one for that ``ft_lr``). The BCE-with-logits loss is then
    back-propagated on every training mini-batch, and the norm of
    ``get_grads(model)`` is averaged over the batches.

    Args:
        config: run config providing ``savedir``, ``ft_savedir``, ``lrs``,
            ``batch_size`` and the model hyper-parameters.
        ft_lr: optional fine-tuning learning rate selecting fine-tuned
            checkpoints instead of pre-trained ones.

    Returns:
        1-D NumPy array with one mean gradient norm per entry of
        ``config.lrs``.
    """
    def checkpoint_path(lr):
        # Pre-training checkpoints are keyed by lr only; fine-tuning ones by
        # the (pre-training lr, fine-tuning lr) pair.
        if ft_lr is None:
            return f'{config.savedir}/pt_lr={lr:.3e}.pt'
        return f'{config.ft_savedir}/pt_lr={lr:.3e}-ft_lr={ft_lr:.3e}.pt'

    # Only the train split (first and third entries) is needed here.
    train_inputs, _, train_targets, _ = torch.load(f'{config.savedir}/data.pt')[:4]
    loader = DataLoader(
        TensorDataset(train_inputs, train_targets),
        batch_size=config.batch_size, shuffle=True, drop_last=True,
    )
    net = init_model(config.num_layers, config.num_hidden, config.num_features)
    loss_fn = nn.BCEWithLogitsLoss()

    mean_norms = []
    for lr in tqdm(config.lrs):
        state = torch.load(checkpoint_path(lr))
        net.load_state_dict(state['model'])

        batch_norms = []
        for inputs, targets in loader:
            net.zero_grad()
            logits = net(inputs)[:, 0]
            loss_fn(logits, targets.to(torch.float)).backward()
            batch_norms.append(torch.norm(get_grads(net)).item())

        mean_norms.append(np.mean(batch_norms))

    return np.array(mean_norms)


def calculate_trajectory_grad_norms(config):
    """Average mini-batch gradient norm along each pre-training trajectory.

    For every learning rate in ``config.lrs`` the pre-training checkpoint is
    loaded and each weight snapshot stored under ``ckpt['trace']['weight']``
    is installed into the model with ``set_weights``. For each snapshot the
    BCE-with-logits loss is back-propagated on every training mini-batch and
    the norm of ``get_grads(model)`` is averaged over the batches.

    Args:
        config: run config providing ``savedir``, ``lrs``, ``batch_size`` and
            the model hyper-parameters.

    Returns:
        NumPy array built from one list per learning rate, each holding one
        mean gradient norm per stored weight snapshot.
    """
    # Only the train split (first and third entries) is needed here.
    train_inputs, _, train_targets, _ = torch.load(f'{config.savedir}/data.pt')[:4]
    loader = DataLoader(
        TensorDataset(train_inputs, train_targets),
        batch_size=config.batch_size, shuffle=True, drop_last=True,
    )
    net = init_model(config.num_layers, config.num_hidden, config.num_features)
    loss_fn = nn.BCEWithLogitsLoss()

    def mean_batch_grad_norm():
        # One pass over the training loader with the weights currently in net.
        batch_norms = []
        for inputs, targets in loader:
            net.zero_grad()
            logits = net(inputs)[:, 0]
            loss_fn(logits, targets.to(torch.float)).backward()
            batch_norms.append(torch.norm(get_grads(net)).item())
        return np.mean(batch_norms)

    norms_per_lr = []
    for lr in tqdm(config.lrs):
        state = torch.load(f'{config.savedir}/pt_lr={lr:.3e}.pt')
        net.load_state_dict(state['model'])

        trajectory = []
        for snapshot in state['trace']['weight']:
            set_weights(net, snapshot)
            trajectory.append(mean_batch_grad_norm())

        norms_per_lr.append(trajectory)

    return np.array(norms_per_lr)


def calculate_final_importance(
    config, feature_groups, ft_lr=None, squared=True, X_test=None, eps=1e-7
):
    """Share of the input-gradient norm carried by each feature group.

    For every learning rate in ``config.lrs`` the matching checkpoint is
    loaded (pre-training when ``ft_lr`` is None, otherwise the fine-tuning
    checkpoint for that ``ft_lr``). The model output is differentiated with
    respect to the test inputs. For each group in ``feature_groups`` the
    per-sample ratio ``||grad[:, group]|| / (||grad|| + eps)`` is computed,
    optionally squared, and averaged over samples.

    Args:
        config: run config providing ``savedir``, ``ft_savedir``, ``lrs`` and
            the model hyper-parameters.
        feature_groups: iterable of column indexers into the input features.
        ft_lr: optional fine-tuning learning rate selecting fine-tuned
            checkpoints.
        squared: if True, square the ratios before averaging.
        X_test: optional input tensor; loaded from ``data.pt`` in
            ``config.savedir`` when omitted. The tensor passed in is not
            modified.
        eps: added to the full gradient norm to avoid division by zero.

    Returns:
        NumPy array of shape ``(len(config.lrs), len(feature_groups))``.
    """
    if X_test is None:
        _, X_test, _, _ = torch.load(f'{config.savedir}/data.pt')[:4]
    model = init_model(config.num_layers, config.num_hidden, config.num_features)

    # Differentiate w.r.t. a detached leaf that shares storage with X_test.
    # This avoids flipping requires_grad (and leaving a .grad) on the
    # caller's tensor, and also works when X_test is not a leaf.
    inputs = X_test.detach().requires_grad_(True)

    grad_norm_ratio_by_lr = []
    for lr in tqdm(config.lrs):
        path = f'{config.savedir}/pt_lr={lr:.3e}.pt' if ft_lr is None else \
            f'{config.ft_savedir}/pt_lr={lr:.3e}-ft_lr={ft_lr:.3e}.pt'
        ckpt = torch.load(path)
        model.load_state_dict(ckpt['model'])

        model.zero_grad()
        inputs.grad = None

        output = model(inputs)
        # Same gradients as back-propagating every per-sample output
        # separately (each receives an implicit gradient of one), but
        # without materialising one tensor per sample.
        output.sum().backward()

        grad = inputs.grad
        denom = grad.norm(dim=1) + eps

        mean_ratios = []
        for features in feature_groups:
            ratios = grad[:, features].norm(dim=1) / denom
            if squared:
                ratios = ratios.square()
            mean_ratios.append(ratios.mean().item())

        grad_norm_ratio_by_lr.append(mean_ratios)

    return np.array(grad_norm_ratio_by_lr)


def calculate_trajectory_importance(
    config, feature_groups, lrs, squared=True, X_test=None, eps=1e-7
):
    """Feature-group share of the input-gradient norm along trajectories.

    For every entry of ``lrs`` a checkpoint is loaded and each weight
    snapshot stored under ``ckpt['trace']['weight']`` is installed into the
    model with ``set_weights``. For each snapshot the model output is
    differentiated with respect to the test inputs. For each group in
    ``feature_groups`` the per-sample ratio
    ``||grad[:, group]|| / (||grad|| + eps)`` is computed, optionally
    squared, and averaged over samples.

    Args:
        config: run config providing ``savedir``, ``ft_savedir`` and the
            model hyper-parameters.
        feature_groups: iterable of column indexers into the input features.
        lrs: iterable whose entries are either a scalar learning rate
            (selects the pre-training checkpoint) or a
            ``(pt_lr, ft_lr)`` pair (selects the fine-tuning checkpoint).
        squared: if True, square the ratios before averaging.
        X_test: optional input tensor; loaded from ``data.pt`` in
            ``config.savedir`` when omitted. The tensor passed in is not
            modified.
        eps: added to the full gradient norm to avoid division by zero.

    Returns:
        NumPy array built from one list per entry of ``lrs``; each holds, for
        every stored weight snapshot, one mean ratio per feature group.
    """
    if X_test is None:
        _, X_test, _, _ = torch.load(f'{config.savedir}/data.pt')[:4]
    model = init_model(config.num_layers, config.num_hidden, config.num_features)

    # Differentiate w.r.t. a detached leaf that shares storage with X_test.
    # This avoids flipping requires_grad (and leaving a .grad) on the
    # caller's tensor, and also works when X_test is not a leaf.
    inputs = X_test.detach().requires_grad_(True)

    grad_norm_ratio_by_lr = []
    for lr in tqdm(lrs):
        # np.isscalar accepts Python ints/floats and NumPy scalars alike, so
        # e.g. an np.float32 learning rate is not mistaken for a pair.
        if np.isscalar(lr):
            # single lr specified, load pre-train checkpoint
            ckpt = torch.load(f'{config.savedir}/pt_lr={lr:.3e}.pt')
        else:
            # two lrs specified, load fine-tune checkpoint
            pt_lr, ft_lr = lr
            ckpt = torch.load(f'{config.ft_savedir}/pt_lr={pt_lr:.3e}-ft_lr={ft_lr:.3e}.pt')
        model.load_state_dict(ckpt['model'])

        grad_ratios = []
        for weight in ckpt['trace']['weight']:
            set_weights(model, weight)

            model.zero_grad()
            inputs.grad = None

            output = model(inputs)
            # Same gradients as back-propagating every per-sample output
            # separately (each receives an implicit gradient of one), but
            # without materialising one tensor per sample.
            output.sum().backward()

            grad = inputs.grad
            denom = grad.norm(dim=1) + eps

            mean_ratios = []
            for features in feature_groups:
                ratios = grad[:, features].norm(dim=1) / denom
                if squared:
                    ratios = ratios.square()
                mean_ratios.append(ratios.mean().item())

            grad_ratios.append(mean_ratios)

        grad_norm_ratio_by_lr.append(grad_ratios)

    return np.array(grad_norm_ratio_by_lr)
