import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from numpy.random import random_sample
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import pickle as pkl
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import sys

class SphereNDSampler():
    """Sample points from a circle embedded in n-dimensional space.

    The circle is parameterised by an angle ``t`` as
    ``sin(t) * v1 + cos(t) * v2`` where ``v1`` and ``v2`` are orthonormal
    vectors chosen at construction time, so every sampled point has unit norm.
    """

    def __init__(self, n):
        """Pick the orthonormal pair (v1, v2) spanning the circle's plane.

        Args:
            n: Dimension of the ambient space; must be at least 2.

        Raises:
            ValueError: If ``n < 2`` (no orthonormal pair exists; for n == 1
                the rejection loop below could never terminate).
        """
        if n < 2:
            raise ValueError("SphereNDSampler needs n >= 2, got %s" % (n,))
        self.x = 0  # not used inside this class; kept for compatibility
        self.n = n
        if n == 2:
            # Stored as arrays rather than lists: multiplying a numpy float
            # scalar by a Python list raises TypeError in get_from_t.
            self.v1 = np.array([1.0, 0.0])
            self.v2 = np.array([0.0, 1.0])
        else:
            v1 = np.random.rand(n)
            v1 = v1/np.linalg.norm(v1)
            v2 = None
            while v2 is None:
                v = np.random.rand(n)
                v = v/np.linalg.norm(v)
                # Reject candidates too closely aligned with v1, then remove
                # the v1 component (Gram-Schmidt step) and renormalise.
                if np.abs(np.dot(v, v1)) < 0.75:
                    v2 = v - np.dot(v, v1)*v1
                    v2 = v2/np.linalg.norm(v2)
            self.v1 = v1
            self.v2 = v2

    def get_from_t(self, t):
        """Return the point ``sin(t) * v1 + cos(t) * v2`` for angle ``t``."""
        return np.sin(t)*self.v1 + np.cos(t)*self.v2

    def get_n_samples(self, num_samples=1000, sample_range=(-np.pi, np.pi)):
        """Draw angles uniformly from ``sample_range`` and map them to points.

        Args:
            num_samples: Number of angles to draw.
            sample_range: ``(low, high)`` bounds of the uniform distribution.

        Returns:
            Tuple ``(t_samples, pts)``: the sampled angles and an array with
            one circle point per angle.
        """
        t_samples = (sample_range[1] - sample_range[0]) * random_sample(size=num_samples) + sample_range[0]
        pts = np.array([self.get_from_t(t_sample) for t_sample in t_samples])
        return t_samples, pts

    def get_distance(self, t1, t2):
        """Return the arc length between angles ``t1`` and ``t2`` on the circle.

        The angular gap is wrapped to one full turn and the shorter of the two
        arcs is returned, so the result lies in ``[0, pi]``. (The previous
        implementation returned ``||t1| - |t2||`` for angles of opposite sign,
        e.g. 0 for t1 = 1 and t2 = -1, which are distinct points.)
        """
        full_turn = 2*np.pi
        gap = np.abs(t1 - t2) % full_turn
        return np.minimum(gap, full_turn - gap)

    def get_interval_length(self, t1, t2):
        """Return the circumference ``2*pi``; both arguments are ignored."""
        return 2*np.pi

class Net(nn.Module):
    """Fully connected ReLU network: n -> l1 -> l2 -> 1.

    Besides the layers, the module keeps the statistics used by
    ``standardize_data`` and the two vectors ``v1``/``v2`` that
    ``get_layerwise_activation`` uses to turn an angle into an input point.
    Note that ``forward`` itself applies no standardization.
    """

    def __init__(self, sample_mean, sample_std, n, v1, v2, l1=10, l2=16):
        super().__init__()
        self.n = n
        self.sample_mean, self.sample_std = sample_mean, sample_std
        self.v1, self.v2 = v1, v2
        # Width of each layer's output, keyed by layer index.
        self.layer_neurons = {0: l1, 1: l2, 2: 1}
        self.fc1 = nn.Linear(n, l1, bias=True)
        self.fc2 = nn.Linear(l1, l2, bias=True)
        self.fcout = nn.Linear(l2, 1, bias=True)
        # Same modules, addressable by index.
        self.layers = [self.fc1, self.fc2, self.fcout]

    def standardize_data(self, val_arr):
        """Return ``(val_arr - sample_mean) / sample_std``."""
        return (val_arr - self.sample_mean) / self.sample_std

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU, then the linear output layer."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fcout(hidden)

    def get_layerwise_activation(self, layer_num, x):
        """Return |pre-activation| of layer ``layer_num`` at angle ``x``.

        The angle is mapped to ``v1*sin(x) + v2*cos(x)``, standardized, and
        pushed through the layers: ReLU is applied after every layer before
        ``layer_num``, and the absolute value of the linear output is taken at
        ``layer_num`` itself.
        """
        point = self.standardize_data(self.v1*np.sin(x) + self.v2*np.cos(x))
        activation = torch.tensor(point).float()
        for index in range(layer_num + 1):
            pre_activation = self.layers[index](activation)
            if index == layer_num:
                activation = torch.abs(pre_activation)
            else:
                activation = F.relu(pre_activation)
        return activation

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        feature_count = 1
        for dim in x.size()[1:]:
            feature_count *= dim
        return feature_count


def duplicate_exists(list_to_check, value_to_check, boundary_thresh):
    """Tell whether ``value_to_check`` is already present up to a tolerance.

    Returns True as soon as some entry of ``list_to_check`` differs from
    ``value_to_check`` by strictly less than ``boundary_thresh``; otherwise
    (including for an empty list) returns False.
    """
    return any(
        np.abs(candidate - value_to_check) < boundary_thresh
        for candidate in list_to_check
    )


def safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, breakpoint_x):
    """Record ``breakpoint_x`` for a neuron unless a near-duplicate is stored.

    ``linear_boundaries`` is a nested dict ``{layer: {neuron: [breakpoints]}}``
    that is updated in place; missing levels are created on demand. A
    breakpoint is skipped when ``duplicate_exists`` finds an already stored
    value within 0.01 of it.
    """
    per_neuron = linear_boundaries.setdefault(layer_number, {})
    if neuron_number not in per_neuron:
        # First breakpoint for this neuron: nothing to compare against.
        per_neuron[neuron_number] = [breakpoint_x]
        return
    recorded = per_neuron[neuron_number]
    if not duplicate_exists(recorded, breakpoint_x, 0.01):
        recorded.append(breakpoint_x)


def count_linear_regions(model_path, net=None, sample_range = (-3.15, 3.15)):
    """Locate the angles at which hidden neurons switch on or off.

    For every neuron of every layer except the last, the neuron's
    ``net.get_layerwise_activation`` value is minimised over the angle ``t``
    with bounded SLSQP, starting from 9 points in ``(0, sample_range[1])`` and
    from their negated counterparts bounded to ``(sample_range[0], 0)``. A
    minimum of at most 0.001 is recorded as a boundary for that neuron
    (near-duplicates within 0.01 are dropped by ``safe_add_linear_boundary``).
    All recorded boundaries are finally merged, dropping values within 0.0001
    of one already kept.

    Args:
        model_path: Path handed to ``torch.load`` when ``net`` is falsy; the
            loaded object must expose ``layers``, ``layer_neurons`` and
            ``get_layerwise_activation``.
        net: Network to analyse; takes precedence over ``model_path``.
        sample_range: ``(low, high)`` angle interval to search.

    Returns:
        Tuple ``(count, boundaries)`` with the sorted merged boundary angles
        and how many there are.
    """
    if not net:
        net = torch.load(model_path)

    zero_threshold = 0.001
    num_init_points = 10
    solver_options = {'eps': 0.000005, 'maxiter': 1000, 'disp': False}
    linear_boundaries = {}

    hidden_layer_count = len(net.layers) - 1
    for layer_number in range(hidden_layer_count):
        for neuron_number in range(net.layer_neurons[layer_number]):

            def fun_to_optimize(t):
                output = net.get_layerwise_activation(layer_number, t).detach().numpy()
                return output[neuron_number]

            for i in range(1, num_init_points):
                start_point = sample_range[1] * i / num_init_points
                # Positive half first, then the mirrored start on the negative half.
                searches = (
                    (start_point, (0, sample_range[1])),
                    (-1 * start_point, (sample_range[0], 0)),
                )
                for initial_t, t_bounds in searches:
                    result = minimize(fun_to_optimize, np.array([initial_t]), method='SLSQP', tol=1e-6,
                                      bounds=[t_bounds], options=dict(solver_options))
                    if result.fun <= zero_threshold:
                        safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, result.x)

    # Flatten the per-neuron arrays into one list of distinct scalar angles.
    boundary_values = []
    for layer_boundaries in linear_boundaries.values():
        for neuron_boundaries in layer_boundaries.values():
            for neuron_boundary in neuron_boundaries:
                for boundary_value in neuron_boundary:
                    if duplicate_exists(boundary_values, boundary_value, 0.0001):
                        continue
                    boundary_values.append(boundary_value)

    boundary_values.sort()
    return len(boundary_values), boundary_values



def generate_training_data(sphere_sampler, num_samples=2000, periodic_freq=1.0, noise_scale = 0.25):
    """Build a noisy sinusoidal regression target on sampled circle points.

    Angles and points come from ``sphere_sampler.get_n_samples``. The target
    for angle ``t`` is ``sin(t * pi / periodic_freq) * (1 + noise)``, with
    ``noise`` drawn from a zero-mean normal of standard deviation
    ``noise_scale``.

    Returns:
        Tuple ``(ts, pts, f_vals)``; ``f_vals`` has shape ``(1, num_samples)``.
    """
    ts, pts = sphere_sampler.get_n_samples(num_samples)
    noise_vals = np.random.normal(scale=noise_scale, size=num_samples)
    targets = [
        np.sin(t_val*np.pi/periodic_freq)*(1 + noise_val)
        for t_val, noise_val in zip(ts, noise_vals)
    ]
    # Wrapping in a list gives the result a leading axis of length one.
    f_vals = np.array([targets])
    return ts, pts, f_vals


def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are left out.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_avg_distance(linear_boundaries, sphere_sampler, sample_range= (-5, 5), num_samples=300):
    """Average distance from random angles to their nearest boundary.

    ``num_samples`` angles are drawn with ``sphere_sampler.get_n_samples``;
    for each one the smallest ``sphere_sampler.get_distance`` to any entry of
    ``linear_boundaries`` is taken, and the mean of those minima is returned.
    With no boundaries every minimum stays at infinity.
    """
    sample_ts, _ = sphere_sampler.get_n_samples(num_samples, sample_range=sample_range)
    nearest = []
    for angle in sample_ts:
        closest = np.inf
        for boundary in linear_boundaries:
            candidate = sphere_sampler.get_distance(angle, boundary)
            closest = candidate if candidate < closest else closest
        nearest.append(closest)
    return np.mean(np.array(nearest))


def get_mean_and_variance(xs):
    """Return ``(mean, standard deviation)`` over all elements of ``xs``.

    Despite the name, the second value is ``np.std(xs)``, not the variance.
    """
    centre = np.mean(xs)
    spread = np.std(xs)
    return centre, spread

def train_model(xs, ys, n, seed_val, sphere_sampler, num_epochs=150, sample_range=(-np.pi, np.pi), run_number=0, l1=10,
                l2=16):
    """Train a ``Net`` on ``(xs, ys)`` and track its boundaries over time.

    The first 80% of the samples are used for training and the rest for
    testing (MSE loss, Adam with lr 0.001, batches of 8). Every fifth epoch
    the boundaries along the circle are counted with ``count_linear_regions``
    and the mean distance to the nearest boundary, divided by
    ``sphere_sampler.get_interval_length``, is recorded. At the end the
    collected series are pickled to ``data_nd/`` and the model's state dict is
    saved to ``models_nd/``; both directories are created if missing.

    Args:
        xs: Input points, one row per sample.
        ys: Targets with shape ``(1, num_samples)``.
        n: Input dimension of the network.
        seed_val: Seed value, used only in the output file names.
        sphere_sampler: Object providing ``v1``, ``v2``, ``get_n_samples``,
            ``get_distance`` and ``get_interval_length``.
        num_epochs: Number of passes over the training split.
        sample_range: Angle interval used for boundary search and sampling.
        run_number: Run index, used in log lines and output file names.
        l1: Width of the first hidden layer.
        l2: Width of the second hidden layer.
    """
    import os  # local import: only needed to prepare the output directories

    # Create the output directories up front so a missing folder cannot
    # discard the results after the whole training run has finished.
    os.makedirs("data_nd", exist_ok=True)
    os.makedirs("models_nd", exist_ok=True)

    xs_mean, xs_std = get_mean_and_variance(xs)
    net = Net(xs_mean, xs_std, n, sphere_sampler.v1, sphere_sampler.v2, l1=l1, l2=l2)
    net = net.float()
    # NOTE(review): the network is trained on raw `xs` below, whereas
    # Net.get_layerwise_activation standardizes its inputs with
    # xs_mean/xs_std, so the boundary search probes a shifted and rescaled
    # curve -- confirm this is intended.

    fun_vals = np.reshape(ys, (ys.shape[1], 1))
    stacked_input = xs
    frac_train = int(0.8*stacked_input.shape[0])
    frac_test = stacked_input.shape[0] - frac_train
    print(frac_train, frac_test)
    train_inputs = stacked_input[:frac_train]
    test_inputs = stacked_input[frac_train:]
    train_outputs = fun_vals[:frac_train]
    test_outputs = fun_vals[frac_train:]
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    trainset = torch.utils.data.TensorDataset(torch.Tensor(train_inputs), torch.Tensor(train_outputs))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)
    testset = torch.utils.data.TensorDataset(torch.Tensor(test_inputs), torch.Tensor(test_outputs))
    testloader = torch.utils.data.DataLoader(testset, batch_size=8,
                                             shuffle=False)
    num_linear_regions = []
    linear_region_epochs = []
    average_distances = []
    total_losses = []
    max_distance = sphere_sampler.get_interval_length(sample_range[0], sample_range[1])
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for inputs, real_fun_vals in trainloader:
            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, real_fun_vals)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
        # Average over the number of batches (the last batch index is one
        # short and is zero when there is a single batch).
        print('[%d] train loss: %.3f' %
              (epoch + 1, epoch_loss / max(num_batches, 1)))

        with torch.no_grad():
            total_loss = 0.0
            num_passes = 0
            for inputs, real_fun_vals in testloader:
                outputs = net(inputs)
                loss = criterion(real_fun_vals, outputs)
                total_loss += loss.item()
                num_passes += 1
            print('[%d, %d] test loss: %.3f' % (run_number, epoch + 1, total_loss / max(num_passes, 1)))

        if epoch % 5 == 0:
            num_regions, linear_boundaries = count_linear_regions('', net, sample_range=sample_range)
            average_distance = get_avg_distance(linear_boundaries, sphere_sampler, sample_range=sample_range)
            average_distance = average_distance/max_distance
            print("Run Number: %s, Epoch: %s, Average Distance: %s, Num Linear Regions: %s, Dim: %s" %
                  (run_number, epoch, average_distance, num_regions, n))
            average_distances.append(average_distance)
            num_linear_regions.append(num_regions)
            # Stored as the summed (not averaged) test loss of this epoch.
            total_losses.append(total_loss)
            linear_region_epochs.append(epoch)

    data_arr = [linear_region_epochs, num_linear_regions, average_distances, total_losses]
    with open("data_nd/sphere_run_%s_seed_%s_dim_%s_neurons_%s.pkl"%(run_number, seed_val, n, l1 + l2), "wb") as f_out:
        pkl.dump(data_arr, f_out)

    torch.save(net.state_dict(), "models_nd/sphere_run_%s_seed_%s_dim_%s.torch"%(run_number, seed_val, n))

if __name__ == '__main__':
    # Usage: <script> SEED DIM [L1 [L2]]
    #   SEED  integer seed for numpy and torch
    #   DIM   dimension of the space the circle is embedded in
    #   L1    width of the first hidden layer (default 10)
    #   L2    width of the second hidden layer (default 16)
    if len(sys.argv) < 3:
        sys.exit("usage: %s SEED DIM [L1 [L2]]" % sys.argv[0])
    seed_val = int(sys.argv[1])
    n = int(sys.argv[2])
    # Each width is optional on its own, so giving only L1 no longer fails.
    l1 = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    l2 = int(sys.argv[4]) if len(sys.argv) > 4 else 16
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)

    sample_range = (-np.pi, np.pi)
    num_runs = 20
    for i in range(num_runs):
        # A fresh circle and data set are drawn for every run.
        sphere_sampler = SphereNDSampler(n)
        ts, xs, ys = generate_training_data(sphere_sampler)
        print("Starting train for run number: %s"%(i))
        # Forward the requested layer widths; previously they were parsed but
        # never passed on, so the defaults were always used.
        train_model(xs, ys, n, seed_val, sphere_sampler, num_epochs=150, sample_range=sample_range, run_number=i,
                    l1=l1, l2=l2)



