using ArgParse
using Arpack
using CSV
using DataFrames
using LinearAlgebra
using LightGraphs
using Logging
using Printf
using PyPlot
using Random
using Statistics

include("Utils.jl")
include("IOUtils.jl")

"""
    oracleEvecs(A, r)

Compute the trailing `n - r` eigenvectors and eigenvalues of the matrix A, and
compute the smallest constant `C` for which Assumption 2 in the main paper is
satisfied.
"""
function oracleEvecs(A, r, T)
    n, _ = size(A)
    evals, V, _ = eigs(A, nev=(r + 1), which=:LM, ritzvec=true)
    @info("Finished finding evecs")
    V = V[:, 1:r]
    rhsVal = inf1(V)
    λ₊ = evals[end]; evals = (evals[1:r] ./ λ₊)  # keep first r evals and scale
    Acumul = copy(A) / λ₊
    C = 0
    for t = 1:T
        lhsVal = infPerRow(Acumul, V, evals.^t)
        # lhsVal = opnorm(Acumul - V * Diagonal(evals.^t) * V', Inf)
        C = max(C, lhsVal / rhsVal)
        @info("Iteration $(t) - C = $(C)")
        Acumul = Acumul * (A / λ₊)
    end
    return C
end

"""
    inf1(V)

Return `opnorm(I - V * V', Inf)` for an `n × r` matrix `V`, building one
column of `I - V * V'` at a time so the dense `n × n` product is never formed.
Because `I - V * V'` is symmetric (Hermitian), its largest absolute column sum
equals its largest absolute row sum. Returns `0.0` when `V` has no rows.
"""
function inf1(V)
    nrows = size(V, 1)
    return mapreduce(max, 1:nrows; init=0.0) do j
        # column j of I - V * V'
        column = -V * V[j, :]
        column[j] += 1.0
        norm(column, 1)
    end
end

"""
    infPerRow(A, V, Λ; blocksize=1000)

Return the largest absolute column sum of `A - V * Diagonal(Λ) * V'`, i.e.
`opnorm(A - V * Diagonal(Λ) * V', 1)`, without materializing the full `n × n`
low-rank term. When `A` is symmetric and `V`, `Λ` are real, the difference is
symmetric, so this also equals its `Inf`-norm (largest absolute row sum).

Columns are processed `blocksize` at a time; only an `n × blocksize` slice of
the difference exists at once. Returns `0.0` when `V` has no rows.

# Arguments
- `A`: `n × n` matrix.
- `V`: `n × r` matrix.
- `Λ`: length-`r` vector of weights applied to the columns of `V`.
- `blocksize`: number of columns per block (default `1000`).

# Throws
- `ArgumentError` if `blocksize < 1`.
"""
function infPerRow(A, V, Λ; blocksize::Integer=1000)
    blocksize > 0 ||
        throw(ArgumentError("blocksize must be positive, got $(blocksize)"))
    n = size(V, 1)
    Vmul = V .* Λ'      # V * Diagonal(Λ), computed once for all blocks
    Cmax = 0.0
    for iStart in 1:blocksize:n
        cols = iStart:min(iStart + blocksize - 1, n)
        Cmax = max(Cmax, opnorm(A[:, cols] - Vmul * V[cols, :]', 1))
    end
    return Cmax
end

"""
    pickFun(data)

Build the graph for dataset `data` by calling the matching loader from
`IOUtils`. Recognized values are `:DBLP`, `:YOUTUBE`, `:LIVEJOURNAL`,
`:GEMSEC`, `:HEP` and `:ASTRO`; anything else raises an `ErrorException`.
"""
function pickFun(data)
    # (dataset symbol => name of the IOUtils function that loads it)
    loaders = (:DBLP        => :genDBLPGraph,
               :YOUTUBE     => :genYoutubeGraph,
               :LIVEJOURNAL => :genLJGraph,
               :GEMSEC      => :genGemsecGraph,
               :HEP         => :genHepGraph,
               :ASTRO       => :genAstroGraph)
    idx = findfirst(entry -> first(entry) == data, loaders)
    if idx === nothing
        throw(ErrorException("Option $(data) not recognized!"))
    end
    loader = getproperty(IOUtils, last(loaders[idx]))
    return loader()
end

"""
    main(dataset, r, T, τ, no_reg)

Load the graph for `dataset` (via `pickFun`), take the matrix returned by
`IOUtils.adjMatrix`, estimate the Assumption 2 constant with
`oracleEvecs(A, r, T)`, log it, and return it.

NOTE(review): `τ` and `no_reg` are accepted but no regularization is applied;
the check always runs on the matrix from `IOUtils.adjMatrix`. A warning is
emitted when regularization was requested (`no_reg == false`), so the option
is not silently ignored.
"""
function main(dataset, r, T, τ, no_reg)
    if !no_reg
        @warn("Regularization with τ = $(τ) was requested but is not " *
              "implemented; using the plain adjacency matrix.")
    end
    g = pickFun(dataset)
    A, _ = IOUtils.adjMatrix(g)
    C = oracleEvecs(A, r, T)
    @info("[$(dataset)] - Maximum C found: $(@sprintf("%.2f", C))")
    return C
end


# Command-line entry point: parse options, seed the global RNG and run `main`.
s = ArgParseSettings(
    description="Verify Assumption 2 on real-world datasets.")
datasets = ["dblp", "youtube", "livejournal", "gemsec", "hep", "astro"]
@add_arg_table s begin
    "--dataset"
        arg_type     = String
        # No sensible default exists; without `required` an omitted option
        # yields `nothing` and `uppercase` below fails with a MethodError.
        required     = true
        range_tester = (x -> lowercase(x) in datasets)
        help         = "dataset; must be one of $(join(datasets, ", "))"
    "--r"
        arg_type     = Int
        help         = "The rank of the spectral embedding"
        default      = 10
    "--no_reg"
        help         = "Set to use just the plain adjacency matrix"
        action       = :store_true
    "--tau"
        help         = "The regularization parameter for the normalized adj. matrix"
        arg_type     = Float64
        default      = 5.0
    "--T"
        help         = "The highest power of diagonal to verify for"
        arg_type     = Int
        default      = 100
    "--seed"
        help         = "The random seed to set"
        arg_type     = Int
        default      = 999
end
parsed  = parse_args(s)
Random.seed!(parsed["seed"])
data    = Symbol(uppercase(parsed["dataset"]))
τ, T, r = parsed["tau"], parsed["T"], parsed["r"]
no_reg  = parsed["no_reg"]
main(data, r, T, τ, no_reg)
