import numpy as np
# For evaluation only, a nest installation is not required. For generation of data,
# a nest version with changes by the authors is required.
try:
    import nest
except ImportError:
    # NEST is only needed to generate simulation data; evaluation works without it.
    nest = None
import create_network, apply_stimulus, get_properties
import pickle
import multiprocessing as mp
import sys
from copy import deepcopy



'''
train input and readout projection on ECG5000 dataset
Use as
python ECG_train.py seed system n_rand solver
where seed is the network realization
system is lin or non for the linear or non-linear network
n_rand is the number of random input vectors u used for comparison (optional, default: 0)
solver is eigenvalue or ridge (the currently supported options; optional, default: eigenvalue)
For example:
python ECG_train.py 1005 non 50 eigenvalue
'''



#%%

# Command-line arguments: seed system [n_rand [solver]]
if len(sys.argv) < 3:
    sys.exit('Usage: python ECG_train.py seed system [n_rand [solver]]')

seed = int(sys.argv[1])
# Seed NumPy's global RNG so that all random draws later in the script
# (e.g. the random input vectors u) are reproducible for this seed.
np.random.seed(seed)

# The network type selects which stored network realization is used.
if sys.argv[2] not in ('lin', 'non'):
    sys.exit("system must be 'lin' or 'non', got " + repr(sys.argv[2]))

# Number of random input vectors u used for comparison (optional, 4th argument).
if len(sys.argv) > 3:
    n_rand = int(sys.argv[3])
else:
    n_rand = 0

# Solver for the readout (optional, 5th argument, i.e. sys.argv[4]).
if len(sys.argv) > 4:
    solver = sys.argv[4]
else:
    solver = 'eigenvalue'

file_label = str(seed)+'_'+sys.argv[2] # name output files

#####
# Preprocess the data
#####

# Load the ECG5000 train/test splits. Column 0 of each row is the class label,
# the remaining columns are the time series of that sample.
train = np.genfromtxt('ECG5000/ECG5000_TRAIN.txt', dtype='float')
test = np.genfromtxt('ECG5000/ECG5000_TEST.txt', dtype='float')
# shape (nu, n) or (t, n) where nu/t is observation/trial index and n is time index

# Binary classification between the two smallest class labels.
binary_classes = np.unique(train[:, 0])[:2]

# Keep only the rows belonging to one of the two chosen classes.
train = train[np.isin(train[:, 0], binary_classes)]
test = test[np.isin(test[:, 0], binary_classes)]


def _relabel_to_pm1(samples):
    """Overwrite the label column of *samples* in place with +1 / -1.

    Labels within 0.1 of the first chosen class become +1, labels within 0.1
    of the second become -1. The script exits if any label matches neither.
    """
    labels = samples[:, 0]
    is_first = np.abs(labels - binary_classes[0]) < 0.1
    is_second = np.abs(labels - binary_classes[1]) < 0.1
    if not np.all(is_first | is_second):
        sys.exit('Error in class selection.')
    # The first class takes precedence, as in an if/elif chain.
    samples[:, 0] = np.where(is_first, 1., -1.)


_relabel_to_pm1(train)
_relabel_to_pm1(test)


# In principle, classes of different sizes are possible, but for simplicity,
# additional samples from the larger class are appended to the test set.
pos_indices = np.where(train[:, 0] > 0)[0]
neg_indices = np.where(train[:, 0] < 0)[0]

if len(pos_indices) == len(neg_indices):
    # Already balanced: train on every sample.
    train_data = train[:, 1:]
    train_labels = train[:, 0]
else:
    # Trim the larger class to the size of the smaller one and move its
    # surplus samples to the end of the test set.
    if len(pos_indices) < len(neg_indices):
        smaller_class, larger_class = pos_indices, neg_indices
    else:
        smaller_class, larger_class = neg_indices, pos_indices
    n_keep = len(smaller_class)

    test = np.append(test, train[larger_class[n_keep:]], axis=0)

    # Sorting keeps the original order of the samples in the training set.
    used_indices = np.sort(np.append(smaller_class, larger_class[:n_keep]))
    train_data = train[used_indices, 1:]
    train_labels = train[used_indices, 0]

# Class membership within the balanced training set.
pos_indices = np.where(train_labels > 0)[0]
neg_indices = np.where(train_labels < 0)[0]

test_data = test[:, 1:]
test_labels = test[:, 0]

# Labels of all samples, training samples first.
all_labels = np.append(train_labels, test_labels)



# Center both sets (in place) with the mean of the training samples.
offset = np.mean(train_data, axis=0)
for _samples in (train_data, test_data):
    _samples -= offset

# Rescale both sets (in place) so that the mean of the positive training
# samples has unit norm, ||mu|| = 1, which keeps the inputs to the
# nonlinearity weak.
normalization_factor = np.linalg.norm(np.mean(train_data[pos_indices], axis=0))
for _samples in (train_data, test_data):
    _samples /= normalization_factor

# Estimated mu. Identical to <zeta_nu x^nu>.
mu = np.mean(train_data[pos_indices], axis=0)


#%%

#####
# Set up the network
#####

# Load the pre-generated network realization for this seed.
# NOTE: pickle can execute arbitrary code while loading; only load trusted files.
with open('data/network_realizations/initialize_network_'+str(seed)+'.txt', 'rb') as handle:
    network_dictionary = pickle.load(handle)

# Select the linear ('lin') or non-linear ('non') network. Exit with a clear
# message instead of failing later with a NameError on `net`.
if sys.argv[2] not in ('lin', 'non'):
    sys.exit("system must be 'lin' or 'non', got " + repr(sys.argv[2]))
net = network_dictionary['net_' + sys.argv[2]]

# Set the network's time resolution to the dataset's resolution:
# len(mu) intervals between 0 and the readout time.
net.stop_points = np.linspace(0., net.readout_time, len(mu)+1)
net.T = len(net.stop_points) - 1
# Before any optimization, the Green's function's product with the given stimuli has to be determined.
# Always use large_N=True; otherwise the O(alpha**2) correction term is still missing. It is also
# less memory intensive, although slower.
net.determine_sample_dynamics(train_data, train_labels, large_N=True)

# Connectivity matrix from eigenvectors and eigenvalues:
# W_ij = Re( sum_a Right[i, a] * lamb[a] * Left[a, j] )
W = np.real(np.einsum('aj, ia, a -> ij', net.Left, net.Right, net.lamb))

opt_steps_non = 30 # steps for alternating optimization of u, v (non-linear reservoir)
opt_steps_lin = 30 # steps for alternating optimization of u, v (linear reservoir)

eta = network_dictionary['eta'] # weight of the covariance term in the soft margin
n_ini = mp.cpu_count() # number of initial conditions for non-linear system optimization. On our machine: 48
steps_ini = 10 # number of steps after which the initial conditions are compared





#%%

#####
# Optimize the system
#####

# Optimization results; every quantity is a list over optimization steps.
# The linear and non-linear reservoirs run the same routine and differ only
# in the number of alternating u/v optimization steps.
_optimization_settings = {
    'lin': ('Optimizing linear system...', opt_steps_lin),
    'non': ('Optimizing non-linear system...', opt_steps_non),
}
if sys.argv[2] in _optimization_settings:
    _message, _n_steps = _optimization_settings[sys.argv[2]]
    print(_message)
    soft_margins, input_vectors, readout_vectors = net.find_good_optimization(
        _n_steps, mu, eta=eta, initial_cond=n_ini, initial_steps=steps_ini,
        solver=solver)

# The optimal soft margin is soft_margins[-1]; the optimal input projection is input_vectors[-1].
print('Finished Optimization. Reached a soft margin of', soft_margins[-1])  # This is the quantity referred to in the evaluations




#%%

#####
# Composition of the soft margin
#####

# contributions to the soft margin from separation and covariance. Used in fig. 4 c / d

# over optimization steps: fig. 4c
def _soft_margin_contributions(u, v):
    """Return [separation term, covariance term] of the soft margin.

    The separation term is (d0 + d1) . v and the covariance term is
    eta / 2 * v . (Sigma0 + Sigma1 + correction) . v, where the distances and
    covariances are evaluated by `net` for the input projection u.
    """
    # Same call order as before: distances first, then the covariances.
    separation = net.linear_distance(u) + net.nonlinear_distance(u)
    covariance = (net.linear_covariance(u) + net.nonlinear_covariance(u)) \
        + net.nonlinear_covariance_correction(u)
    return np.array([
        np.einsum('i, i -> ', separation, v),
        eta / 2 * np.einsum('i, ij, j ->', v, covariance, v),
    ])


composition_optimization = np.empty((len(input_vectors), 2))
for step, u_step in enumerate(input_vectors):
    composition_optimization[step] = _soft_margin_contributions(u_step, readout_vectors[step])


# over readout times: fig. 4d
# Soft-margin contributions as a function of the readout time, evaluated with
# the final optimized input vector (input_vectors[-1]) and readout vector
# (readout_vectors[-1]).
# 100 equally spaced readout times from readout_time/100 to readout_time.
readout_times = np.linspace(net.readout_time/100, net.readout_time, 100)
# Network states of the training samples, shape (time, trial, neuron).
responses_over_time = np.empty((len(readout_times), len(train_labels), net.N))

for time_idx, time in enumerate(readout_times):

    # for intermediate time points, the network has to be modified.
    # It is proposed to be used only at readout time.
    # A deep copy is modified so that `net` keeps its original time grid.
    net_temp = deepcopy(net)

    # Truncate the time grid at `time`: keep all stop points before it and end
    # exactly at `time`. T is the resulting number of time intervals.
    net_temp.stop_points = np.append(net.stop_points[net.stop_points < time], time)
    net_temp.T=len(net_temp.stop_points)-1
    net_temp.readout_time = time

    # Only the first T input values of each stimulus are passed on
    # (presumably one input value per time interval; confirm in determine_responses).
    responses_over_time[time_idx] = net_temp.determine_responses(train_data[:, :net_temp.T], input_vectors[-1])

# zeta_nu y^nu. The basis for the cumulants M, Sigma.
# Each trial's states are multiplied by its +1/-1 label.
generalized_states_temp = np.einsum('nti, t -> nti', responses_over_time, train_labels)

# M(t): mean over trials
dist_temp = np.mean(generalized_states_temp, axis=1) #(time, neuron)
# Sigma(t): covariance over trials with bias=True (normalized by the number of trials)
Sigma_temp = np.array([np.cov(generalized_states_temp[time_idx].T, bias=True) for time_idx in range(len(readout_times))]) #(time, neuron, neuron)

# Shape (time, 2): column 0 is the separation term M(t).v, column 1 the
# covariance term eta/2 * v.Sigma(t).v.
composition_time = np.array([
        np.einsum('ni, i -> n', dist_temp, readout_vectors[-1]),
        eta / 2 * np.einsum('i, nij, j -> n', readout_vectors[-1], Sigma_temp, readout_vectors[-1])
        ]).T

# save the results for evaluation.
composition_dictionary = {
        'composition_optimization': composition_optimization,
        'composition_temporal': composition_time,
        'times': readout_times
        }



#%%

#####
# Simulate all stimuli with the optimized input projection
#####

# Since real datasets are not Gaussian, the result is checked with simulation recordings.
# Later, cumulants of zeta_nu y^nu and the readout vector are calculated anew.
# Tolerable differences in the solutions are expected from the approximations made in the paper.

# The simulation needs the (modified) NEST installation; fail early and clearly
# instead of with an AttributeError on `None` further below.
if nest is None:
    sys.exit('NEST could not be imported, but it is required to simulate the network.')

input_projection = input_vectors[-1]
poly_coeffs_sim = np.copy(net.poly_coeffs)
poly_coeffs_sim[2] *= 2 # This is alpha, the strength of the non-linearity.
# the prefactor of NEST's polynomial_rate_ipn neuron's quadratic non-linearity
# is 1/2, so we need to multiply it by two to use the same equation of motion.
# also, in NEST the stimuli x(t) pass through the non-linearity. We will take
# measures against that further below.

# simulation resolution
dt=0.001
# readout resolution
dt_multi = 0.1

simulation_stimuli = np.append(train_data, test_data, axis=0)
# As mentioned above, inside NEST the output of the step rate generator passes
# through the non-linearity. This is undone by replacing each stimulus x with
# the root s of (c2/2)*s**2 + c1*s = x (assumes NEST applies exactly this
# polynomial with zero offset; confirm). The square-root argument becomes
# negative if the stimuli are too large, in which case a rescaling is necessary.
if poly_coeffs_sim[2] != 0:
    simulation_stimuli = - poly_coeffs_sim[1]/poly_coeffs_sim[2] + np.sqrt((poly_coeffs_sim[1]/poly_coeffs_sim[2])**2 \
                        + 2./poly_coeffs_sim[2] * simulation_stimuli)
    if np.any(np.isnan(simulation_stimuli)):
        sys.exit('Could not invert the NEST non-linearity for all stimuli '
                 '(negative square-root argument); rescale the stimuli.')
# Only print when errors occur
nest.set_verbosity('M_ERROR')
# dictionaries for multimeter and neuron setup
neuron_dict={'poly_coeffs':poly_coeffs_sim, 'linear_summation': False, 'mu': 0., 'sigma': 0., 'tau': net.tau}
multi_dict={'withtime': True, 'record_from': ['rate'], 'interval': dt_multi}
# setup multimeter, neurons and step rate generator (self-made by the authors) with the routine described in create_network
multi, n, rate_gen = create_network.create_network(N=net.N, N_recorded=net.N, neuron_dict=neuron_dict, multi_dict=multi_dict, neuron_type='polynomial_rate_ipn', input_type='same', W=W, input_vector=input_projection, dt=dt, setKernel=True)

# Apply each stimulus (one trial per sample, followed by an appended rate of 0)
# using the routine in apply_stimulus, then sort the multimeter recordings
# into times_sim (resolution dt_multi) and rates #(time, trial, neuron).
readout=np.empty((len(all_labels), net.N))
# NOTE(review): int() truncates; if N*(T+dt_multi)/dt_multi is meant to be an
# integer, floating-point error can make this one sample short. Consider round().
multi_list=np.zeros((len(all_labels), 3, int(net.N*(net.stop_points[-1]+dt_multi)/dt_multi)))
for trial in range(len(all_labels)):
    readout[trial], multi_list[trial] = apply_stimulus.apply_stimulus(rate_amplitude=np.append(simulation_stimuli[trial], 0.), dt=dt, rate_generator=rate_gen,  neurons=n, T=net.stop_points[-1]+dt_multi, multimeter=multi, rate_timestamps=net.stop_points)
rates, times_sim = get_properties.get_rate(multi_list, net.N, len(all_labels)) #(time, trial, neuron)




#%%

#####
# Evaluate the simulation results
#####


# Class-generalized network states zeta_nu * y^nu, used to compute the
# distances and covariances of the network states.
final_states = rates[-1]  # y(T), #(trial, neuron)
final_states -= np.mean(final_states, axis=0)  # centered in place, over all trials
n_train = len(train_labels)
final_train_states = final_states[:n_train]
final_test_states = final_states[n_train:]

# zeta_nu y^nu: every training state multiplied by its +1/-1 label
generalized_train_states = final_train_states * train_labels[:, np.newaxis]
# M(T), based on simulation
dist_sim = np.mean(generalized_train_states, axis=0)
# Sigma(T), based on simulation
Sigma_sim = np.cov(generalized_train_states.T, bias=True)
# Readout vector determined anew from the simulated network-state cumulants
v_sim = get_properties.find_simulation_readout(Sigma_sim, dist_sim, eta, net.N, net, final_train_states, train_labels)

# Signed projections of all final states onto the readout vector
opt_distances = np.einsum('ti, i -> t', final_states, v_sim)

# Labels assigned by the readout: the sign of the projection
assigned_labels = np.sign(opt_distances)
assigned_train_labels = assigned_labels[:n_train]
assigned_test_labels = assigned_labels[n_train:]


def _fraction_correct(assigned, true):
    """Fraction of matching +1/-1 labels; an assigned label of 0 counts as half correct."""
    return np.mean(0.5 * (1 + assigned * true))


accs = _fraction_correct(assigned_labels, all_labels)
accs_train = _fraction_correct(assigned_train_labels, train_labels)
accs_test = _fraction_correct(assigned_test_labels, test_labels)

print('Correctly classified samples in the training set: ', accs_train)
print('Correctly classified samples in the test set: ', accs_test)  # This is the quantity referred to in the evaluations
print('Overall correctly classified samples: ', accs)


#%%

#####
# Comparison with expected accuracies for random input vectors
#####

# Accuracy and soft margin for n_rand random input vectors u, for comparison
# with the optimized u.
rand_responses = np.zeros((n_rand, len(all_labels), net.N)) # y(T) for all random u
rand_readouts = np.zeros((n_rand, net.N)) # optimal readout for all random u

soft_margins_rand = np.zeros(n_rand) # soft margins for all random u
accs_rand = np.zeros(n_rand) # accuracies in all data set for all random u
accs_train_rand = np.zeros(n_rand) # accuracies in train set for all random u
accs_test_rand = np.zeros(n_rand) # accuracies in test set for all random u


for rand_idx in range(n_rand):

    # NOTE(review): u_rand is drawn but never passed on; alternating_optimization
    # below presumably draws its own random initial u (initial_guesses=1).
    # Confirm this. The draw is kept so that the global RNG stream is unchanged.
    u_rand = np.random.rand(net.N) - 0.5
    u_rand /= np.linalg.norm(u_rand)

    # determines readout and returns soft margin, used u and corresponding, optimal v
    soft_margins_rand[rand_idx], input_vectors_rand, readout_vectors_rand = net.alternating_optimization(1, mu, eta = eta, initial_guesses=1, solver=solver)
    rand_readouts[rand_idx] = readout_vectors_rand[0]

    rand_responses[rand_idx] = net.determine_responses(np.append(train_data, test_data, axis=0), input_vectors_rand[0], center=True)

    # NOTE(review): the labels use readout_vectors_rand[-1], while rand_readouts
    # (and hence rand_distances) store readout_vectors_rand[0]. These agree only
    # if the returned lists have length 1. Confirm which readout is intended.
    assigned_labels_rand = np.sign(np.einsum('ti, i -> t', rand_responses[rand_idx], readout_vectors_rand[-1]))
    assigned_train_labels_rand = assigned_labels_rand[:len(train_labels)]
    assigned_test_labels_rand = assigned_labels_rand[len(train_labels):]

    # fraction of correct +1/-1 labels (an assigned 0 counts as half correct)
    accs_rand[rand_idx] = np.mean( 0.5 * (1 + assigned_labels_rand * all_labels) )
    accs_train_rand[rand_idx] = np.mean( 0.5 * (1 + assigned_train_labels_rand * train_labels) )
    accs_test_rand[rand_idx] = np.mean( 0.5 * (1 + assigned_test_labels_rand * test_labels) )


# signed projections (random u, trial) onto the stored random readouts
rand_distances = np.einsum('rti, ri -> rt', rand_responses, rand_readouts)

# min/max of an empty array raise ValueError, so skip the summary when no
# random u were requested (n_rand defaults to 0).
if n_rand > 0:
    print('Random u: soft margins ranging between', np.min(soft_margins_rand, axis=0), ' and', np.max(soft_margins_rand, axis=0), ' with mean', np.mean(soft_margins_rand, axis=0))
    print('Random u: training set accuracies ranging between', np.min(accs_train_rand, axis=0), ' and', np.max(accs_train_rand, axis=0), ' with mean', np.mean(accs_train_rand, axis=0))
    print('Random u: testing set accuracies ranging between', np.min(accs_test_rand, axis=0), ' and', np.max(accs_test_rand, axis=0), ' with mean', np.mean(accs_test_rand, axis=0))
    print('Random u: overall accuracies ranging between', np.min(accs_rand, axis=0), ' and', np.max(accs_rand, axis=0), ' with mean', np.mean(accs_rand, axis=0))

#%%

#####
# Saving data for evaluation
#####



# Soft margin and test accuracy for the optimized input projection.
opt_dictionary = dict(
    soft_margins_train=soft_margins[-1],
    accs_test=accs_test,
)

# Soft margins and test accuracies for every random input projection.
rand_dictionary = dict(
    soft_margins_train_rand=soft_margins_rand,
    accs_test_rand=accs_test_rand,
)

# Labels and signed readout projections of the network states, for visualization.
figure_dictionary = dict(
    train_labels=train_labels,
    test_labels=test_labels,
    opt_distances=opt_distances,
    rand_distances=rand_distances,
)

# Everything the evaluation scripts need, in one pickled dictionary.
full_dictionary = dict(
    rand_dictionary=rand_dictionary,
    opt_dictionary=opt_dictionary,
    figure_dictionary=figure_dictionary,
    composition_dict=composition_dictionary,
)

output_path = 'data/ECG/ECG5000_' + file_label + '.txt'
with open(output_path, 'wb') as handle:
    pickle.dump(full_dictionary, handle)

