"""
Script to split a JSON file with units into several files (one per GPU), so that
the expected visualization work is distributed about evenly across them.
"""

import json
import os
from argparse import ArgumentParser

from stimuli_generation.sample_units import (init_layer_sizes, layer_sizes,
                                             model_types)
from stimuli_generation.utils import (accuracies, load_model, read_units_file,
                                      split_unit)


def get_fair_unit_lists(model_name, units, n):
    """
    Takes a list of units and distributes them into n lists, so that the expected work
    for every list is about the same.

    The units are sorted by the number of channels in their layer (looked up in
    ``layer_sizes`` after ``init_layer_sizes`` has run on the loaded model) and then
    dealt out round-robin, so every list receives units from small and large layers.

    :param model_name: the name of the model
    :param units: list of units; the list itself is not modified
    :param n: desired number of lists, must be at least 1
    :return: a list of n lists of units that together contain every unit exactly once
    :raises ValueError: if n is smaller than 1
    """
    if n < 1:
        raise ValueError(f"Number of lists must be at least 1, got {n}.")

    is_clip = "clip" in model_name

    # Sort units by number of channels in their layer, which is the main factor for
    # the GPU time it will take to visualize the unit.
    # Load the model and populate layer_sizes-dictionary to know possible number of
    # units per layer
    model = load_model(model_name)
    if is_clip:
        model = model.visual
    init_layer_sizes(model_name, model)

    def get_key(unit):
        """Custom sorting key: the size of the layer the unit belongs to."""
        layer, _channel = split_unit(unit)
        if is_clip:
            # Layer sizes were collected from model.visual, so strip the
            # "visual_" prefix from the unit's layer name before the lookup.
            layer = layer[len("visual_") :]
        return layer_sizes[model_types[model_name]][layer]

    # sorted() returns a new list, so the caller's list is left untouched.
    sorted_units = sorted(units, key=get_key)
    # Round-robin over the sorted units: list i gets units i, i+n, i+2n, ...
    units_lists = [sorted_units[i::n] for i in range(n)]

    # Internal invariant: strided slices with n >= 1 partition the sorted list.
    assert sum(len(ln) for ln in units_lists) == len(
        units
    ), "Not all units could be assigned to lists!"

    return units_lists


def split_units_list(args):
    """
    Takes the CLI-arguments, splits the units into fair lists and writes each list
    to its own json-file, one per GPU.

    For an input file ``path/to/units.json`` the outputs are ``path/to/units_0.json``,
    ``path/to/units_1.json``, ..., each containing ``{"units": [...]}``.

    :param args: the CLI-arguments; ``model_name``, ``units``, ``num_gpus`` and
        ``units_file`` are read from it
    """

    unit_lists = get_fair_unit_lists(args.model_name, args.units, args.num_gpus)

    # Strip only the final extension. Splitting the path on the first "." would
    # break inputs such as "./units.json" or "data.v2/units.json".
    base, _ext = os.path.splitext(args.units_file)

    for i, unit_list in enumerate(unit_lists):
        # store the chosen units
        with open(f"{base}_{i}.json", "w", encoding="utf-8") as f:
            json.dump({"units": unit_list}, f)

    print("Done with parallel processing.")


if __name__ == "__main__":
    parser = ArgumentParser()
    model_choices = list(accuracies.keys())
    parser.add_argument(
        "--model_name",
        choices=model_choices,
        required=True,
        help="Which model to use. Supported values are: "
        f"{', '.join(model_choices)}.",
    )

    parser.add_argument(
        "--units_file",
        type=str,
        required=True,
        help="Path to json-file with unit names as generated by sample_units.py.",
    )

    parser.add_argument("--num_gpus", type=int, default=1, help="How many GPUs to use")

    arguments = parser.parse_args()

    # Reject non-positive GPU counts up front with a proper CLI error instead of
    # failing later while splitting the units.
    if arguments.num_gpus < 1:
        parser.error("--num_gpus must be at least 1.")

    # read units from file
    arguments.units = read_units_file(arguments.units_file)

    split_units_list(arguments)
