import math
import os
import pickle
import random
import string
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from scipy.linalg import eig
# Seed derived from the script's own file name: sys.argv[0][-4] is the
# character just before a 3-character extension, so this assumes the file is
# named like "<something><digit>.py" (any other name raises ValueError).
sd = int(sys.argv[0][-4])
# Seed BOTH generators: numpy drives the voting, but the comparison graph
# (get_pairs) and random_subset draw from the stdlib `random` module, which
# np.random.seed does not affect -- without this, runs are not reproducible.
np.random.seed(sd)
random.seed(sd)
now = datetime.now()
tm = now.strftime("%H_%M_%S_%d_%m_%y")
# Output pickles: squared-error results and Kendall-tau results.
fname_res = './dump/res_' + tm + "_" + str(sd) + '.p'
fname_res_kt = './dump/res_kt_' + tm + "_" + str(sd) + '.p'
print(fname_res)
print(fname_res_kt)

class RC_Byz_flips_deletions:
    """Rank-centrality weight estimation from pairwise votes with Byzantine voters.

    ``n`` items get random positive weights (normalised to sum to 1) and a
    random comparison graph.  For every compared pair ``(i, j)``, ``k`` voters
    each cast one 0/1 vote: honest voters vote 1 with probability
    ``w_j / (w_i + w_j)``; Byzantine voters (each voter is Byzantine with
    probability ``bd``, drawn once per bucket) vote 1 with probability
    ``stra(i, j, w_i, w_j)``.  Voters flagged as outliers are excluded, the
    remaining vote fractions form a Markov chain, and its stationary
    distribution is the weight estimate.

    Args:
        n: number of items.
        k: number of voters per bucket of comparisons.
        bd: probability that a single voter is Byzantine.
        stra: Byzantine strategy ``stra(i, j, wi, wj) -> probability``.
        pair_factor: each pair is compared with probability
            ``pair_factor * log(n) / n`` (default 20, the original constant).
    """

    # Target number of neighbours per voting bucket.  All 2**m subsets of an
    # m-item bucket are enumerated, so cost is exponential in this value.
    bucket_size = 12
    # A subset triggers exclusion only if more than k / exclusion_div
    # not-yet-excluded voters are outliers for it.
    exclusion_div = 20

    def __init__(self, n, k, bd, stra, pair_factor=20):
        self.k = k
        self.n = n
        self.bd = bd
        self.stra = stra
        self.get_pairs(pair_factor)
        self.get_weights()

    def get_pairs(self, factor):
        """Sample the comparison graph into ``self.pairs``.

        Each unordered pair ``(i, j)`` with ``i < j`` is kept independently
        with probability ``factor * log(n) / n`` (every pair when that is
        >= 1), using the stdlib ``random`` generator.  ``self.pairs`` is an
        int array of shape ``(num_pairs, 2)``.
        """
        lmt = (factor * math.log(self.n)) / self.n
        ls = [[i, j]
              for i in range(self.n)
              for j in range(i + 1, self.n)
              if random.random() < lmt]
        # reshape keeps a (0, 2) shape even when no pair was sampled.
        self.pairs = np.array(ls, dtype=int).reshape(-1, 2)

    def get_weights(self):
        """Draw ground-truth weights uniformly in [1, 100) and normalise to sum 1."""
        x = np.random.uniform(low=1.0, high=100.0, size=self.n)
        x /= np.sum(x)
        self.weights = x

    def _cast_votes(self, i, itms):
        """Return a ``(len(itms), k)`` 0/1 vote matrix for pairs ``(i, j)``, j in itms.

        The number of Byzantine voters is drawn once per bucket.  Honest voters
        fill the leading columns, Byzantine voters the trailing ones.
        """
        bad = int(np.sum(np.random.binomial(1, self.bd, self.k)))
        good = self.k - bad
        votes = np.zeros((len(itms), self.k))
        wi = self.weights[i]
        for row, j in enumerate(itms):
            wj = self.weights[j]
            honest = np.random.binomial(1, wj / (wi + wj), good)
            p_bad = self.stra(i, j, wi, wj)
            byz = np.random.binomial(1, p_bad, bad)
            votes[row] = np.concatenate((honest, byz))
        return votes

    def _find_outliers(self, votes):
        """Return a boolean mask over the k voters marking excluded voters.

        For every subset S of the bucket's items (enumerated in binary order,
        first item = most significant bit), each voter's count of 1-votes on S
        is compared with the median over voters; a deviation larger than
        ``sqrt(len(items))`` makes the voter an outlier for S.  When more than
        ``k / exclusion_div`` not-yet-excluded voters are outliers for S, all
        of S's outliers are excluded.  The result depends on enumeration order.
        """
        m = votes.shape[0]
        thr_out = math.sqrt(m)
        thr2 = self.k / self.exclusion_div
        shifts = np.arange(m - 1, -1, -1)  # bit p of the subset <-> item p
        votes_t = votes.T
        excluded = np.zeros(self.k, dtype=bool)
        for subset in range(2 ** m):
            mask = (subset >> shifts) & 1
            U = votes_t @ mask
            Umed = np.median(U)
            outlier = (U > Umed + thr_out) | (U < Umed - thr_out)
            if np.count_nonzero(outlier & ~excluded) > thr2:
                excluded |= outlier
        return excluded

    def getA(self):
        """Build the vote-fraction matrix ``self.A`` and degree vector ``self.deg``.

        The neighbours of each item ``i`` are split round-robin into buckets of
        about ``bucket_size`` items; each bucket gets freshly drawn voters.
        ``A[i, j]`` is the fraction of non-excluded voters that voted 1 on
        ``(i, j)``; entries for non-compared pairs stay 0.
        """
        A = np.zeros((self.n, self.n))
        deg = np.zeros(self.n)
        adj = [[] for _ in range(self.n)]
        for a, b in self.pairs:
            deg[a] += 1
            deg[b] += 1
            adj[a].append(b)
            adj[b].append(a)

        for i in range(self.n):
            if i % 30 == 0:
                print("i = " + str(i))
            if not adj[i]:
                continue
            # At least one bucket: a degree below bucket_size used to yield
            # zero buckets and crash with IndexError.
            num_buckets = max(1, len(adj[i]) // self.bucket_size)
            buckets = [adj[i][b::num_buckets] for b in range(num_buckets)]

            for itms in buckets:
                votes = self._cast_votes(i, itms)
                keep = ~self._find_outliers(votes)
                if not keep.any():
                    # Every voter was flagged: fall back to all voters rather
                    # than dividing by zero.
                    keep[:] = True
                A[i, itms] = votes[:, keep].mean(axis=1)

        self.A = A
        self.deg = deg

    def getP(self):
        """Build the row-stochastic transition matrix ``self.P``.

        Off-diagonal entries are ``A / max_degree``; each diagonal entry takes
        the remaining mass so every row sums to 1.
        """
        self.getA()
        # Guard against an empty comparison graph (max degree 0).
        mxdeg = max(np.max(self.deg), 1)
        P = self.A / mxdeg
        np.fill_diagonal(P, 0)
        np.fill_diagonal(P, 1 - P.sum(axis=1))
        self.P = P

    def get_stationary(self, Tr):
        """Return the stationary distribution of transition matrix ``Tr``.

        Takes the left eigenvector for eigenvalue 1, normalised to sum to 1.
        If no eigenvalue passes ``np.isclose(., 1)``, the eigenvalue nearest
        to 1 is used instead of failing with IndexError.
        """
        evals, evecs = np.linalg.eig(Tr.T)
        close = np.flatnonzero(np.isclose(evals, 1))
        idx = close[0] if close.size else np.argmin(np.abs(evals - 1))
        evec1 = evecs[:, idx]
        return (evec1 / evec1.sum()).real

    def norm(self, x):
        """Return the *squared* Euclidean norm of ``x`` (sum of squares).

        Consequently ``get_error`` returns a squared relative error.
        """
        return np.sum(x * x)

    def get_error(self):
        """Return ||predicted - true||^2 / ||true||^2 (squared relative error)."""
        df = self.predicted_wts - self.weights
        return self.norm(df) / self.norm(self.weights)

    def get_predicted_ws(self):
        """Run the full pipeline.

        Returns:
            (predicted weights, squared relative error, Kendall tau between the
            predicted and true weights).
        """
        self.getP()
        prws = self.get_stationary(self.P)
        self.predicted_wts = prws
        er = self.get_error()
        kt, _ = stats.kendalltau(self.predicted_wts, self.weights)
        return prws, er, kt

def opposite_vote(i, j, wi, wj):
    """Deterministic Byzantine strategy: vote 1 exactly when ``wi > wj``.

    Honest voters vote 1 with probability ``wj / (wi + wj)``, so this goes
    against the honest tendency.  ``i`` and ``j`` are accepted for signature
    compatibility and ignored.
    """
    return 1 if wi > wj else 0

def opposite_vote_probabilistic(i, j, wi, wj):
    """Byzantine strategy: vote 1 with probability ``wi / (wi + wj)``.

    This mirrors the honest probability ``wj / (wi + wj)``.  ``i`` and ``j``
    are ignored.
    """
    total = wi + wj
    return wi / total

def fixed_order_vote(i, j, wi, wj):
    """Weight-blind Byzantine strategy: vote 1 iff ``i > j`` (index order).

    The weights ``wi`` and ``wj`` are ignored.
    """
    return int(i > j)
    
def random_subset(i, j, wi, wj):
    """Mixed Byzantine strategy chosen by a stdlib ``random`` coin flip.

    With probability 1/2 (``random.random() <= 0.5``) the honest probability
    ``wj / (wi + wj)`` is returned; otherwise ``opposite_vote`` decides.
    """
    if random.random() <= 0.5:
        return wj / (wi + wj)
    return opposite_vote(i, j, wi, wj)


# Experiment grid: Byzantine fractions x item counts (k voters == n items).
# res[b][t]    = sqrt(squared relative error) for bf[b], all_items[t]
# res_kt[b][t] = Kendall tau for the same run
res = []
res_kt = []
all_items = [50 + 40 * x for x in range(6)]
bf = [0.1, 0.2]

# Create the output directory up front so a missing ./dump does not throw
# away the entire (long) experiment at the final pickle.dump.
for out_path in (fname_res, fname_res_kt):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

for byz_fr in bf:
    print("Byzantine Fraction = " + str(byz_fr))
    res1 = []
    res2 = []
    for n_items in all_items:
        print("n = " + str(n_items))
        rc = RC_Byz_flips_deletions(n_items, n_items, byz_fr, fixed_order_vote)
        _, err, kt = rc.get_predicted_ws()
        rel_err = math.sqrt(err)  # get_error returns a squared ratio
        print(byz_fr, rel_err, kt)
        res1.append(rel_err)
        res2.append(kt)
    res.append(res1)
    res_kt.append(res2)

with open(fname_res, 'wb') as fh:
    pickle.dump(res, fh)
with open(fname_res_kt, 'wb') as fh:
    pickle.dump(res_kt, fh)