from copy import deepcopy
import elephant.spike_train_generation as stg
import elephant.statistics as stat
import neo
import numpy as np
import quantities as pq
from ..tools import annotate_overlapping_spikes, compute_sync_rate
[docs]class SpikeTrainGenerator:
"""
Class for generation of spike trains called by the gen_recordings function.
The list of parameters is in default_params/recordings_params.yaml (spiketrains field).
Parameters
----------
params : dict
Dictionary with parameters to simulate spiketrains. Default values can be retrieved with
mr.get_default_recordings_params()['spiketrains']
spiketrains : list of neo.SpikeTrain
List of neo.SpikeTrain objects to instantiate a SpikeTrainGenerator with existing data
verbose : bool
If True, output is verbose
"""
def __init__(self, params=None, spiketrains=None, seed=None, verbose=False):
self._verbose = verbose
self._has_spiketrains = False
self.params = {}
if params is None:
if self._verbose:
print("Using default parameters")
if spiketrains is None:
self.params = deepcopy(params)
if seed is None:
seed = np.random.randint(1000)
if self._verbose:
print("Spiketrains seed: ", seed)
self.params["seed"] = seed
np.random.seed(self.params["seed"])
if "t_start" not in self.params.keys():
params["t_start"] = 0
self.params["t_start"] = params["t_start"] * pq.s
if "duration" not in self.params.keys():
params["duration"] = 10
self.params["t_stop"] = self.params["t_start"] + params["duration"] * pq.s
if "min_rate" not in self.params.keys():
params["min_rate"] = 0.1
self.params["min_rate"] = params["min_rate"] * pq.Hz
if "ref_per" not in self.params.keys():
params["ref_per"] = 2
self.params["ref_per"] = params["ref_per"] * pq.ms
if "process" not in self.params.keys():
params["process"] = "poisson"
self.params["process"] = params["process"]
if "gamma_shape" not in self.params.keys() and params["process"] == "gamma":
params["gamma_shape"] = 2
self.params["gamma_shape"] = params["gamma_shape"]
if "rates" in self.params.keys(): # all firing rates are provided
self.params["rates"] = self.params["rates"] * pq.Hz
self.n_neurons = len(self.params["rates"])
else:
rates = []
types = []
if "f_exc" not in self.params.keys():
params["f_exc"] = 5
self.params["f_exc"] = params["f_exc"] * pq.Hz
if "f_inh" not in self.params.keys():
params["f_inh"] = 15
self.params["f_inh"] = params["f_inh"] * pq.Hz
if "st_exc" not in self.params.keys():
params["st_exc"] = 1
self.params["st_exc"] = params["st_exc"] * pq.Hz
if "st_inh" not in self.params.keys():
params["st_inh"] = 3
self.params["st_inh"] = params["st_inh"] * pq.Hz
if "n_exc" not in self.params.keys():
params["n_exc"] = 2
self.params["n_exc"] = params["n_exc"]
if "n_inh" not in self.params.keys():
params["n_inh"] = 1
self.params["n_inh"] = params["n_inh"]
for exc in np.arange(self.params["n_exc"]):
rate = self.params["st_exc"] * np.random.randn() + self.params["f_exc"]
if rate < self.params["min_rate"]:
rate = self.params["min_rate"]
rates.append(rate)
types.append("e")
for inh in np.arange(self.params["n_inh"]):
rate = self.params["st_inh"] * np.random.randn() + self.params["f_inh"]
if rate < self.params["min_rate"]:
rate = self.params["min_rate"]
rates.append(rate)
types.append("i")
self.params["rates"] = rates
self.params["types"] = types
self.n_neurons = len(self.params["rates"])
self.info = params
self.spiketrains = False
else:
self.spiketrains = spiketrains
self.info = {}
self._has_spiketrains = True
if params is not None:
self.params = deepcopy(params)
[docs] def set_spiketrain(self, idx, spiketrain):
"""
Sets spike train idx to new spiketrain.
Parameters
----------
idx : int
Index of spike train to set
spiketrain : neo.SpikeTrain
New spike train
"""
self.spiketrains[idx] = spiketrain
[docs] def generate_spikes(self):
"""
Generate spike trains based on default_params of the SpikeTrainGenerator class.
self.spiketrains contains the newly generated spike trains
"""
if not self._has_spiketrains:
self.spiketrains = []
idx = 0
for n in np.arange(self.n_neurons):
rate = self.params["rates"][n]
if self.params["process"] == "poisson":
st = stg.homogeneous_poisson_process(rate, self.params["t_start"], self.params["t_stop"])
elif self.params["process"] == "gamma":
st = stg.homogeneous_gamma_process(
self.params["gamma_shape"], rate, self.params["t_start"], self.params["t_stop"]
)
self.spiketrains.append(st)
self.spiketrains[-1].annotate(fr=rate)
if "n_exc" in self.params.keys() and "n_inh" in self.params.keys():
if idx < self.params["n_exc"]:
self.spiketrains[-1].annotate(cell_type="E")
else:
self.spiketrains[-1].annotate(cell_type="I")
idx += 1
# check consistency and remove spikes below refractory period
for idx, st in enumerate(self.spiketrains):
isi = stat.isi(st)
idx_remove = np.where(isi < self.params["ref_per"])[0]
spikes_to_remove = len(idx_remove)
unit = st.times.units
while spikes_to_remove > 0:
new_times = np.delete(st.times, idx_remove[0]) * unit
st = neo.SpikeTrain(new_times, t_start=self.params["t_start"], t_stop=self.params["t_stop"])
isi = stat.isi(st)
idx_remove = np.where(isi < self.params["ref_per"])[0]
spikes_to_remove = len(idx_remove)
st.annotations = self.spiketrains[idx].annotations
self.set_spiketrain(idx, st)
else:
print("SpikeTrainGenerator initialized with existing spike trains!")
[docs] def add_synchrony(self, idxs, rate=0.05, time_jitt=1 * pq.ms, verbose=False):
"""
Adds synchronous spikes between pairs of spike trains at a certain rate.
Parameters
----------
idxs : list or array
Spike train indexes to add synchrony to
rate : float
Rate of added synchrony spike to spike train idxs[1] for each spike of idxs[0]
time_jitt : quantity
Maximum time jittering between added spikes
verbose : bool
If True output is verbose
Returns
-------
sync_rate : float
New synchrony rate
fr1 : quantity
Firing rate spike train 1
fr2 : quantity
Firing rate spike train 2
"""
idx1 = idxs[0]
idx2 = idxs[1]
st1 = self.spiketrains[idx1]
st2 = self.spiketrains[idx2]
times1 = st1.times
times2 = st2.times
t_start = st2.t_start
t_stop = st2.t_stop
unit = times2.units
sync_rate = compute_sync_rate(times1, times2, time_jitt)
if sync_rate < rate:
added_spikes_t1 = 0
added_spikes_t2 = 0
spiketrains = [st1, st2]
curr_overlaps = np.floor(sync_rate * (len(times1) + len(times2)))
tot_spikes = len(times1) + len(times2)
# this assumes that: target_overlaps = curr_overlaps + add_overlaps
add_overlaps = int(np.round((curr_overlaps - rate * tot_spikes) / (rate - 1)))
# find non-overlappping spikes
annotate_overlapping_spikes(spiketrains)
st1_no_idx = np.where(spiketrains[0].annotations["overlap"] == "NO")[0]
st2_no_idx = np.where(spiketrains[1].annotations["overlap"] == "NO")[0]
st1_no = times1[st1_no_idx]
st2_no = times2[st2_no_idx]
all_times_no_shuffle = np.concatenate((st1_no, st2_no))
all_times_no_shuffle = all_times_no_shuffle[np.random.permutation(len(all_times_no_shuffle))] * unit
for t in all_times_no_shuffle:
if added_spikes_t1 + added_spikes_t2 <= add_overlaps:
# check time difference (since they are NO, they most likely won't violate ref_period)
if t in times1:
t1_jitt = (
time_jitt.rescale(unit).magnitude * np.random.rand(1)
+ t.rescale(unit).magnitude
- (time_jitt.rescale(unit) / 2).magnitude
)
if t1_jitt < t_stop:
times2 = np.concatenate((np.array(times2), np.array(t1_jitt)))
times2 = times2 * unit
added_spikes_t1 += 1
elif t in times2:
t2_jitt = (
time_jitt.rescale(unit).magnitude * np.random.rand(1)
+ t.rescale(unit).magnitude
- (time_jitt.rescale(unit) / 2).magnitude
)
if t2_jitt < t_stop:
times1 = np.concatenate((np.array(times1), np.array(t2_jitt)))
times1 = times1 * unit
added_spikes_t2 += 1
else:
break
times1 = np.sort(times1)
times2 = np.sort(times2)
# remove spike trains violating ref period
ref_violations_idxs1 = np.where(np.diff(times1) < self.params["ref_per"])[0]
ref_violations_idxs2 = np.where(np.diff(times2) < self.params["ref_per"])[0]
if len(ref_violations_idxs1) > 0:
print(f"Remove {len(ref_violations_idxs1)} violations from times1")
times1 = np.delete(times1, ref_violations_idxs1) * unit
if len(ref_violations_idxs2) > 0:
print(f"Remove {len(ref_violations_idxs2)} violations from times2")
times2 = np.delete(times2, ref_violations_idxs2) * unit
sync_rate = compute_sync_rate(times1, times2, time_jitt)
if verbose:
print(
"Added",
added_spikes_t1,
"spikes to spike train",
idxs[0],
"and",
added_spikes_t2,
"spikes to spike train",
idxs[1],
"Sync rate:",
sync_rate,
)
else:
spiketrains = [st1, st2]
annotate_overlapping_spikes(spiketrains)
max_overlaps = np.floor(rate * (len(times1) + len(times2)))
curr_overlaps = np.floor(sync_rate * (len(times1) + len(times2)))
remove_overlaps = int(curr_overlaps - max_overlaps)
if curr_overlaps > max_overlaps:
st1_to_idx = np.where(spiketrains[0].annotations["overlap"] == "TO")[0]
st2_to_idx = np.where(spiketrains[1].annotations["overlap"] == "TO")[0]
perm = np.random.permutation(len(st1_to_idx))[:remove_overlaps]
st1_ovrl_idx = st1_to_idx[perm]
st2_ovrl_idx = st2_to_idx[perm]
idx_rm_1 = st1_ovrl_idx[: remove_overlaps // 2]
idx_rm_2 = st2_ovrl_idx[remove_overlaps // 2 :]
times1 = np.delete(st1.times, idx_rm_1)
times1 = times1 * unit
times2 = np.delete(st2.times, idx_rm_2)
times2 = times2 * unit
sync_rate = compute_sync_rate(times1, times2, time_jitt)
if verbose:
print(
"Removed",
len(idx_rm_1),
"spikes from spike train",
idxs[0],
"and",
len(idx_rm_2),
"spikes from spike train",
idxs[1],
"Sync rate:",
sync_rate,
)
st1 = neo.SpikeTrain(times1, t_start=t_start, t_stop=t_stop)
st2 = neo.SpikeTrain(times2, t_start=t_start, t_stop=t_stop)
st1.annotations = self.spiketrains[idx1].annotations
st2.annotations = self.spiketrains[idx2].annotations
self.set_spiketrain(idx1, st1)
self.set_spiketrain(idx2, st2)
fr1 = len(st1.times) / st1.t_stop
fr2 = len(st2.times) / st2.t_stop
return sync_rate, fr1, fr2