import nest
import nest.raster_plot
import random
import numpy as np
import pylab
import matplotlib.pyplot as plt
import pickle, yaml
import time

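# This script connects two populations of AdEx neurons with one-to-one electrical
# synapses (gap junctions, enabled when 'gap_junction' is 1 in the configuration).
# Tests 1-3 use a single neuron per population; test 4 uses 'neurons_per_pop'
# neurons per population and additionally connects the populations with random
# inhibitory chemical synapses.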

nest.Install("nestml_gap_module")

# Import parameters for the network from the YAML configuration file.
# The context manager guarantees the file handle is closed after parsing.
with open(r'test_configuration_gap.yaml', encoding='utf-8') as config_file:
    args = yaml.load(config_file, Loader=yaml.FullLoader)
print('\nLoading parameters from configuration file:\n')

# Use the configured seed when one is given (a missing 'seed' key is treated
# like an explicit null). Otherwise draw one at random, starting at 1 because
# NEST only accepts rng_seed values in [1, 2**32 - 1].
config_seed = args.get('seed')
rng_seed = np.random.randint(1, 10**7) if config_seed is None else config_seed

sim_time = args['simulation_time']
dt = args['delta_clock']
test_num = args['test_number']
gj_on = args['gap_junction']
gap_weight = args['gap_weight']
n_neuron = args['neurons_per_pop']  # neurons per population (used by test 4)

# Reset NEST and apply the global simulation settings before any node is created.
nest.ResetKernel()
nest.set_verbosity('M_WARNING')
plt.rcParams.update({'font.size': 20})
kernel_settings = {
    'resolution': dt,
    'print_time': True,
    'rng_seed': rng_seed,
}
nest.SetKernelStatus(kernel_settings)

# Set neuron parameters.
# Randomised values built with nest.random.normal. Only tonic_neuronparams uses
# them, and none of the test cases below pass tonic_neuronparams to nest.Create.
C_m_initial_tonic = nest.random.normal(mean=200.0, std=40.0)  # pF (500/80)
I_e_tonic = nest.random.normal(mean=320.0, std=80.0)  # pA
t_ref_initial = nest.random.normal(mean=1.0, std=0.2)  # ms
V_m_initial = nest.random.normal(mean=-60.0, std=10.0)  # mV

# Tonic firing, Naud et al. 2008, Fig 4a.
tonic_neuronparams_static1 = {
    'C_m': 200., 'g_L': 10., 'E_L': -70., 'V_th': -50., 'Delta_T': 2.,
    'tau_w': 30., 'a': 3., 'b': 0., 'V_reset': -58., 'I_e': 320.,
    't_ref': 1., 'V_m': -60.,
}
# Identical to static1 except for the initial membrane potential.
tonic_neuronparams_static2 = {**tonic_neuronparams_static1, 'V_m': -50.}
# Same model with randomised C_m, I_e, t_ref and initial V_m.
tonic_neuronparams = {
    **tonic_neuronparams_static1,
    'C_m': C_m_initial_tonic,
    'I_e': I_e_tonic,
    't_ref': t_ref_initial,
    'V_m': V_m_initial,
}

# Set noise parameters
noise_params = {"dt": 1., "std": 320}

# Set synapse parameters (negative weight: inhibitory chemical synapse).
inh_syn_params = {
    "synapse_model": "static_synapse",
    "weight": -1.,  # nS
    "delay": 1.,  # ms
}
conn_dict_custom = {'rule': 'pairwise_bernoulli', 'p': 0.03}

# Create the two neuron populations for the selected test case.
# Tests 1-3 use one neuron per population; test 4 uses n_neuron neurons.
# Test 1 gives both neurons identical parameters; tests 2-4 start adex2 from a
# different initial V_m (tonic_neuronparams_static2).
neuron_model = 'aeif_cond_alpha_gap_nestml'
if test_num == 1:
    adex1 = nest.Create(neuron_model, 1, tonic_neuronparams_static1)
    adex2 = nest.Create(neuron_model, 1, tonic_neuronparams_static1)
elif test_num in (2, 3):
    adex1 = nest.Create(neuron_model, 1, tonic_neuronparams_static1)
    adex2 = nest.Create(neuron_model, 1, tonic_neuronparams_static2)
elif test_num == 4:
    adex1 = nest.Create(neuron_model, n_neuron, tonic_neuronparams_static1)
    adex2 = nest.Create(neuron_model, n_neuron, tonic_neuronparams_static2)
else:
    # Fail early with a clear message instead of a NameError on adex1 later.
    raise ValueError(f"Unsupported test_number {test_num!r}: expected 1, 2, 3 or 4")
    
# Set recording device parameters
mm_params = {'interval': 1., 'record_from': ['V_m']}
# NOTE(review): sd_params uses NEST 2.x spike_detector keys ("withtime",
# "withgid", ...) and is never passed to any device in this script.
sd_params = {"withtime" : True, "withgid" : True, 'to_file' : False, 'flush_after_simulate' : False, 'flush_records' : True}

# Create multimeters. The parameters are passed by keyword because the second
# positional argument of nest.Create is the node count n.
vm1 = nest.Create('multimeter', params=mm_params)
vm2 = nest.Create('multimeter', params=mm_params)

# Create spike recorders, one per neuron. They are sized from the populations
# themselves so the one_to_one connections below always have matching sizes.
# Tests 1-3 create single neurons even when neurons_per_pop > 1.
spike_detector1 = nest.Create("spike_recorder", len(adex1))
spike_detector2 = nest.Create("spike_recorder", len(adex2))

# Connect multimeters
nest.Connect(vm1, adex1)
nest.Connect(vm2, adex2)

# Connect spike recorders to neuron populations
nest.Connect(adex1, spike_detector1, "one_to_one")
spike_detector1.n_events = 0  # ensure no spikes left from previous simulations
nest.Connect(adex2, spike_detector2, "one_to_one")
spike_detector2.n_events = 0  # ensure no spikes left from previous simulations

# Electrical coupling: neuron i of adex1 <-> neuron i of adex2 via gap junctions.
if gj_on == 1:
    nest.Connect(adex1,
                 adex2,
                 conn_spec={"rule": "one_to_one", "make_symmetric": True},
                 # Alternative: {"rule": "symmetric_pairwise_bernoulli", "p": .5,
                 #               "allow_autapses": False, "make_symmetric": True}
                 syn_spec={"synapse_model": "gap_junction", "weight": gap_weight})

# Connect populations with chemical synapses
if test_num == 4:
    nest.Connect(adex1, adex2, conn_dict_custom, inh_syn_params)
    nest.Connect(adex2, adex1, conn_dict_custom, inh_syn_params)

# Both directions of the symmetric gap junctions are fetched and updated
# together. Updating only adex1 -> adex2 would leave the reverse connections at
# gap_weight and make the electrical coupling asymmetric.
adex1_adex2_gap = nest.GetConnections(adex1, adex2, synapse_model="gap_junction")
adex2_adex1_gap = nest.GetConnections(adex2, adex1, synapse_model="gap_junction")
print('Original weights: ', adex1_adex2_gap)
# NOTE(review): this hard-coded weight supersedes 'gap_weight' from the YAML
# configuration for every gap junction; confirm the override is intended.
nest.SetStatus(adex1_adex2_gap, {'weight': 2})
nest.SetStatus(adex2_adex1_gap, {'weight': 2})
print('New weights: ', adex1_adex2_gap)

nest.Simulate(sim_time)

# Pick one neuron position (0-based) and compare it with the neuron at the same
# position in the other population (its one_to_one gap-junction partner).
# randrange(k) yields 0..k-1. The range comes from the created population,
# because tests 1-3 create one neuron per population regardless of n_neuron.
neuron_to_sample = random.randrange(len(adex1))

# Fetch each multimeter's events once. Select the sampled neuron's trace by
# sender id rather than relying on the order samples are interleaved in.
events1 = nest.GetStatus(vm1, 'events')[0]
senders1 = np.asarray(events1['senders'])
times1 = np.asarray(events1['times'])
sampled1 = senders1 == adex1.tolist()[neuron_to_sample]
V1 = np.asarray(events1['V_m'])[sampled1]

events2 = nest.GetStatus(vm2, 'events')[0]
senders2 = np.asarray(events2['senders'])
times2 = np.asarray(events2['times'])
sampled2 = senders2 == adex2.tolist()[neuron_to_sample]
V2 = np.asarray(events2['V_m'])[sampled2]

print('Average Vm difference between neurons: ', np.mean(V1 - V2))

# Plot against the recorded sample times so the x-axis really is in ms.
pylab.figure()
pylab.plot(times1[sampled1], V1, label='Adex1')
pylab.plot(times2[sampled2], V2, label='Adex2')
pylab.xlabel('time (ms)')
pylab.ylabel('membrane potential (mV)')
pylab.legend()

pylab.show()
