# Evaluation-only use does not require a NEST installation. Generating data
# requires a NEST version that includes the authors' modifications.
try:
    import nest
except Exception:
    # Best-effort optional import: any failure while loading NEST leaves the
    # module importable with `nest` set to None. `Exception` (rather than a
    # bare `except`) lets KeyboardInterrupt and SystemExit propagate.
    nest = None
import readout
import numpy as np
import sys

'''
This function applies a stepwise constant stimulus to the neurons.
'''

def apply_stimulus(rate_amplitude, dt, rate_generator,  neurons, T, multimeter, reset=1., reset_rate=0., input_vector=None, rate_timestamps=None):
    '''
    Drive the neurons with a stepwise constant stimulus and read out their rates.

    The call first runs a noise-free reset phase, then programs the step rate
    generator(s), simulates for ``T + dt`` and finally collects the recorded
    rates from the multimeter.

    Parameters
    ----------
    rate_amplitude : np.array
        One- or two-dimensional. If one-dimensional, the same array is used
        for all neurons (via the single ``rate_generator``). If
        two-dimensional, row ``i`` is sent to ``rate_generator[i]``.
        Contains the rate at every step; steps are equally spaced by ``dt``
        unless ``rate_timestamps`` is given.
        NOTE: the last value (1-D) or last column (2-D) is overwritten with
        zero IN PLACE, i.e. an array passed by the caller is modified.

    dt : float
        Simulation time resolution and step size if ``rate_timestamps`` is
        not given.

    rate_generator :
        If one-dimensional rate_amplitude:
        step rate generator created by nest. Its amplitudes are overwritten.
        If two-dimensional rate_amplitude:
        sequence of step rate generators, indexed per neuron. Their
        amplitudes are overwritten.

    neurons :
        Neurons created by nest that receive the input.

    T : float
        Time passed to ``readout.readout_rate``; the stimulus phase is
        simulated for ``T + dt``.

    multimeter :
        Multimeter created by nest, connected to (part of) the neurons.

    reset : float
        Length of the reset phase; ``reset - dt`` is simulated with zero
        noise and zero rate before the network is reset.

    reset_rate : float
        Rate assigned to the neurons after the reset (scaled per neuron by
        ``input_vector`` if that is given).

    input_vector : np.array
        Per-neuron weights multiplied onto ``reset_rate``.
        Default: every neuron gets ``reset_rate`` unscaled.

    rate_timestamps : np.array
        Time stamps at which the stimulus changes its value. They are
        shifted by the kernel time at call, ``reset`` and ``dt``. The
        caller's array is not modified.

    Returns
    -------
    Y_T : np.array
        One value per recorded sender, as returned by
        ``readout.readout_rate``.

    multi : np.array
        ``[senders, times, rate]`` of all multimeter events, with times
        shifted by ``-reset`` and minus the kernel time at call.

    Raises
    ------
    SystemExit
        If ``rate_amplitude`` is neither one- nor two-dimensional. Raised
        before the simulator is touched.
    ImportError
        If nest could not be imported.
    '''
    ############
    #Validation#
    ############

    # Fail fast: reject malformed input before the reset simulation has
    # advanced the kernel. np.asarray returns ndarray input unchanged, so the
    # documented in-place zeroing still acts on the caller's array.
    rate_amplitude = np.asarray(rate_amplitude)
    if rate_amplitude.ndim not in (1, 2):
        sys.exit('rate_amplitude must be one- or twodimensional.')

    if nest is None:
        raise ImportError('apply_stimulus requires a working nest installation.')

    #######
    #Reset#
    #######

    kernel_time = nest.GetKernelStatus()['time']
    # Remember every neuron's noise level so it can be restored afterwards.
    stored_std = nest.GetStatus(neurons, 'std')
    nest.SetStatus(neurons, {'std': 0., 'rate': 0.})

    nest.Simulate(reset - dt)
    nest.ResetNetwork()
    if input_vector is None:
        nest.SetStatus(neurons, {'rate': reset_rate})
    else:
        for idx, neuron in enumerate(neurons):
            nest.SetStatus([neuron], {'rate': reset_rate * input_vector[idx]})

    #######
    #Input#
    #######

    # Set the final value to zero and pick the array whose length defines the
    # number of stimulus steps.
    if rate_amplitude.ndim == 1:
        rate_amplitude[-1] = 0.
        step_template = rate_amplitude
    else:
        rate_amplitude.T[-1] = np.zeros(len(neurons))
        step_template = rate_amplitude[0]

    # Absolute switching times in kernel time.
    if rate_timestamps is None:
        rate_timestamps = np.cumsum(dt * np.ones_like(step_template)) + kernel_time + reset
    else:
        rate_timestamps = np.copy(rate_timestamps) + kernel_time + reset + dt

    if rate_amplitude.ndim == 1:
        # Same stimulus for all neurons
        nest.SetStatus(rate_generator, {'amplitude_times': rate_timestamps, 'amplitude_values': rate_amplitude})
    else:
        # Each neuron has its own stimulus row and generator
        for idx in range(len(neurons)):
            nest.SetStatus(rate_generator[idx], {'amplitude_times': rate_timestamps, 'amplitude_values': rate_amplitude[idx]})

    ##########
    #Simulate#
    ##########

    # Restore each neuron's own noise level (not just the first neuron's).
    nest.SetStatus(neurons, [{'std': value} for value in stored_std])
    nest.Simulate(T + dt)

    #########
    #Readout#
    #########

    # Fetch the recorded events once; nothing is simulated past this point.
    events = nest.GetStatus(multimeter)[0]['events']

    # Find neurons measured by multimeter
    senders = np.unique(events['senders'])

    # Store rates of recorded neurons
    Y_T = np.zeros(len(senders))

    # Only the first readout is called without printselection=False.
    for count_idx, neuron_idx in enumerate(senders):
        if count_idx == 0:
            Y_T[count_idx] = readout.readout_rate(multimeter, kernel_time, T, neuron_idx)
        else:
            Y_T[count_idx] = readout.readout_rate(multimeter, kernel_time, T, neuron_idx, printselection=False)

    # Shift event times so they are relative to the end of the reset phase.
    shifted_times = events['times'] - reset - kernel_time
    multi = np.array([events['senders'], shifted_times, events['rate']])

    ########
    #Return#
    ########
    return Y_T, multi
