using GerryChain
using TickTock
using Plots, Shapefile
using Printf
using JLD
using SparseArrays

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Input data: precinct shapefile, its population column, and the column that
# holds the real-world district assignment used to seed every chain.
SHAPEFILE_PATH = "Shapefiles/NC/NC_VTD.shp"
POPULATION_COL = "TOTPOP"
SEED_MAP       = "CD"

# Election tracked along each chain, plus the shapefile columns holding the
# blue- and red-party vote totals for that election.
ELECTION   = "PR16"
BLUE_VOTES = "EL16G_PR_D"
RED_VOTES  = "EL16G_PR_R"

# Ensemble shape: number of independent chains (one output map each), the
# number of ReCom steps per chain, and the tolerance handed to
# PopulationConstraint.
NUM_MAPS     = 10
CHAIN_LENGTH = 10
EPSILON      = 0.02

# Base name of the output file (".jld" is appended when saving).
OUTPUT_FILENAME = "NC_ensemble"

# Build the precinct graph from the shapefile; POPULATION_COL names the
# population attribute passed to BaseGraph.
graph = BaseGraph(SHAPEFILE_PATH, POPULATION_COL)

# One slot per ensemble map, all empty to start. Slot 1 is filled with the
# real-world (SEED_MAP) partition right away so its district count can be
# reported before any chain runs.
partitions = Vector{Union{Nothing, Partition}}(nothing, NUM_MAPS)
partitions[1] = Partition(graph, SEED_MAP)

# One slot per chain for its recorded score data (used later to decide how
# each map's districts are colored).
results = Vector{Union{Nothing, ChainScoreData}}(nothing, NUM_MAPS)

# Report the run configuration, then produce each ensemble map with its own
# ReCom chain, and finally write all maps to disk.
println("No. of  chain runs                  = $NUM_MAPS")
println("Length of each chain                = $CHAIN_LENGTH")
println("No. of  districts in each partition = $(partitions[1].num_dists)")
println("Running the ReCom chains...")
tick()
for map_idx in 1:NUM_MAPS
    # Every chain starts over from the real-world seed map.
    seed = Partition(graph, SEED_MAP)
    partitions[map_idx] = seed

    # Population-balance constraint for this chain, built from the seed map.
    constraint = PopulationConstraint(graph, seed, EPSILON)

    # Election tracking: blue and red vote counts per district.
    elec = Election(ELECTION, [BLUE_VOTES, RED_VOTES], seed.num_dists)
    vote_metrics = vcat(
        vote_count("count_d", elec, BLUE_VOTES),
        vote_count("count_r", elec, RED_VOTES),
    )
    trackers = [ElectionTracker(elec, vote_metrics)]

    # CHAIN_LENGTH ReCom steps yield the map_idx-th map. NOTE(review): this
    # relies on recom_chain advancing `seed` in place so that
    # partitions[map_idx] holds the final map when saved below — confirm.
    println("Generating map $map_idx ...")
    results[map_idx] = recom_chain(graph, seed, constraint, CHAIN_LENGTH, trackers)
end
println("Saving the ReCom chains to file...")
save(OUTPUT_FILENAME * ".jld", "maps", partitions)

tock()
