import numpy as np
import matplotlib.pyplot as plt
import nest
import nest.voltage_trace
import time
import sys

# Benchmark: NL iaf_cond_exp "LGN" neurons driven by NS step_current_generators,
# simulated with a configurable number of threads; reports the wall-clock time
# spent in nest.Simulate.
#
# Usage: python <script> NUM_LGN NUM_STEP_SOURCES NUM_THREADS

# All times given in milliseconds
dt = 0.5       # simulation resolution
dt_rec = 1.    # padding added to t_end when calling nest.Simulate below
t_end = 2000.  # Simulation time

if len(sys.argv) < 4:
    # Fail with a readable message instead of an IndexError.
    sys.exit("usage: {} NUM_LGN NUM_STEP_SOURCES NUM_THREADS".format(sys.argv[0]))

NL         = int(sys.argv[1]) # number of lgn neurons
NS         = int(sys.argv[2]) # number of step_current_source created
numThreads = int(sys.argv[3]) # number of threads

nest.set_verbosity("M_WARNING")
nest.ResetKernel()
# Configure the kernel before any node is created. The attribute syntax used
# here is the newer NEST 3 form; the older equivalent is
# nest.SetKernelStatus({'local_num_threads': numThreads}).
nest.local_num_threads = numThreads
nest.resolution = dt
nest.print_time = True
# Seed the RNG *before* Create/Connect: the fixed_indegree rule below draws
# its random sources when Connect is called, so a seed set after Connect
# would leave the connectivity determined by the default seed instead.
nest.rng_seed = 1

t0 = 0  # not used in the visible script; kept for compatibility

# synaptic time constants of LGN connections
tau_ex_L = 1.5  # in ms
tau_in_L = 10.0  # in ms
tau_m_L = 10.0  # membrane time constant, in ms
C_m_L = 0.29*1000  # nF to pF conversion
g_L = C_m_L/tau_m_L  # leak conductance chosen so that C_m / g_L == tau_m_L

paramsL = {
    'V_m': -70.0,  # initial membrane potential
    'E_L': -70.0,  # v_rest, Leak reversal potential
    'C_m': C_m_L,  # cm
    't_ref': 2.0,  # tau_refrac
    'V_th': -57.0,  # v_thresh
    'V_reset': -70.0,  # v_reset
    'E_ex': 0.0,  # e_rev_E
    'E_in': -75.0,  # e_rev_I
    'g_L': g_L,  # cm/tau_m_L
    'tau_syn_ex': tau_ex_L,  # tau_syn_E
    'tau_syn_in': tau_in_L,  # tau_syn_I
    'I_e': 0.0,  # Constant input current
}
## Create NL lgn neurons
lgn_pop = nest.Create('iaf_cond_exp', NL)
nest.SetStatus(lgn_pop, paramsL)

## create time list for step current: an amplitude update every sdt ms up to
## t_end. Every value is 600., so after t = sdt the current stays constant;
## the repeated updates mainly exercise the generator's update path.
sdt = 7
amplitude_tims = np.arange(sdt, t_end, sdt)
amplitude_vals = [600.]*len(amplitude_tims)

## inject the step current into lgn neurons
nest_stepcurrent = nest.Create('step_current_generator', NS, params={"amplitude_values": amplitude_vals, "amplitude_times": amplitude_tims, "start": 0., "stop": t_end})
# Each LGN neuron receives input from 25 randomly drawn generators.
con_params_ll = {"rule": "fixed_indegree", "indegree": 25}
nest.Connect(nest_stepcurrent, lgn_pop, con_params_ll)

## run the simulation
# perf_counter is monotonic and high-resolution, suited to timing intervals.
time_start = time.perf_counter()
nest.Simulate(t_end + dt_rec)
time_evaluate = time.perf_counter()

## print the results
print('Simulation time: {:.3f} s'.format(time_evaluate-time_start))
print('Number of lgn neurons: {}'.format(NL))
print('Number of step current sources: {}'.format(NS))