import numpy as np
import random as rd

def ME(N,On,alpha,eps,delta):
    """Run Median Elimination over the active arms and return the survivor.

    Each round pulls every active arm k = ceil(2/el**2 * log(3/dl)) times and
    keeps the arms whose success count is at least the median count of the
    active arms; el then shrinks by 3/4 and dl halves. Arms tied with the
    median are all kept, so a round may eliminate nothing (e.g. active counts
    [k, 0, 0]); with deterministic arms (alpha in {0, 1}) the loop can then
    repeat forever.

    Args:
        N: number of arms.
        On: length-N sequence of 0/1 flags marking the initially active arms.
        alpha: length-N success probabilities of the Bernoulli arms.
        eps: accuracy parameter (initial el = eps/4).
        delta: failure-probability parameter (initial dl = delta/2).

    Returns:
        [index of the last surviving arm, total number of samples drawn].

    Raises:
        ValueError: if no arm is active.
    """
    if np.sum(On) == 0:
        raise ValueError("ME needs at least one active arm")
    t = 0
    el = eps/4
    dl = delta/2
    while np.sum(On) > 1:
        k = int(np.ceil((2/(el*el))*np.log(3/dl)))
        # Only active arms are sampled; eliminated arms keep a count of 0.
        S = [np.random.binomial(k, alpha[i]) if On[i] else 0 for i in range(N)]
        # Every arm active at the start of the round drew k samples, so charge
        # them all before eliminating (not just the survivors).
        t = t + k*np.sum(On)
        m = np.median([S[i] for i in range(N) if On[i]])
        # Survivors must already be active: an eliminated arm's count of 0
        # would otherwise pass S >= m whenever the median is 0.
        On = [1. if (On[j] and S[j] >= m) else 0. for j in range(N)]
        el = 3*el/4
        dl = dl/2
    return [np.argmax(On), t]
    



def NAIVE(N,On,alpha,eps,delta):
    """Pull every active arm equally often and return the empirically best one.

    Each active arm is pulled k = ceil(2/eps**2 * log(n_active/delta)) times;
    the active arm with the most successes is returned, ties broken uniformly
    at random with the `random` module.

    Args:
        N: number of arms.
        On: length-N sequence of 0/1 flags marking the active arms.
        alpha: length-N success probabilities of the Bernoulli arms.
        eps: accuracy parameter.
        delta: failure-probability parameter.

    Returns:
        [index of the chosen active arm, total samples k*n_active].

    Raises:
        ValueError: if no arm is active.
    """
    n_active = sum(On)
    if n_active == 0:
        raise ValueError("NAIVE needs at least one active arm")
    k = int(np.ceil((2/(eps*eps))*np.log(n_active/delta)))
    # Inactive arms score -1, below any binomial count, so they can never be
    # returned, even when every active arm draws 0 successes.
    S = np.array([np.random.binomial(k, alpha[i]) if On[i] else -1
                  for i in range(N)])
    maximisers = np.flatnonzero(S == S.max())
    return [rd.choice(maximisers), k*n_active]


def AGG(N,On,alpha,eps,delta):
    """Aggregation phase: repeatedly keep only the top fraction of active arms.

    With phi = sqrt(6*log(N)/N**0.75) and frac = delta + phi, it runs
    t = ceil(log(N) / (4*log(1/frac))) rounds. Round i pulls every active arm
    l = round((i+1) * 2*log(1/delta)/eps**2) times and keeps the
    ceil(n_active*frac) arms with the most successes. When frac >= 1 no round
    is run.

    Args:
        N: number of arms.
        On: length-N sequence of 0/1 flags marking the active arms.
        alpha: length-N success probabilities of the Bernoulli arms.
        eps: accuracy parameter.
        delta: failure-probability parameter.

    Returns:
        [length-N list of 0/1 survivor flags, total number of samples drawn].
    """
    Nsamples = 0
    phi = np.sqrt(6*np.log(N)/np.power(N,0.75))
    frac = delta + phi
    if frac >= 1:
        # log(1/frac) <= 0: no aggregation round applies (frac == 1 would
        # otherwise divide by zero).
        return [list(On), 0]
    t = int(np.ceil(np.log(N)/(4*np.log(1/frac))))
    for i in range(t):
        l = int(np.round((i+1)*(np.log(1/delta)*2/(eps*eps))))
        n_active = sum(On)
        # Inactive arms score -1 so argsort ties at 0 successes cannot revive them.
        S = np.array([np.random.binomial(l, alpha[j]) if On[j] else -1
                      for j in range(N)])
        keep = int(np.ceil(n_active*frac))
        # The `keep` largest counts; slicing from the front also handles
        # keep == 0 (a [-0:] slice would select every arm).
        idx = S.argsort()[::-1][:keep]
        # O(N) mask instead of an O(N*keep) `j in idx` scan.
        mask = np.zeros(N, dtype=int)
        mask[idx] = 1
        On = mask.tolist()
        # Every arm active at the start of the round was pulled l times.
        Nsamples = Nsamples + l*n_active
    return [On,Nsamples]

def SABA(N,alpha,eps,delta):
    """Simple ABA: filter all arms with AGG, then pick the winner with NAIVE.

    The failure budget is split evenly (delta/2 per phase); both phases use
    the full accuracy eps.

    Returns:
        [index of the chosen arm, AGG samples + NAIVE samples].
    """
    half_delta = delta/2
    survivors, agg_cost = AGG(N, [1]*N, alpha, eps, half_delta)
    winner, naive_cost = NAIVE(N, survivors, alpha, eps, half_delta)
    return [winner, agg_cost + naive_cost]


def ABA(N,alpha,eps,delta):
    """Approximate best arm: AGG filtering, a random reserve, then NAIVE.

    Instances with N below max(10000, 1/delta**4) go straight to NAIVE.
    Otherwise a reserve of ceil(N**0.875 / 2) arms is drawn at random, AGG
    filters all arms with accuracy (1 - 1/e)*eps and confidence delta/2, the
    reserve arms are switched back on, and NAIVE chooses among the result
    with accuracy eps/e and confidence delta/e.

    Returns:
        [index of the chosen arm, total samples drawn].
    """
    everyone = [1]*N
    small_instance = N < max(10000, 1/delta**4)
    if small_instance:
        return NAIVE(N, everyone, alpha, eps, delta)

    reserve_size = int(np.ceil(np.power(N, 0.875)/2))
    reserve = rd.sample(range(N), reserve_size)

    agg_eps = (1-1/np.e)*eps
    survivors, agg_cost = AGG(N, everyone, alpha, agg_eps, delta/2)
    for arm in reserve:
        survivors[arm] = 1

    winner, naive_cost = NAIVE(N, survivors, alpha, eps/np.e, delta/np.e)
    return [winner, agg_cost + naive_cost]
   
def ABALE(N,alpha,eps,delta,lamb):
    """ABA preceded by a one-shot pre-filter over all arms.

    Steps:
      1. Draw a reserve R of ceil(N**0.75) arms uniformly at random.
      2. Pull every arm l = ceil((lamb/2 + 1) * log(1/delta) / (2*eps**2))
         times and keep the ceil(N*lamb/50) arms with the most successes.
      3. Run AGG on the kept arms with accuracy a*eps, a = sqrt(1 - lamb/8),
         and confidence delta/4.
      4. Re-activate R and let NAIVE choose with accuracy (1 - a)*eps and
         confidence delta/4.

    Args:
        N: number of arms.
        alpha: length-N success probabilities of the Bernoulli arms.
        eps: accuracy parameter.
        delta: failure-probability parameter.
        lamb: tuning parameter; must satisfy 0 < lamb < 8 so that a is a
            real number strictly between 0 and 1.

    Returns:
        [index of the chosen arm, l*N + AGG samples + NAIVE samples].

    Raises:
        ValueError: if lamb is not in (0, 8).
    """
    if not 0 < lamb < 8:
        raise ValueError("ABALE needs 0 < lamb < 8, got " + repr(lamb))
    a = np.sqrt(1-lamb/8)
    R = rd.sample(range(N),int(np.ceil(np.power(N,0.75))))

    l = int(np.ceil((lamb/2+1)*(np.log(1/delta)*1/(2*eps*eps))))
    S = [np.random.binomial(l,alpha[j]) for j in range(N)]

    idx = np.array(S).argsort()[-int(np.ceil(N*lamb/50)):][::-1]
    # O(N) mask via fancy indexing instead of an O(N*len(idx)) `j in idx` scan.
    mask = np.zeros(N, dtype=int)
    mask[idx] = 1
    On = mask.tolist()
    [On,NsamplesA] = AGG(N,On,alpha,a*eps,delta/4)
    # Reserve arms get a second chance regardless of the filtering above.
    for r in R:
        On[r] = 1

    [idx,NsamplesB] = NAIVE(N,On,alpha,(1-a)*eps,delta/4)
    # l*N pays for the pre-filter pulls of step 2.
    return [idx,l*N+NsamplesA+NsamplesB]
    
# NOTE(review): the triple-quoted string below is disabled code kept for
# reference. Python evaluates it as a discarded string literal, so nothing in
# it is defined or executed. It is not valid Python as written:
#   - OptimalWeights mixes indentation levels, uses Julia syntax
#     (vec(collect(1:K)), zeros(Int,K)), and references `index` and
#     `oneStepOpt`, neither of which is defined in this file.
#   - eT calls rate(), which is not defined in this file, and evaluates N[i]
#     on the integer N (presumably times[i] was meant — TODO confirm).
#   - np.max(np.sqrt(t) - N/2,0) passes 0 as an axis; presumably the builtin
#     max(..., 0) was intended — TODO confirm.
# di(x,y) returns (x-y)**2.
'''
def di(x,y):
    return (x-y)**2


def OptimalWeights(mu, delta=1e-11):
    
  # returns T*(mu) and w*(mu)
  K=len(mu)
  IndMax=[i for i in range(K) if mu[i]==np.max(mu)]
  L=len(IndMax)
  if (L>1):
     # multiple optimal arms
     vOpt=[1/L if mu[i]==np.max(mu) else 0 for i in range(K) ]
     return 0,vOpt
  else:
      mu2 = mu.copy()
      mu2.sort()
      mu2 = mu[::-1]

     unsorted=vec(collect(1:K))
     invindex=zeros(Int,K)
     invindex[index]=unsorted 
     # one-step optim
     vOpt,NuOpt=oneStepOpt(mu,delta)
     # back to good ordering
     nuOpt=NuOpt[invindex]
     NuOpt=zeros(1,K)
     NuOpt[1,:]=nuOpt
     return vOpt,NuOpt
     
     
def eT(N,alpha,eps,delta):
    condition = True
    
    times = [1]*N
    S =  [np.random.binomial(1,alpha[i]) for i in range(N)]
    t=N
    best=1
    while (condition):
        Mu=[S[i]/times[i] for i in range(N)]
        # Empirical best arm
        maxlist = [i  for i in range(N) if S[i]==np.max(S)]
        best = rd.sample(maxlist,1)[0]
        I=1
        
        if (len(maxlist)>1):
            # if multiple maxima, draw one them at random
            I = rd.sample(maxlist,1)[0]
        else:
            best = maxlist[0]
            # compute the stopping statistic
            NB=times[best]
            SB=S[best]
            muB=SB/NB
            MuMid=[(SB+S[i])/(NB+times[i]) for i in range(N)]
            
            Score=np.min([NB*di(muB,MuMid[i])+N[i]*di(Mu[i],MuMid[i]) for i in range(N) if i not in maxlist])
            
            if (Score > rate(t,0,delta)):
                # stop 
                condition=False
            elif (t >10000000):
                # stop and outputs (0,0) 
                condition=False
                best=N-1
                print('f')
            else:
                if (np.min(times) <= np.max(np.sqrt(t) - N/2,0)):
                    # forced exploration
                    I=np.argmin(times)
                else:
                    # continue and sample an arm
                    val,Dist=OptimalWeights(Mu,1e-11)
                    # choice of the arm
                    I=np.argmin([Dist-times[i]/t for i in range(N)])
        # draw the arm 
        t+=1
        S[I]+=np.random.binomial(1,alpha[I])
        times[I]+=1
    return [best,t]
    
'''    
   
  

    
def simul(N,eps,delta,lamb=7.48877):
    """Run NAIVE, SABA, ABA and ABALE once on a linear instance and print results.

    The instance is alpha[i] = 1 - (i+1)/N, so arm 0 has the largest mean.

    Args:
        N: number of arms.
        eps: accuracy parameter passed to every algorithm.
        delta: failure-probability parameter passed to every algorithm.
        lamb: ABALE's tuning parameter (must be in (0, 8)). The default is the
            value the module-level experiment uses for delta = 0.05; pass a
            value suited to other deltas. Previously ABALE was called without
            it, which always raised TypeError.
    """
    print('N:',N,'eps:',eps,'delta:',delta)
    alpha = [1-(1+i)/N for i in range(N)]
    print('naive: ',NAIVE(N,[1]*N,alpha,eps,delta))
    print('saba: ',SABA(N,alpha,eps,delta))
    print('aba: ',ABA(N,alpha,eps,delta))
    print('abale: ',ABALE(N,alpha,eps,delta,lamb))


             
# ---------------------------------------------------------------------------
# Experiment: run NAIVE, median elimination (ME), SABA, ABA and ABALE on one
# instance for `rounds` repetitions and report sample counts and accuracy.
#
# NOTE(review): this runs at import time (N = 300000 arms, 500 rounds);
# consider moving it under `if __name__ == "__main__":` so the algorithms can
# be imported without triggering the experiment.
# ---------------------------------------------------------------------------
np.random.seed(1589)
# The algorithms also draw from the `random` module (reserve sets in ABA and
# ABALE, tie-breaking in NAIVE); seed it as well so runs are reproducible.
rd.seed(1589)
eps = 0.2
N= 300000
delta = 0.05

# Instance: every arm has mean 0.5 except arm 0, whose mean exceeds the others
# by slightly more than eps, so by the fail criterion below (gap >= eps)
# returning any other arm counts as a failure.
alpha = [0.5]*N
alpha[0] = 0.5+eps+1e-13
# Alternative instances used in other runs:
#alpha[1] = alpha[0]
#alpha[2] = alpha[0]
#alpha = [1-(1+i)/N for i in range(N)]
#alpha = np.random.rand(N)
#alpha.sort()
#alpha = alpha[::-1]
#alpha = [0.6,0.2,0.3,0.45,0.55,0.6]
best = np.max(alpha)

# ABALE tuning parameter; ABALE uses sqrt(1 - lamb/8), so lamb must be < 8.
# The relation noted on the last line, lamb/100 = delta**(lamb**2/64), is
# satisfied by the delta = 0.1 and delta = 0.005 values.
# NOTE(review): for delta = 0.05 that relation gives lamb ~= 7.4488
# (7.44877 fits it to ~1e-5), so 7.48877 may be a digit transposition —
# confirm before relying on these results.
lamb = 7.48877 #for delta = 1/20
#lamb = 8.31448 #for delta = 1/10 which is bad...
#lamb = 5.85489 #for delta = 0.005   x/100 = (1/delta)^(-x^2/64)

# Per-algorithm tallies (suffix N = NAIVE, S = SABA, A = ABA, B = ABALE,
# M = median elimination):
#   Tsamp*  - total samples drawn over all rounds
#   fail*   - rounds whose returned arm's mean is at least eps below the best
#   best*   - rounds whose returned arm has the best mean
#   approx* - rounds whose returned arm's mean is within eps of the best
TsampN = 0
failN = 0
bestN = 0
approxN = 0


TsampS = 0
failS = 0
bestS = 0
approxS = 0


TsampA = 0
failA = 0
bestA = 0
approxA = 0


TsampB = 0
failB = 0
bestB = 0
approxB = 0

TsampM = 0
failM = 0
bestM = 0
approxM = 0

rounds = 500
print('*******************************')
print('N,eps,del,rounds',N,eps,delta,rounds)
print('*******************************')
for i in range(rounds):
    # Each algorithm returns [chosen arm index, number of samples drawn].
    [rN,sampN] = NAIVE(N,[1]*N,alpha,eps,delta)
    [rS,sampS] = SABA(N,alpha,eps,delta)
    [rA,sampA] = ABA(N,alpha,eps,delta)
    [rB,sampB] = ABALE(N,alpha,eps,delta,lamb)
    [rM,sampM] = ME(N,[1]*N,alpha,eps,delta)

    TsampN = TsampN+sampN
    failN = failN + 1.*(best-alpha[rN]>=eps)
    bestN = bestN + 1.*(best-alpha[rN]==0)
    approxN = approxN + 1.*(best-alpha[rN]<eps)

    TsampM = TsampM+sampM
    failM = failM + 1.*(best-alpha[rM]>=eps)
    bestM = bestM + 1.*(best-alpha[rM]==0)
    approxM = approxM + 1.*(best-alpha[rM]<eps)


    TsampS = TsampS+sampS
    failS = failS + 1.*(best-alpha[rS]>=eps)
    bestS = bestS + 1.*(best-alpha[rS]==0)
    approxS = approxS + 1.*(best-alpha[rS]<eps)

    TsampA = TsampA+sampA
    failA = failA + 1.*(best-alpha[rA]>=eps)
    bestA = bestA + 1.*(best-alpha[rA]==0)
    approxA = approxA + 1.*(best-alpha[rA]<eps)

    TsampB = TsampB+sampB
    failB = failB + 1.*(best-alpha[rB]>=eps)
    bestB = bestB + 1.*(best-alpha[rB]==0)
    approxB = approxB + 1.*(best-alpha[rB]<eps)


    # Progress report every rounds/20 rounds (every 25 for rounds = 500).
    if(np.mod(i,rounds/20)==0):
        print(i,' *******************************')
        print('naive: number of queries',TsampN/(i+1),' found best:',bestN,' found approx:',approxN,' fail:',failN)
        print('MEDIAN: number of queries',TsampM/(i+1),' found best:',bestM,' found approx:',approxM,' fail:',failM)
        print('saba: number of queries',TsampS/(i+1),' found best:',bestS,' found approx:',approxS,' fail:',failS)
        print('aba: number of queries',TsampA/(i+1),' found best:',bestA,' found approx:',approxA,' fail:',failA)
        print('abale: number of queries',TsampB/(i+1),' found best:',bestB,' found approx:',approxB,' fail:',failB)

# Final summary: average samples per repetition and outcome counts.
print('*******************************')
print('N,eps,del',N,eps,delta)
print('*******************************')

print('median elimination: average number of queries per exp.:',TsampM/rounds,' found best:',bestM,' found approx:',approxM,' fail:',failM)
print('NAIVE: average number of queries per exp.:',TsampN/rounds,' found best:',bestN,' found approx:',approxN,' fail:',failN)
print('SABA: average number of queries per exp.:',TsampS/rounds,' found best:',bestS,' found approx:',approxS,' fail:',failS)
print('ABA: average number of queries per exp.:',TsampA/rounds,' found best:',bestA,' found approx:',approxA,' fail:',failA)
print('ABALE elimination: average number of queries per exp.:',TsampB/rounds,' found best:',bestB,' found approx:',approxB,' fail:',failB)
    
    
    
    