import rotor
from rotor.algorithms import persistent, constrained, autocapper, opportunist, pofo, utils, offload
from all_chains import chains
import time
import sys

# Bandwidth unit: 2**30 / 1000, the same float as 1024*1024*1.024.
# NOTE(review): presumably 1 GiB/s expressed in bytes per millisecond
# (if rotor durations are in ms) — confirm against rotor's cost model.
gbs = 2**30 / 1000
# Offload bandwidths benchmarked, as multiples of ``gbs``.
BW_values = [factor * gbs for factor in (12, 16, 24, 36)]

class ChainDescr(rotor.algorithms.Chain):
    """A ``rotor.algorithms.Chain`` built from a sequence of per-stage tuples.

    ``data`` is an iterable of 6-tuples. It is transposed into the columns
    ``fw, bw, cw, cbw, ftmp, btmp``, which are passed to ``Chain.__init__``.
    If the last entry of the ``fw`` or ``ftmp`` column is ``None``, that
    entry is dropped first.

    Raises:
        AttributeError: wraps any TypeError or ValueError raised while
            transposing ``data`` or inside ``Chain.__init__``.
    """

    def __init__(self, data):
        try:
            columns = [list(column) for column in zip(*data)]
            fw, bw, cw, cbw, ftmp, btmp = columns
            # These two columns may end with a None entry; drop it.
            fw = fw[:-1] if fw[-1] is None else fw
            ftmp = ftmp[:-1] if ftmp[-1] is None else ftmp
            super().__init__(fw, bw, cw, cbw, ftmp, btmp)
        except (TypeError, ValueError) as err:
            raise AttributeError("Error when creating ChainDescr") from err


# Each entry is (label, callable, extra keyword arguments for the callable).
# Algorithms benchmarked without offloading:
normal_algorithms = [
    ("persistent", persistent, dict()),
]
# Algorithms benchmarked with offloading:
offload_algorithms = [
    ("autocapper", autocapper, dict()),
    ("opportunist", opportunist, dict()),
    ("pofo", pofo, dict(nb_slots=50)),
    ("offload", offload, dict(nb_slots=50)),
]


def one_alg(alg, chain, target_memory, discretize, **kwargs):
    """Run one scheduling algorithm and measure how long it takes.

    Args:
        alg: callable invoked as ``alg(chain, target_memory, **kwargs)``.
        chain: the chain to schedule. It must provide ``discretize(unit)``
            when ``discretize`` is true.
        target_memory: memory budget handed to ``alg``.
        discretize: if true, memory is expressed in units of
            ``target_memory // 500``. ``alg`` then sees a budget of 500 units,
            and the chain and the ``BW`` keyword (if given) are rescaled to
            that unit.
        **kwargs: extra keyword arguments forwarded to ``alg``.

    Returns:
        ``(seq, elapsed_ms)``. ``seq`` is whatever ``alg`` returned, or
        ``None`` if it raised an ``Exception``. ``elapsed_ms`` is the
        wall-clock duration of the call in milliseconds.
    """
    if discretize:
        # NOTE(review): assumes target_memory >= 500; below that mem_unit is 0
        # and the rescaling below cannot work.
        mem_unit = target_memory // 500
        target_memory = 500
        chain = chain.discretize(mem_unit)
        if "BW" in kwargs:
            kwargs["BW"] = kwargs["BW"] / mem_unit
    start = time.perf_counter()
    try:
        seq = alg(chain, target_memory, **kwargs)
    except Exception:
        # A failing algorithm is reported as "no sequence" (this script's
        # callers then print NA). Catching Exception instead of using a bare
        # except lets KeyboardInterrupt / SystemExit still stop the run.
        seq = None
    end = time.perf_counter()
    return (seq, 1000 * (end - start))

def mem_usage(seq, chain):
    """Return the memory usage reported by ``utils.simulate_sequence``.

    The sequence ``seq`` is simulated on ``chain`` with display turned off.
    NOTE(review): the ``None`` second argument presumably means "no memory
    limit" — confirm against rotor's ``simulate_sequence``.
    """
    usage = utils.simulate_sequence(seq, None, chain, display=False)
    return usage

# Budget ratios r: each benchmarked budget is
# min_memory + r * (max_memory - min_memory).
memory_ratios = [
    0.05,
    0.1,
    0.2,
    0.3,
    0.5,
    0.75,
]

if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser("Simulation code")
    parser.add_argument("--network", "-n", nargs='*', help="Only consider networks that contain one of these strings", default=None)
    parser.add_argument("--batch", "-b", nargs='*', type=int, help="Only consider networks with one of these batch sizes", default=[8])
    parser.add_argument("--depth", "-d", nargs='*', type=int, help="Only consider one of these depths", default=None)
    parser.add_argument("--size", "-s", nargs='*', type=int, help="Only consider networks with one of these sizes", default=None)
    args = parser.parse_args()

    def valid(param):
        """Return True if the chain key ``param`` passes the command-line filters.

        ``param`` is split on ":" into ``net:depth:size:batch``. Depth, size
        and batch must parse as integers. An empty or missing filter accepts
        every value.
        """
        net, depth, size, batch = param.split(":")
        depth, size, batch = int(depth), int(size), int(batch)
        if args.batch and batch not in args.batch:
            return False
        if args.network and all(n not in net for n in args.network):
            return False
        # Chains deeper than 200 are always skipped, regardless of filters.
        if depth > 200:
            return False
        if args.depth and depth not in args.depth:
            return False
        if args.size and size not in args.size:
            return False
        return True

    # Each result row printed to stdout has the columns:
    #   param ratio target_memory BW name duration mem_used compute_cost transfer_cost compute_time
    # "NA" marks values that do not apply or could not be computed.
    for (param, chain_descr) in chains.items():
        if not valid(param):
            continue
        chain = ChainDescr(chain_descr)
        sim = utils.SimulateOffload(chain)
        # max_memory comes from the no-checkpoint sequence and min_memory from
        # the recompute-all sequence; the budgets below interpolate between them.
        sequential = utils.no_checkpoint(chain.length)
        max_memory = mem_usage(sequential, chain)
        seq_time = sequential.get_makespan(chain)
        print(param, 1.0, max_memory, "NA", "sequential", seq_time, max_memory, seq_time, 0.0, 0.0, flush=True)
        min_memory = mem_usage(utils.recompute_all(chain.length), chain)

        def run_normals(ratio, target_memory):
            """Run every non-offloading algorithm for one budget and print one row each."""
            for (name, alg, kwargs) in normal_algorithms:
                (seq, compute_time) = one_alg(alg, chain, target_memory, True, **kwargs)
                if seq:
                    duration = seq.get_makespan(chain)
                    mem_used = mem_usage(seq, chain)
                else:
                    duration = "NA"
                    mem_used = "NA"
                print(param, ratio, target_memory, "NA", name, duration, mem_used, duration, 0.0, compute_time, flush=True)

        def run_offloads(ratio, target_memory, BW):
            """Run every offloading algorithm for one budget and bandwidth and print one row each.

            A sequence whose simulated usage exceeds the budget is reported
            as NA, with a warning on stderr.
            """
            for (name, alg, kwargs) in offload_algorithms:
                (seq, compute_time) = one_alg(alg, chain, target_memory, alg == opportunist, BW=BW, **kwargs)
                # Simulate once: the usage is needed both for the budget check
                # and for the warning message.
                usage = mem_usage(seq, chain) if seq else None
                if seq and usage <= target_memory:
                    seq = sim.synchronize(seq, target_memory)
                    (duration, mem_used) = sim.simulate(seq, BW)
                    compute_cost = seq.get_makespan(chain)
                    transfer_cost = seq.get_offloaded_data(chain) / BW
                else:
                    if seq:
                        print("Warning:", param, target_memory, BW, name, "got usage",
                              usage, file=sys.stderr)
                    (duration, mem_used) = ("NA", "NA")
                    compute_cost, transfer_cost = ("NA", "NA")
                print(param, ratio, target_memory, BW, name, duration, mem_used, compute_cost, transfer_cost, compute_time, flush=True)

        BW = BW_values[0]
        target_values = {r: int(min_memory + r * (max_memory - min_memory)) for r in memory_ratios}
        for ratio, target_memory in target_values.items():
            run_normals(ratio, target_memory)
            run_offloads(ratio, target_memory, BW)


        # Bandwidth sweep at a fixed budget ratio (must be one of memory_ratios).
        # NOTE(review): BW_values[0] at this ratio was already run in the loop
        # above, so those rows are computed and printed a second time here.
        ratio = 0.2
        target_memory = target_values[ratio]
        for BW in BW_values:
            run_offloads(ratio, target_memory, BW)

