"""test_model.py
   Test models
   Developed as part of DeepThinking project
   Sept 2021
"""

import argparse
import os
import uuid
from collections import OrderedDict

import json
import torch
from icecream import ic

from utils.common import load_model_from_checkpoint, get_dataloaders, to_json, get_optimizer, now
from utils.testing_utils import test


# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115

def get_training_args(path_to_args):
    """Load the arguments recorded for a training run.

    Args:
        path_to_args: Directory that contains an ``args.json`` file.

    Returns:
        The value stored under the top-level ``"0"`` key of that JSON file
        (callers index it like a dict of training arguments).

    Raises:
        FileNotFoundError: If ``args.json`` is not present in the directory.
        KeyError: If the JSON document has no ``"0"`` entry.
    """
    args_file = os.path.join(path_to_args, "args.json")
    # Decode explicitly as UTF-8 so the result does not depend on the platform locale.
    with open(args_file, "r", encoding="utf-8") as fp:
        args_dict = json.load(fp)
    return args_dict["0"]


def main():
    """Evaluate a saved checkpoint and write its accuracy statistics to disk.

    Command-line flags choose the checkpoint (``--model_path``), the directory
    holding the training run's ``args.json`` (``--args_path``) and the
    evaluation settings. A fixed set of fields -- including ``model`` and
    ``width`` -- is then overwritten with the values recorded at training time,
    so ``load_model_from_checkpoint`` receives the training-time settings
    rather than the command-line ones.

    The merged arguments are written to ``<output>/<run_id>/args.json`` and the
    accuracies plus run metadata to ``<output>/<run_id>/stats.json``.

    Raises:
        SystemExit: Via ``parser.error`` when ``--args_path`` is not supplied.
        ValueError: If the recorded ``weight_for_loss`` is outside ``[0, 1]``.
    """

    print("\n_________________________________________________\n")
    print(now(), "test_model.py main() running.")

    parser = argparse.ArgumentParser(description="Deep Thinking")

    parser.add_argument("--args_path", default=None, type=str, help="where are the args saved?")
    parser.add_argument("--model", default="recur_net_recall_1d", type=str,
                        help="model for training")
    parser.add_argument("--model_path", default=None, type=str, help="where is the model saved?")
    parser.add_argument("--output", default="output_default", type=str, help="output subdirectory")
    parser.add_argument("--problem", default="prefix_sums", type=str,
                        help="one of 'prefix_sums', 'chess', or 'mazes'")
    parser.add_argument("--quick_test", action="store_true", help="test with test data only")
    parser.add_argument("--test_batch_size", default=500, type=int, help="batch size for testing")
    parser.add_argument("--test_data", default=20, type=int, help="what size eval data")
    parser.add_argument("--test_iterations", nargs="+", default=[20], type=int,
                        help="how many, if testing with a different number iterations")
    parser.add_argument("--test_mode", default="default", type=str, help="testing mode")
    parser.add_argument("--train_batch_size", default=128, type=int,
                        help="batch size for training")
    parser.add_argument("--width", default=120, type=int, help="width of the network")

    args = parser.parse_args()
    if args.args_path is None:
        # Without this check os.path.join(None, ...) fails with an opaque TypeError.
        parser.error("--args_path is required (directory containing the training args.json)")
    args.run_id = uuid.uuid1().hex
    training_args = get_training_args(args.args_path)

    # Copy training-time settings onto args, in this fixed order (it determines
    # the order new attributes appear in vars(args)). NOTE: this overrides any
    # --model / --width passed on the command line.
    for key in ("epochs", "inner_data", "lr", "lr_factor", "max_iters", "min_k", "min_n",
                "model", "optimizer", "train_data", "train_mode", "weight_for_loss", "width"):
        setattr(args, key, training_args[key])

    args.train_mode, args.test_mode = args.train_mode.lower(), args.test_mode.lower()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # cuDNN benchmarking is enabled for every train_mode except 'dual_loop'.
    if args.train_mode != "dual_loop":
        torch.backends.cudnn.benchmark = True

    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0 <= args.weight_for_loss <= 1:
        raise ValueError("Weighting for loss not in [0, 1], exiting.")

    args.output = os.path.join(args.output, args.run_id)
    to_json(vars(args), args.output, "args.json")

    ####################################################
    #               Dataset and Network and Optimizer
    loaders = get_dataloaders(args)

    # Only the network is used here; the returned start epoch and optimizer
    # state dict are not needed for evaluation.
    net, _, _ = load_model_from_checkpoint(args.model,
                                           args.model_path,
                                           args.width,
                                           args.problem,
                                           args.max_iters,
                                           device)

    # Always also evaluate at max_iters. Build a new list rather than appending,
    # so the list object argparse uses as the flag's default is not mutated.
    args.test_iterations = sorted(set(args.test_iterations + [args.max_iters]))

    pytorch_total_params = sum(p.numel() for p in net.parameters())

    print(f"This {args.model} has {pytorch_total_params/1e6:0.3f} million parameters.")
    ####################################################


    ####################################################
    #        Test
    print("==> Starting testing...")

    if args.quick_test:
        # The parser defines no --use_comet flag, so reading args.use_comet
        # directly raised AttributeError; fall back to False when absent.
        test_acc = test(net, [loaders["test"]], args.test_mode, args.test_iterations,
                        args.problem, device,
                        disable_tqdm=getattr(args, "use_comet", False))
        test_acc = test_acc[0]
        val_acc, train_acc = None, None
    else:
        test_acc, val_acc, train_acc = test(net,
                                            [loaders["test"], loaders["val"], loaders["train"]],
                                            args.test_mode,
                                            args.test_iterations,
                                            args.problem, device)

    print(f"{now()} Training accuracy: {train_acc}")
    print(f"{now()} Val accuracy: {val_acc}")
    print(f"{now()} Testing accuracy (hard data): {test_acc}")


    model_name_str = f"{args.model}_width={args.width}"
    # NOTE(review): "learning rate" and "lr" carry the same value; both are kept
    # so that any existing readers of stats.json continue to find their key.
    stats = OrderedDict([("epochs", args.epochs),
                         ("inner_data", args.inner_data),
                         ("learning rate", args.lr),
                         ("lr", args.lr),
                         ("lr_factor", args.lr_factor),
                         ("max_iters", args.max_iters),
                         ("min_k", args.min_k),
                         ("min_n", args.min_n),
                         ("model", model_name_str),
                         ("model_path", args.model_path),
                         ("num_params", pytorch_total_params),
                         ("optimizer", args.optimizer),
                         ("val_acc", val_acc),
                         ("run_id", args.run_id),
                         ("test_acc", test_acc),
                         ("test_data", args.test_data),
                         ("test_iters", args.test_iterations),
                         ("test_mode", args.test_mode),
                         ("train_data", args.train_data),
                         ("train_acc", train_acc),
                         ("train_batch_size", args.train_batch_size),
                         ("train_mode", args.train_mode),
                         ("weight", args.weight_for_loss)])
    to_json(stats, args.output, "stats.json")
    ####################################################


if __name__ == "__main__":
    main()
