using GerryChain
using TickTock
using Plots, Shapefile
using Printf
using JLD
using SparseArrays
using Random
using DataFrames
using CSV
using Graphs
using ArgParse

### METADATA ###
# Input geography: precinct (VTD) shapefile and the attribute holding population.
SHAPEFILE_PATH = "Shapefiles/NC/NC_VTD.shp"
POPULATION_COL = "TOTPOP"

# Election columns used to pick district winners and prorate unhappy voters.
BLUE_VOTES     = "EL16G_PR_D"
RED_VOTES      = "EL16G_PR_R"

# Precomputed ensembles (".jld" is appended when loading) and how many entries to use.
ENSEMBLE_FILENAME   = "NC_ensemble"
TREE_FILENAME       = "random_spanning_trees_NC"
NUM_MAPS            = 10
NUM_TREES           = 100

# Districting and audit parameters.
NUM_DISTRICTS       = 13
EPSILON             = 0.02   # allowed relative deviation from the ideal district population
DG_ROBUST_THRESHOLD = 0.5    # a group deviates when more than this fraction of it is unhappy
VERIFICATION_ON     = true   # backtrack and check every candidate group against exact data

# Turn a BitSet of edge indices of the global `graph` into a directed tree:
# the edges are looked up in graph.edge_src / graph.edge_dst, added to an undirected
# graph on graph.num_nodes vertices, and the DFS tree rooted at vertex 1 is returned
# (its out-neighbors are used as children by the DP).
function convert_to_graph(edges::BitSet)::SimpleDiGraph
    undirected = SimpleGraph(graph.num_nodes)
    foreach(e -> add_edge!(undirected, graph.edge_src[e], graph.edge_dst[e]), edges)
    return dfs_tree(undirected, 1)
end

# Audit map `cur_map_index` against every tree in START_TREE:END_TREE for deviating
# groups of one color (blue when `isblue`, red otherwise).
#
# With VERIFICATION_ON, every tree is searched and the total number of verified
# deviating groups is returned. Without verification, 1 is returned as soon as any
# tree reports a candidate (the summary line is then skipped), and 0 otherwise.
function audit_with_tree_ensemble(cur_map_index::Int, isblue::Bool)::Int
    color = isblue ? "blue" : "red"
    found = 0
    println("------------------------------------------")
    for tree_idx in START_TREE:END_TREE
        println("Start auditing for ", color, " DGs for map ", cur_map_index,
                " using tree ", tree_idx)
        from_this_tree = audit_with_tree(cur_map_index, tree_idx, isblue)
        if VERIFICATION_ON
            found += from_this_tree
        elseif from_this_tree > 0
            return 1
        end
    end
    println("Found ", found, " ", color, " DGs for map ", cur_map_index,
            " using trees ", START_TREE, " to ", END_TREE)
    return found
end

# Run the tree DP for one (map, tree) pair and count deviating groups of one color
# (blue when `isblue`). Scaled populations are recomputed and the DP tables reset
# first. Nodes are then visited in the cached reversed topological order, so every
# child's row is complete before its parent's. After a node's row is filled for all
# population levels 0:maxpop, it is checked as the root of a deviating group.
# Without verification the search stops at the first hit.
function audit_with_tree(cur_map_index::Int, cur_tree_index::Int, isblue::Bool)::Int
    compute_scaled_pop_at_random()
    reset_dp_tables()
    tree = trees[cur_tree_index]
    order = reversed_orders[cur_tree_index]

    found = 0
    for step in 1:graph.num_nodes
        node = order[step]
        for level in 0:maxpop
            vertical_dp(cur_map_index, tree, node, level, isblue)
        end
        is_a_dev_group(cur_map_index, cur_tree_index, node, isblue) || continue
        found += 1
        VERIFICATION_ON || break
    end
    return found
end

# Decide whether a deviating group rooted at `cur_precinct` was found.
# Each population level whose DP value exceeds DG_ROBUST_THRESHOLD * minpop * PRECISION
# is a candidate. Without verification the first candidate is accepted. With
# verification, the candidate's precincts are rebuilt via `backtrack` and checked
# against exact data by `verify`, which also records accepted groups.
function is_a_dev_group(
    cur_map_index::Int,
    cur_tree_index::Int,
    cur_precinct::Int,
    isblue::Bool)::Bool

    cutoff = DG_ROBUST_THRESHOLD * minpop * PRECISION
    for level in 0:maxpop
        global_dp_table[cur_precinct, level+1] > cutoff || continue
        VERIFICATION_ON || return true
        members = BitSet()
        backtrack(cur_precinct, level, members)
        verify(cur_map_index, cur_tree_index, isblue, members) && return true
    end
    return false
end

"""
    verify(cur_map_index, cur_tree_index, isblue, this_dev_group) -> Bool

Check a candidate group of precincts against exact (unscaled) data.

The group is accepted when both of these hold:
- its exact population lies within `(1 ± EPSILON) * graph.total_pop / NUM_DISTRICTS`;
- its exact unhappy points (blue if `isblue`, else red, w.r.t. map `cur_map_index`)
  exceed `DG_ROBUST_THRESHOLD` times that population.

On acceptance, details are printed and a row is appended to
`dev_group_table[cur_map_index]`. Returns whether the group was accepted.
"""
function verify(
    cur_map_index::Int,
    cur_tree_index::Int,
    isblue::Bool,
    this_dev_group::BitSet)::Bool

    unhappy = isblue ? unhappy_blues[cur_map_index] : unhappy_reds[cur_map_index]
    total_exact_population = 0
    total_exact_unhappy_pts = 0
    for precinct in this_dev_group
        total_exact_population += graph.attributes[precinct][POPULATION_COL]
        total_exact_unhappy_pts += unhappy[precinct]
    end

    # Same bound expressions as used elsewhere, so boundary rounding is unchanged.
    upper_pop = (1 + EPSILON) * graph.total_pop / NUM_DISTRICTS
    lower_pop = (1 - EPSILON) * graph.total_pop / NUM_DISTRICTS
    if total_exact_population > upper_pop || total_exact_population < lower_pop
        return false
    end

    if total_exact_unhappy_pts > DG_ROBUST_THRESHOLD * total_exact_population
        println("Found Deviating Group with total population ", total_exact_population, " and total unhappy points ", total_exact_unhappy_pts)
        println("This deviating group has the following precincts: ", this_dev_group)
        println("There are ", length(this_dev_group), " precincts in this deviating group;")
        println("Saving this deviating group to dataframe...")
        push!(dev_group_table[cur_map_index], (
            cur_tree_index,
            this_dev_group,
            (isblue ? "Blue" : "Red"),
            total_exact_unhappy_pts,
            total_exact_population,
            (total_exact_unhappy_pts / total_exact_population),
        ))
        return true
    end
    return false
end

# Rebuild the precincts of the group recorded at global_backtrack_table[cur_precinct, pop+1],
# adding them to `this_dev_group`. Each entry holds (node, level) pairs. A pair naming
# the entry's own node marks that node as a member; any other pair is a child whose
# own entry at that level is expanded. An explicit work stack replaces recursion, so
# deep trees cannot exhaust the call stack.
function backtrack(
    cur_precinct,
    pop::Int,
    this_dev_group::BitSet)

    pending = Tuple{Int, Int}[]
    push!(pending, (cur_precinct, pop))
    while !isempty(pending)
        (node, level) = pop!(pending)
        for (member, member_level) in global_backtrack_table[node, level+1]
            if member == node
                push!(this_dev_group, node)
            else
                push!(pending, (member, member_level))
            end
        end
    end
end

"""
    vertical_dp(cur_map_index, cur_tree, cur_precinct_index, pop, isblue)

Fill one entry of the vertical DP table, `global_dp_table[v, pop+1]` (written
`A[v, pop]` below), for node `v = cur_precinct_index` of `cur_tree`.

`A[v, k]` is the largest number of unhappy points (blue if `isblue`, else red,
w.r.t. map `cur_map_index`) of a group made of `v` plus solutions chosen for some
of `v`'s children, with total scaled population at most `k`. It is 0 when no such
group fits. The matching `global_backtrack_table[v, pop+1]` entry records the
`(node, pop_level)` pairs that `backtrack` uses to rebuild the group.

Assumes every child's row and `A[v, pop-1]` are already computed. The caller in
`audit_with_tree` visits nodes children-first and levels in increasing order.
Also assumes `reset_dp_tables` has emptied the backtrack entries beforehand:
the leaf case pushes into the existing entry, and some branches leave it untouched.
"""
function vertical_dp(
    cur_map_index::Int,
    cur_tree::SimpleDiGraph,
    cur_precinct_index::Int,
    pop::Int,
    isblue::Bool)

    children = outneighbors(cur_tree, cur_precinct_index) # children of this node in the rooted tree
    uhp_cur_precinct = isblue ? unhappy_blues[cur_map_index][cur_precinct_index] : unhappy_reds[cur_map_index][cur_precinct_index]

    # Base case (leaf)
    if (length(children) == 0)
        if (pop >= scaled_pop[cur_precinct_index]) # pop level is enough for including the whole leaf node (this includes the case where both are 0)
            global_dp_table[cur_precinct_index, pop+1] = uhp_cur_precinct
            push!(global_backtrack_table[cur_precinct_index, pop+1], (cur_precinct_index, scaled_pop[cur_precinct_index])) # self: marks the leaf as a member
        else 
            global_dp_table[cur_precinct_index, pop+1] = 0
            # nothing is pushed into the backtrack table here; the entry stays empty from the reset
        end
    else # General case
        budget = pop - scaled_pop[cur_precinct_index] # population left for the children once the current precinct is included
        if (budget < 0) # the desired pop is less than that for this node, the only feasible subgraph is an empty subgraph
            global_dp_table[cur_precinct_index, pop+1] = 0 
            # nothing is pushed into the backtrack table here; the entry stays empty from the reset
        else
            lb = (pop == 0) ? 0 : global_dp_table[cur_precinct_index, pop] # if pop is nonzero, lower bound is A[cur_precinct_index, pop-1] ("at most pop" is monotone)
            ub = uhp_cur_precinct
            for i = 1:length(children)
                ub += global_dp_table[children[i], budget+1] # upper bound: as if every child got the entire population budget
            end
            if (lb >= ub) # lucky case: no need to do horizontal dp. In this case, the solution provided by lb is optimal
                global_dp_table[cur_precinct_index, pop+1] = lb
                if (pop > 0) # the configuration for pop-1 is also the answer here
                    global_backtrack_table[cur_precinct_index, pop+1] = deepcopy(global_backtrack_table[cur_precinct_index, pop])
                end
                # when pop == 0 the backtrack entry is left as reset (empty)
            else # do the horizontal dp; here budget can be 0
                (local_dp_value, local_dp_info) = horizontal_dp(cur_map_index, cur_precinct_index, children, budget)
                global_dp_table[cur_precinct_index, pop+1] = uhp_cur_precinct + local_dp_value
                # the horizontal dp returns the children's (child, pop_level) choices
                global_backtrack_table[cur_precinct_index, pop+1] = push!(local_dp_info, (cur_precinct_index, scaled_pop[cur_precinct_index])) # add self into the backtracking set

            end
        end
    end
end

"""
    horizontal_dp(cur_map_index, cur_precinct_index, children, budget;
                  dp_table = global_dp_table, pops = scaled_pop)

Distribute a scaled population `budget` among the subtrees rooted at `children`
to maximize their total unhappy points. Each child's best value at population
level `k` is read from `dp_table[child, k+1]`, the vertical DP table.

Internally, `B[j, k+1]` (`local_dp_table`) is the best value using only the first
`j` children with at most `k` scaled population. `S[j, k+1]` (`local_set_table`)
is the matching set of `(child, pop_level)` pairs from which the chosen precincts
can be rebuilt.

Returns `(B[end, budget+1], S[end, budget+1])`; an empty `children` list yields
`(0, Set())`. `cur_map_index` and `cur_precinct_index` are kept for interface
compatibility and are not used. The `dp_table` and `pops` keywords default to the
script's global tables; they let the routine be driven with explicit inputs.
"""
function horizontal_dp(
    cur_map_index::Int,
    cur_precinct_index::Int,
    children::Vector{Int},
    budget::Int;
    dp_table = global_dp_table,
    pops = scaled_pop)::Tuple{Int, Set{Tuple{Int, Int}}}

    nchildren = length(children)
    nchildren == 0 && return (0, Set{Tuple{Int, Int}}())

    local_dp_table = zeros(Int, nchildren, budget + 1)
    local_set_table = [Set{Tuple{Int, Int}}() for _ in 1:nchildren, _ in 0:budget]

    # First child: B[1, k] = A[w_1, k]; the child is used at level k once it fits.
    first_child = children[1]
    for pop = 0:budget
        local_dp_table[1, pop+1] = dp_table[first_child, pop+1]
        if pop >= pops[first_child]
            push!(local_set_table[1, pop+1], (first_child, pop))
        end
    end

    for j = 2:nchildren
        child = children[j]
        for pop = 0:budget
            # upper bound: B[j,k] <= B[j-1,k] + A[w_j,k]
            ub = local_dp_table[j-1, pop+1] + dp_table[child, pop+1]
            # lower bound: B[j,k] >= B[j,k-1] for k >= 1, and B[j,0] >= 0
            lb = (pop == 0) ? 0 : local_dp_table[j, pop]
            # bestx == -1 means lb is still the inherited bound, i.e. the
            # configuration of B[j, pop-1] (or the empty group when pop == 0).
            bestx = -1

            for x = 0:pop # x is the population level given to w_j
                lb >= ub && break # bounds meet: lb is optimal
                candidate = local_dp_table[j-1, pop-x+1] + dp_table[child, x+1]
                if candidate > lb
                    lb = candidate
                    bestx = x
                end
                # B[j-1, pop-x] + A[w_j, pop] bounds every split giving w_j more than x
                ub = min(ub, local_dp_table[j-1, pop-x+1] + dp_table[child, pop+1])
            end
            local_dp_table[j, pop+1] = lb

            # Record the configuration that actually produced lb.
            if bestx >= 0
                if bestx >= pops[child]
                    push!(local_set_table[j, pop+1], (child, bestx))
                end
                union!(local_set_table[j, pop+1], local_set_table[j-1, pop-bestx+1])
            elseif pop > 0
                union!(local_set_table[j, pop+1], local_set_table[j, pop])
            end
        end
    end
    # The local tables are discarded after this call, so the set needs no copy.
    return (local_dp_table[nchildren, budget+1], local_set_table[nchildren, budget+1])
end

# Divide `x` by PRECISION and round to an integer: up when the fractional part is
# strictly greater than `threshold`, down otherwise. Threshold 0.0 rounds every
# non-integer up and 1.0 always rounds down. A uniformly random threshold rounds up
# with probability equal to the fractional part, which preserves the expectation.
function roundandscale(x, threshold)::Int
    scaled = x / PRECISION
    fractional = scaled - floor(scaled)
    if fractional > threshold
        return ceil(scaled)
    end
    return floor(scaled)
end

# Prorate a precinct's population `pop` by the losing party's vote share
# `losing_votes / total_votes`, rounding down. A precinct with no recorded votes
# contributes no unhappy points; without this guard, 0/0 = NaN cannot be stored
# as an Int.
_prorated_unhappy(losing_votes, total_votes, pop)::Int =
    iszero(total_votes) ? 0 : floor(Int, (losing_votes * pop) / total_votes)

# Fill unhappy_blues[cur_map_index] and unhappy_reds[cur_map_index] for one map.
# In each district, the party with more total votes wins; ties go to red. Voters of
# the losing party in that district count as unhappy. Each precinct's share is its
# population prorated by the losing party's vote share there, and the winning
# color gets 0.
function compute_unhappy_pop(cur_map_index::Int)
    blues = zeros(Int, graph.num_nodes)
    reds = zeros(Int, graph.num_nodes)
    unhappy_blues[cur_map_index] = blues
    unhappy_reds[cur_map_index] = reds

    for district_index = 1:NUM_DISTRICTS
        current_district = partitions[cur_map_index].dist_nodes[district_index] # set of precinct nodes

        democrat_votes = 0
        republic_votes = 0
        for precinct in current_district
            democrat_votes += graph.attributes[precinct][BLUE_VOTES]
            republic_votes += graph.attributes[precinct][RED_VOTES]
        end
        blue_wins = democrat_votes > republic_votes

        for precinct in current_district
            b = graph.attributes[precinct][BLUE_VOTES]
            r = graph.attributes[precinct][RED_VOTES]
            p = graph.attributes[precinct][POPULATION_COL]
            if blue_wins
                blues[precinct] = 0
                reds[precinct] = _prorated_unhappy(r, b + r, p)
            else
                blues[precinct] = _prorated_unhappy(b, b + r, p)
                reds[precinct] = 0
            end
        end
    end
    return nothing
end

# Fill `scaled_pop` with every precinct's population divided by PRECISION, rounding
# any fractional part up (threshold 0.0). Appears not to be called in this file.
function compute_scaled_pop()
    round_up = 0.0
    for node in Base.OneTo(graph.num_nodes)
        raw_pop = graph.attributes[node][POPULATION_COL]
        scaled_pop[node] = roundandscale(raw_pop, round_up)
    end
    return nothing
end

# Fill `scaled_pop` with every precinct's population divided by PRECISION.
# By default it always rounds down (threshold 1.0 never triggers a ceiling). With
# `randomize = true`, each precinct gets a uniform random threshold instead, so it
# is rounded up with probability equal to its fractional part, which preserves the
# expected population. Random numbers are drawn only in that mode.
function compute_scaled_pop_at_random(; randomize::Bool = false)
    for i = 1:graph.num_nodes
        threshold = randomize ? rand() : 1.0
        scaled_pop[i] = roundandscale(graph.attributes[i][POPULATION_COL], threshold)
    end
    return nothing
end

# Clear the DP state before a new tree audit: zero every DP value and replace every
# backtrack entry with a fresh empty set. The tables are modified in place. The
# previous `global_dp_table = zeros(...)` only bound a new local variable inside
# this function and never touched the global. The keywords default to the global
# tables.
function reset_dp_tables(;
    dp_table::AbstractMatrix = global_dp_table,
    backtrack_table::AbstractMatrix = global_backtrack_table)

    fill!(dp_table, 0)
    for idx in eachindex(backtrack_table)
        backtrack_table[idx] = Set{Tuple{Int, Int}}()
    end
    return nothing
end

# Parse the audit's command-line options from `args`, which defaults to the
# process arguments `ARGS`, and return them as a Dict keyed by option name
# ("sm", "em", "st", "et", "precision", "output").
# Options:
# - --sm / --em: first and last map index to audit;
# - --st / --et: first and last tree index to use;
# - --precision / -p: population scaling factor;
# - --output / -o: prefix for output files.
function parse_commandline(args = ARGS)
    s = ArgParseSettings()

    @add_arg_table s begin
        "--sm"
            help = "start map"
            arg_type = Int
            default = 1
        "--em"
            help = "end map"
            arg_type = Int
            default = 1
            required = true
        "--st"
            help = "start tree"
            arg_type = Int
            default = 1
        "--et"
            help = "end tree"
            arg_type = Int
            default = 1
            required = true
        "--precision", "-p"
            help = "precision"
            arg_type = Int
            default = 1000
            required = true
        "--output", "-o"
            help = "output name file"
            arg_type = String
            default = "Output"
    end

    return parse_args(args, s)
end

parsed_args = parse_commandline()

# Parse Command Line Arguments
START_MAP             = parsed_args["sm"]
END_MAP               = parsed_args["em"]
START_TREE            = parsed_args["st"]
END_TREE              = parsed_args["et"]
PRECISION             = parsed_args["precision"]
OUTPUT_FILENAME       = parsed_args["output"]

# Fail fast on arguments that would otherwise crash much later. Map and tree
# indices address arrays of length NUM_MAPS / NUM_TREES, and PRECISION is the
# divisor used when scaling populations.
START_MAP >= 1 ||
    throw(ArgumentError("--sm must be at least 1, got $START_MAP"))
END_MAP <= NUM_MAPS ||
    throw(ArgumentError("--em must be at most NUM_MAPS = $NUM_MAPS, got $END_MAP"))
START_TREE >= 1 ||
    throw(ArgumentError("--st must be at least 1, got $START_TREE"))
END_TREE <= NUM_TREES ||
    throw(ArgumentError("--et must be at most NUM_TREES = $NUM_TREES, got $END_TREE"))
PRECISION > 0 ||
    throw(ArgumentError("--precision must be positive, got $PRECISION"))

# Build the precinct graph (attributes and population) from the shapefile.
graph = BaseGraph(SHAPEFILE_PATH, POPULATION_COL)

# Per-map unhappy populations, filled once per map by compute_unhappy_pop():
# unhappy_blues[i][j] / unhappy_reds[i][j] is the (prorated, floored) number of
# unhappy blue / red points in precinct j w.r.t. map i; entries stay `nothing`
# until computed.
# scaled_pop[j] is the population of precinct j divided by PRECISION and rounded to
# an integer (see compute_scaled_pop_at_random, which currently rounds down).
unhappy_blues = Array{Union{Nothing, Array{Int}}}(nothing, NUM_MAPS)
unhappy_reds = Array{Union{Nothing, Array{Int}}}(nothing, NUM_MAPS)
scaled_pop = zeros(Int, graph.num_nodes)

# Scaled bounds on a district's population: maxpop rounds up (threshold 0.0),
# minpop rounds down (threshold 1.0).
maxpop = roundandscale((1 + EPSILON) * graph.total_pop / NUM_DISTRICTS, 0.0)
minpop = roundandscale((1 - EPSILON) * graph.total_pop / NUM_DISTRICTS, 1.0)

# Vertical DP table: global_dp_table[v, k+1] holds the best unhappy count for a group
# built from node v and part of its subtree, with scaled population at most k (k = 0:maxpop).
global_dp_table = zeros(Int, graph.num_nodes, maxpop+1)
# Backtrack table with the same (node ID, pop level) indexing. Each entry is a set
# of (w, p) pairs: w is a child whose entry at level p is expanded, or the entry's
# own node, which marks that node as a member of the group.
global_backtrack_table = Array{Union{Nothing, Set{Tuple{Int, Int}}}}(nothing, graph.num_nodes, maxpop+1)
reset_dp_tables()

# Initial scaled populations; audit_with_tree() recomputes them before each tree audit.
compute_scaled_pop_at_random()

# Load the map ensemble (JLD file) and precompute, for every map, the unscaled
# unhappy population of each color.
println("Loading the map ensemble from file...")
tick()
partitions = load(string(ENSEMBLE_FILENAME, ".jld"), "maps")
foreach(compute_unhappy_pop, 1:NUM_MAPS)
tock()
println("Successfully loaded the map ensemble from file.")

# Load the tree ensemble; each tree is stored as a BitSet of edge indices of `graph`.
println("Loading the tree ensemble from file...")
tick()
sparsetrees = load(string(TREE_FILENAME, ".jld"), "trees")
tock()
println("Successfully loaded the tree ensemble from file.")

# Root every tree at node 1 and cache a children-before-parents visiting order
# (reversed topological order), which the bottom-up DP relies on.
println("Parsing the tree ensemble...")
tick()
trees = Array{Union{Nothing, SimpleDiGraph}}(nothing, NUM_TREES)
reversed_orders = Array{Union{Nothing, Vector{Int}}}(nothing, NUM_TREES)
for t in 1:NUM_TREES
    rooted = convert_to_graph(sparsetrees[t])
    trees[t] = rooted
    reversed_orders[t] = reverse(topological_sort_by_dfs(rooted))
end
tock()
println("Successfully parsed the tree ensemble.")

# Map-level summary: counts of blue and red deviating groups per audited map.
map_data_table = DataFrame("Map" => Int[], "Blue DG" => Int[], "Red DG" => Int[])

# One table of deviating groups per audited map; maps outside START_MAP:END_MAP stay `nothing`.
dev_group_table = Array{Union{Nothing, DataFrame}}(nothing, NUM_MAPS)
foreach(START_MAP:END_MAP) do m
    dev_group_table[m] = DataFrame(
        "Tree"        => Int[],
        "Group"       => BitSet[],
        "Type"        => String[],
        "Unhappy_pop" => Int[],
        "Total_Pop"   => Float64[],
        "Unhappy_pct" => Float16[],
    )
end

# Audit every requested map for blue, then red, deviating groups and record the counts.
tick()
for m in START_MAP:END_MAP
    row = (m, audit_with_tree_ensemble(m, true), audit_with_tree_ensemble(m, false))
    push!(map_data_table, row)
end
tock()

show(map_data_table)
println("")

# Persist the summary and the per-map deviating groups, first as JLD, then as CSV.
println("Starting writing dataframe to files...")
save(string(OUTPUT_FILENAME, "_stats.jld"), "data", map_data_table)
foreach(m -> save(string(OUTPUT_FILENAME, "_dgs_map_", m, ".jld"), "data", dev_group_table[m]),
        START_MAP:END_MAP)
println("Finished writing dataframe to files.")

println("Starting outputing dataframes to CSV files...")
CSV.write(string(OUTPUT_FILENAME, "_stats.csv"), map_data_table)
foreach(m -> CSV.write(string(OUTPUT_FILENAME, "_dgs_map_", m, ".csv"), dev_group_table[m]),
        START_MAP:END_MAP)
println("Finished outputing dataframe to CSV files.")
