# -*- coding: utf-8 -*-
#
# test_stdp_triplet_synapse.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
# This script tests the stdp_triplet_synapse in NEST.
import nest
import unittest
from math import exp
import numpy as np
[docs]@nest.check_stack
class STDPTripletConnectionTestCase(unittest.TestCase):
"""Check stdp_triplet_connection model properties."""
[docs] def setUp(self):
nest.set_verbosity('M_WARNING')
nest.ResetKernel()
# settings
self.dendritic_delay = 1.0
self.decay_duration = 5.0
self.synapse_model = "stdp_triplet_synapse"
self.syn_spec = {
"model": self.synapse_model,
"delay": self.dendritic_delay,
# set receptor 1 post-synaptically, to not generate extra spikes
"receptor_type": 1,
"weight": 5.0,
"tau_plus": 16.8,
"tau_plus_triplet": 101.0,
"Aplus": 0.1,
"Aminus": 0.1,
"Aplus_triplet": 0.1,
"Aminus_triplet": 0.1,
"Kplus": 0.0,
"Kplus_triplet": 0.0,
"Wmax": 100.0,
}
self.post_neuron_params = {
"tau_minus": 33.7,
"tau_minus_triplet": 125.0,
}
# setup basic circuit
self.pre_neuron = nest.Create("parrot_neuron")
self.post_neuron = nest.Create(
"parrot_neuron", 1, params=self.post_neuron_params)
nest.Connect(self.pre_neuron, self.post_neuron, syn_spec=self.syn_spec)
[docs] def generateSpikes(self, neuron, times):
"""Trigger spike to given neuron at specified times."""
delay = 1.
gen = nest.Create("spike_generator", 1, {
"spike_times": [t - delay for t in times]})
nest.Connect(gen, neuron, syn_spec={"delay": delay})
[docs] def status(self, which):
"""Get synapse parameter status."""
stats = nest.GetConnections(
self.pre_neuron, synapse_model=self.synapse_model)
return nest.GetStatus(stats, [which])[0][0]
[docs] def decay(self, time, Kplus, Kplus_triplet, Kminus, Kminus_triplet):
"""Decay variables."""
Kplus *= exp(-time / self.syn_spec["tau_plus"])
Kplus_triplet *= exp(-time / self.syn_spec["tau_plus_triplet"])
Kminus *= exp(-time / self.post_neuron_params["tau_minus"])
Kminus_triplet *= exp(-time /
self.post_neuron_params["tau_minus_triplet"])
return (Kplus, Kplus_triplet, Kminus, Kminus_triplet)
[docs] def facilitate(self, w, Kplus, Kminus_triplet):
"""Facilitate weight."""
Wmax = self.status("Wmax")
return np.sign(Wmax) * (abs(w) + Kplus * (
self.syn_spec["Aplus"] +
self.syn_spec["Aplus_triplet"] * Kminus_triplet)
)
[docs] def depress(self, w, Kminus, Kplus_triplet):
"""Depress weight."""
Wmax = self.status("Wmax")
return np.sign(Wmax) * (abs(w) - Kminus * (
self.syn_spec["Aminus"] +
self.syn_spec["Aminus_triplet"] * Kplus_triplet)
)
[docs] def assertAlmostEqualDetailed(self, expected, given, message):
"""Improve assetAlmostEqual with detailed message."""
messageWithValues = "%s (expected: `%s` was: `%s`" % (
message, str(expected), str(given))
self.assertAlmostEqual(given, expected, msg=messageWithValues)
[docs] def test_badPropertiesSetupsThrowExceptions(self):
"""Check that exceptions are thrown when setting bad parameters."""
def setupProperty(property):
bad_syn_spec = self.syn_spec.copy()
bad_syn_spec.update(property)
nest.Connect(self.pre_neuron, self.post_neuron,
syn_spec=bad_syn_spec)
def badPropertyWith(content, parameters):
self.assertRaisesRegexp(
nest.NESTError, "BadProperty(.+)" + content,
setupProperty, parameters
)
badPropertyWith("Kplus", {"Kplus": -1.0})
badPropertyWith("Kplus_triplet", {"Kplus_triplet": -1.0})
[docs] def test_varsZeroAtStart(self):
"""Check that pre and post-synaptic variables are zero at start."""
self.assertAlmostEqualDetailed(
0.0, self.status("Kplus"), "Kplus should be zero")
self.assertAlmostEqualDetailed(0.0, self.status(
"Kplus_triplet"), "Kplus_triplet should be zero")
[docs] def test_preVarsIncreaseWithPreSpike(self):
"""Check that pre-synaptic variables (Kplus, Kplus_triplet) increase
after each pre-synaptic spike."""
self.generateSpikes(self.pre_neuron, [2.0])
Kplus = self.status("Kplus")
Kplus_triplet = self.status("Kplus_triplet")
nest.Simulate(20.0)
self.assertAlmostEqualDetailed(
Kplus + 1.0,
self.status("Kplus"),
"Kplus should have increased by 1")
self.assertAlmostEqualDetailed(
Kplus_triplet + 1.0,
self.status("Kplus_triplet"),
"Kplus_triplet should have increased by 1")
[docs] def test_preVarsDecayAfterPreSpike(self):
"""Check that pre-synaptic variables (Kplus, Kplus_triplet) decay
after each pre-synaptic spike."""
self.generateSpikes(self.pre_neuron, [2.0])
# trigger computation
self.generateSpikes(self.pre_neuron, [2.0 + self.decay_duration])
(Kplus, Kplus_triplet, _, _) = self.decay(
self.decay_duration, 1.0, 1.0, 0.0, 0.0)
Kplus += 1.0
Kplus_triplet += 1.0
nest.Simulate(20.0)
self.assertAlmostEqualDetailed(
Kplus, self.status("Kplus"), "Kplus should have decay")
self.assertAlmostEqualDetailed(Kplus_triplet, self.status(
"Kplus_triplet"), "Kplus_triplet should have decay")
[docs] def test_preVarsDecayAfterPostSpike(self):
"""Check that pre-synaptic variables (Kplus, Kplus_triplet) decay
after each post-synaptic spike."""
self.generateSpikes(self.pre_neuron, [2.0])
self.generateSpikes(self.post_neuron, [3.0, 4.0])
# trigger computation
self.generateSpikes(self.pre_neuron, [2.0 + self.decay_duration])
(Kplus, Kplus_triplet, _, _) = self.decay(
self.decay_duration, 1.0, 1.0, 0.0, 0.0)
Kplus += 1.0
Kplus_triplet += 1.0
nest.Simulate(20.0)
self.assertAlmostEqualDetailed(
Kplus, self.status("Kplus"), "Kplus should have decay")
self.assertAlmostEqualDetailed(Kplus_triplet, self.status(
"Kplus_triplet"), "Kplus_triplet should have decay")
[docs] def test_weightChangeWhenPrePostSpikes(self):
"""Check that weight changes whenever a pre-post spike pair happen."""
self.generateSpikes(self.pre_neuron, [2.0])
self.generateSpikes(self.post_neuron, [4.0])
self.generateSpikes(self.pre_neuron, [6.0]) # trigger computation
Kplus = self.status("Kplus")
Kplus_triplet = self.status("Kplus_triplet")
Kminus = 0.0
Kminus_triplet = 0.0
weight = self.status("weight")
Wmax = self.status("Wmax")
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0, Kplus, Kplus_triplet, Kminus, Kminus_triplet)
weight = self.depress(weight, Kminus, Kplus_triplet)
Kplus += 1.0
Kplus_triplet += 1.0
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0 + self.dendritic_delay, Kplus, Kplus_triplet,
Kminus, Kminus_triplet
)
weight = self.facilitate(weight, Kplus, Kminus_triplet)
Kminus += 1.0
Kminus_triplet += 1.0
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0 - self.dendritic_delay, Kplus, Kplus_triplet,
Kminus, Kminus_triplet
)
weight = self.depress(weight, Kminus, Kplus_triplet)
nest.Simulate(20.0)
self.assertAlmostEqualDetailed(weight, self.status(
"weight"), "weight should have decreased")
[docs] def test_weightChangeWhenPrePostPreSpikes(self):
"""Check that weight changes whenever a pre-post-pre spike triplet
happen."""
self.generateSpikes(self.pre_neuron, [2.0, 6.0])
self.generateSpikes(self.post_neuron, [4.0])
self.generateSpikes(self.pre_neuron, [8.0]) # trigger computation
Kplus = self.status("Kplus")
Kplus_triplet = self.status("Kplus_triplet")
Kminus = 0.0
Kminus_triplet = 0.0
weight = self.status("weight")
Wmax = self.status("Wmax")
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0, Kplus, Kplus_triplet, Kminus, Kminus_triplet)
weight = self.depress(weight, Kminus, Kplus_triplet)
Kplus += 1.0
Kplus_triplet += 1.0
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0 + self.dendritic_delay, Kplus, Kplus_triplet,
Kminus, Kminus_triplet
)
weight = self.facilitate(weight, Kplus, Kminus_triplet)
Kminus += 1.0
Kminus_triplet += 1.0
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0 - self.dendritic_delay, Kplus, Kplus_triplet,
Kminus, Kminus_triplet
)
weight = self.depress(weight, Kminus, Kplus_triplet)
Kplus += 1.0
Kplus_triplet += 1.0
(Kplus, Kplus_triplet, Kminus, Kminus_triplet) = self.decay(
2.0, Kplus, Kplus_triplet, Kminus, Kminus_triplet)
weight = self.depress(weight, Kminus, Kplus_triplet)
nest.Simulate(20.0)
self.assertAlmostEqualDetailed(weight, self.status(
"weight"), "weight should have decreased")
[docs] def test_maxWeightStaturatesWeight(self):
"""Check that setting maximum weight property keep weight limited."""
limited_weight = self.status("weight") + 1e-10
limited_syn_spec = self.syn_spec.copy()
limited_syn_spec.update({"Wmax": limited_weight})
nest.Connect(self.pre_neuron, self.post_neuron,
syn_spec=limited_syn_spec)
self.generateSpikes(self.pre_neuron, [2.0])
self.generateSpikes(self.pre_neuron, [3.0]) # trigger computation
nest.Simulate(20.0)
self.assertAlmostEqualDetailed(limited_weight, self.status(
"weight"), "weight should have been limited")
[docs]@nest.check_stack
class STDPTripletInhTestCase(STDPTripletConnectionTestCase):
[docs] def setUp(self):
nest.set_verbosity('M_WARNING')
nest.ResetKernel()
# settings
self.dendritic_delay = 1.0
self.decay_duration = 5.0
self.synapse_model = "stdp_triplet_synapse"
self.syn_spec = {
"model": self.synapse_model,
"delay": self.dendritic_delay,
# set receptor 1 post-synaptically, to not generate extra spikes
"receptor_type": 1,
"weight": -5.0,
"tau_plus": 16.8,
"tau_plus_triplet": 101.0,
"Aplus": 0.1,
"Aminus": 0.1,
"Aplus_triplet": 0.1,
"Aminus_triplet": 0.1,
"Kplus": 0.0,
"Kplus_triplet": 0.0,
"Wmax": -100.0,
}
self.post_neuron_params = {
"tau_minus": 33.7,
"tau_minus_triplet": 125.0,
}
# setup basic circuit
self.pre_neuron = nest.Create("parrot_neuron")
self.post_neuron = nest.Create("parrot_neuron", 1,
params=self.post_neuron_params)
nest.Connect(self.pre_neuron, self.post_neuron, syn_spec=self.syn_spec)
[docs]def suite_inh():
return unittest.makeSuite(STDPTripletInhTestCase, "test")
[docs]def suite():
return unittest.makeSuite(STDPTripletConnectionTestCase, "test")
[docs]def run():
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite())
runner.run(suite_inh())
if __name__ == "__main__":
run()