import json
import os
import sys
from collections import defaultdict
from datetime import datetime

sys.path.append("../")

from causal_discovery.utils import set_cluster, str2bool
from causal_graphs.graph_generation import (generate_categorical_graph,
                                            get_graph_func)
from causal_graphs.graph_visualization import visualize_graph

from experiments.utils import (get_basic_parser, save_dataset, set_seed,
                               test_graph)


def parse_args_generated_graphs(args=None):
    """Parse command-line options for the generated-graphs experiment.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` produced by the generated-graphs parser.
    """
    return get_generated_graphs_parser().parse_args(args=args)


def get_generated_graphs_parser():
    """Return the shared basic parser extended with graph-generation options.

    Adds, in this order: ``--graph_type``, ``--num_graphs``, ``--num_vars``,
    ``--num_categs`` and ``--edge_prob`` on top of ``get_basic_parser()``.
    """
    parser = get_basic_parser()
    # (flag, value type, default, help text) -- order matters for --help output.
    option_specs = (
        (
            "--graph_type",
            str,
            "random",
            "Which graph type to test on. Currently supported are: "
            "chain, bidiag, collider, jungle, tree, full, regular, random, "
            "random_max_#N where #N is to be replaced with an integer. "
            "random_max_10 is random with max. 10 parents per node.",
        ),
        (
            "--num_graphs",
            int,
            1,
            "Number of graphs to generate and sequentially test on.",
        ),
        (
            "--num_vars",
            int,
            8,
            "Number of variables that the graphs should have.",
        ),
        (
            "--num_categs",
            int,
            10,
            "Number of categories/different values each variable can take.",
        ),
        (
            "--edge_prob",
            float,
            0.2,
            "For random graphs, the probability of two arbitrary nodes to be connected.",
        ),
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    return parser


def main(args, logger_prefix=""):
    """Generate synthetic categorical graphs and run structure learning on each.

    For every index in ``range(args.num_graphs)`` this seeds the RNGs with
    ``args.seed + index`` and generates a graph of type ``args.graph_type``.
    It saves the graph (and optionally a PDF visualization) into the
    checkpoint directory and then calls ``test_graph`` on it.

    Args:
        args: Parsed command-line namespace (see
            ``get_generated_graphs_parser``). It is dumped verbatim to
            ``args.json`` inside the checkpoint directory.
        logger_prefix: String inserted into the per-graph Neptune log prefix
            ``logs/{logger_prefix}g{index}/``.

    Raises:
        AssertionError: if ``args.save_dataset_path`` is set while
            ``args.interventions_policy`` is not ``"ce_shd_reduction"`` or
            ``args.interventions_check_max`` is not ``-1``.
    """
    # NOTE(review): imported lazily, presumably so the logger module is only
    # loaded when an experiment actually runs -- confirm.
    import causal_discovery.logger as lg

    lg.create_neptune_logger(args.use_neptune_logger)

    # Basic checkpoint directory creation. Without an explicit directory, use a
    # timestamped one, e.g. checkpoints/2024_01_31__12_05_09/.
    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = datetime.now().strftime("checkpoints/%Y_%m_%d__%H_%M_%S/")
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)

    set_cluster(args.cluster)

    # Dataset collection is only supported for one specific intervention setup.
    dataset_to_save = None
    if args.save_dataset_path is not None:
        assert args.interventions_policy == "ce_shd_reduction", (
            "save_dataset_path requires interventions_policy == "
            f"'ce_shd_reduction', got {args.interventions_policy!r}"
        )
        assert args.interventions_check_max == -1, (
            "save_dataset_path requires interventions_check_max == -1, "
            f"got {args.interventions_check_max!r}"
        )
        dataset_to_save = defaultdict(list)

    for gindex in range(args.num_graphs):
        if lg.NEPTUNE_LOGGER is not None:
            lg.NEPTUNE_LOGGER.prefix = f"logs/{logger_prefix}g{gindex}/"
            lg.NEPTUNE_LOGGER.reset()
        # Offset the seed per graph; a constant seed would make every
        # iteration generate the same graph.
        graph_seed = args.seed + gindex
        set_seed(graph_seed)
        # Generate graph
        print(
            "Generating %s graph with %i variables..."
            % (args.graph_type, args.num_vars)
        )
        graph = generate_categorical_graph(
            num_vars=args.num_vars,
            min_categs=args.num_categs,
            max_categs=args.num_categs,
            edge_prob=args.edge_prob,
            connected=True,
            use_nn=True,
            graph_func=get_graph_func(args.graph_type),
            seed=graph_seed,
            embed_dim=args.embed_dim,
        )
        # Zero-padded 1-based index keeps per-graph files sorted, e.g. 001_random.
        file_id = "%s_%s" % (str(gindex + 1).zfill(3), args.graph_type)
        # Save graph
        graph.save_to_file(os.path.join(checkpoint_dir, "graph_%s.pt" % (file_id)))
        # Visualization is skipped for graphs with more than 100 variables.
        if graph.num_vars <= 100 and args.visualize:
            print("Visualizing graph...")
            figsize = max(3, graph.num_vars**0.7)
            visualize_graph(
                graph,
                filename=os.path.join(checkpoint_dir, "graph_%s.pdf" % (file_id)),
                figsize=(figsize, figsize),
                layout="circular" if graph.num_vars < 40 else "graphviz",
            )

        # Start structure learning
        test_graph(
            graph, args, checkpoint_dir, file_id, dataset_to_save=dataset_to_save
        )

    if args.save_dataset_path is not None:
        save_dataset(dataset_to_save, args.save_dataset_path)


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the full experiment.
    cli_args = parse_args_generated_graphs()
    main(cli_args)
