import numpy as np
from ctypes import *
from numpy.ctypeslib import ndpointer
import deepzono_nodes as dn
from elina_scalar import *
from elina_dimension import *
from elina_linexpr0 import *
from elina_abstract0 import *
from fppoly import *
from config import config

# Field type for DoubleMatrix.p: a C-contiguous 1-D array of pointer-sized
# unsigned ints (presumably the ``double **`` row pointers of double_matrix_t
# -- TODO confirm against double_matrix.h).
_doublepp = ndpointer(dtype=np.uintp, ndim=1, flags='C')

# Load libkrelu. "../libkrelu.so" is resolved against the *current working
# directory*, so it only succeeds when the process is started from the
# expected directory. It is still tried first to keep existing behaviour;
# on failure, the same relative location is tried next to this module
# (presumably the layout the original path assumed).
try:
    krelu_api = CDLL("../libkrelu.so")
except OSError:
    import os
    krelu_api = CDLL(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "libkrelu.so"))


class DoubleMatrix(Structure):
    """ctypes mirror of ``double_matrix_t`` from double_matrix.h.

    Field names, order and types must line up with the C struct, because
    libkrelu returns pointers to it (see the bindings below).
    """

    _fields_ = [
        ('p', _doublepp),
        ('nbrows', c_size_t),
        ('nbcolumns', c_size_t),
        ('_maxrows', c_size_t),
        ('nbeq', c_size_t),
    ]


# ctypes argument type for a float64 NumPy array handed to C as ``double *``.
# ndpointer passes only the array's data pointer, so a strided view (e.g. a
# column slice) would be read by C as if it were contiguous. Requiring C
# contiguity makes ctypes raise TypeError for such arrays instead of
# silently passing the wrong values.
_c_double_array = ndpointer(c_double, flags='C_CONTIGUOUS')

# relu_with_2_var(coeff_x, coeff_y, lb, ub, xy) -> DoubleMatrix*.
# Callers in this module release the result with double_matrix_free_c.
relu_with_2_var_c = krelu_api.relu_with_2_var
relu_with_2_var_c.argtypes = [_c_double_array] * 5
relu_with_2_var_c.restype = POINTER(DoubleMatrix)

# relu_with_3_var(...) -> DoubleMatrix* (three double arrays; not used in
# this module).
relu_with_3_var_c = krelu_api.relu_with_3_var
relu_with_3_var_c.argtypes = [_c_double_array] * 3
relu_with_3_var_c.restype = POINTER(DoubleMatrix)

# get_num_rows(mat) -> number of rows of a DoubleMatrix.
get_num_rows_c = krelu_api.get_num_rows
get_num_rows_c.argtypes = [POINTER(DoubleMatrix)]
get_num_rows_c.restype = c_size_t

# get_matrix_element(mat, row, col) -> the double stored at (row, col).
get_matrix_element_c = krelu_api.get_matrix_element
get_matrix_element_c.argtypes = [POINTER(DoubleMatrix), c_size_t, c_size_t]
get_matrix_element_c.restype = c_double

# double_matrix_free(mat): releases a matrix returned by relu_with_*_var.
double_matrix_free_c = krelu_api.double_matrix_free
double_matrix_free_c.argtypes = [POINTER(DoubleMatrix)]
double_matrix_free_c.restype = None


def generate_tvpi_linexpr0(x, y, coeff_x, coeff_y):
    """Allocate the sparse ELINA expression ``coeff_x*dim_x + coeff_y*dim_y``.

    The constant term is set to 0 and the two linear terms fill sparse slots
    0 and 1 in the order given. Nothing here frees the expression; the
    caller is responsible for it.

    Args:
        x, y: ELINA dimension indices of the two terms.
        coeff_x, coeff_y: coefficients (set as doubles) of those terms.

    Returns:
        The pointer returned by ``elina_linexpr0_alloc``, now populated.
    """
    linexpr0 = elina_linexpr0_alloc(ElinaLinexprDiscr.ELINA_LINEXPR_SPARSE, 2)
    constant = pointer(linexpr0.contents.cst)
    elina_scalar_set_double(constant.contents.val.scalar, 0)

    for slot, (dim, value) in enumerate(((x, coeff_x), (y, coeff_y))):
        term = pointer(linexpr0.contents.p.linterm[slot])
        term.contents.dim = ElinaDim(dim)
        elina_scalar_set_double(term.contents.coeff.val.scalar, value)
    return linexpr0


def generate_3vpi_linexpr0(x, y, z, coeff_x, coeff_y, coeff_z):
    """Allocate the sparse ELINA expression ``coeff_x*dim_x + coeff_y*dim_y + coeff_z*dim_z``.

    The constant term is set to 0 and the three linear terms fill sparse
    slots 0, 1 and 2 in the order given. Nothing here frees the expression;
    the caller is responsible for it.

    Args:
        x, y, z: ELINA dimension indices of the three terms.
        coeff_x, coeff_y, coeff_z: coefficients (set as doubles) of those terms.

    Returns:
        The pointer returned by ``elina_linexpr0_alloc``, now populated.
    """
    linexpr0 = elina_linexpr0_alloc(ElinaLinexprDiscr.ELINA_LINEXPR_SPARSE, 3)
    constant = pointer(linexpr0.contents.cst)
    elina_scalar_set_double(constant.contents.val.scalar, 0)

    terms = ((x, coeff_x), (y, coeff_y), (z, coeff_z))
    for slot, (dim, value) in enumerate(terms):
        term = pointer(linexpr0.contents.p.linterm[slot])
        term.contents.dim = ElinaDim(dim)
        elina_scalar_set_double(term.contents.coeff.val.scalar, value)
    return linexpr0


class Relu2Vars:
    """Constraints for a pair of neurons, computed by libkrelu's ``relu_with_2_var``.

    The C matrix returned by ``relu_with_2_var`` is copied into ``self.cons``
    and then freed.

    Attributes:
        x, y: the two neuron indices, stored as given.
        lb, ub: the bound arrays passed in (as built by encode_2reLu_cons in
            this module: float64 arrays of length 2 for x and y).
        cons: float64 array of shape ``(nbrows, 6)`` holding the matrix rows.
            The column meaning is defined by libkrelu -- TODO document.
    """

    def __init__(self, x, y, coeff_x, coeff_y, lb, ub, xy):
        self.x = x
        self.y = y
        self.lb = lb
        self.ub = ub
        mat = relu_with_2_var_c(coeff_x, coeff_y, lb, ub, xy)
        try:
            nbrows = get_num_rows_c(mat)
            self.cons = np.zeros([nbrows, 6], dtype=np.double)
            for i in range(nbrows):
                for j in range(6):
                    self.cons[i, j] = get_matrix_element_c(mat, i, j)
        finally:
            # Release the C matrix even if copying fails, so it never leaks.
            double_matrix_free_c(mat)


class Relu3Vars:
    """Constraint set for three neurons assembled from their three pairwise sets.

    ``relu_with_2_var`` (libkrelu) is called for the pairs (x, y), (x, z)
    and (y, z). Each 6-column row it returns is scattered into an 8-column
    row of ``self.cons`` according to ``_PAIR_COLUMNS``, and every C matrix
    is freed afterwards.

    Attributes:
        x, y, z: the three neuron indices (relative to ``offset``).
        lb, ub: the bound sequences passed in; positions 0, 1, 2 belong to
            x, y, z respectively.
        cons: float64 array of shape ``(n_xy + n_xz + n_yz, 8)`` with the
            rows of the (x, y), (x, z) and (y, z) sets, in that order.
            Columns 0-1 are shared by every pair. Judging by the scatter
            pattern, columns 2-4 and 5-7 are two per-neuron groups ordered
            x, y, z (presumably inputs and outputs -- TODO confirm against
            libkrelu).
    """

    # For each pair (positions into (x, y, z)): the destination column in
    # ``cons`` of source columns 0..5 of the pairwise matrix.
    _PAIR_COLUMNS = (
        ((0, 1), (0, 1, 2, 3, 5, 6)),
        ((0, 2), (0, 1, 2, 4, 5, 7)),
        ((1, 2), (0, 1, 3, 4, 6, 7)),
    )

    def __init__(self, man, element, tdim, layerno, offset, x, y, z, lb, ub, domain):
        self.x = x
        self.y = y
        self.z = z
        self.lb = lb
        self.ub = ub
        neurons = (x, y, z)
        # Unit coefficients for every quadrant, shared by all three pairs.
        coeff1 = np.ones(4, dtype=np.double)
        coeff2 = np.ones(4, dtype=np.double)

        blocks = []
        for (a, b), columns in self._PAIR_COLUMNS:
            lb2 = np.array([lb[a], lb[b]], dtype=np.double)
            ub2 = np.array([ub[a], ub[b]], dtype=np.double)
            xy = np.zeros(4, dtype=np.double)
            compute_2relu_input(man, element, tdim, layerno, offset,
                                neurons[a], neurons[b], coeff1, coeff2, xy, domain)
            mat = relu_with_2_var_c(coeff1, coeff2, lb2, ub2, xy)
            try:
                nbrows = get_num_rows_c(mat)
                block = np.zeros([nbrows, 8], dtype=np.double)
                for i in range(nbrows):
                    for src, dst in enumerate(columns):
                        block[i, dst] = get_matrix_element_c(mat, i, src)
            finally:
                # Release the C matrix even if copying fails.
                double_matrix_free_c(mat)
            blocks.append(block)

        self.cons = np.concatenate(blocks, axis=0)

def compute_coeff(nn, layerno, x, y):
    """Return the per-quadrant coefficients used when pairing neurons ``x`` and ``y``.

    The coefficients are currently always 1. The previous body kept a
    weight-based heuristic behind an always-false ``if 0:``, so it never ran.
    The arguments are unused but kept so existing callers need not change.

    Args:
        nn: the network (unused).
        layerno: layer index (unused).
        x, y: neuron indices (unused).

    Returns:
        ``(coeff_x, coeff_y)``: two independent float64 arrays of length 4,
        all ones, one entry per sign quadrant.
    """
    coeff_x = np.ones(4, dtype=np.double)
    coeff_y = np.ones(4, dtype=np.double)
    return coeff_x, coeff_y

def compute_2relu_input(man, element, tdim, layerno, offset, x, y, coeff_x, coeff_y, xy, domain):
    """Fill ``xy`` with upper bounds of four signed combinations of two neurons.

    For k = 0..3 the expression ``sx*coeff_x[k]*d_x + sy*coeff_y[k]*d_y`` is
    bounded, where ``d_x``/``d_y`` are ELINA dimensions ``offset+x`` /
    ``offset+y``. The signs (sx, sy) are (+,+), (+,-), (-,-), (-,+) for
    k = 0, 1, 2, 3. The supremum of each bound is written to ``xy[k]``.

    Args:
        man, element: ELINA manager and abstract element to query.
        tdim: ElinaDim of the scratch dimension used in 'refinezono' mode.
            encode_2reLu_cons builds it as ``ElinaDim(offset+length)`` and
            adds that dimension to ``element`` before calling.
        layerno: layer index passed to ``get_bounds_for_linexpr0``.
        offset: dimension offset of the layer's neurons.
        x, y: neuron indices relative to ``offset``.
        coeff_x, coeff_y: indexable coefficient sequences (indices 0..3 used).
        xy: output array with at least 4 entries, written in place.
        domain: 'refinezono' assigns each expression to ``tdim`` and bounds
            that dimension; any other value uses fppoly's
            ``get_bounds_for_linexpr0``.
    """
    # (sign of the x term, sign of the y term) for xy[0..3].
    quadrant_signs = ((1, 1), (1, -1), (-1, -1), (-1, 1))
    for k, (sign_x, sign_y) in enumerate(quadrant_signs):
        linexpr0 = generate_tvpi_linexpr0(offset + x, offset + y,
                                          sign_x * coeff_x[k], sign_y * coeff_y[k])
        try:
            if domain == 'refinezono':
                # Bound the scratch dimension ``tdim`` (== offset+length in
                # the caller); the old code referenced an undefined ``length``.
                # NOTE(review): the element returned by the destructive
                # assign is only rebound locally.
                element = elina_abstract0_assign_linexpr_array(man, True, element, tdim, linexpr0, 1, None)
                bound = elina_abstract0_bound_dimension(man, element, tdim)
            else:
                bound = get_bounds_for_linexpr0(man, element, linexpr0, layerno)
        finally:
            # The expression is only needed to compute the bound; free it
            # so repeated calls do not leak.
            elina_linexpr0_free(linexpr0)
        # NOTE(review): the returned interval ``bound`` is not freed here
        # either -- consider elina_interval_free once confirmed safe.
        xy[k] = bound.contents.sup.contents.val.dbl

def encode_2reLu_cons(nn, man, element, offset, layerno, length, lbi, ubi, relu3vars_list, relu2vars_list, relu1var_list, need_pop, domain):
    """Group the straddling ReLU inputs of one layer into 3-, 2- and 1-neuron sets.

    Neuron ``i`` in ``range(length)`` is a candidate when ``lbi[i] < 0`` and
    ``ubi[i] > 0``. Candidates are ordered by ``np.argsort`` of their widths
    ``ubi - lbi`` (ascending) and consumed in that order:

    * triples as ``Relu3Vars`` when ``config.use_3relu`` is set,
    * then pairs as ``Relu2Vars`` when ``config.use_2relu`` is set,
    * then any remaining candidates as bare indices.

    One list per kind is appended to ``relu3vars_list``, ``relu2vars_list``
    and ``relu1var_list``. When ``need_pop`` is true, the last entry of each
    of those lists is popped first. In 'refinezono' mode, one scratch
    dimension at ``offset+length`` is added to ``element`` for the duration
    of the call and removed afterwards.
    """
    if need_pop:
        relu3vars_list.pop()
        relu2vars_list.pop()
        relu1var_list.pop()

    candidate_vars = [i for i in range(length) if lbi[i] < 0 and ubi[i] > 0]
    widths = np.asarray([ubi[i] - lbi[i] for i in candidate_vars])
    num_candidates = len(candidate_vars)
    ordered = [candidate_vars[k] for k in np.argsort(widths)]

    tdim = ElinaDim(offset + length)
    if domain == 'refinezono':
        element = dn.add_dimensions(man, element, offset + length, 1)

    relu3varsi = []
    relu2varsi = []
    consumed = 0

    if config.use_3relu:
        for start in range(0, 3 * (num_candidates // 3), 3):
            x, y, z = ordered[start:start + 3]
            lb = np.array([lbi[x], lbi[y], lbi[z]], dtype=np.double)
            ub = np.array([ubi[x], ubi[y], ubi[z]], dtype=np.double)
            relu3varsi.append(Relu3Vars(man, element, tdim, layerno, offset, x, y, z, lb, ub, domain))
            consumed += 3

    if config.use_2relu:
        for start in range(consumed, num_candidates - 1, 2):
            x, y = ordered[start], ordered[start + 1]
            coeff_x, coeff_y = compute_coeff(nn, layerno, x, y)
            lb = np.array([lbi[x], lbi[y]], dtype=np.double)
            ub = np.array([ubi[x], ubi[y]], dtype=np.double)
            xy = np.zeros(4, dtype=np.double)
            compute_2relu_input(man, element, tdim, layerno, offset, x, y, coeff_x, coeff_y, xy, domain)
            relu2varsi.append(Relu2Vars(x, y, coeff_x, coeff_y, lb, ub, xy))
            consumed += 2

    # Whatever was not grouped is handled one neuron at a time.
    relu1vari = ordered[consumed:]

    if domain == 'refinezono':
        element = dn.remove_dimensions(man, element, offset + length, 1)

    relu3vars_list.append(relu3varsi)
    relu2vars_list.append(relu2varsi)
    relu1var_list.append(relu1vari)
