import matlab.engine
import numpy as np
import torch
import sys
sys.path.append('..')
from experiment import Experiment, MethodNest, Job
from hyperbox import Hyperbox
from relu_nets import ReLUNet
from neural_nets import data_loaders as dl
from neural_nets import train
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from other_methods import LOCAL_METHODS, GLOBAL_METHODS, OTHER_METHODS
from utilities import Factory, DoEvery
import utilities as utils
import os
import time
import pickle 

def build_network(layers, train_params, data_seed=None):
    """Create a ReLUNet with the given layer sizes and train it.

    Args:
        layers: layer-size list forwarded to ``ReLUNet``.
        train_params: parameters forwarded to ``train.training_loop``.
        data_seed: accepted for call compatibility but currently ignored.

    Returns:
        The network object after ``train.training_loop`` has run on it.
    """
    # NOTE(review): data_seed is never used here; callers relying on it for
    # reproducibility get no seeding from this function.
    net = ReLUNet(layers)
    train.training_loop(net, train_params)
    return net

def make_eij(i, j, num_classes):
    """Return the vector e_i - e_j of length ``num_classes``.

    Entry ``i`` is set to 1.0 first and entry ``j`` to -1.0 second, so when
    ``i == j`` the single touched entry ends up as -1.0.
    """
    vec = np.zeros(num_classes)
    # Sequential writes (not fancy indexing) keep the j-write winning on i == j.
    for idx, val in ((i, 1.0), (j, -1.0)):
        vec[idx] = val
    return vec


def naive_multiclass(network, box):
    """Naive multiclass baseline: one LipProblem per (predicted, other) class pair.

    The predicted class ``i`` comes from classifying the box center. For each
    other class ``j``, a LipProblem is solved with ``c_vector = e_i - e_j``
    (via ``make_eij``). Each solve's compute time is printed as it happens.

    Returns:
        Tuple ``(elapsed, value)``. ``elapsed`` is the result of
        ``utils.Timer().stop()`` measured over the whole loop. ``value`` is the
        minimum of the per-pair ``compute_max_lipschitz().value`` results.

    Raises:
        ValueError: from ``min`` when the network has a single output class
            (no pairs are solved).
    """
    # NOTE(review): aggregation is min over j; if this baseline should match
    # a max-over-targets style, max may be intended — confirm.
    stopwatch = utils.Timer()
    n_out = network.layer_sizes[-1]
    label = network.classify_np(box.get_center())
    per_pair = {}
    for other in range(n_out):
        if other != label:
            pair = (label, other)
            problem = LipProblem(network, box,
                                 c_vector=make_eij(label, other, n_out),
                                 primal_norm='linf', verbose=False, num_threads=2)
            outcome = problem.compute_max_lipschitz()
            print(pair, outcome.compute_time)
            per_pair[pair] = outcome.value
    elapsed = stopwatch.stop()
    return (elapsed, min(per_pair.values()))

def do_data_box(network, datum, radius):
    """Compare cross-Lipschitz computations on the box around one datum.

    Builds the box with ``Hyperbox.build_linf_ball(datum, radius)``. It then
    solves the 'targetCrossLipschitz' and 'trueTargetCrossLipschitz' LipProblem
    styles and runs the ``naive_multiclass`` baseline.

    Returns:
        Dict keyed by style name ('targetCrossLipschitz',
        'trueTargetCrossLipschitz', 'naive'). Each value is a
        ``(time, value)`` tuple. The dict is also printed after a separator
        line.
    """
    box = Hyperbox.build_linf_ball(datum, radius)
    style_outputs = {}
    for style in ('targetCrossLipschitz', 'trueTargetCrossLipschitz'):
        solved = LipProblem(network, box, style, primal_norm='linf',
                            verbose=True, num_threads=2).compute_max_lipschitz()
        style_outputs[style] = (solved.compute_time, solved.value)
    style_outputs['naive'] = naive_multiclass(network, box)

    print('-' * 100)
    print(style_outputs)
    return style_outputs



if __name__ == '__main__':
    # ---- Experiment configuration -------------------------------------
    DIMENSION = 8
    NUM_CLASSES = 100
    LAYERS = [DIMENSION, 40, 40, 40, NUM_CLASSES]
    NUM_TRIALS = 10    # number of independently trained networks
    NUM_SAMPLES = 20   # number of data points evaluated per network
    RANDOM_SEED = 420

    # ---- Synthetic dataset --------------------------------------------
    # Seed with RANDOM_SEED rather than a duplicated literal so the declared
    # constant actually controls the dataset.
    data_params = dl.RandomKParameters(2000, 200, radius=0.01, dimension=DIMENSION,
                                       num_classes=NUM_CLASSES)
    dataset = dl.RandomDataset(data_params, random_seed=RANDOM_SEED)
    trainset, _ = dataset.split_train_val(1.0)
    # Probe points are drawn from the first training batch.
    # presumably trainset[0] is an (inputs, labels) pair — TODO confirm
    train_batch = trainset[0][0]

    # ---- Training setup: cross-entropy plus weight regularization ------
    xentropy = train.XEntropyReg()
    l1_reg = train.LpWeightReg(scalar=5e-4)
    loss = train.LossFunctional(regularizers=[xentropy, l1_reg])
    train_params = train.TrainParameters(trainset, trainset, 2000, loss_functional=loss,
                                         test_after_epoch=20)

    # ---- Run: one pickle file per box radius ---------------------------
    for rad, radname in [(0.1, 'crosslip_exp_r01.pkl')]:
        results = []
        for trial in range(NUM_TRIALS):
            network = build_network(LAYERS, train_params)
            for sample_idx in range(NUM_SAMPLES):
                datum = train_batch[sample_idx]
                results.append(do_data_box(network, datum, rad))
        # Results for this radius are written only after all trials finish.
        with open(radname, 'wb') as f:
            pickle.dump(results, f)