import argparse
import os

import torch
from torch.utils.data import DataLoader

from src.data.data_partition.data_partition import (
    dirichlet_load_test,
    load_test_ood,
)
from src.models.score import Energy, MLPScore
from src.models.wideresnet import WideResNet
from src.utils.main_utils import make_save_path, set_seed

def _build_parser():
    """Create the command-line parser for this evaluation script.

    Every option is optional. The defaults describe a CIFAR-10 run split
    across 10 clients with a Dirichlet alpha of 0.5, scored against SVHN
    as the out-of-distribution dataset.
    """
    cli = argparse.ArgumentParser(description="arguments for OOD generalization and detection training")
    # (flag, type, default, help); help is None for options without help text.
    option_table = (
        ("--seed", int, 42, None),
        ("--method", str, "foogd", None),
        ("--model_name", str, "cifar10_alpha0.5_fedavg_foogd_model.pt", None),
        ("--id_dataset", str, "cifar10", "the ID dataset"),
        ("--ood_dataset", str, "SVHN", None),
        ("--dataset_path", str, "./data", "path to dataset"),
        ("--alpha", float, 0.5, "parameter of dirichlet distribution"),
        ("--num_client", int, 10, "number of clients"),
        ("--dataset_seed", int, 21, "seed to split dataset"),
        ("--backbone", str, "wideresnet", "backbone model of task"),
        ("--device", str, "cuda:0", "device"),
        ("--score_method", str, "sm", None),
    )
    for flag, kind, default_value, help_text in option_table:
        # Only forward `help` when the option actually has help text.
        extra = {} if help_text is None else {"help": help_text}
        cli.add_argument(flag, type=kind, default=default_value, **extra)
    return cli


parser = _build_parser()


def _make_eval_loader(dataset, batch_size=128):
    """Wrap ``dataset`` in a ``DataLoader`` with this script's evaluation settings.

    Yields batches of ``batch_size`` (128 by default) with ``shuffle=True``.
    NOTE(review): shuffling is not needed for order-independent metrics;
    consider ``shuffle=False`` once it is confirmed the server's test methods
    do not depend on sample order.
    """
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


def test():
    """Evaluate a saved federated checkpoint on ID, corrupted and OOD test data.

    Steps:

    * parse the command-line options defined on ``parser``;
    * rebuild the per-client ID and corruption test splits and the OOD test set;
    * build the ``WideResNet`` backbone and the ``Energy(MLPScore())`` score model;
    * create the server/client classes returned by ``get_server_and_client``;
    * load ``args.model_name`` from the directory returned by
      ``make_save_path(args)``.

    It then prints the ID accuracy, FPR95 and AUROC
    (``server.test_classification_detection_ability``) and the accuracy for
    each corruption type (``server.test_corrupt_accuracy``).
    """
    args = parser.parse_args()
    set_seed(args.seed)
    save_path = make_save_path(args)

    print(f'test {args.method} on {args.id_dataset} dataset, alpha = {args.alpha}')

    corrupt_list = [
        "brightness", "fog", "glass_blur", "motion_blur", "snow", "contrast", "frost", "impulse_noise", "pixelate",
        "defocus_blur", "jpeg_compression", "elastic_transform", "gaussian_noise", "shot_noise", "zoom_blur", 'spatter',
        'gaussian_blur', 'saturate', 'speckle_noise'
    ]

    id_datasets, cor_datasets, num_class = dirichlet_load_test(
        args.dataset_path, args.id_dataset, args.num_client, args.alpha, corrupt_list, args.dataset_seed
    )
    # NOTE(review): the meaning of the trailing positional ``False`` is defined
    # by ``load_test_ood``; confirm against its signature.
    ood_dataset = load_test_ood(args.dataset_path, args.ood_dataset, args.dataset_seed, False)

    backbone = WideResNet(depth=40, num_classes=num_class, widen_factor=2, dropRate=0.3)
    # Attached to ``args``, which is passed to ``get_server_and_client`` below;
    # presumably consumed there — verify.
    args.score_model = Energy(net=MLPScore())

    device = torch.device(args.device)

    client_id_loaders = [_make_eval_loader(id_datasets[idx]) for idx in range(args.num_client)]
    ood_loader = _make_eval_loader(ood_dataset)

    server_args = {}
    # NOTE(review): every client receives the same ``backbone`` instance.
    client_args = [
        {
            "cid": cid,
            "device": device,
            "backbone": backbone,
        }
        for cid in range(args.num_client)
    ]
    from src.utils.main_utils import get_server_and_client
    Server, Client, client_args, server_args = get_server_and_client(args, client_args, server_args)
    server = Server(server_args)
    clients = [Client(client_args[idx]) for idx in range(args.num_client)]

    server.clients.extend(clients)
    # map_location lets a checkpoint saved on one device (e.g. cuda:0) be
    # loaded when evaluating on another (e.g. --device cpu).
    checkpoint = torch.load(os.path.join(save_path, args.model_name), map_location=device)

    # corruption type -> one loader per client, in client-id order.
    client_cor_loaders = {
        cor_type: [_make_eval_loader(cor_datasets[cid][cor_type]) for cid in range(args.num_client)]
        for cor_type in corrupt_list
    }
    id_accuracy, fpr95, auroc = server.test_classification_detection_ability(
        checkpoint, client_id_loaders, ood_loader, args.score_method
    )

    cor_accuracy = server.test_corrupt_accuracy(client_cor_loaders)
    print(f"test in distribution accuracy: {id_accuracy}")
    print(f"test fpr95: {fpr95}")
    print(f"test auroc: {auroc}")
    for key, value in cor_accuracy.items():
        print(f"corrupt type {key} accuracy: {value}")


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not when imported.
    test()
