import os
import shutil
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
import h5py
import MEAutility as mu
import neo
import numpy as np
import quantities as pq
import scipy.signal as ss
import yaml
from joblib import Parallel, delayed
from lazy_ops import DatasetView
from packaging.version import parse
from quantities import Quantity
from . import __version__ as mearec_version
# PyYAML >= 5.0 requires an explicit Loader argument for yaml.load;
# remember at import time which API variant is available.
use_loader = parse(yaml.__version__) >= parse("5.0.0")
def safe_yaml_load(yaml_file):
    """Load a YAML file, passing an explicit Loader when PyYAML requires one."""
    with open(yaml_file, "r") as stream:
        if use_loader:
            return yaml.load(stream, Loader=yaml.FullLoader)
        return yaml.load(stream)
### GET DEFAULT SETTINGS ###
def get_default_config(print_version=False):
    """
    Returns default_info and mearec_home path.

    On first use (or first use of a new MEArec version) the per-version
    config folder is populated with the packaged default parameters and
    cell models, and a mearec.conf file is written.

    Returns
    -------
    default_info : dict
        Default_info from config file
    mearec_path : str
        Mearec home path
    """
    package_dir = Path(os.path.split(__file__)[0])
    mearec_home = Path(os.path.expanduser("~")) / ".config" / "mearec"
    version_folder = mearec_home / mearec_version
    if print_version:
        print(f"MEArec version: {mearec_version}\n")
    if not mearec_home.is_dir():
        mearec_home.mkdir(exist_ok=True, parents=True)
    conf_file = version_folder / "mearec.conf"
    if not conf_file.is_file():
        # first run for this version: copy packaged defaults next to the conf file
        version_folder.mkdir(exist_ok=True, parents=True)
        shutil.copytree(str(package_dir / "default_params"), str(version_folder / "default_params"))
        shutil.copytree(str(package_dir / "cell_models"), str(version_folder / "cell_models"))
        default_info = {
            "templates_params": str(version_folder / "default_params" / "templates_params.yaml"),
            "recordings_params": str(version_folder / "default_params" / "recordings_params.yaml"),
            "templates_folder": str(version_folder / "templates"),
            "recordings_folder": str(version_folder / "recordings"),
            "cell_models_folder": str(version_folder / "cell_models" / "bbp"),
        }
        with conf_file.open("w") as f:
            yaml.dump(default_info, f)
    else:
        default_info = safe_yaml_load(conf_file)
    return default_info, str(mearec_home)
def get_default_cell_models_folder():
    """
    Returns default cell models folder.

    Returns
    -------
    cell_models_folder : str
        Path to default cell models folder
    """
    default_info, _ = get_default_config()
    return default_info["cell_models_folder"]
def get_default_templates_params():
    """
    Returns default templates parameters.

    Returns
    -------
    templates_params : dict
        Dictionary with default templates parameters
    """
    default_info, _ = get_default_config()
    # load template parameters from the per-version config folder
    return safe_yaml_load(default_info["templates_params"])
def get_default_recordings_params():
    """
    Returns default recordings parameters.

    Returns
    -------
    recordings_params : dict
        Dictionary with default recording parameters
    """
    default_info, _ = get_default_config()
    # load recording parameters from the per-version config folder
    return safe_yaml_load(default_info["recordings_params"])
def get_default_drift_dict():
    """Return a fresh dictionary with the default drift parameters."""
    drift_params = dict(
        # drift mode and sampling rate of the drift signal
        drift_mode_speed="slow",
        drift_mode_probe="rigid",
        drift_fs=100,
        # non-rigid drift options
        non_rigid_gradient_mode="linear",
        non_rigid_linear_direction=1,
        non_rigid_step_depth_boundary=None,
        non_rigid_step_factors=None,
        # slow drift options
        slow_drift_velocity=5,
        slow_drift_amplitude=None,
        # NOTE(review): "triangluar" looks like a typo for "triangular" --
        # kept byte-identical because consumers may match this exact string.
        slow_drift_waveform="triangluar",
        # fast drift options
        fast_drift_period=10,
        fast_drift_max_jump=20,
        fast_drift_min_jump=5,
        # drift time window
        t_start_drift=None,
        t_end_drift=None,
        # externally supplied drift signal
        external_drift_vector_um=None,
        external_drift_times=None,
        external_drift_factors=None,
    )
    return drift_params
def available_probes():
    """
    Returns list of available probes

    Returns
    -------
    probe_list : list
        List of available probes in MEAutility
    """
    probe_list = mu.return_mea_list()
    return probe_list
### LOAD FUNCTIONS ###
def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose=False):
    """
    Loads EAP from temporary folder.

    Parameters
    ----------
    templates_folder : str
        Path to temporary folder
    celltypes : list (optional)
        List of celltypes to be loaded
    samples_per_cat : int (optional)
        The number of eap to load per category
    verbose : bool
        If True output is verbose

    Returns
    -------
    templates : np.array
        Templates (n_eap, n_elec, n_sample)
    locations : np.array
        Locations (n_eap, 3)
    rotations : np.array
        Rotations (n_eap, 3)
    celltypes : np.array
        Cell types (n_eap)
    """
    if verbose:
        print("Loading eap data ...")
    templates_folder = Path(templates_folder)
    # matching eap-/pos-/rot- files share the "<prefix>-<celltype>.npy" naming,
    # so sorting each list keeps the triples aligned for zip below
    eaplist = sorted(f for f in templates_folder.iterdir() if f.name.startswith("eap"))
    loclist = sorted(f for f in templates_folder.iterdir() if f.name.startswith("pos"))
    rotlist = sorted(f for f in templates_folder.iterdir() if f.name.startswith("rot"))
    eap_list = []
    loc_list = []
    rot_list = []
    cat_list = []
    loaded_categories = set()
    ignored_categories = set()
    eap_dtype = None
    for eap_file, loc_file, rot_file in zip(eaplist, loclist, rotlist):
        # file name is "eap-<celltype>.npy"; strip prefix and ".npy"
        celltype = eap_file.parts[-1].split("-")[1][:-4]
        if verbose:
            print("loading cell type: ", eap_file)
        # previously the whole loading body was duplicated in two branches;
        # a single guard keeps the logic in one place
        if celltypes is not None and celltype not in celltypes:
            ignored_categories.add(celltype)
            continue
        eaps = np.load(str(eap_file), mmap_mode="r")
        locs = np.load(str(loc_file))
        rots = np.load(str(rot_file))
        eap_dtype = eaps.dtype
        if samples_per_cat is None or samples_per_cat > len(eaps):
            samples_to_read = len(eaps)
        else:
            samples_to_read = samples_per_cat
        eap_list.extend(eaps[:samples_to_read])
        rot_list.extend(rots[:samples_to_read])
        loc_list.extend(locs[:samples_to_read])
        cat_list.extend([celltype] * samples_to_read)
        loaded_categories.add(celltype)
    if len(eap_list) > 0:
        # assemble all templates into a single on-disk memmap to limit RAM usage
        all_eaps = np.lib.format.open_memmap(
            templates_folder / "all_eaps.npy",
            mode="w+",
            dtype=eap_dtype,
            shape=(len(eap_list), *eap_list[0].shape),
        )
        for i, eap in enumerate(eap_list):
            all_eaps[i, ...] = eap
    else:
        all_eaps = np.array([])
    if verbose:
        print("Done loading spike data ...")
    return all_eaps, np.array(loc_list), np.array(rot_list), np.array(cat_list, dtype=str)
def load_templates(templates, return_h5_objects=True, verbose=False, check_suffix=True):
    """
    Load generated eap templates.

    Parameters
    ----------
    templates : str or Path object
        templates file
    return_h5_objects : bool
        If True output objects are h5 objects
    verbose : bool
        If True output is verbose
    check_suffix : bool
        If True, hdf5 suffix is checked

    Returns
    -------
    tempgen : TemplateGenerator
        TemplateGenerator object
    """
    from MEArec import TemplateGenerator

    if verbose:
        print("Loading templates...")
    templates = Path(templates)
    if templates.suffix not in [".h5", ".hdf5"] and check_suffix:
        raise Exception("Recordings must be an hdf5 file (.h5 or .hdf5)")
    f = h5py.File(str(templates), "r")
    info = load_dict_from_hdf5(f, "info/")
    temp_dict = {}
    raw_celltypes = np.array(f.get("celltypes"))
    temp_dict["celltypes"] = np.array([c.decode("utf-8") for c in raw_celltypes])
    # either keep lazy h5 datasets or materialize them as arrays
    for field in ("locations", "rotations", "templates"):
        dataset = f.get(field)
        temp_dict[field] = dataset if return_h5_objects else np.array(dataset)
    if verbose:
        print("Done loading templates...")
    if not return_h5_objects:
        # file can only be closed when nothing lazy points into it
        f.close()
    return TemplateGenerator(temp_dict=temp_dict, info=info)
def load_recordings(
    recordings, return_h5_objects=True, load=None, load_waveforms=True, check_suffix=True, verbose=False
):
    """
    Load generated recordings.

    Parameters
    ----------
    recordings : str or Path object
        Recordings file
    return_h5_objects : bool
        If True output objects are h5 objects
    load : list
        List of fields to be loaded (('recordings', 'channel_positions', 'voltage_peaks', 'spiketrains',
        'timestamps', 'spike_traces', 'templates'))
    load_waveforms : bool
        If True waveforms are loaded to spiketrains
    check_suffix : bool
        If True, hdf5 suffix is checked
    verbose : bool
        If True output is verbose

    Returns
    -------
    recgen : RecordingGenerator
        RecordingGenerator object
    """
    from MEArec import RecordingGenerator

    if verbose:
        print("Loading recordings...")
    recordings = Path(recordings)
    if recordings.suffix not in [".h5", ".hdf5"] and check_suffix:
        raise Exception("Recordings must be an hdf5 file (.h5 or .hdf5)")
    f = h5py.File(str(recordings), "r")
    # files written before 1.5.0 store (n_channels, n_samples) and must be transposed
    file_version = f.attrs.get("mearec_version", "1.4.0")
    need_transpose = parse(file_version) < parse("1.5.0")
    if need_transpose:
        print(
            "Warning: MEArec file created with version <1.5. This could result in lower efficiency. To upgrade"
            "your file to the new format use: mr.convert_recording_to_new_version(filename)"
        )
    rec_dict, info = load_recordings_from_file(
        f,
        return_h5_objects=return_h5_objects,
        load=load,
        need_transpose=need_transpose,
        load_waveforms=load_waveforms,
    )
    if verbose:
        print("Done loading recordings...")
    if not return_h5_objects:
        # safe to close: everything was materialized as numpy arrays
        f.close()
    recgen = RecordingGenerator(rec_dict=rec_dict, info=info)
    if "gain_to_uV" in rec_dict:
        recgen.gain_to_uV = rec_dict["gain_to_uV"]
    return recgen
def load_recordings_from_file(f, path="", return_h5_objects=True, load=None, need_transpose=False, load_waveforms=True):
    """
    Load generated recordings from an open h5 file handle.

    Parameters
    ----------
    f : h5py.File
        Open h5 file object
    path : str
        Path inside the h5 database
    return_h5_objects : bool
        If True output objects are (lazy) h5 objects; otherwise datasets are
        materialized as numpy arrays
    load : list
        List of fields to be loaded (('recordings', 'channel_positions', 'voltage_peaks', 'spiketrains',
        'timestamps', 'spike_traces', 'templates'))
    need_transpose : bool
        If True, 'recordings' and 'spike_traces' use the legacy (<1.5)
        (n_channels, n_samples) layout and are transposed on load
    load_waveforms : bool
        If True waveforms are loaded to spiketrains

    Returns
    -------
    rec_dict : dict
        Dictionary with the loaded fields
    info : dict
        Info dictionary stored in the file
    """
    if load is None:
        load = [
            "recordings",
            "channel_positions",
            "voltage_peaks",
            "spiketrains",
            "timestamps",
            "spike_traces",
            "templates",
            "template_ids",
            "drift_dict",
        ]
    else:
        assert isinstance(load, list), (
            "'load' should be a list with strings of what to be loaded "
            "('recordings', 'channel_positions', 'voltage_peaks', 'spiketrains', "
            "'timestamps', 'spike_traces', 'templates')"
        )

    def _fetch(name, transpose=False):
        # Wrap a dataset lazily or materialize it, transposing legacy layouts.
        dset = f.get(path + name)
        if return_h5_objects:
            return DatasetView(dset).lazy_transpose() if transpose else dset
        arr = np.array(dset)
        return arr.T if transpose else arr

    rec_dict = {}
    info = load_dict_from_hdf5(f, path + "info/")
    if f.get(path + "voltage_peaks") is not None and "voltage_peaks" in load:
        rec_dict["voltage_peaks"] = _fetch("voltage_peaks")
    if f.get(path + "channel_positions") is not None and "channel_positions" in load:
        rec_dict["channel_positions"] = _fetch("channel_positions")
    if f.get(path + "recordings") is not None and "recordings" in load:
        rec_dict["recordings"] = _fetch("recordings", transpose=need_transpose)
        if "gain_to_uV" in f.get(path + "recordings").attrs:
            rec_dict["gain_to_uV"] = f.get(path + "recordings").attrs["gain_to_uV"]
    if f.get(path + "spike_traces") is not None and "spike_traces" in load:
        rec_dict["spike_traces"] = _fetch("spike_traces", transpose=need_transpose)
    if f.get(path + "templates") is not None and "templates" in load:
        rec_dict["templates"] = _fetch("templates")
    if f.get(path + "original_templates") is not None:
        rec_dict["original_templates"] = _fetch("original_templates")
    if f.get(path + "template_locations") is not None:
        rec_dict["template_locations"] = _fetch("template_locations")
    if f.get(path + "template_rotations") is not None:
        rec_dict["template_rotations"] = _fetch("template_rotations")
    if f.get(path + "template_celltypes") is not None:
        # celltypes are stored as bytes; decode to str
        celltypes = np.array([n.decode() for n in f.get(path + "template_celltypes")])
        rec_dict["template_celltypes"] = celltypes
    if f.get(path + "timestamps") is not None and "timestamps" in load:
        if return_h5_objects:
            rec_dict["timestamps"] = f.get(path + "timestamps")
        else:
            rec_dict["timestamps"] = np.array(f.get(path + "timestamps")) * pq.s
    if f.get(path + "template_ids") is not None and "template_ids" in load:
        # BUGFIX: materialize when not returning h5 objects -- the caller
        # closes the file in that case, which would invalidate a live dataset
        ids = f.get(path + "template_ids")
        rec_dict["template_ids"] = ids if return_h5_objects else np.array(ids)
    if f.get(path + "spiketrains") is not None and "spiketrains" in load:
        spiketrains = []
        # group keys are unit indices stored as strings; sort numerically
        sorted_units = sorted(int(u) for u in f.get(path + "spiketrains/"))
        for unit in sorted_units:
            unit = str(unit)
            times = np.array(f.get(path + "spiketrains/" + unit + "/times"))
            t_stop = np.array(f.get(path + "spiketrains/" + unit + "/t_stop"))
            if f.get(path + "spiketrains/" + unit + "/waveforms") is not None and load_waveforms:
                waveforms = np.array(f.get(path + "spiketrains/" + unit + "/waveforms"))
            else:
                waveforms = None
            annotations = load_dict_from_hdf5(f, path + "spiketrains/" + unit + "/annotations/")
            st = neo.core.SpikeTrain(times, t_stop=t_stop, waveforms=waveforms, units=pq.s)
            st.annotations = annotations
            spiketrains.append(st)
        rec_dict["spiketrains"] = spiketrains
    if f.get(path + "drift_list") is not None:
        drift_list = []
        for key in f.get(path + "drift_list").keys():
            drift_list.append(load_dict_from_hdf5(f, path + "drift_list/" + str(key) + "/"))
        rec_dict["drift_list"] = drift_list
    else:
        rec_dict["drift_list"] = None
    return rec_dict, info
def save_template_generator(tempgen, filename=None, verbose=True):
    """
    Save templates to disk.

    Parameters
    ----------
    tempgen : TemplateGenerator
        TemplateGenerator object to be saved
    filename : str
        Path to .h5 file
    verbose : bool
        If True output is verbose
    """
    # fail early with a clear message instead of Path(None) raising TypeError
    assert filename is not None, "Provide an .h5 or .hdf5 file name"
    filename = Path(filename)
    assert filename.suffix in [".h5", ".hdf5"], "Provide an .h5 or .hdf5 file name"
    # create parent folder if needed (race-free, unlike check-then-makedirs)
    filename.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(filename, "w") as f:
        save_dict_to_hdf5(tempgen.info, f, "info/")
        f.attrs["date"] = datetime.now().strftime("%y-%m-%d %H:%M:%S")
        if len(tempgen.celltypes) > 0:
            # h5py requires bytes for string datasets
            celltypes = [str(x).encode("utf-8") for x in tempgen.celltypes]
            f.create_dataset("celltypes", data=celltypes)
        if len(tempgen.locations) > 0:
            f.create_dataset("locations", data=tempgen.locations)
        if len(tempgen.rotations) > 0:
            f.create_dataset("rotations", data=tempgen.rotations)
        if len(tempgen.templates) > 0:
            f.create_dataset("templates", data=tempgen.templates)
    if verbose:
        print("\nSaved templates in", filename, "\n")
def save_recording_generator(recgen, filename=None, verbose=False, include_spike_traces: bool = True):
    """
    Save recordings to disk.

    Parameters
    ----------
    recgen : RecordingGenerator
        RecordingGenerator object to be saved
    filename : str
        Path to .h5 file
    verbose : bool
        If True output is verbose
    include_spike_traces : bool, default=True
        If True, will include the spike traces (which can be large for many units)
    """
    out_path = Path(filename)
    parent = out_path.parent
    if not parent.is_dir():
        os.makedirs(str(parent))
    assert out_path.suffix in [".h5", ".hdf5"], "Provide an .h5 or .hdf5 file name"
    with h5py.File(out_path, "w") as h5_file:
        h5_file.attrs["mearec_version"] = mearec_version
        h5_file.attrs["date"] = datetime.now().strftime("%y-%m-%d %H:%M:%S")
        save_recording_to_file(recgen, h5_file, include_spike_traces=include_spike_traces)
    if verbose:
        print("\nSaved recordings in", out_path, "\n")
def save_recording_to_file(recgen, f, path="", include_spike_traces: bool = True):
    """
    Save recordings to an open h5 file handle.

    Parameters
    ----------
    recgen : RecordingGenerator
        RecordingGenerator object to be saved
    f : h5py.File
        Open h5 file object
    path : str
        Path inside the h5 file under which datasets are written
    include_spike_traces : bool, default=True
        If True, will include the spike traces (can be heavy)
    """
    save_dict_to_hdf5(recgen.info, f, path + "info/")
    if len(recgen.voltage_peaks) > 0:
        f.create_dataset(path + "voltage_peaks", data=recgen.voltage_peaks)
    if len(recgen.channel_positions) > 0:
        f.create_dataset(path + "channel_positions", data=recgen.channel_positions)
    if len(recgen.recordings) > 0:
        f.create_dataset(path + "recordings", data=recgen.recordings)
        if recgen.gain_to_uV is not None:
            # BUGFIX: honor `path` (previously wrote the attribute to the
            # root-level "recordings", which is wrong for non-empty paths)
            f[path + "recordings"].attrs["gain_to_uV"] = recgen.gain_to_uV
    if len(recgen.spike_traces) > 0 and include_spike_traces:
        f.create_dataset(path + "spike_traces", data=recgen.spike_traces)
    if len(recgen.spiketrains) > 0:
        for ii in range(len(recgen.spiketrains)):
            st = recgen.spiketrains[ii]
            f.create_dataset(path + "spiketrains/{}/times".format(ii), data=st.times.rescale("s").magnitude)
            f.create_dataset(path + "spiketrains/{}/t_stop".format(ii), data=st.t_stop)
            if st.waveforms is not None:
                f.create_dataset(path + "spiketrains/{}/waveforms".format(ii), data=st.waveforms)
            save_dict_to_hdf5(st.annotations, f, path + "spiketrains/{}/annotations/".format(ii))
    if len(recgen.templates) > 0:
        f.create_dataset(path + "templates", data=recgen.templates)
    if len(recgen.original_templates) > 0:
        f.create_dataset(path + "original_templates", data=recgen.original_templates)
    if len(recgen.template_locations) > 0:
        f.create_dataset(path + "template_locations", data=recgen.template_locations)
    if len(recgen.template_rotations) > 0:
        f.create_dataset(path + "template_rotations", data=recgen.template_rotations)
    if len(recgen.template_celltypes) > 0:
        # h5py requires bytes for string datasets
        celltypes = [n.encode("ascii", "ignore") for n in recgen.template_celltypes]
        f.create_dataset(path + "template_celltypes", data=celltypes)
    if len(recgen.timestamps) > 0:
        f.create_dataset(path + "timestamps", data=recgen.timestamps)
    if hasattr(recgen, "template_ids"):
        if recgen.template_ids is not None:
            f.create_dataset(path + "template_ids", data=recgen.template_ids)
    if recgen.drift_list is not None:
        for i, drift_dict in enumerate(recgen.drift_list):
            save_dict_to_hdf5(drift_dict, f, path + "drift_list/" + str(i) + "/")
def save_dict_to_hdf5(dic, h5file, path):
    """
    Save dictionary to h5 file.

    Thin public wrapper around ``recursively_save_dict_contents_to_group``.

    Parameters
    ----------
    dic : dict
        Dictionary to be saved
    h5file : file
        Hdf5 file object
    path : str
        Path to the h5 field
    """
    recursively_save_dict_contents_to_group(h5file, path, dic)
def recursively_save_dict_contents_to_group(h5file, path, dic):
    """
    Save dictionary recursively to h5 file (helper function).

    Scalars, quantities, sequences and nested dicts are mapped to h5
    datasets/groups. ``None`` is stored as the string ``"null"`` and an
    empty list as ``"[]"`` (both are restored by ``clean_dict`` on load).

    Parameters
    ----------
    h5file : file
        Hdf5 file object
    path : str
        Path to the h5 field
    dic : dict
        Dictionary to be saved

    Raises
    ------
    ValueError
        If an item type cannot be serialized
    """
    for key, item in dic.items():
        if isinstance(item, (int, float, np.integer, str, bytes, np.bool_)):
            if isinstance(item, np.str_):
                # store a plain str rather than a numpy scalar string
                item = str(item)
            h5file[path + key] = item
        elif isinstance(item, pq.Quantity):
            # quantities are stored as their bare float magnitude
            h5file[path + key] = float(item.magnitude)
        elif isinstance(item, (list, np.ndarray)):
            if len(item) > 0:
                if isinstance(item[0], (str, bytes)):
                    # h5py cannot store unicode arrays directly
                    item = [n.encode("ascii", "ignore") for n in item]
                h5file[path + key] = np.array(item)
            else:
                # sentinel restored to an empty array by clean_dict
                h5file[path + key] = "[]"
        elif isinstance(item, tuple):
            h5file[path + key] = np.array(item)
        elif item is None:
            # sentinel restored to None by clean_dict
            h5file[path + key] = "null"
        elif isinstance(item, dict):
            recursively_save_dict_contents_to_group(h5file, path + key + "/", item)
        else:
            # include the offending key and type in the exception instead of
            # printing debug output to stdout and losing the context
            raise ValueError("Cannot save key '%s' of type %s" % (key, type(item)))
def load_dict_from_hdf5(h5file, path):
    """
    Load h5 object as dict.

    Thin public wrapper around ``recursively_load_dict_contents_from_group``.

    Parameters
    ----------
    h5file : file
        Hdf5 file object
    path : str
        Path to the h5 field

    Returns
    -------
    dictionary : dict
        Loaded dictionary
    """
    return recursively_load_dict_contents_from_group(h5file, path)
def recursively_load_dict_contents_from_group(h5file, path):
    """
    Load h5 object as dict recursively (helper function).

    Parameters
    ----------
    h5file : file
        Hdf5 file object
    path : str
        Path to the h5 field

    Returns
    -------
    dictionary : dict
        Loaded dictionary
    """
    ans = {}
    for key, item in h5file[path].items():
        # use the public h5py API instead of the private h5py._hl modules
        if isinstance(item, h5py.Dataset):
            value = item[()]
            # handle bytes strings
            ans[key] = value.decode() if isinstance(value, bytes) else value
        elif isinstance(item, h5py.Group):
            ans[key] = recursively_load_dict_contents_from_group(h5file, path + key + "/")
    return clean_dict(ans)
def clean_dict(d):
    """
    Clean dictionary loaded from h5 file.

    Restores the "null"/"[]" sentinel strings written at save time to
    None / empty arrays, decodes byte-string arrays, and converts other
    arrays to lists. The dictionary is modified in place.

    Parameters
    ----------
    d : dict
        Dictionary to be cleaned.

    Returns
    -------
    d : dict
        Cleaned dictionary
    """
    for key, value in d.items():
        if isinstance(value, dict):
            clean_dict(value)
        elif isinstance(value, str):
            if value == "null":
                d[key] = None
            elif value == "[]":
                d[key] = np.array([])
        elif isinstance(value, np.ndarray):
            if len(value) == 0:
                continue
            if isinstance(value[0], np.bytes_):
                d[key] = [n.decode() for n in value]
            else:
                d[key] = list(value)
    return d
def _clean_numpy_scalar(v):
if isinstance(v, np.bool_):
v = bool(v)
if isinstance(v, np.float_):
v = float(v)
if isinstance(v, np.int_):
v = int(v)
return v
def clean_dict_for_yaml(d):
    """
    Clean dictionary before saving to yaml.

    Numpy scalars (top-level values and elements of list values) are
    converted to Python builtins so the yaml dump stays portable.

    Parameters
    ----------
    d : dict
        Dictionary to be cleaned.

    Returns
    -------
    d : dict
        Cleaned (shallow) copy of the dictionary
    """
    cleaned = {}
    for key, value in d.items():
        if isinstance(value, list):
            cleaned[key] = [_clean_numpy_scalar(element) for element in value]
        else:
            cleaned[key] = _clean_numpy_scalar(value)
    return cleaned
def convert_recording_to_new_version(filename, new_filename=None):
    """
    Converts MEArec h5 file from a version <1.5 to the new format >=1.5.

    Parameters
    ----------
    filename : str
        Path to original .h5 file
    new_filename : str (optional)
        Path to new .h5 file. If None (default), the original file is overwritten
    """
    filename = Path(filename)
    assert filename.suffix in [".h5", ".hdf5"], "Provide an .h5 or .hdf5 file name"
    in_place = new_filename is None
    if not in_place:
        new_filename = Path(new_filename)
        assert new_filename.suffix in [".h5", ".hdf5"]
        # start the new file as a copy, then rewrite its datasets below
        shutil.copy(filename, new_filename)
    with h5py.File(filename, "r+") as f:
        file_version = f.attrs.get("mearec_version", "1.4.0")
        if parse(file_version) >= parse("1.5.0"):
            print("The provided mearec file is already up to date")
            return
        # legacy layout (<1.5) stores (n_channels, n_samples); transpose it
        print("Updating file")
        recordings = f.get("recordings")[:]
        spike_traces = f.get("spike_traces")[:]
        if in_place:
            del f["recordings"]
            del f["spike_traces"]
            f.create_dataset("recordings", data=recordings.T)
            f.create_dataset("spike_traces", data=spike_traces.T)
            f.attrs["mearec_version"] = mearec_version
        else:
            with h5py.File(new_filename, "r+") as fnew:
                del fnew["recordings"]
                del fnew["spike_traces"]
                fnew.create_dataset("recordings", data=recordings.T)
                fnew.create_dataset("spike_traces", data=spike_traces.T)
                fnew.attrs["mearec_version"] = mearec_version
### TEMPLATES INFO ###
def get_binary_cat(celltypes, excit, inhib):
    """
    Returns binary category depending on cell type.

    Parameters
    ----------
    celltypes : np.array
        String array with cell types
    excit : list
        List of substrings for excitatory cell types (e.g. ['PC', 'UTPC'])
    inhib : list
        List of substrings for inhibitory celltypes (e.g. ['BP', 'MC'])

    Returns
    -------
    binary_cat : np.array
        Array with binary cell type ('E'/'I', 'U' when unclassified)
    """
    if len(celltypes) == 0:
        # avoid IndexError on the model-detection probe below
        return np.array([], dtype=str)
    binary_cat = []
    # Find if bbp or custom models: bbp names look like "L5_TTPC1_<x>_<y>"
    sample_type = celltypes[0]
    if sample_type.startswith("L") and len(sample_type.split("_")) == 4:
        models = "bbp"
    else:
        models = "custom"
    for cat in celltypes:
        if models == "bbp":
            # for bbp models only the second token carries the cell type
            cell_str = str(cat).split("_")[1]
        else:
            cell_str = str(cat)
        # BUGFIX/consistency: match inhibitory substrings against the same
        # cell-type token as the excitatory ones (previously the full name
        # was used, so e.g. an 'MC' anywhere in a bbp name misclassified)
        if any(ex in cell_str for ex in excit):
            binary_cat.append("E")
        elif any(inh in cell_str for inh in inhib):
            binary_cat.append("I")
        else:
            binary_cat.append("U")
    return np.array(binary_cat, dtype=str)
def get_templates_features(templates, feat_list, dt=None, templates_times=None, threshold_detect=0):
    """
    Computes several templates features.

    Parameters
    ----------
    templates : np.array
        EAP templates
    feat_list : list
        List of features to be computed (amp, width, fwhm, ratio, speed, neg, pos)
    dt : float
        Sampling period; when given it takes precedence over templates_times
    templates_times : np.array
        Sample times; only used when dt is not given (note: 'width', 'fwhm'
        and 'speed' still require dt)
    threshold_detect : float
        Threshold to zero out features

    Returns
    -------
    feature_dict : dict
        Dictionary with features (keys: amp, width, fwhm, ratio, speed, neg, pos)
    """
    if dt is not None:
        templates_times = np.arange(templates.shape[-1]) * dt
    else:
        # the time-based features below multiply by dt directly
        if "width" in feat_list or "fwhm" in feat_list or "speed" in feat_list:
            raise NotImplementedError("Please, specify either dt or templates_times.")
    # normalize input to 3D: (n_templates, n_electrodes, n_samples)
    if len(templates.shape) == 1:
        templates = np.reshape(templates, [1, 1, -1])
    elif len(templates.shape) == 2:
        templates = np.reshape(templates, [1, templates.shape[0], templates.shape[1]])
    if len(templates.shape) != 3:
        raise ValueError("Cannot handle templatess with shape", templates.shape)
    features = {}
    # per-template, per-electrode accumulators
    amps = np.zeros((templates.shape[0], templates.shape[1]))
    na_peak = np.zeros((templates.shape[0], templates.shape[1]))
    rep_peak = np.zeros((templates.shape[0], templates.shape[1]))
    if "width" in feat_list:
        features["width"] = np.zeros((templates.shape[0], templates.shape[1]))
    if "fwhm" in feat_list:
        features["fwhm"] = np.zeros((templates.shape[0], templates.shape[1]))
    if "ratio" in feat_list:
        features["ratio"] = np.zeros((templates.shape[0], templates.shape[1]))
    if "speed" in feat_list:
        features["speed"] = np.zeros((templates.shape[0], templates.shape[1]))
    if "neg" in feat_list:
        features["neg"] = np.zeros((templates.shape[0], templates.shape[1]))
    if "pos" in feat_list:
        features["pos"] = np.zeros((templates.shape[0], templates.shape[1]))
    for i in range(templates.shape[0]):
        # For AMP feature
        # per-electrode trough index, then peak index restricted to samples
        # after the trough
        min_idx = np.array(
            [np.unravel_index(templates[i, e].argmin(), templates[i, e].shape)[0] for e in range(templates.shape[1])]
        )
        max_idx = np.array(
            [
                np.unravel_index(templates[i, e, min_idx[e] :].argmax(), templates[i, e, min_idx[e] :].shape)[0]
                + min_idx[e]
                for e in range(templates.shape[1])
            ]
        )
        # for na and rep
        # global trough over all electrodes; repolarization peak follows it
        min_elid, min_idx_na = np.unravel_index(templates[i].argmin(), templates[i].shape)
        max_idx_rep = templates[i, min_elid, min_idx_na:].argmax() + min_idx_na
        na_peak[i, :] = templates[i, :, min_idx_na]
        rep_peak[i, :] = templates[i, :, max_idx_rep]
        # peak-to-trough amplitude per electrode
        amps[i, :] = np.array(
            [templates[i, e, max_idx[e]] - templates[i, e, min_idx[e]] for e in range((templates.shape[1]))]
        )
        # electrodes below the detection threshold get zeroed / defaulted
        too_low = np.where(amps[i, :] < threshold_detect)
        amps[i, too_low] = 0
        if "ratio" in feat_list:
            min_id_ratio = np.array(
                [
                    np.unravel_index(templates[i, e, min_idx_na:].argmin(), templates[i, e, min_idx_na:].shape)[0]
                    + min_idx_na
                    for e in range(templates.shape[1])
                ]
            )
            max_id_ratio = np.array(
                [
                    np.unravel_index(templates[i, e, min_idx_na:].argmax(), templates[i, e, min_idx_na:].shape)[0]
                    + min_idx_na
                    for e in range(templates.shape[1])
                ]
            )
            features["ratio"][i, :] = np.array(
                [
                    np.abs(templates[i, e, max_id_ratio[e]]) / np.abs(templates[i, e, min_id_ratio[e]])
                    for e in range(templates.shape[1])
                ]
            )
            # If below 'detectable threshold, set amp and width to 0
            too_low = np.where(amps[i, :] < threshold_detect)
            features["ratio"][i, too_low] = 1
        if "speed" in feat_list:
            # delay of the per-electrode trough relative to the global trough
            features["speed"][i, :] = np.array((min_idx - min_idx_na) * dt)
            features["speed"][i, too_low] = min_idx_na * dt
        if "width" in feat_list:
            features["width"][i, :] = np.abs(templates_times[max_idx] - templates_times[min_idx])
            features["width"][i, too_low] = templates.shape[2] * dt  # templates_times[-1]-templates_times[0]
        if "fwhm" in feat_list:
            import scipy.signal as ss

            # half height between baseline (first sample) and trough
            min_peak = np.min(templates[i], axis=1)
            fwhm_ref = np.array([templates[i, e, 0] for e in range(templates.shape[1])])
            fwhm_V = (fwhm_ref + min_peak) / 2.0
            id_trough = [np.where(templates[i, e] < fwhm_V[e])[0] for e in range(templates.shape[1])]
            # no linear interpolation
            features["fwhm"][i, :] = [
                (id_trough[e][-1] - id_trough[e][0]) * dt if len(id_trough[e]) > 1 else templates.shape[2] * dt
                for e in range(templates.shape[1])
            ]
            features["fwhm"][i, too_low] = templates.shape[2] * dt  # EAP_times[-1]-EAP_times[0]
    if "amp" in feat_list:
        features.update({"amp": amps})
    if "neg" in feat_list:
        features.update({"neg": na_peak})
    if "pos" in feat_list:
        features.update({"pos": rep_peak})
    return features
### TEMPLATES OPERATIONS ###
def is_position_within_boundaries(position, x_lim, y_lim, z_lim):
    """
    Check if position is within given boundaries.

    Parameters
    ----------
    position : np.array
        3D position
    x_lim : list
        Boundaries in x dimension (low, high); None disables the check
    y_lim : list
        Boundaries in y dimension (low, high); None disables the check
    z_lim : list
        Boundaries in z dimension (low, high); None disables the check

    Returns
    -------
    valid_position : bool
        If True the position is within boundaries
    """
    for axis, limits in enumerate((x_lim, y_lim, z_lim)):
        if limits is not None and not (limits[0] <= position[axis] <= limits[1]):
            return False
    return True
def select_templates(
loc,
templates,
bin_cat,
n_exc,
n_inh,
min_dist=25,
x_lim=None,
y_lim=None,
z_lim=None,
min_amp=None,
max_amp=None,
drifting=False,
drift_dir=None,
preferred_dir=None,
angle_tol=15,
n_overlap_pairs=None,
overlap_threshold=0.8,
verbose=False,
):
"""
Select templates given specified rules.
Parameters
----------
loc : np.array
Array with 3D soma locations
templates : np.array
Array with eap templates (n_eap, n_channels, n_samples)
bin_cat : np.array
Array with binary category (E-I)
n_exc : int
Number of excitatory cells to be selected
n_inh : int
Number of inhibitory cells to be selected
min_dist : float
Minimum allowed distance between somata (in um)
x_lim : list
Boundaries in x dimension (low, high)
y_lim : list
Boundaries in y dimension (low, high)
z_lim : list
Boundaries in z dimension (low, high)
min_amp : float
Minimum amplitude in uV
max_amp : float
Maximum amplitude in uV
drifting : bool
If True drifting templates are selected
drift_dir : np.array
3D array with drift direction for each template
preferred_dir : np.array
3D array with preferred
angle_tol : float
Tollerance in degrees for selecting final drift position
n_overlap_pairs: int
Number of spatially overlapping templates to select
overlap_threshold: float
Threshold for considering spatially overlapping pairs ([0-1])
verbose : bool
If True the output is verbose
Returns
-------
selected_idxs : np.array
Selected template indexes
selected_cat : list
Selected templates binary type
"""
pos_sel = []
selected_idxs = []
categories = np.unique(bin_cat)
if bin_cat is not None and "E" in categories and "I" in categories:
if verbose:
print("Selecting Excitatory and Inhibitory cells")
excinh = True
selected_cat = []
else:
if verbose:
print("Selecting random templates (cell types not specified)")
excinh = False
selected_cat = []
permuted_idxs = np.random.permutation(len(loc))
if bin_cat is not None:
permuted_bin_cats = bin_cat[permuted_idxs]
else:
permuted_bin_cats = ["U"] * len(loc)
if verbose:
print("Min dist: ", min_dist, "Min amp: ", min_amp)
if min_amp is None:
min_amp = 0
if max_amp is None:
max_amp = np.inf
if drifting:
if drift_dir is None or preferred_dir is None:
raise Exception("For drift selection provide drifting angles and preferred drift direction")
n_sel = 0
n_sel_exc = 0
n_sel_inh = 0
iter = 0
current_overlapping_pairs = 0
for i, (id_cell, bcat) in enumerate(zip(permuted_idxs, permuted_bin_cats)):
placed = False
iter += 1
if n_sel == n_exc + n_inh:
break
# Excitatory and inhibitory cells
if excinh:
# excitatory cell
if bcat == "E":
if n_sel_exc < n_exc:
dist = np.array([np.linalg.norm(loc[id_cell] - p) for p in pos_sel])
if np.any(dist < min_dist):
if verbose:
print("Distance violation", np.min(dist), iter)
pass
else:
amp = np.max(np.abs(np.min(templates[id_cell])))
if not drifting:
if (
is_position_within_boundaries(loc[id_cell], x_lim, y_lim, z_lim)
and min_amp < amp < max_amp
):
# save cell
if n_overlap_pairs is None:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if len(selected_idxs) == 0:
# save cell
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
possible_selected = deepcopy(selected_idxs)
possible_selected.append(id_cell)
possible_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(possible_selected)], overlap_threshold
)
)
current_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(selected_idxs)], overlap_threshold
)
)
if (
current_overlapping_pairs < n_overlap_pairs
and possible_overlapping_pairs <= n_overlap_pairs
):
if possible_overlapping_pairs == current_overlapping_pairs:
continue
else:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
if verbose:
print("Number of overlapping pairs:", possible_overlapping_pairs)
else:
if possible_overlapping_pairs == current_overlapping_pairs:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if verbose:
print("Overlapping violation:", possible_overlapping_pairs)
else:
if verbose:
print("Amplitude or boundary violation", amp, loc[id_cell], iter)
else:
# drifting
if (
is_position_within_boundaries(loc[id_cell, 0], x_lim, y_lim, z_lim)
and min_amp < amp < max_amp
):
# save cell
drift_angle = np.rad2deg(np.arccos(np.dot(drift_dir[id_cell], preferred_dir)))
if drift_angle - angle_tol <= 0:
if n_overlap_pairs is None:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if len(selected_idxs) == 0:
# save cell
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
possible_selected = deepcopy(selected_idxs)
possible_selected.append(id_cell)
possible_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(possible_selected), 0], overlap_threshold
)
)
current_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(selected_idxs), 0], overlap_threshold
)
)
if (
current_overlapping_pairs < n_overlap_pairs
and possible_overlapping_pairs <= n_overlap_pairs
):
if possible_overlapping_pairs == current_overlapping_pairs:
continue
else:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
if verbose:
print(
"Number of overlapping pairs:", possible_overlapping_pairs
)
else:
if possible_overlapping_pairs == current_overlapping_pairs:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if verbose:
print("Overlapping violation:", possible_overlapping_pairs)
else:
if verbose:
print("Drift violation", loc[id_cell, 0], iter)
else:
if verbose:
print("Amplitude or boundary violation", amp, loc[id_cell, 0], iter)
if placed:
n_sel_exc += 1
selected_cat.append("E")
# inhibitory cell
elif bcat == "I":
if n_sel_inh < n_inh:
dist = np.array([np.linalg.norm(loc[id_cell] - p) for p in pos_sel])
if np.any(dist < min_dist):
if verbose:
print("Distance violation", np.min(dist), iter)
pass
else:
amp = np.max(np.abs(np.min(templates[id_cell])))
if not drifting:
if (
is_position_within_boundaries(loc[id_cell], x_lim, y_lim, z_lim)
and min_amp < amp < max_amp
):
# save cell
if n_overlap_pairs is None:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if len(selected_idxs) == 0:
# save cell
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
possible_selected = deepcopy(selected_idxs)
possible_selected.append(id_cell)
possible_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(possible_selected)], overlap_threshold
)
)
current_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(selected_idxs)], overlap_threshold
)
)
if (
current_overlapping_pairs < n_overlap_pairs
and possible_overlapping_pairs <= n_overlap_pairs
):
if possible_overlapping_pairs == current_overlapping_pairs:
continue
else:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
if verbose:
print("Number of overlapping pairs:", possible_overlapping_pairs)
else:
if possible_overlapping_pairs == current_overlapping_pairs:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if verbose:
print("Overlapping violation:", possible_overlapping_pairs)
else:
if verbose:
print("Amplitude or boundary violation", amp, loc[id_cell], iter)
else:
# drifting
if (
is_position_within_boundaries(loc[id_cell, 0], x_lim, y_lim, z_lim)
and min_amp < amp < max_amp
):
# save cell
drift_angle = np.rad2deg(np.arccos(np.dot(drift_dir[id_cell], preferred_dir)))
if drift_angle - angle_tol <= 0:
if n_overlap_pairs is None:
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if len(selected_idxs) == 0:
# save cell
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
possible_selected = deepcopy(selected_idxs)
possible_selected.append(id_cell)
possible_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(possible_selected), 0], overlap_threshold
)
)
current_overlapping_pairs = len(
find_overlapping_templates(
templates[sorted(selected_idxs), 0], overlap_threshold
)
)
if (
current_overlapping_pairs < n_overlap_pairs
and possible_overlapping_pairs <= n_overlap_pairs
):
if possible_overlapping_pairs == current_overlapping_pairs:
continue
else:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
if verbose:
print(
"Number of overlapping pairs:", possible_overlapping_pairs
)
else:
if possible_overlapping_pairs == current_overlapping_pairs:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if verbose:
print("Overlapping violation:", possible_overlapping_pairs)
else:
if verbose:
print("Drift violation", loc[id_cell], iter)
else:
if verbose:
print("Amplitude or boundary violation", amp, loc[id_cell, 0], iter)
if placed:
n_sel_inh += 1
selected_cat.append("I")
# unknown cell type
else:
dist = np.array([np.linalg.norm(loc[id_cell] - p) for p in pos_sel])
if np.any(dist < min_dist):
if verbose:
print("Distance violation", np.min(dist), iter)
pass
else:
amp = np.max(np.abs(np.min(templates[id_cell])))
if not drifting:
if is_position_within_boundaries(loc[id_cell], x_lim, y_lim, z_lim) and min_amp < amp < max_amp:
if n_overlap_pairs is None:
# save cell
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if len(selected_idxs) == 0:
# save cell
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
possible_selected = deepcopy(selected_idxs)
possible_selected.append(id_cell)
possible_overlapping_pairs = len(
find_overlapping_templates(
templates[np.array(possible_selected)], overlap_threshold
)
)
current_overlapping_pairs = len(
find_overlapping_templates(templates[np.array(selected_idxs)], overlap_threshold)
)
if (
current_overlapping_pairs < n_overlap_pairs
and possible_overlapping_pairs <= n_overlap_pairs
):
if possible_overlapping_pairs == current_overlapping_pairs:
continue
else:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
if verbose:
print("Number of overlapping pairs:", possible_overlapping_pairs)
else:
if possible_overlapping_pairs == current_overlapping_pairs:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
if verbose:
print("Overlapping violation:", possible_overlapping_pairs)
else:
if verbose:
print("Amplitude or boundary violation", amp, loc[id_cell], iter)
else:
# drifting
if is_position_within_boundaries(loc[id_cell, 0], x_lim, y_lim, z_lim) and min_amp < amp < max_amp:
# save cell
drift_angle = np.rad2deg(np.arccos(np.dot(drift_dir[id_cell], preferred_dir)))
if drift_angle - angle_tol <= 0:
if n_overlap_pairs is None:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
else:
possible_selected = deepcopy(selected_idxs)
possible_selected.append(id_cell)
overlapping = find_overlapping_templates(
templates[np.array(possible_selected), 0], overlap_threshold
)
possible_overlapping_pairs = len(overlapping)
if possible_overlapping_pairs <= n_overlap_pairs:
pos_sel.append(loc[id_cell])
selected_idxs.append(id_cell)
n_sel += 1
placed = True
current_overlapping_pairs = len(overlapping)
if verbose:
print("Number of overlapping pairs:", current_overlapping_pairs)
else:
if verbose:
print("Overlapping violation:", current_overlapping_pairs)
else:
if verbose:
print("Drift violation", loc[id_cell, 0], iter)
else:
if verbose:
print("Amplitude or boundary violation", amp, loc[id_cell, 0], iter)
if placed:
selected_cat.append("U")
if i == len(permuted_idxs) - 1 and n_sel < n_exc + n_inh:
raise RuntimeError(
"Templates could not be selected. \n"
"Decrease number of spiketrains, decrease 'min_dist', or use more templates."
)
return selected_idxs, selected_cat
def resample_templates(
    templates, n_resample, up, down, drifting, dtype, verbose, n_jobs=None, tmp_file=None, parallel=False
):
    """
    Resamples the templates to a specified sampling frequency.

    Parameters
    ----------
    templates : np.array
        Array with templates (n_neurons, n_channels, n_samples)
        or (n_neurons, n_drift, n_channels, n_samples) if drifting
    n_resample : int
        Samples for resampled templates
    up : float
        The original sampling frequency in Hz
        (NOTE(review): passed straight through to the resampler — confirm the
        up/down semantics against _resample_parallel)
    down : float
        The new sampling frequency in Hz
    drifting : bool
        If True templates are assumed to be drifting
    dtype : dtype
        Dtype of the output array
    verbose : bool
        If True output is verbose
    n_jobs : int
        Number of jobs for parallel processing. If None half cpus are used
    tmp_file : str
        Path to tmp file to memmap. If None, processing is in memory
    parallel : bool
        If True each template is resampled in parallel

    Returns
    -------
    templates_rs : np.array
        Array with resampled templates (n_neurons, n_channels, n_resample)
        or (n_neurons, n_drift, n_channels, n_resample) if drifting
    """
    # Output has the same shape as the input except the last (time) axis,
    # which holds n_resample samples. This covers both drifting and
    # non-drifting layouts without branching.
    out_shape = templates.shape[:-1] + (n_resample,)
    if tmp_file is not None:
        templates_rs = np.memmap(tmp_file, shape=out_shape, dtype=dtype, mode="w+")
    else:
        # bug fix: the in-memory branch previously ignored 'dtype' and always
        # allocated float64, inconsistent with the memmap branch
        templates_rs = np.zeros(out_shape, dtype=dtype)

    if parallel:
        if n_jobs is None:
            n_jobs = os.cpu_count() // 2
        if verbose:
            print("Resampling with", n_jobs, "jobs")
        output_list = Parallel(n_jobs=n_jobs)(
            delayed(_resample_parallel)(i, tem, up, down, drifting) for i, tem in enumerate(templates)
        )
        for i, template_rs in enumerate(output_list):
            _store_resampled(templates_rs, i, template_rs)
    else:
        for i, tem in enumerate(templates):
            _store_resampled(templates_rs, i, _resample_parallel(i, tem, up, down, drifting))
    return templates_rs


def _store_resampled(templates_rs, i, template_rs):
    """Write one resampled template into the output array, truncating or
    zero-padding along the time (last) axis if the resampler returned a
    different number of samples than requested."""
    n_out = templates_rs.shape[-1]
    n_in = template_rs.shape[-1]
    if n_in < n_out:
        # bug fix: the original sliced with len(template_rs) (the length of the
        # FIRST axis, i.e. the channel/drift count) instead of the number of
        # time samples, which raised a broadcast error whenever they differed
        templates_rs[i, ..., :n_in] = template_rs
    elif n_in > n_out:
        templates_rs[i] = template_rs[..., :n_out]
    else:
        templates_rs[i] = template_rs
def pad_templates(templates, pad_samples, drifting, dtype, verbose, n_jobs=None, tmp_file=None, parallel=False):
    """
    Pads the templates on both ends.

    Parameters
    ----------
    templates : np.array
        Array with templates (n_neurons, n_channels, n_samples)
        or (n_neurons, n_drift, n_channels, n_samples) if drifting
    pad_samples : list
        List of 2 ints with number of samples for padding before and after
    drifting : bool
        If True templates are assumed to be drifting
    dtype : dtype
        Dtype of the output array
    verbose : bool
        If True output is verbose
    n_jobs : int
        Number of jobs for parallel processing. If None half cpus are used
    tmp_file : str
        Path to tmp file to memmap. If None, processing is in memory
    parallel : bool
        If True each template is padded in parallel

    Returns
    -------
    templates_pad : np.array
        Array with padded templates (n_neurons, n_channels, n_padded_sample)
        or (n_neurons, n_drift, n_channels, n_padded_sample) if drifting
    """
    padded_template_samples = templates.shape[-1] + np.sum(pad_samples)
    # same shape as the input except the time (last) axis, which grows by the padding
    out_shape = templates.shape[:-1] + (padded_template_samples,)
    if tmp_file is not None:
        templates_pad = np.memmap(tmp_file, shape=out_shape, dtype=dtype, mode="w+")
    else:
        # bug fix: the in-memory branch previously ignored 'dtype' and always
        # allocated float64, inconsistent with the memmap branch
        templates_pad = np.zeros(out_shape, dtype=dtype)
    if parallel:
        # parallel workers write directly into the shared memmap, so a tmp file is required
        assert tmp_file is not None
        if n_jobs is None:
            n_jobs = os.cpu_count() // 2
        if verbose:
            print("Padding with", n_jobs, "jobs")
        Parallel(n_jobs=n_jobs)(
            delayed(_pad_parallel)(i, tem, pad_samples, drifting, verbose, templates_pad)
            for i, tem in enumerate(templates)
        )
    else:
        for i, tem in enumerate(templates):
            templates_pad[i] = _pad_parallel(i, tem, pad_samples, drifting, verbose, None)
    return templates_pad
def jitter_templates(
    templates, upsample, fs, n_jitters, jitter, drifting, dtype, verbose, n_jobs=None, tmp_file=None, parallel=False
):
    """
    Adds jittered replicas to the templates.

    Parameters
    ----------
    templates : np.array
        Array with templates (n_neurons, n_channels, n_samples)
        or (n_neurons, n_drift, n_channels, n_samples) if drifting
    upsample : int
        Factor for upsampling the templates
    fs : Quantity
        Sampling frequency of the templates
    n_jitters : int
        Number of jittered copies for each template
    jitter : quantity
        Jitter in time for shifting the template
    drifting : bool
        If True templates are assumed to be drifting
    dtype : dtype
        Dtype of the output array
    verbose : bool
        If True output is verbose
    n_jobs : int
        Number of jobs for parallel processing. If None half cpus are used
    tmp_file : str
        Path to tmp file to memmap. If None, processing is in memory
    parallel : bool
        If True each template is jittered in parallel

    Returns
    -------
    templates_jitter : np.array
        Array with jittered templates (n_neurons, n_jitters, n_channels, n_samples)
        or (n_neurons, n_drift, n_jitters, n_channels, n_samples) if drifting
    """
    # the jitter axis is inserted right before the trailing (channel, sample)
    # axes, which covers both the drifting and non-drifting layouts
    out_shape = templates.shape[:-2] + (n_jitters,) + templates.shape[-2:]
    if tmp_file is not None:
        templates_jitter = np.memmap(tmp_file, shape=out_shape, dtype=dtype, mode="w+")
    else:
        # bug fix: the in-memory branch previously ignored 'dtype' and always
        # allocated float64, inconsistent with the memmap branch
        templates_jitter = np.zeros(out_shape, dtype=dtype)
    if parallel:
        # parallel workers write directly into the shared memmap, so a tmp file is required
        assert tmp_file is not None
        if n_jobs is None:
            n_jobs = os.cpu_count() // 2
        if verbose:
            print("Jittering with", n_jobs, "jobs")
        Parallel(n_jobs=n_jobs)(
            delayed(_jitter_parallel)(i, tem, upsample, fs, n_jitters, jitter, drifting, verbose, templates_jitter)
            for i, tem in enumerate(templates)
        )
    else:
        for i, tem in enumerate(templates):
            templates_jitter[i] = _jitter_parallel(i, tem, upsample, fs, n_jitters, jitter, drifting, verbose, None)
    return templates_jitter
def cubic_padding(template, pad_samples):
    """
    Cubic spline padding on left and right side to 0. The initial offset of the
    template is also removed.

    Parameters
    ----------
    template : np.array
        Template to be padded (n_elec, n_samples)
    pad_samples : list
        Padding samples before and after the template (2 ints)

    Returns
    -------
    padded_template : np.array
        Padded template (n_elec, n_pre + n_samples + n_post)
    """
    import scipy.interpolate as interp

    assert len(pad_samples) == 2, "'pad_samples' should be a list/tuple/array of length 2!"
    n_pre, n_post = int(pad_samples[0]), int(pad_samples[1])
    padded_template = np.zeros((template.shape[0], n_pre + template.shape[1] + n_post))
    for i, sp in enumerate(template):
        # remove the initial offset so the waveform starts at 0
        # (array subtraction makes a copy; no need for deepcopy)
        sp_copy = sp - sp[0]
        padded_sp = np.zeros(n_pre + len(sp) + n_post)
        padded_t = np.arange(len(padded_sp))
        # fill pre and post intervals with linear ramps from 0 to the first/last
        # sample so the cubic spline has a smooth, well-behaved support
        x_pre_pad = np.arange(n_pre)
        x_post_pad = np.arange(n_post)[::-1]
        padded_sp[:n_pre] = sp_copy[0] / float(n_pre) * x_pre_pad
        padded_sp[n_pre:-n_post] = sp_copy
        padded_sp[-n_post:] = sp_copy[-1] / float(n_post) * x_post_pad
        f = interp.interp1d(padded_t, padded_sp, kind="cubic")
        # note: the original also filled a 'splines' array with f() over the whole
        # support, but never used it — that dead computation has been removed
        padded_template[i, :n_pre] = f(x_pre_pad)
        padded_template[i, n_pre:-n_post] = sp_copy
        padded_template[i, -n_post:] = f(np.arange(n_pre + len(sp_copy), n_pre + len(sp_copy) + n_post))
    return padded_template
def find_overlapping_templates(templates, thresh=0.8):
    """
    Find spatially overlapping templates.

    Parameters
    ----------
    templates : np.array
        Array with templates (n_templates, n_elec, n_samples), or
        (n_templates, n_jitters, n_elec, n_samples) in which case the first
        jittered copy of each template is used
    thresh : float
        Percent threshold to consider two templates to be overlapping.

    Returns
    -------
    overlapping_pairs : np.array
        Array with sorted overlapping index pairs (n_overlapping, 2)
    """
    has_jitter = len(templates.shape) == 4  # jittered layout
    overlapping_pairs = []
    # Scan only i < j so each pair is examined once (replaces the original
    # O(n^2) membership dedup). The overlap test is asymmetric, so both
    # orderings are checked, matching the original's full i/j sweep.
    for i in range(len(templates)):
        temp_1 = templates[i][0] if has_jitter else templates[i]
        for j in range(i + 1, len(templates)):
            temp_2 = templates[j][0] if has_jitter else templates[j]
            if are_templates_overlapping([temp_1, temp_2], thresh) or are_templates_overlapping(
                [temp_2, temp_1], thresh
            ):
                overlapping_pairs.append([i, j])
    return np.array(overlapping_pairs)
def are_templates_overlapping(templates, thresh):
    """
    Returns true if templates are spatially overlapping

    Parameters
    ----------
    templates : np.array
        Array with 2 templates (2, n_elec, n_samples)
    thresh : float
        Overlapping threshold ([0 - 1])

    Returns
    -------
    overlap : bool
        Whether the templates are spatially overlapping or not
    """
    assert len(templates) == 2
    first, second = templates[0], templates[1]
    # (electrode, sample) location of the first template's negative peak
    peak_loc = np.unravel_index(first.argmin(), first.shape)
    # second template's amplitude at that location vs. its global peak amplitude
    amp_on_peak = np.abs(np.min(second[peak_loc]))
    amp_global = np.abs(np.min(second))
    return bool(amp_on_peak > thresh * amp_global)
### SPIKETRAIN OPERATIONS ###
def annotate_overlapping_spikes(spiketrains, t_jitt=1 * pq.ms, overlapping_pairs=None, parallel=True, verbose=True):
    """
    Annotate spike trains with temporal and spatio-temporal overlapping labels
    (stored in each train's 'overlap' annotation).
    NO - Non overlap
    TO - Temporal overlap
    STO - Spatio-temporal overlap

    Parameters
    ----------
    spiketrains : list
        List of neo spike trains to be annotated (modified in place)
    t_jitt : Quantity
        Time jitter to consider overlapping spikes in time (default 1 ms)
    overlapping_pairs : np.array
        Array with overlapping information between spike trains (n_spiketrains, 2)
    parallel : bool
        If True spike trains are processed in parallel with multiprocessing
    verbose : bool
        If True output is verbose
    """
    if parallel:
        import multiprocessing

        threads = []
        manager = multiprocessing.Manager()
        # shared proxy dict used by workers to hand back the annotated trains
        return_spiketrains = manager.dict()
        for i, st_i in enumerate(spiketrains):
            p = multiprocessing.Process(
                target=annotate_parallel,
                args=(
                    i,
                    st_i,
                    spiketrains,
                    t_jitt,
                    overlapping_pairs,
                    return_spiketrains,
                    verbose,
                ),
            )
            p.start()
            threads.append(p)
        for p in threads:
            p.join()
        # retrieve annotated spiketrains
        for i, st in enumerate(spiketrains):
            spiketrains[i] = return_spiketrains[i]
    else:
        # find overlapping spikes
        for i, st_i in enumerate(spiketrains):
            if verbose:
                print("Annotating overlapping spike train ", i)
            over = np.array(["NONE"] * len(st_i))
            for i_sp, t_i in enumerate(st_i):
                for j, st_j in enumerate(spiketrains):
                    if i != j:
                        # find overlapping
                        id_over = np.where((st_j > t_i - t_jitt) & (st_j < t_i + t_jitt))[0]
                        if not np.any(overlapping_pairs):
                            # no spatial info available: any coincidence is a temporal overlap
                            if len(id_over) != 0:
                                over[i_sp] = "TO"
                        else:
                            pair = [i, j]
                            pair_i = [j, i]
                            # trains i and j are spatially overlapping (pair listed in either order)
                            if np.any([np.all(pair == p) for p in overlapping_pairs]) or np.any(
                                [np.all(pair_i == p) for p in overlapping_pairs]
                            ):
                                if len(id_over) != 0:
                                    over[i_sp] = "STO"
                            else:
                                if len(id_over) != 0:
                                    over[i_sp] = "TO"
                        # NOTE(review): the label is overwritten for each partner train j,
                        # so a later temporal-only coincidence can replace an earlier
                        # "STO" with "TO" — confirm this precedence is intended
            over[over == "NONE"] = "NO"
            st_i.annotate(overlap=over)
def annotate_parallel(i, st_i, spiketrains, t_jitt, overlapping_pairs, return_spiketrains, verbose):
    """
    Helper function to annotate spike trains in parallel.

    Parameters
    ----------
    i : int
        Index of spike train
    st_i : neo.SpikeTrain
        Spike train to be processed
    spiketrains : list
        List of neo spiketrains
    t_jitt : Quantity
        Time jitter to consider overlapping spikes in time (default 1 ms)
    overlapping_pairs : np.array
        Array with overlapping information between spike trains (n_spiketrains, 2)
    return_spiketrains : dict
        Shared (multiprocessing manager) dict; the annotated train is stored under key i
    verbose : bool
        If True output is verbose
    """
    if verbose:
        print("Annotating overlapping spike train ", i)
    # labels: "NO" no overlap, "TO" temporal overlap, "STO" spatio-temporal overlap
    over = np.array(["NONE"] * len(st_i))
    for i_sp, t_i in enumerate(st_i):
        for j, st_j in enumerate(spiketrains):
            if i != j:
                # find overlapping
                id_over = np.where((st_j > t_i - t_jitt) & (st_j < t_i + t_jitt))[0]
                if not np.any(overlapping_pairs):
                    # no spatial info available: any coincidence is a temporal overlap
                    if len(id_over) != 0:
                        over[i_sp] = "TO"
                else:
                    pair = [i, j]
                    pair_i = [j, i]
                    # trains i and j are spatially overlapping (pair listed in either order)
                    if np.any([np.all(pair == p) for p in overlapping_pairs]) or np.any(
                        [np.all(pair_i == p) for p in overlapping_pairs]
                    ):
                        if len(id_over) != 0:
                            over[i_sp] = "STO"
                    else:
                        if len(id_over) != 0:
                            over[i_sp] = "TO"
                # NOTE(review): the label is overwritten for each partner train j, so a
                # later temporal-only coincidence can replace "STO" — confirm intended
    over[over == "NONE"] = "NO"
    st_i.annotate(overlap=over)
    return_spiketrains[i] = st_i
def resample_spiketrains(spiketrains, fs=None):
    """
    Resamples spike trains at the given sampling frequency.

    Parameters
    ----------
    spiketrains : list
        List of neo spiketrains to be resampled
    fs : Quantity
        New sampling frequency

    Returns
    -------
    resampled_mat : np.array
        Matrix with resampled binned spike trains

    Raises
    ------
    Exception
        If fs is not provided
    ValueError
        If fs is not a pq.Quantity
    """
    import elephant.conversion as conv

    if fs is None:
        raise Exception("Provide either sampling frequency fs or time period T")
    if not isinstance(fs, Quantity):
        raise ValueError("fs must be of type pq.Quantity")
    # bin width corresponding to one sample at the new rate
    # bug fix: rescale() returns a new quantity; the original discarded the result
    binsize = (1.0 / fs).rescale("ms")
    resampled_mat = []
    for sts in spiketrains:
        spikes = conv.BinnedSpikeTrain(sts, binsize=binsize).to_array()
        resampled_mat.append(np.squeeze(spikes))
    return np.array(resampled_mat)
def compute_sync_rate(times1, times2, time_jitt):
    """
    Compute synchrony rate between two spike trains.

    Parameters
    ----------
    times1 : quantity array
        Spike times 1
    times2 : quantity array
        Spike times 2
    time_jitt : quantity
        Maximum time jittering between added spikes

    Returns
    -------
    rate : float
        Synchrony rate (0-1)
    """
    count = 0
    for t1 in times1:
        # a spike in times1 is synchronous if at least one spike in times2
        # lies within +/- time_jitt of it
        # (the match test is now computed once per spike; the original
        # evaluated the same np.where up to three times and printed debug output)
        if np.count_nonzero(np.abs(times2 - t1) <= time_jitt) >= 1:
            count += 1
    rate = count / (len(times1) + len(times2))
    return rate
### CONVOLUTION OPERATIONS ###
def compute_modulation(st, n_el=1, mrand=1, sdrand=0.05, n_spikes=1, exp=0.2, max_burst_duration=100 * pq.ms):
    """
    Computes modulation value for an input spike train.

    Parameters
    ----------
    st : neo.SpikeTrain
        Input spike train
    n_el : int
        Number of electrodes to compute modulation.
        If 1, modulation is computed at the template level.
        If n_elec, modulation is computed at the electrode level.
    mrand : float
        Mean for Gaussian modulation (should be 1)
    sdrand : float
        Standard deviation for Gaussian modulation
    n_spikes : int
        Number of spikes for bursting behavior.
        If 0, no ISI-dependent modulation.
        If 1, no bursting behavior.
        If > 1, up to n_spikes consecutive spike are modulated with an exponentially decaying function.
    exp : float
        Exponent for exponential modulation (default 0.2)
    max_burst_duration : Quantity
        Maximum duration of a bursting event. After this duration, bursting modulation is reset.

    Returns
    -------
    mod : np.array
        Modulation value for each spike in the spike train
    cons : np.array
        Number of consecutive spikes computed for each spike
    """
    import elephant.statistics as stat

    if n_el == 1:
        # template-level modulation: one scalar per spike
        ISI = stat.isi(st).rescale("ms")
        # max_burst_duration = 2*mean_ISI
        mod = np.zeros(len(st))
        # the first spike always gets plain Gaussian modulation
        mod[0] = sdrand * np.random.randn() + mrand
        cons = np.zeros(len(st))
        last_burst_event = 0 * pq.s
        for i, isi in enumerate(ISI):
            if n_spikes == 0:
                # no isi-dependent modulation
                mod[i + 1] = sdrand * np.random.randn() + mrand
            elif n_spikes == 1:
                if isi > max_burst_duration:
                    mod[i + 1] = sdrand * np.random.randn() + mrand
                else:
                    # short ISI: scale amplitude down as isi^exp, normalized to max_burst_duration
                    mod[i + 1] = (
                        isi.magnitude**exp * (1.0 / max_burst_duration.magnitude**exp) + sdrand * np.random.randn()
                    )
            else:
                # bursting: count spikes since the last burst reset (or within the
                # last max_burst_duration window when no reset has happened yet)
                if last_burst_event.magnitude == 0:
                    consecutive_idx = np.where((st > st[i] - max_burst_duration) & (st <= st[i]))[0]
                    consecutive = len(consecutive_idx)
                else:
                    consecutive_idx = np.where((st > last_burst_event) & (st <= st[i]))[0]
                    consecutive = len(consecutive_idx)
                # NOTE(review): the electrode-level branch below tests
                # 'consecutive == n_spikes' here it is 'n_spikes - 1' — confirm
                # which threshold is intended
                if consecutive == n_spikes - 1:
                    last_burst_event = st[i + 1]
                if consecutive >= 1:
                    if st[i + 1] - st[consecutive_idx[0]] >= max_burst_duration:
                        # burst exceeded the maximum duration: reset the counter
                        last_burst_event = st[i + 1] - 0.001 * pq.ms
                        consecutive = 0
                if consecutive == 0:
                    mod[i + 1] = sdrand * np.random.randn() + mrand
                elif consecutive == 1:
                    amp = (isi / float(consecutive)) ** exp * (1.0 / max_burst_duration.magnitude**exp)
                    # scale std by amp
                    mod[i + 1] = amp + amp * sdrand * np.random.randn()
                else:
                    # NOTE(review): i ranges over range(len(ISI)), so 'i != len(ISI)'
                    # is always true and the else branch below is unreachable
                    if i != len(ISI):
                        isi_mean = np.mean(ISI[i - consecutive + 1 : i + 1])
                    else:
                        isi_mean = np.mean(ISI[i - consecutive + 1 :])
                    amp = (isi_mean / float(consecutive)) ** exp * (1.0 / max_burst_duration.magnitude**exp)
                    # scale std by amp
                    mod[i + 1] = amp + amp * sdrand * np.random.randn()
                cons[i + 1] = consecutive
    else:
        # electrode-level modulation: one value per spike per electrode
        if n_spikes == 0:
            mod = sdrand * np.random.randn(len(st), n_el) + mrand
            cons = []
        else:
            ISI = stat.isi(st).rescale("ms")
            mod = np.zeros((len(st), n_el))
            mod[0] = sdrand * np.random.randn(n_el) + mrand
            cons = np.zeros(len(st))
            last_burst_event = 0 * pq.s
            for i, isi in enumerate(ISI):
                if n_spikes == 1:
                    if isi > max_burst_duration:
                        mod[i + 1] = sdrand * np.random.randn(n_el) + mrand
                    else:
                        mod[i + 1] = isi.magnitude**exp * (
                            1.0 / max_burst_duration.magnitude**exp
                        ) + sdrand * np.random.randn(n_el)
                else:
                    if isi > max_burst_duration:
                        # long gap: plain Gaussian modulation and burst counter reset
                        mod[i + 1] = sdrand * np.random.randn(n_el) + mrand
                        consecutive = 0
                    elif last_burst_event.magnitude == 0:
                        consecutive_idx = np.where((st > st[i] - max_burst_duration) & (st <= st[i]))[0]
                        consecutive = len(consecutive_idx)
                    else:
                        consecutive_idx = np.where((st > last_burst_event) & (st <= st[i]))[0]
                        consecutive = len(consecutive_idx)
                    if consecutive == n_spikes:
                        last_burst_event = st[i + 1]
                    if consecutive >= 1:
                        if st[i + 1] - st[consecutive_idx[0]] >= max_burst_duration:
                            last_burst_event = st[i + 1] - 0.001 * pq.ms
                            consecutive = 0
                    if consecutive == 0:
                        mod[i + 1] = sdrand * np.random.randn(n_el) + mrand
                    elif consecutive == 1:
                        amp = (isi.magnitude / float(consecutive)) ** exp * (1.0 / max_burst_duration.magnitude**exp)
                        # scale std by amp
                        if amp > 1:
                            raise Exception
                        mod[i + 1] = amp + amp * sdrand * np.random.randn(n_el)
                    else:
                        # NOTE(review): as above, 'i != len(ISI)' is always true here
                        if i != len(ISI):
                            isi_mean = np.mean(ISI[i - consecutive + 1 : i + 1])
                        else:
                            isi_mean = np.mean(ISI[i - consecutive + 1 :])
                        amp = (isi_mean / float(consecutive)) ** exp * (1.0 / max_burst_duration.magnitude**exp)
                        # scale std by amp
                        mod[i + 1] = amp + amp * sdrand * np.random.randn(n_el)
                    cons[i + 1] = consecutive
    return np.array(mod), cons
def compute_bursting_template(template, mod, wc_mod, filtfilt=False):
    """
    Compute modulation in shape for a template with low-pass filter.

    Parameters
    ----------
    template : np.array
        Template to be modulated (num_chan, n_samples) or (n_samples)
    mod : int, float or np.array
        Amplitude modulation for template or single electrodes
    wc_mod : float
        Normalized frequency of low-pass filter
    filtfilt : bool
        If True forward-backward (zero-phase) filter is used

    Returns
    -------
    temp_filt : np.array
        Modulated template
    """
    # Accept plain Python scalars as well as numpy arrays: the original
    # accessed mod.size directly, which raised AttributeError for the
    # documented int (and float) inputs in the 2D branch.
    mod = np.asarray(mod)
    # file-level `import scipy.signal as ss` is used (the redundant local import was removed)
    b, a = ss.butter(3, wc_mod)
    if template.ndim == 2:
        if filtfilt:
            temp_filt = ss.filtfilt(b, a, template, axis=1)
        else:
            temp_filt = ss.lfilter(b, a, template, axis=1)
        if mod.size > 1:
            # per-electrode modulation: rescale each filtered channel so its
            # negative peak equals mod * original channel peak
            temp_filt = np.array(
                [m * np.min(temp) / np.min(temp_f) * temp_f for (m, temp, temp_f) in zip(mod, template, temp_filt)]
            )
        else:
            # single modulation value: rescale the whole template by its global peak
            temp_filt = (mod * np.min(template) / np.min(temp_filt)) * temp_filt
    else:
        if filtfilt:
            temp_filt = ss.filtfilt(b, a, template)
        else:
            temp_filt = ss.lfilter(b, a, template)
        temp_filt = (mod * np.min(template) / np.min(temp_filt)) * temp_filt
    return temp_filt
def sigmoid(x, b=1):
    """
    Logistic sigmoid shifted to be centered at 0.

    Parameters
    ----------
    x : np.array
        Input values
    b : float
        Sigmoid slope

    Returns
    -------
    x_sig : np.array
        Sigmoid of the input, in the open interval (-0.5, 0.5)
    """
    exp_term = np.exp(-b * x)
    return 1 / (1 + exp_term) - 0.5
def compute_stretched_template(template, mod, shape_stretch=30.0):
    """
    Compute modulation in shape for a template by stretching its time axis
    with a sigmoid transform (for modulation values < 1) and rescaling its
    amplitude.

    Parameters
    ----------
    template : np.array
        Template to be modulated (num_chan, n_samples) or (n_samples)
    mod : int, float or np.array
        Amplitude modulation for template or single electrodes
    shape_stretch : float
        Sigmoid range to stretch the template

    Returns
    -------
    temp_filt : np.array
        Modulated template
    """
    import scipy.interpolate as interp

    # Accept any scalar (int OR float) as well as arrays: the original only
    # converted ints, so a plain float crashed on the mod.size access below.
    mod = np.asarray(mod)
    if mod.size > 1:
        # per-electrode modulation: stretch by the mean modulation
        stretch_factor = np.mean(mod)
        mod_value = np.mean(mod)
    else:
        stretch_factor = mod
        mod_value = mod
    if len(template.shape) == 2:
        # center the time axis on the negative peak and normalize it
        min_idx = np.unravel_index(np.argmin(template), template.shape)[1]
        x_centered = np.arange(-min_idx, template.shape[1] - min_idx)
        x_centered = x_centered / float(np.ptp(x_centered))
        x_centered = x_centered * shape_stretch
        if stretch_factor >= 1:
            # no stretching for modulation >= 1
            x_stretch = x_centered
        else:
            x_stretch = sigmoid(x_centered, 1 - stretch_factor)
            x_stretch = x_stretch / float(np.ptp(x_stretch))
            x_stretch *= shape_stretch + (np.min(x_centered) - np.min(x_stretch))
        x_recovered = np.max(x_stretch) / np.max(x_centered) * x_centered
        # rounding avoids interp1d out-of-range errors from float noise
        x_stretch = np.round(x_stretch, 6)
        x_recovered = np.round(x_recovered, 6)
        temp_filt = np.zeros(template.shape)
        for i, t in enumerate(template):
            try:
                f = interp.interp1d(x_stretch, t, kind="cubic")
                temp_filt[i] = f(x_recovered)
            except Exception as e:
                raise Exception("'shape_stretch' is too large. Try reducing it (default = 30)")
        if mod.size > 1:
            # rescale each channel so its negative peak equals mod * original peak
            temp_filt = np.array(
                [m * np.min(temp) / np.min(temp_f) * temp_f for (m, temp, temp_f) in zip(mod, template, temp_filt)]
            )
        else:
            temp_filt = (mod * np.min(template) / np.min(temp_filt)) * temp_filt
    else:
        # 1D template: same procedure on a single channel
        min_idx = np.argmin(template)
        x_centered = np.arange(-min_idx, len(template) - min_idx)
        x_centered = x_centered / float(np.ptp(x_centered))
        x_centered = x_centered * shape_stretch
        if stretch_factor >= 1:
            x_stretch = x_centered
        else:
            x_stretch = sigmoid(x_centered, 1 - stretch_factor)
            x_stretch = x_stretch / float(np.ptp(x_stretch))
            x_stretch *= shape_stretch + (np.min(x_centered) - np.min(x_stretch))
        x_recovered = np.max(x_stretch) / np.max(x_centered) * x_centered
        x_stretch = np.round(x_stretch, 6)
        x_recovered = np.round(x_recovered, 6)
        try:
            f = interp.interp1d(x_stretch, template, kind="cubic")
            temp_filt = f(x_recovered)
        except Exception as e:
            raise Exception("'shape_stretch' is too large. Try reducing it (default = 30)")
        temp_filt = (mod_value * np.min(template) / np.min(temp_filt)) * temp_filt
    return temp_filt
def convolve_single_template(
    spike_id,
    st_idx,
    template,
    n_samples,
    cut_out=None,
    modulation=False,
    mod_array=None,
    bursting=False,
    shape_stretch=None,
):
    """Convolve single template with spike train. Used to compute 'spike_traces'.

    Parameters
    ----------
    spike_id : int
        Index of spike trains - template.
    st_idx : np.array
        Spike times (in samples, relative to the chunk start)
    template : np.array
        Array with single template, shape (n_jitters, n_template_samples)
    n_samples : int
        Number of samples in chunk
    cut_out : list
        Number of samples before and after the peak. If None, the template
        is centered on the spike sample.
    modulation : bool
        If True modulation is applied
    mod_array : np.array
        Array with modulation value for each spike
    bursting : bool
        If True templates are modulated in shape
    shape_stretch : float
        Range of sigmoid transform for bursting shape stretch

    Returns
    -------
    spike_trace : np.array
        Trace with convolved signal (n_samples)

    Raises
    ------
    Exception
        If 'template' is not 2-dimensional.
    """
    # Validate first: previously a non-2D template combined with cut_out=None
    # raised a NameError on 'len_spike' before reaching the informative error.
    if len(template.shape) != 2:
        raise Exception("For drifting len(template.shape) should be 2")
    njitt, len_spike = template.shape
    if cut_out is None:
        cut_out = [len_spike // 2, len_spike // 2]
    spike_trace = np.zeros(n_samples)
    # a single jittered version is drawn and used for the whole spike train
    temp_jitt = template[np.random.randint(njitt)]

    for pos, spos in enumerate(st_idx):
        # select the waveform to add for this spike
        if not modulation:
            snippet = temp_jitt
        elif bursting:
            # bursting modulation also stretches the template shape
            snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch)
        else:
            # amplitude-only modulation: average per-electrode values if needed
            if mod_array[pos].size > 1:
                mod_value = np.mean(mod_array[pos])
            else:
                mod_value = mod_array[pos]
            snippet = mod_value * temp_jitt
        # paste the snippet into the trace, cropping at the chunk borders
        start = spos - cut_out[0]
        if start >= 0 and start + len_spike <= n_samples:
            spike_trace[start : start + len_spike] += snippet
        elif start < 0:
            # spike close to chunk start: crop the snippet on the left
            spike_trace[: start + len_spike] += snippet[-start:]
        else:
            # spike close to chunk end: crop the snippet on the right
            spike_trace[start:] += snippet[: n_samples - start]
    return spike_trace
def convolve_templates_spiketrains(
    spike_id,
    st_idx,
    template,
    n_samples,
    cut_out=None,
    modulation=False,
    mod_array=None,
    verbose=False,
    bursting=False,
    shape_stretch=None,
    max_channels_per_template=None,
    recordings=None,
    drift_idxs=None,
):
    """
    Convolve template with spike train on all electrodes. Used to compute 'recordings'.

    Parameters
    ----------
    spike_id : int
        Index of spike trains - template (used only for verbose output)
    st_idx : np.array
        Spike times (in samples, relative to the chunk start)
    template : np.array
        Array with template: (n_jitters, n_elec, len_spike), or
        (n_drift_steps, n_jitters, n_elec, len_spike) when 'drift_idxs' is given
    n_samples : int
        Number of samples in chunk
    cut_out : list
        Number of samples before and after the peak. If None, the template
        is centered on the spike sample
    modulation : bool
        If True modulation is applied ('mod_array' is then required)
    mod_array : np.array
        Modulation value(s) for each spike: scalars for 'template' modulation,
        per-electrode arrays for 'electrode' modulation
    verbose : bool
        If True output is verbose
    bursting : bool
        If True templates are modulated in shape
    shape_stretch : float
        Range of sigmoid transform for bursting shape stretch
    max_channels_per_template : int
        Maximum number of channels to be convolved
    recordings : np.array
        Array (n_samples, n_elec) to accumulate into. If None it is created
    drift_idxs : None or np.array
        Drift-step index for each spike in 'st_idx' (drifting templates only)

    Returns
    -------
    recordings : np.array
        Trace with convolved signals (n_samples, n_elec)
    """
    drifting = drift_idxs is not None
    if drifting:
        assert template.ndim == 4
    else:
        assert template.ndim == 3
    if verbose:
        print("Convolution with spike:", spike_id)
    if drifting:
        drift_steps = template.shape[0]
        n_jitt = template.shape[1]
        n_elec = template.shape[2]
        len_spike = template.shape[3]
        # keep drift indices within the available drift steps
        drift_idxs = drift_idxs.clip(0, drift_steps - 1)
    else:
        n_jitt = template.shape[0]
        n_elec = template.shape[1]
        len_spike = template.shape[2]
    if recordings is None:
        recordings = np.zeros((n_samples, n_elec))
    else:
        assert recordings.shape == (n_samples, n_elec), "'recordings' has the wrong shape"
    dtype = recordings.dtype
    if cut_out is None:
        cut_out = [len_spike // 2, len_spike // 2]
    if not modulation:
        # No modulation: unit gain for every spike
        mod_array = np.ones_like(st_idx)
    else:
        assert mod_array is not None, " For 'electrode' and 'template' modulations provide 'mod_array'"
    for pos, spos in enumerate(st_idx):
        # each spike uses a randomly drawn jittered version of the template
        rand_idx = np.random.randint(n_jitt)
        if drifting:
            drift_ind = drift_idxs[pos]
            temp_jitt = template[drift_ind, rand_idx]
        else:
            temp_jitt = template[rand_idx]
        if max_channels_per_template is None:
            elec_idxs = np.arange(n_elec)
        else:
            # find max channels: restrict the convolution to the electrodes with
            # the largest absolute amplitude at the template peak sample
            peak_idx = np.unravel_index(np.argmax(np.abs(temp_jitt)), temp_jitt.shape)[1]
            elec_idxs = np.argsort(np.abs(temp_jitt[:, peak_idx]))[::-1][:max_channels_per_template]
            temp_jitt = temp_jitt[elec_idxs]
        # NOTE(review): the bursting branches test 'spos - cut_out[0] + len_spike <= n_samples'
        # while the written slice ends at 'spos + cut_out[1]' (the non-bursting branches test
        # 'spos + cut_out[1] <= n_samples'); these agree only when
        # len_spike == cut_out[0] + cut_out[1] — confirm this always holds for callers.
        if bursting:
            # NOTE(review): the 'template' and 'electrode' sub-branches below are
            # textually identical; compute_stretched_template handles scalar vs
            # per-electrode modulation internally (it branches on mod.size).
            if not isinstance(mod_array[0], (list, tuple, np.ndarray)):
                # template
                if spos - cut_out[0] >= 0 and spos - cut_out[0] + len_spike <= n_samples:
                    snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch).T
                    recordings[spos - cut_out[0] : spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                elif spos - cut_out[0] < 0:
                    # spike close to chunk start: crop snippet on the left
                    diff = -(spos - cut_out[0])
                    snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch)[:, diff:].T
                    recordings[: spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                else:
                    # spike close to chunk end: crop snippet on the right
                    diff = n_samples - (spos - cut_out[0])
                    snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch)[:, :diff].T
                    recordings[spos - cut_out[0] :, elec_idxs] += snippet.astype(dtype)
            else:
                # electrode
                if spos - cut_out[0] >= 0 and spos - cut_out[0] + len_spike <= n_samples:
                    snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch).T
                    recordings[spos - cut_out[0] : spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                elif spos - cut_out[0] < 0:
                    diff = -(spos - cut_out[0])
                    snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch)[:, diff:].T
                    recordings[: spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                else:
                    diff = n_samples - (spos - cut_out[0])
                    snippet = compute_stretched_template(temp_jitt, mod_array[pos], shape_stretch)[:, :diff].T
                    recordings[spos - cut_out[0] :, elec_idxs] += snippet.astype(dtype)
        else:
            if not isinstance(mod_array[0], (list, tuple, np.ndarray)):
                # template + none: a single scalar gain per spike
                if spos - cut_out[0] >= 0 and spos + cut_out[1] <= n_samples:
                    snippet = mod_array[pos] * temp_jitt.T
                    recordings[spos - cut_out[0] : spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                elif spos - cut_out[0] < 0:
                    diff = -(spos - cut_out[0])
                    snippet = mod_array[pos] * temp_jitt[:, diff:].T
                    recordings[: spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                else:
                    diff = n_samples - (spos - cut_out[0])
                    snippet = mod_array[pos] * temp_jitt[:, :diff].T
                    recordings[spos - cut_out[0] :, elec_idxs] += snippet.astype(dtype)
            else:
                # electrode: one gain per electrode for this spike
                if spos - cut_out[0] >= 0 and spos + cut_out[1] <= n_samples:
                    snippet = np.array([a * t for (a, t) in zip(mod_array[pos], temp_jitt)]).T
                    recordings[spos - cut_out[0] : spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                elif spos - cut_out[0] < 0:
                    diff = -(spos - cut_out[0])
                    snippet = np.array([a * t for (a, t) in zip(mod_array[pos], temp_jitt[:, diff:])]).T
                    recordings[: spos + cut_out[1], elec_idxs] += snippet.astype(dtype)
                else:
                    diff = n_samples - (spos - cut_out[0])
                    snippet = np.array([a * t for (a, t) in zip(mod_array[pos], temp_jitt[:, :diff])]).T
                    recordings[spos - cut_out[0] :, elec_idxs] += snippet.astype(dtype)
    return recordings
def compute_drift_idxs_from_drift_list(spike_index, spike_train_frames, drift_list, fs):
    """
    Map each spike of one unit onto a drift-step index.

    The displacement of every drift signal in 'drift_list' is sampled at the
    spike times, scaled by the unit-specific drift factor, and summed; the
    result is shifted by the central drift step.

    Parameters
    ----------
    spike_index : int
        Index of the unit (selects the entry in each 'drift_factors')
    spike_train_frames : np.array
        Spike times in samples
    drift_list : list of dict
        Drift signal dictionaries ('drift_vector_idxs', 'drift_fs',
        'drift_factors', 'drift_times', 'drift_steps')
    fs : float
        Sampling frequency of the recording in Hz

    Returns
    -------
    drift_idxs : np.array (uint16)
        Drift-step index for each spike
    """
    n_spikes = len(spike_train_frames)
    displacements = np.zeros(n_spikes, dtype="int16")
    spike_times_s = spike_train_frames / fs
    center_step = drift_list[0]["drift_steps"] // 2
    for drift_dict in drift_list:
        step_signal = np.array(drift_dict["drift_vector_idxs"])  # 0 - num_steps
        signal_fs = drift_dict["drift_fs"]
        unit_factors = drift_dict["drift_factors"]
        signal_times = drift_dict["drift_times"]
        if signal_times is None:
            # the drift signal spans the entire recording
            in_period = np.ones(n_spikes, dtype=bool)
            sample_idxs = (spike_times_s * signal_fs).astype("int")
        else:
            # the drift signal is only active within a time window
            in_period = np.logical_and(spike_times_s >= signal_times[0], spike_times_s <= signal_times[-1])
            sample_idxs = np.searchsorted(signal_times, spike_times_s[in_period])
        scaled = step_signal[sample_idxs] * unit_factors[spike_index]
        displacements[in_period] += scaled.astype("int16")
    # displacements are relative to the central drift step
    return (displacements + center_step).astype("uint16")
def extract_units_drift_vector(mearec_file=None, recgen=None, time_vector=None):
    """
    Retrieve drift vector per unit.

    Internally the per-unit drift vector is constructed as a linear sum of
    drift indices multiplied by a per-cell factor.
    This drift index is then converted to micrometers given the cell locations.
    Here `time_vector` is assumed to be the center of the bins (even if
    internally everything is floored to the left of the bin).

    Parameters
    ----------
    mearec_file : str or None
        The MEArec filename
    recgen : RecordingGenerator or None
        The RecordingGenerator
    time_vector : array or None
        An external time vector to interpolate drift.
        If None the internal drift vector with highest fs is used.

    Returns
    -------
    units_drift_vectors : array
        The drift vector in micrometers, shape (n_time_bin, n_units)
    time_vector : array
        Time vector in seconds, shape (n_time_bin,)
    """
    import scipy.interpolate

    if mearec_file is not None:
        recgen = load_recordings(mearec_file)

    drift_list = recgen.drift_list
    locations = np.array(recgen.template_locations)

    if time_vector is None:
        # use the clock of the drift signal with the highest sampling frequency
        best = np.argmax([d["drift_fs"] for d in drift_list])
        drift_dict = drift_list[best]
        main_fs = drift_dict["drift_fs"]
        length = len(drift_list[best]["drift_vector_idxs"])
        time_vector = np.arange(length) / main_fs
    else:
        # fix: np.median(np.diff(time_vector)) is the sampling *period*;
        # invert it so the comparison with 'drift_fs' below is meaningful
        main_fs = 1.0 / np.median(np.diff(time_vector))

    # interpolate drift_vector_idxs on the same clock
    for drift_dict in drift_list:
        drift_vector_idxs = np.array(drift_dict["drift_vector_idxs"])
        drift_fs = drift_dict["drift_fs"]
        if drift_fs == main_fs and drift_vector_idxs.shape[0] == time_vector.shape[0]:
            # no interpolation needed
            interpolated_drift_vector_idxs = drift_vector_idxs
        else:
            # linear interpolation on the time vector
            # note that we use the center of the bins here
            local_times = np.arange(drift_vector_idxs.shape[0]) / drift_fs + 0.5 / drift_fs
            f = scipy.interpolate.interp1d(local_times, drift_vector_idxs)
            interpolated_drift_vector_idxs = f(time_vector)
        drift_dict["interpolated_drift_vector_idxs"] = interpolated_drift_vector_idxs

    n_units = len(recgen.spiketrains)
    units_drift_vectors = np.zeros((time_vector.size, n_units), dtype="float32")
    # drift indices are relative to the central drift step
    mid_point_idx = drift_list[0]["drift_steps"] // 2
    for unit_index in range(n_units):
        summed_drift_idxs = np.zeros(time_vector.size, dtype="int16")
        for drift_dict in drift_list:
            interpolated_drift_vector_idxs = drift_dict["interpolated_drift_vector_idxs"]
            drift_factors = drift_dict["drift_factors"]
            summed_drift_idxs += (interpolated_drift_vector_idxs * drift_factors[unit_index]).astype("int16")
        summed_drift_idxs = (summed_drift_idxs + mid_point_idx).astype("uint16")
        # convert drift-step indices to depth (z coordinate) in micrometers
        locs = locations[unit_index, :, 2]
        units_drift_vectors[:, unit_index] = locs[summed_drift_idxs]

    return units_drift_vectors, time_vector
### RECORDING OPERATION ###
def extract_wf(spiketrains, recordings, fs, cut_out=2, timestamps=None):
    """
    Extract waveforms from recordings and load them in the waveform field of neo spike trains.

    Each spike train gets its 'waveforms' attribute set to an array of shape
    (n_spikes, n_elec, n_wf_samples).

    Parameters
    ----------
    spiketrains : list
        List of neo spike trains
    recordings : np.array
        Array with recordings (n_samples, n_elec)
    fs : Quantity
        Sampling frequency
    cut_out : float or list
        Length in ms to cut before and after spike peak. If a single value the cut is symmetrical
    timestamps : Quantity array (optional)
        Array with recordings timestamps
    """
    if cut_out is None:
        cut_out = 2
    # convert the ms cut-outs to numbers of samples before/after the peak
    if not isinstance(cut_out, list):
        n_pad = int(cut_out * pq.ms * fs.rescale("kHz"))
        n_pad = [n_pad, n_pad]
    else:
        n_pad = [int(p * pq.ms * fs.rescale("kHz")) for p in cut_out]

    n_samples, n_elec = recordings.shape
    if timestamps is None:
        timestamps = np.arange(n_samples) / fs.rescale("Hz")

    for st in spiketrains:
        sp_rec_wf = []
        for t in st:
            # first sample at or after the spike time (clipped to the last sample)
            idx = np.where(timestamps >= t)[0]
            if len(idx) > 0:
                idx = idx[0]
            else:
                idx = len(timestamps) - 1
            # cut the waveform, zero-padding at the recording borders
            # (fix: the original strict inequalities let a spike exactly at the
            # border fall through all branches and reuse a stale 'spike_rec')
            if idx - n_pad[0] >= 0 and idx + n_pad[1] <= n_samples:
                spike_rec = recordings[idx - n_pad[0] : idx + n_pad[1]]
            elif idx - n_pad[0] < 0:
                spike_rec = recordings[: idx + n_pad[1]]
                spike_rec = np.pad(spike_rec, ((np.abs(idx - n_pad[0]), 0), (0, 0)), "constant")
            else:
                spike_rec = recordings[idx - n_pad[0] :]
                spike_rec = np.pad(spike_rec, ((0, idx + n_pad[1] - n_samples), (0, 0)), "constant")
            sp_rec_wf.append(spike_rec.T)
        st.waveforms = np.array(sp_rec_wf)
def filter_analog_signals(signals, freq, fs, filter_type="bandpass", mode="filtfilt", order=3):
    """
    Filter analog signals with zero-phase Butterworth filter.
    The function raises a ValueError if the required filter is not stable.

    Parameters
    ----------
    signals : np.array
        Array of analog signals (n_samples, n_elec) or (n_samples,)
    freq : list or float
        Cutoff frequency-ies in Hz
    fs : Quantity
        Sampling frequency
    filter_type : str
        Filter type ('lowpass', 'highpass', 'bandpass', 'bandstop')
    mode : str
        Filtering mode ('filtfilt', 'lfilter')
    order : int
        Filter order

    Returns
    -------
    signals_filt : np.array
        Filtered signals

    Raises
    ------
    ValueError
        If the designed filter is unstable.
    """
    from scipy.signal import butter, filtfilt, lfilter

    fn = fs / 2.0
    freq = freq.rescale(pq.Hz)
    # normalize cutoff(s) by the Nyquist frequency as butter() expects
    band = freq / fn

    assert mode in ["filtfilt", "lfilter"], "Filtering mode not recognized"
    filter_func = filtfilt if mode == "filtfilt" else lfilter

    b, a = butter(order, band, btype=filter_type)
    # the filter is stable iff all poles (roots of the denominator 'a') lie
    # strictly inside the unit circle (fix: the original evaluated the same
    # np.roots(a) condition twice)
    if not np.all(np.abs(np.roots(a)) < 1):
        raise ValueError("Filter is not stable")
    if len(signals.shape) == 2:
        signals_filt = filter_func(b, a, signals, axis=0)
    elif len(signals.shape) == 1:
        signals_filt = filter_func(b, a, signals)
    return signals_filt
### PLOTTING ###
def plot_rasters(
    spiketrains, cell_type=False, ax=None, overlap=False, color=None, fs=10, marker="|", mew=2, markersize=5
):
    """
    Plot raster for spike trains.

    Parameters
    ----------
    spiketrains : list
        List of neo spike trains
    cell_type : bool
        If True and 'cell_type' is in the spike train annotations, spike trains
        are colored by type ('E' blue, 'I' red)
    ax : axes
        Plot on the given axes
    overlap : bool
        Plot spike colors based on 'overlap' annotations
        ('STO' red, 'TO' green, 'NO' black)
    color : matplotlib color (single or list)
        Color or color list
    fs : int
        Font size
    marker : matplotlib arg
        Marker type
    mew : matplotlib arg
        Width of marker
    markersize : int
        Marker size

    Returns
    -------
    ax : axis
        Matplotlib axis
    """
    import matplotlib.pylab as plt

    if not ax:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    if overlap:
        if "overlap" not in spiketrains[0].annotations.keys():
            # fix: previously raised a bare Exception with no message
            raise Exception("Spike trains must be annotated with 'overlap' information")
    for i, spiketrain in enumerate(spiketrains):
        t = spiketrain.rescale(pq.s)
        if cell_type:
            if "cell_type" in spiketrain.annotations.keys():
                # excitatory units in blue, inhibitory in red
                if spiketrain.annotations["cell_type"] == "E":
                    ax.plot(t, i * np.ones_like(t), "b", marker=marker, mew=mew, markersize=markersize, ls="")
                elif spiketrain.annotations["cell_type"] == "I":
                    ax.plot(t, i * np.ones_like(t), "r", marker=marker, mew=mew, markersize=markersize, ls="")
            else:
                # no cell-type annotation: fall back to user colors or black
                if color is not None:
                    if isinstance(color, list) or isinstance(color, np.ndarray):
                        ax.plot(
                            t, i * np.ones_like(t), color=color[i], marker=marker, mew=mew, markersize=markersize, ls=""
                        )
                    else:
                        ax.plot(
                            t, i * np.ones_like(t), color=color, marker=marker, mew=mew, markersize=markersize, ls=""
                        )
                else:
                    ax.plot(t, i * np.ones_like(t), "k", marker=marker, mew=mew, markersize=markersize, ls="")
        else:
            if not overlap:
                if color is not None:
                    if isinstance(color, list) or isinstance(color, np.ndarray):
                        ax.plot(
                            t, i * np.ones_like(t), color=color[i], marker=marker, mew=mew, markersize=markersize, ls=""
                        )
                    else:
                        ax.plot(
                            t, i * np.ones_like(t), color=color, marker=marker, mew=mew, markersize=markersize, ls=""
                        )
                else:
                    ax.plot(t, i * np.ones_like(t), "k", marker=marker, mew=mew, markersize=markersize, ls="")
            elif overlap:
                # color each spike individually by its overlap label
                for j, t_sp in enumerate(spiketrain):
                    if spiketrain.annotations["overlap"][j] == "STO":
                        ax.plot(t_sp, i, "r", marker=marker, mew=mew, markersize=markersize, ls="")
                    elif spiketrain.annotations["overlap"][j] == "TO":
                        ax.plot(t_sp, i, "g", marker=marker, mew=mew, markersize=markersize, ls="")
                    elif spiketrain.annotations["overlap"][j] == "NO":
                        ax.plot(t_sp, i, "k", marker=marker, mew=mew, markersize=markersize, ls="")
    ax.axis("tight")
    ax.set_xlim([spiketrains[0].t_start.rescale(pq.s), spiketrains[0].t_stop.rescale(pq.s)])
    ax.set_xlabel("Time (s)", fontsize=fs)
    ax.set_ylabel("Spike Train Index", fontsize=fs)
    ax.set_yticks(np.arange(len(spiketrains)))
    ax.set_yticklabels(np.arange(len(spiketrains)))
    return ax
def plot_templates(
    gen,
    template_ids=None,
    single_jitter=True,
    ax=None,
    single_axes=False,
    max_templates=None,
    drifting=False,
    cmap=None,
    ncols=6,
    **kwargs,
):
    """
    Plot templates.

    Parameters
    ----------
    gen : TemplateGenerator or RecordingGenerator
        Generator object containing templates
    template_ids : int or list
        The template(s) to plot
    single_jitter : bool
        If True and jittered templates are present, a single jittered template is plotted
    ax : axis
        Matplotlib axis
    single_axes : bool
        If True all templates are plotted on the same axis
    max_templates : int
        Maximum number of templates to be plotted
    drifting : bool
        If True and templates are drifting, drifting templates are displayed
        (requires 'template_ids' to be a single int)
    cmap : matplotlib colormap
        Colormap to be used
    ncols : int
        Number of columns for subplots

    Returns
    -------
    ax : ax
        Matplotlib axes
    """
    import matplotlib.pylab as plt
    from matplotlib import gridspec

    templates = gen.templates
    mea = mu.return_mea(info=gen.info["electrodes"])

    # reduce the template array to the requested drift/jitter view
    if "params" in gen.info.keys():
        # TemplateGenerator: drop the drift axis unless drifting display is requested
        if gen.info["params"]["drifting"]:
            if not drifting:
                templates = templates[:, 0]
    if "recordings" in gen.info.keys():
        # RecordingGenerator: templates may have drift and/or jitter axes
        if gen.info["recordings"]["drifting"]:
            if single_jitter:
                if not drifting:
                    if len(templates.shape) == 5:
                        templates = templates[:, 0, 0]
                    else:
                        templates = templates[:, 0]
                else:
                    if len(templates.shape) == 5:
                        templates = templates[:, :, 0]
            else:
                if not drifting:
                    if len(templates.shape) == 5:
                        templates = templates[:, 0]
        else:
            if single_jitter:
                if len(templates.shape) == 4:
                    templates = templates[:, 0]

    if drifting:
        assert isinstance(template_ids, (int, np.integer)), (
            "When plotting drifting templates, 'template_ids' should " "be a single index (int)"
        )
        single_axes = True

    if template_ids is not None:
        if isinstance(template_ids, (int, np.integer)):
            template_ids = list(np.array([template_ids]))
        elif isinstance(template_ids, list):
            template_ids = list(np.array(template_ids))
    else:
        template_ids = list(np.arange(templates.shape[0]))

    if max_templates is not None:
        # NOTE(review): this overrides any user-provided 'template_ids' with a
        # random subset of ALL templates — confirm this is the intended behavior
        if max_templates < len(templates):
            random_idxs = np.random.permutation(len(templates))
            template_ids = np.arange(templates.shape[0])[random_idxs][:max_templates]

    n_sources = len(template_ids)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()

    if "vscale" not in kwargs.keys():
        kwargs["vscale"] = 1.5 * np.max(np.abs(templates[template_ids]))

    if single_axes:
        # all selected templates share the same axis
        if cmap is not None:
            cm = plt.get_cmap(cmap)
            colors = [cm(i / len(template_ids)) for i in np.arange(len(template_ids))]
        else:
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        for n, t in enumerate(templates):
            if n in template_ids:
                if len(t.shape) == 3:
                    if not drifting:
                        # multiple jitters: plot the average template
                        mu.plot_mea_recording(
                            t.mean(axis=0), mea, colors=colors[np.mod(n, len(colors))], ax=ax, **kwargs
                        )
                    else:
                        # drifting: color each drift step along a sequential colormap
                        if cmap is None:
                            cmap = "Reds"
                        cm = plt.get_cmap(cmap)
                        colors = [cm(i / t.shape[0]) for i in np.arange(t.shape[0])]
                        mu.plot_mea_recording(t, mea, colors=colors, ax=ax, **kwargs)
                else:
                    mu.plot_mea_recording(t, mea, colors=colors[np.mod(n, len(colors))], ax=ax, **kwargs)
    else:
        # one subplot per template
        if n_sources > ncols:
            nrows = int(np.ceil(len(template_ids) / ncols))
        else:
            nrows = 1
            ncols = n_sources
        if cmap is not None:
            cm = plt.get_cmap(cmap)
            colors = [cm(i / len(template_ids)) for i in np.arange(len(template_ids))]
        else:
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax)
        for i_n, n in enumerate(template_ids):
            r = i_n // ncols
            c = np.mod(i_n, ncols)
            gs_sel = gs[r, c]
            ax_t = fig.add_subplot(gs_sel)
            if cmap is not None:
                mu.plot_mea_recording(templates[n], mea, ax=ax_t, colors=colors[i_n], **kwargs)
            else:
                mu.plot_mea_recording(templates[n], mea, ax=ax_t, colors=colors[np.mod(i_n, len(colors))], **kwargs)
        ax.axis("off")
    return ax
def plot_recordings(
    recgen,
    ax=None,
    start_time=None,
    end_time=None,
    overlay_templates=False,
    n_templates=None,
    max_channels_per_template=16,
    cmap=None,
    templates_lw=1,
    **kwargs,
):
    """
    Plot recordings.

    Parameters
    ----------
    recgen : RecordingGenerator
        Recording generator object to plot
    ax : axis
        Matplotlib axis
    start_time : float
        Start time to plot recordings in s
    end_time : float
        End time to plot recordings in s
    overlay_templates : bool
        If True, templates are overlaid on the recordings
    n_templates : int
        Number of templates to overlay (if overlay_templates is True)
    max_channels_per_template : int
        Number of maximum channels in which the template is overlaid
    cmap : matplotlib colormap
        Colormap to be used
    templates_lw : float
        Line width for the overlaid templates

    Returns
    -------
    ax : axis
        Matplotlib axis
    """
    import matplotlib.pylab as plt

    recordings = recgen.recordings
    mea = mu.return_mea(info=recgen.info["electrodes"])
    fs = recgen.info["recordings"]["fs"]
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    # convert time window (s) to a frame window
    if start_time is None:
        start_frame = 0
    else:
        start_frame = int(start_time * fs)
    if end_time is None:
        end_frame = recordings.shape[0]
    else:
        end_frame = int(end_time * fs)
    if max_channels_per_template is None:
        max_channels_per_template = len(recordings)
    if "vscale" not in kwargs.keys():
        kwargs["vscale"] = 1.5 * np.max(np.abs(recordings))

    mu.plot_mea_recording(recordings[start_frame:end_frame, :].T, mea, ax=ax, **kwargs)

    if overlay_templates:
        # fix: 'templates_lw' was only honored when the caller also passed 'lw'
        kwargs["lw"] = templates_lw
        fs = recgen.info["recordings"]["fs"] * pq.Hz
        if n_templates is None:
            template_ids = np.arange(len(recgen.templates))
        else:
            template_ids = np.random.permutation(len(recgen.templates))[:n_templates]
        # template length in samples (cut-out plus padding, per side)
        cut_out_samples = [
            int((c + p) * fs.rescale("kHz").magnitude)
            for (c, p) in zip(recgen.info["templates"]["cut_out"], recgen.info["templates"]["pad_len"])
        ]
        spike_idxs = []
        for st in recgen.spiketrains:
            spike_idxs.append((st.times * fs).magnitude.astype("int"))
        n_samples = recordings.shape[0]
        if cmap is not None:
            cm = plt.get_cmap(cmap)
            colors_t = [cm(i / len(template_ids)) for i in np.arange(len(template_ids))]
        else:
            colors_t = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        i_col = 0
        if "colors" in kwargs.keys():
            del kwargs["colors"]
        for i, sp in enumerate(spike_idxs):
            if i in template_ids:
                template = recgen.templates[i]
                sp_frames = sp * fs
                if recgen.drift_list is None:
                    drift_idxs = None
                else:
                    drift_idxs = compute_drift_idxs_from_drift_list(i, sp_frames, recgen.drift_list, fs)
                # re-convolve this unit's template to get its clean contribution
                rec_t = convolve_templates_spiketrains(
                    i,
                    sp,
                    template,
                    n_samples,
                    max_channels_per_template=max_channels_per_template,
                    cut_out=cut_out_samples,
                    drift_idxs=drift_idxs,
                ).T
                # blank near-zero samples so only the spikes are drawn
                rec_t[np.abs(rec_t) < 1e-4] = np.nan
                mu.plot_mea_recording(
                    rec_t[:, start_frame:end_frame], mea, ax=ax, colors=colors_t[np.mod(i_col, len(colors_t))], **kwargs
                )
                i_col += 1
                del rec_t
    return ax
def plot_amplitudes(
    recgen,
    spiketrain_id=None,
    electrode=None,
    ax=None,
    color=None,
    cmap=None,
    single_axes=True,
    marker="*",
    ms=5,
    ncols=6,
):
    """
    Plot waveform amplitudes over time.

    Parameters
    ----------
    recgen : RecordingGenerator
        Recording generator object to plot spike train waveform from
    spiketrain_id : int or list
        Indexes of spike trains
    electrode : int or 'max'
        Electrode id or 'max'
    ax : axis
        Matplotlib axis
    color : matplotlib color (single or list)
        Color of the waveform amplitudes
    cmap : matplotlib colormap
        Colormap to be used
    single_axes : bool
        If True all templates are plotted on the same axis
    marker : str
        Matplotlib marker (default '*')
    ms : int
        Markersize (default 5)
    ncols : int
        Number of columns for subplots

    Returns
    -------
    ax : axis
        Matplotlib axis
    """
    import matplotlib.gridspec as gridspec
    import matplotlib.pylab as plt

    if spiketrain_id is None:
        spiketrain_id = np.arange(len(recgen.spiketrains))
    elif isinstance(spiketrain_id, (int, np.integer)):
        spiketrain_id = [spiketrain_id]
    n_units = len(spiketrain_id)
    if electrode is None:
        electrode = "max"
    # make sure waveforms are available for all requested units
    waveforms = []
    for sp in spiketrain_id:
        wf = recgen.spiketrains[sp].waveforms
        if wf is None:
            fs = recgen.info["recordings"]["fs"] * pq.Hz
            extract_wf([recgen.spiketrains[sp]], recgen.recordings, fs)
            wf = recgen.spiketrains[sp].waveforms
        waveforms.append(wf)

    # build one color per unit
    # (fix: an explicit 'color' was previously ignored for multiple units, and
    # for a single unit 'colors[0]' picked the first *character* of the string)
    if color is None:
        if n_units > 1 and cmap is not None:
            cm = plt.get_cmap(cmap)
            colors = [cm(i / n_units) for i in np.arange(n_units)]
        elif n_units > 1:
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        else:
            colors = ["k"]
    elif isinstance(color, (list, np.ndarray)):
        colors = color
    else:
        colors = [color] * n_units

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()

    # minimum amplitude on the selected electrode for each spike
    amps = []
    for i, id in enumerate(spiketrain_id):
        st = recgen.spiketrains[id]
        wf = st.waveforms
        mwf = np.mean(wf, axis=0)
        if electrode == "max":
            # electrode with the largest negative peak in the mean waveform
            max_elec = np.unravel_index(np.argmin(mwf), mwf.shape)[0]
            print("Max electrode", max_elec)
        else:
            assert isinstance(electrode, (int, np.integer)), "'electrode' can be 'max' or type int"
            max_elec = electrode
        amps.append(np.array([np.min(w[max_elec]) for w in wf]))

    if single_axes:
        for i_n, n in enumerate(spiketrain_id):
            amp = amps[i_n]
            st = recgen.spiketrains[n]
            ax.plot(st, amp, marker=marker, ms=ms, color=colors[i_n], ls="")
    else:
        # one subplot per unit
        if n_units > ncols:
            nrows = int(np.ceil(len(spiketrain_id) / ncols))
        else:
            nrows = 1
            ncols = n_units
        gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax)
        for i_n, n in enumerate(spiketrain_id):
            r = i_n // ncols
            c = np.mod(i_n, ncols)
            gs_sel = gs[r, c]
            ax_amp = fig.add_subplot(gs_sel)
            amp = amps[i_n]
            st = recgen.spiketrains[n]
            ax_amp.plot(st, amp, marker=marker, ms=ms, color=colors[i_n], ls="")
        ax.axis("off")
    return ax
def plot_pca_map(
    recgen, n_pc=2, max_elec=None, cmap="rainbow", cut_out=2, n_units=None, ax=None, whiten=False, pc_comp=None
):
    """
    Plots a PCA map of the waveforms.

    Parameters
    ----------
    recgen : RecordingGenerator
        Recording generator object to plot PCA scores of
    n_pc : int
        Number of principal components (default 2)
    max_elec : int
        Max number of electrodes to plot
    cmap : matplotlib colormap
        Colormap to be used
    cut_out : float or list
        Cut outs in ms for waveforms (if not computed). If float the cut out is symmetrical.
    n_units : int
        Number of units to color (defaults to all spike trains)
    ax : axis
        Matplotlib axis
    whiten : bool
        If True, PCA scores are whitened (requires the PCA to be fit here)
    pc_comp : np.array
        PC component matrix to be used.

    Returns
    -------
    ax : axis
        Matplotlib axis
    pca_scores : list
        List of np.arrays with pca scores for the different units
    pca_component : np.array
        PCA components matrix (n_pc, n_waveform_timepoints)
    """
    try:
        from sklearn.decomposition import PCA
    except ImportError:
        # fix: a bare 'except' previously masked unrelated errors too
        raise Exception("'plot_pca_map' requires scikit-learn package")
    import matplotlib.gridspec as gridspec
    import matplotlib.pylab as plt

    waveforms = []
    n_spikes = []
    if n_units is None:
        n_units = len(recgen.spiketrains)
    if recgen.spiketrains[0].waveforms is None:
        print("Computing waveforms")
        recgen.extract_waveforms(cut_out=cut_out)
    for st in recgen.spiketrains:
        wf = st.waveforms
        waveforms.append(wf)
    n_elec = waveforms[0].shape[1]
    if n_pc == 1:
        pc_dims = [0]
    elif n_pc > 1:
        pc_dims = np.arange(n_pc)
    else:
        pc_dims = [0]
    # choose which electrodes to display
    if max_elec is not None and max_elec < n_elec:
        if max_elec == 1:
            elec_dims = [np.random.randint(n_elec)]
        elif max_elec > 1:
            elec_dims = np.random.permutation(np.arange(n_elec))[:max_elec]
        else:
            elec_dims = [np.random.randint(n_elec)]
    else:
        elec_dims = np.arange(n_elec)
    # stack all waveforms: (total_spikes * n_elec, n_wf_samples)
    for i_w, wf in enumerate(waveforms):
        wf_reshaped = wf.reshape((wf.shape[0] * wf.shape[1], wf.shape[2]))
        n_spikes.append(len(wf) * n_elec)
        if i_w == 0:
            all_waveforms = wf_reshaped
        else:
            all_waveforms = np.vstack((all_waveforms, wf_reshaped))
    if pc_comp is None:
        compute_pca = True
    elif pc_comp.shape == (n_pc, all_waveforms.shape[1]):
        compute_pca = False
    else:
        print("'pc_comp' has wrong dimensions. Recomputing PCA")
        compute_pca = True
    if whiten and not compute_pca:
        # fix: this combination previously hit a NameError on 'pca' below
        raise Exception("'whiten' requires the PCA to be fit here; do not provide a valid 'pc_comp'")
    if compute_pca:
        print("Fitting PCA of %d dimensions on %d waveforms" % (n_pc, len(all_waveforms)))
        pca = PCA(n_components=n_pc, whiten=whiten)
        pca.fit(all_waveforms)
        pc_comp = pca.components_
    # project each unit's waveforms on the components: (n_spikes, n_elec, n_pc)
    pca_scores = []
    for st in recgen.spiketrains:
        pct = np.dot(st.waveforms, pc_comp.T)
        if whiten:
            pct /= np.sqrt(pca.explained_variance_)
        pca_scores.append(pct)

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()
    ax.axis("off")
    if cmap is not None:
        cm = plt.get_cmap(cmap)
        colors = [cm(i / n_units) for i in np.arange(n_units)]
    else:
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    # upper-triangular grid of (electrode, PC) scatter plots; histograms on the diagonal
    nrows = len(pc_dims) * len(elec_dims)
    ncols = nrows
    gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax)
    for p1 in pc_dims:
        for i1, ch1 in enumerate(elec_dims):
            for p2 in pc_dims:
                for i2, ch2 in enumerate(elec_dims):
                    r = n_pc * i1 + p1
                    c = n_pc * i2 + p2
                    gs_sel = gs[r, c]
                    ax_sel = fig.add_subplot(gs_sel)
                    if c < r:
                        ax_sel.axis("off")
                    else:
                        if r == 0:
                            ax_sel.set_xlabel("Ch." + str(ch2 + 1) + ":PC" + str(p2 + 1))
                            ax_sel.xaxis.set_label_position("top")
                        ax_sel.set_xticks([])
                        ax_sel.set_yticks([])
                        ax_sel.spines["right"].set_visible(False)
                        ax_sel.spines["top"].set_visible(False)
                        for i, pc in enumerate(pca_scores):
                            if i1 == i2 and p1 == p2:
                                h, b, _ = ax_sel.hist(pc[:, i1, p1], bins=50, alpha=0.6, color=colors[i], density=True)
                                ax_sel.set_ylabel("Ch." + str(ch1 + 1) + ":PC" + str(p1 + 1))
                            else:
                                ax_sel.plot(
                                    pc[:, i2, p2], pc[:, i1, p1], marker="o", ms=1, ls="", alpha=0.5, color=colors[i]
                                )
    fig.subplots_adjust(wspace=0.02, hspace=0.02)
    return ax, pca_scores, pc_comp
def plot_cell_drifts(recgen, ax=None):
    """
    Plot drifting positions for all cells.

    Parameters
    ----------
    recgen : RecordingGenerator
        Recording generator object
    ax : axis
        Matplotlib axis

    Returns
    -------
    ax : axis
        Matplotlib axis
    """
    import matplotlib.pylab as plt

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()
    assert recgen.drift_list is not None, "No drift info is available"
    sampling_rate = recgen.info["recordings"]["fs"]
    positions = recgen.template_locations
    n_drift_steps = positions.shape[1]
    for unit, spiketrain in enumerate(recgen.spiketrains):
        # drift-step index at each spike of this unit
        spike_frames = (spiketrain.magnitude * sampling_rate).astype("int")
        step_idxs = compute_drift_idxs_from_drift_list(unit, spike_frames, recgen.drift_list, sampling_rate)
        step_idxs = step_idxs.clip(0, n_drift_steps - 1)
        unit_locs = positions[unit]
        is_drifting = bool(spiketrain.annotations.get("drifting", False))
        if is_drifting:
            # depth (z) trajectory along the drift steps
            depth = unit_locs[step_idxs, 2]
        else:
            # non-drifting unit: constant depth at the central drift step
            depth = [unit_locs[unit_locs.shape[0] // 2, 2]] * len(spiketrain.magnitude)
        ax.plot(spiketrain.magnitude, depth, label=f"Unit {unit}")
    ax.legend()
    return ax
######### HELPER FUNCTIONS #########
def _resample_parallel(i, template, up, down, drifting):
"""
Resamples a template to a specified sampling frequency.
Parameters
----------
template : np.array
Array with one template (n_channels, n_samples) or (n_drift, n_channels, n_samples) if drifting
up : float
The original sampling frequency in Hz
down : float
The new sampling frequency in Hz
drifting : bool
If True templates are assumed to be drifting
Returns
-------
template_rs : np.array
Array with resampled template (n_channels, n_resample)
or (n_drift, n_channels, n_resample)
"""
if not drifting:
tem_poly = ss.resample_poly(template, up, down, axis=1)
else:
tem_poly = ss.resample_poly(template, up, down, axis=2)
return tem_poly
def _jitter_parallel(i, template, upsample, fs, n_jitters, jitter, drifting, verbose, templates_jitter):
"""
Adds jittered replicas to one template.
Parameters
----------
template : np.array
Array with templates (n_channels, n_samples) or (n_drift, n_channels, n_samples) if drifting
upsample : int
Factor for upsampling the templates
n_jitters : int
Number of jittered copies for each template
jitter : quantity
Jitter in time for shifting the template
drifting : bool
If True templates are assumed to be drifting
verbose : bool
If True output is verbose
Returns
-------
template_jitt : np.array
Array with one jittered template (n_jitters, n_channels, n_samples)
or (n_drift, n_jitters, n_channels, n_samples) if drifting
"""
rng = np.random.RandomState(i)
if not drifting:
template_jitter = np.zeros((n_jitters, template.shape[0], template.shape[1]))
temp_up = ss.resample_poly(template, upsample, 1, axis=1)
nsamples_up = temp_up.shape[1]
for n in np.arange(n_jitters):
# align waveform
shift = int((jitter * (rng.rand() - 0.5) * upsample * fs).magnitude)
if shift > 0:
t_jitt = np.pad(temp_up, [(0, 0), (np.abs(shift), 0)], "constant")[:, :nsamples_up]
elif shift < 0:
t_jitt = np.pad(temp_up, [(0, 0), (0, np.abs(shift))], "constant")[:, -nsamples_up:]
else:
t_jitt = temp_up
temp_down = t_jitt[:, ::upsample]
template_jitter[n] = temp_down
else:
if verbose:
print("Jittering: neuron ", i)
template_jitter = np.zeros((template.shape[0], n_jitters, template.shape[1], template.shape[2]))
for tp, tem_p in enumerate(template):
temp_up = ss.resample_poly(tem_p, upsample, 1, axis=1)
nsamples_up = temp_up.shape[1]
for n in np.arange(n_jitters):
# align waveform
shift = int((jitter * rng.rand() * upsample * fs).magnitude)
if shift > 0:
t_jitt = np.pad(temp_up, [(0, 0), (np.abs(shift), 0)], "constant")[:, :nsamples_up]
elif shift < 0:
t_jitt = np.pad(temp_up, [(0, 0), (0, np.abs(shift))], "constant")[:, -nsamples_up:]
else:
t_jitt = temp_up
temp_down = t_jitt[:, ::upsample]
template_jitter[tp, n] = temp_down
if templates_jitter is None:
return template_jitter
else:
templates_jitter[i] = template_jitter
def _pad_parallel(i, template, pad_samples, drifting, verbose, templates_pad):
    """
    Pads one template on both ends (via cubic_padding).

    Parameters
    ----------
    i : int
        Index of the template
    template : np.array
        Array with one template (n_channels, n_samples) or
        (n_drift, n_channels, n_samples) if drifting
    pad_samples : list
        List of 2 ints with number of samples for padding before and after
    drifting : bool
        If True templates are assumed to be drifting
    verbose : bool
        If True output is verbose
    templates_pad : np.array or None
        Optional pre-allocated output array; when given, the result is
        written to templates_pad[i] and nothing is returned

    Returns
    -------
    template_pad : np.array or None
        Array with padded template (n_channels, n_padded_sample)
        or (n_drift, n_channels, n_padded_sample) if drifting.
        None when templates_pad is provided.
    """
    if drifting:
        if verbose:
            print("Padding edges: neuron ", i)
        n_padded = template.shape[-1] + np.sum(pad_samples)
        tem_pad = np.zeros((template.shape[0], template.shape[1], n_padded))
        # pad each drift step independently
        for step, step_template in enumerate(template):
            tem_pad[step] = cubic_padding(step_template, pad_samples)
    else:
        tem_pad = cubic_padding(template, pad_samples)

    if templates_pad is not None:
        templates_pad[i] = tem_pad
    else:
        return tem_pad
def _annotate_parallel(i, st_i, spiketrains, t_jitt, overlapping_pairs, verbose):
"""
Helper function to annotate spike trains in parallel.
Parameters
----------
i : int
Index of spike train
st_i : neo.SpikeTrain
Spike train to be processed
spiketrains : list
List of neo spiketrains
t_jitt : Quantity
Time jitter to consider overlapping spikes in time (default 1 ms)
overlapping_pairs : np.array
Array with overlapping information between spike trains (n_spiketrains, 2)
verbose : bool
If True output is verbose
"""
if verbose:
print("Annotating overlapping spike train ", i)
over = np.array(["NONE"] * len(st_i))
for i_sp, t_i in enumerate(st_i):
for j, st_j in enumerate(spiketrains):
if i != j:
# find overlapping
id_over = np.where((st_j > t_i - t_jitt) & (st_j < t_i + t_jitt))[0]
if not np.any(overlapping_pairs):
if len(id_over) != 0:
over[i_sp] = "TO"
else:
pair = [i, j]
pair_i = [j, i]
if np.any([np.all(pair == p) for p in overlapping_pairs]) or np.any(
[np.all(pair_i == p) for p in overlapping_pairs]
):
if len(id_over) != 0:
over[i_sp] = "STO"
else:
if len(id_over) != 0:
over[i_sp] = "TO"
over[over == "NONE"] = "NO"
st_i.annotate(overlap=over)
return st_i