# Source code for bluepyopt.ephys.efeatures
"""eFeature classes"""
"""
Copyright (c) 2016-2020, EPFL/Blue Brain Project
This file is part of BluePyOpt <https://github.com/BlueBrain/BluePyOpt>
This library is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License version 3.0 as published
by the Free Software Foundation.
This library is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
details.
You should have received a copy of the GNU Lesser General Public License
along with this library; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""
# pylint: disable=R0914
import logging
import numpy as np
from bluepyopt.ephys.base import BaseEPhys
from bluepyopt.ephys.serializer import DictMixin
from .extra_features_utils import *
logger = logging.getLogger(__name__)
def masked_cosine_distance(exp, model):
    """Weighted cosine distance between two arrays, skipping non-finite data

    Only the positions where both `exp` and `model` are finite take part in
    the cosine distance. The distance is then multiplied by the number of
    finite entries of `exp` divided by the total number of entries.

    Args:
        exp (np.ndarray): experimental values
        model (np.ndarray): model values, same length as `exp`

    Returns:
        float: the weighted cosine distance
    """
    from scipy.spatial.distance import cosine

    finite_exp = np.isfinite(exp)
    finite_both = finite_exp & np.isfinite(model)

    raw_distance = cosine(exp[finite_both], model[finite_both])
    weight = sum(finite_exp) / len(finite_both)

    return raw_distance * weight
[docs]
class eFELFeature(EFeature, DictMixin):

    """eFEL feature

    Evaluates one eFEL feature on a set of recordings and scores it against
    an experimental mean and standard deviation.
    """

    SERIALIZED_FIELDS = ('name', 'efel_feature_name', 'recording_names',
                         'stim_start', 'stim_end', 'exp_mean',
                         'exp_std', 'threshold', 'comment')

    def __init__(
            self,
            name,
            efel_feature_name=None,
            recording_names=None,
            stim_start=None,
            stim_end=None,
            exp_mean=None,
            exp_std=None,
            threshold=None,
            stimulus_current=None,
            comment='',
            interp_step=None,
            double_settings=None,
            int_settings=None,
            string_settings=None,
            force_max_score=False,
            max_score=250
    ):
        """Constructor

        Args:
            name (str): name of the eFELFeature object
            efel_feature_name (str): name of the eFeature in the eFEL library
                (ex: 'AP1_peak')
            recording_names (dict): eFEL features can accept several
                recordings as input. Maps a location name to a recording
                name; the key '' is mandatory and designates the main
                recording
            stim_start (float): stimulation start time (ms)
            stim_end (float): stimulation end time (ms)
            exp_mean (float): experimental mean of this eFeature
            exp_std(float): experimental standard deviation of this eFeature
            threshold(float): spike detection threshold (mV)
            stimulus_current (float): value set as the eFEL
                'stimulus_current' double setting before extraction
            comment (str): comment
            interp_step(float): interpolation step (ms)
            double_settings(dict): dictionary with efel double settings that
                should be set before extracting the features
            int_settings(dict): dictionary with efel int settings that
                should be set before extracting the features
            string_settings(dict): dictionary with efel string settings that
                should be set before extracting the features
            force_max_score (bool): if True, the score is capped at
                max_score
            max_score (float): score returned when the trace cannot be
                built; also handed to eFEL as its error distance
        """

        super().__init__(name, comment)

        self.recording_names = recording_names
        self.efel_feature_name = efel_feature_name
        self.exp_mean = exp_mean
        self.exp_std = exp_std
        self.stim_start = stim_start
        self.stim_end = stim_end
        self.threshold = threshold
        self.interp_step = interp_step
        self.stimulus_current = stimulus_current
        self.double_settings = double_settings
        self.int_settings = int_settings
        self.string_settings = string_settings
        self.force_max_score = force_max_score
        self.max_score = max_score

    def _construct_efel_trace(self, responses):
        """Construct trace that can be passed to eFEL

        Args:
            responses (dict): maps recording names to responses exposing
                'time' and 'voltage' (or to None)

        Returns:
            dict or None: the eFEL trace, or None when a needed recording
            is absent from `responses` or is None

        Raises:
            ValueError: if '' is not a key of recording_names
        """

        if '' not in self.recording_names:
            raise ValueError(
                'eFELFeature: \'\' needs to be in recording_names')

        main_recording = self.recording_names['']

        # Check every recording before building anything, the main one
        # first: it supplies the time axis of all locations, so it must be
        # available whatever the iteration order of recording_names.
        needed = [main_recording, *self.recording_names.values()]
        for recording_name in needed:
            if recording_name not in responses:
                logger.debug(
                    "Recording named %s not found in responses %s",
                    recording_name,
                    responses)
                return None
            if responses[recording_name] is None:
                return None

        time = responses[main_recording]['time']

        trace = {}
        for location_name, recording_name in self.recording_names.items():
            if location_name == '':
                postfix = ''
            else:
                postfix = ';%s' % location_name

            trace['T%s' % postfix] = time
            trace['V%s' % postfix] = responses[recording_name]['voltage']
            trace['stim_start%s' % postfix] = [self.stim_start]
            trace['stim_end%s' % postfix] = [self.stim_end]

        return trace

    def _setup_efel(self):
        """Set up efel before extracting the feature"""

        import efel
        efel.reset()

        if self.threshold is not None:
            efel.setThreshold(self.threshold)

        if self.stimulus_current is not None:
            efel.setDoubleSetting('stimulus_current', self.stimulus_current)

        if self.interp_step is not None:
            efel.setDoubleSetting('interp_step', self.interp_step)

        if self.double_settings is not None:
            for setting_name, setting_value in self.double_settings.items():
                efel.setDoubleSetting(setting_name, setting_value)

        if self.int_settings is not None:
            for setting_name, setting_value in self.int_settings.items():
                efel.setIntSetting(setting_name, setting_value)

        if self.string_settings is not None:
            for setting_name, setting_value in self.string_settings.items():
                efel.setStrSetting(setting_name, setting_value)

    def calculate_feature(self, responses, raise_warnings=False):
        """Calculate feature value

        Args:
            responses (dict): maps recording names to responses
            raise_warnings (bool): forwarded to eFEL

        Returns:
            the mean feature value reported by eFEL, or None when no trace
            could be built from `responses`
        """

        efel_trace = self._construct_efel_trace(responses)

        if efel_trace is None:
            feature_value = None
        else:
            import efel
            self._setup_efel()
            try:
                values = efel.getMeanFeatureValues(
                    [efel_trace],
                    [self.efel_feature_name],
                    raise_warnings=raise_warnings)
                feature_value = values[0][self.efel_feature_name]
            finally:
                # eFEL settings are global: always restore the defaults
                efel.reset()

        logger.debug(
            'Calculated value for %s: %s',
            self.name,
            str(feature_value))

        return feature_value

    def calculate_score(self, responses, trace_check=False):
        """Calculate the score

        Args:
            responses (dict): maps recording names to responses
            trace_check (bool): forwarded to eFEL

        Returns:
            float: distance computed by eFEL from exp_mean and exp_std, or
            max_score when no trace could be built from `responses`
        """

        efel_trace = self._construct_efel_trace(responses)

        if efel_trace is None:
            score = self.max_score
        else:
            import efel
            self._setup_efel()
            try:
                score = efel.getDistance(
                    efel_trace,
                    self.efel_feature_name,
                    self.exp_mean,
                    self.exp_std,
                    trace_check=trace_check,
                    error_dist=self.max_score
                )
                if self.force_max_score:
                    score = min(score, self.max_score)
            finally:
                # eFEL settings are global: always restore the defaults
                efel.reset()

        logger.debug('Calculated score for %s: %f', self.name, score)

        return score

    def __str__(self):
        """String representation"""

        return "%s for %s with stim start %s and end %s, " \
            "exp mean %s and std %s and AP threshold override %s" % \
            (self.efel_feature_name,
             self.recording_names,
             self.stim_start,
             self.stim_end,
             self.exp_mean,
             self.exp_std,
             self.threshold)
[docs]
class extraFELFeature(EFeature, DictMixin):

    """extraFEL feature

    Computes a feature on the mean extracellular waveform cut around the
    spikes detected on a somatic recording, and scores it against
    experimental values.
    """

    SERIALIZED_FIELDS = ('name', 'extrafel_feature_name', 'recording_names',
                         'somatic_recording_name', 'fcut', 'fs',
                         'channel_ids', 'stim_start', 'stim_end',
                         'exp_mean', 'exp_std', 'threshold', 'comment')

    def __init__(
            self,
            name,
            extrafel_feature_name=None,
            recording_names=None,
            somatic_recording_name=None,
            fcut=None,
            fs=None,
            filt_type=None,
            ms_cut=None,
            upsample=None,
            skip_first_spike=True,
            skip_last_spike=True,
            channel_ids=None,
            stim_start=None,
            stim_end=None,
            exp_mean=None,
            exp_std=None,
            threshold=None,
            comment='',
            interp_step=None,
            double_settings=None,
            int_settings=None,
            force_max_score=False,
            max_score=250,
    ):
        """Constructor

        Args:
            name (str): name of the extraFELFeature object
            extrafel_feature_name (str): name of the eFeature in the
                spikefeatures library (ex: 'halfwidth')
            recording_names (dict): eFEL features can accept several
                recordings as input
            somatic_recording_name (str): intracellular recording from soma,
                used to detect spikes. If None, spikes are detected from
                extracellular trace
            fcut (float, array, or None): cutoff frequency(ies) for filter. If
                float, a high-pass filter is used. If array-like a bandpass
                filter is used. If None, traces are not filtered
            fs (float): sampling frequency to resample extracellular traces
                (in kHz)
            filt_type (str): type of the bandpass filter used
                (default 'filtfilt')
            ms_cut (float, list, or None): cut in ms before and after the
                intra peak. If scalar, the cut is symmetrical
            upsample (int, or None): upsample factor for average waveform
                before computing features
            skip_first_spike (bool): if True, the first spike is skipped
                before computing the average waveform
                (to avoid artifacts)
            skip_last_spike (bool): if True, the last spike is skipped
                before computing the average waveform
                (to avoid artifacts)
            channel_ids (int, np.array, or None): if None, all channels are
                used to compute the feature and calculate the score
                (using the cosine_dist). If int, a single channel is used and
                the score is the normalised deviation from the exp value.
                If list/array, the cosine distance is computed over a subset
                of channels
            stim_start (float): stimulation start time (ms)
            stim_end (float): stimulation end time (ms)
            exp_mean (list of floats): experimental mean of this eFeature
            exp_std (list of floats): experimental standard deviation
                of this eFeature
            threshold (float): spike detection threshold (mV)
            comment (str): comment
            interp_step (float): interpolation step (ms)
            double_settings (dict): dictionary with efel double settings that
                should be set before extracting the features
            int_settings (dict): dictionary with efel int settings that
                should be set before extracting the features
            force_max_score (bool): if True, the score is capped at
                max_score
            max_score (float): score returned when a needed response or the
                feature value is unavailable, or when the score is not
                finite
        """

        super().__init__(name, comment)

        self.recording_names = recording_names
        self.somatic_recording_name = somatic_recording_name
        self.extrafel_feature_name = extrafel_feature_name
        self.fcut = fcut
        self.fs = fs
        self.filt_type = filt_type
        self.ms_cut = ms_cut
        self.upsample = upsample
        self.skip_first_spike = skip_first_spike
        self.skip_last_spike = skip_last_spike
        self.channel_ids = channel_ids
        self.exp_mean = exp_mean
        self.exp_std = exp_std
        self.stim_start = stim_start
        self.stim_end = stim_end
        self.threshold = threshold
        self.interp_step = interp_step
        self.double_settings = double_settings
        self.int_settings = int_settings
        self.force_max_score = force_max_score
        self.max_score = max_score

    def _construct_somatic_efel_trace(self, responses):
        """Construct trace that can be passed to eFEL

        Returns None when the somatic recording is absent from `responses`
        or is None.
        """

        if self.somatic_recording_name not in responses:
            logger.debug(
                "Recording named %s not found in responses %s",
                self.somatic_recording_name,
                responses,
            )
            return None

        response = responses[self.somatic_recording_name]
        if response is None:
            return None

        return {
            "T": response["time"],
            "V": response["voltage"],
            "stim_start": [self.stim_start],
            "stim_end": [self.stim_end],
        }

    def _setup_efel(self):
        """Set up efel before extracting the feature"""

        import efel
        efel.reset()

        if self.threshold is not None:
            efel.setThreshold(self.threshold)

        if self.interp_step is not None:
            efel.setDoubleSetting("interp_step", self.interp_step)

        if self.double_settings is not None:
            for setting_name, setting_value in self.double_settings.items():
                efel.setDoubleSetting(setting_name, setting_value)

        if self.int_settings is not None:
            for setting_name, setting_value in self.int_settings.items():
                efel.setIntSetting(setting_name, setting_value)

    def _get_peak_times(self, responses, raise_warnings=False):
        """Return the eFEL 'peak_time' values of the somatic recording

        Returns None when the somatic trace cannot be built.
        """

        efel_trace = self._construct_somatic_efel_trace(responses)
        if efel_trace is None:
            return None

        import efel
        self._setup_efel()
        try:
            peaks = efel.getFeatureValues(
                [efel_trace], ["peak_time"], raise_warnings=raise_warnings
            )
            return peaks[0]["peak_time"]
        finally:
            # eFEL settings are global: always restore the defaults
            efel.reset()

    def calculate_feature(
        self,
        responses,
        raise_warnings=False,
        return_waveforms=False,
    ):
        """Calculate feature value

        Args:
            responses (dict): maps recording names to responses exposing
                'time' and 'voltage' (or to None)
            raise_warnings (bool): forwarded to eFEL for spike detection
            return_waveforms (bool): if True, the mean waveform is returned
                together with the feature value

        Returns:
            the feature value, or the tuple (feature value, mean waveform)
            if return_waveforms is True. The value (or both tuple members)
            is None when peak times or the main recording are unavailable.

        Raises:
            ValueError: if the recording is not uniformly sampled and no
                resampling frequency `fs` was given
        """
        from .extra_features_utils import calculate_features

        unavailable = (None, None) if return_waveforms else None

        peak_times = self._get_peak_times(
            responses, raise_warnings=raise_warnings
        )
        if peak_times is None:
            return unavailable

        if len(peak_times) > 1 and self.skip_first_spike:
            peak_times = peak_times[1:]

        if len(peak_times) > 1 and self.skip_last_spike:
            peak_times = peak_times[:-1]

        main_recording = self.recording_names[""]
        if main_recording not in responses \
                or responses[main_recording] is None:
            return unavailable
        response = responses[main_recording]

        # Resample only when the time step varies by more than 0.1% of its
        # mean, i.e. when the recording is not uniformly sampled.
        time_steps = np.diff(response["time"])
        if np.std(time_steps) > 0.001 * np.mean(time_steps):
            if self.fs is None:
                raise ValueError(
                    "extraFELFeature: fs is needed to resample a "
                    "non-uniformly sampled recording"
                )
            logger.info("extraFELFeature.calculate_feature: interpolate")
            response_interp = _interpolate_response(response, fs=self.fs)
        else:
            response_interp = response

        if self.fcut is not None:
            logger.info("extraFELFeature.calculate_feature: filter enabled")
            response_filter = _filter_response(response_interp,
                                               fcut=self.fcut,
                                               filt_type=self.filt_type)
        else:
            logger.info("extraFELFeature.calculate_feature: filter disabled")
            response_filter = response_interp

        ewf = _get_waveforms(response_filter, peak_times, self.ms_cut)
        mean_wf = np.mean(ewf, axis=0)

        values = calculate_features(
            mean_wf,
            self.fs * 1000,
            upsample=self.upsample,
            feature_names=[self.extrafel_feature_name]
        )
        feature_value = values[self.extrafel_feature_name]

        if self.channel_ids is not None:
            feature_value = feature_value[self.channel_ids]

        logger.debug(
            "Calculated value for %s: %s", self.name, str(feature_value)
        )

        if return_waveforms:
            return feature_value, mean_wf
        return feature_value

    def calculate_score(self, responses, trace_check=False):
        """Calculate the score

        A scalar feature value is scored as |value - exp_mean| / exp_std;
        an array value is scored with masked_cosine_distance against
        exp_mean.

        Args:
            responses (dict): maps recording names to responses
            trace_check (bool): unused, kept for interface compatibility

        Returns:
            float: the score, or max_score when a needed response or the
            feature value is unavailable or the score is not finite
        """

        main_recording = self.recording_names[""]
        lfp_recording = main_recording.replace("soma.v", "MEA.LFP")
        for recording_name in (lfp_recording, main_recording):
            if recording_name not in responses \
                    or responses[recording_name] is None:
                return self.max_score

        feature_value = self.calculate_feature(responses)

        if feature_value is None:
            # No peak times or no usable recording: nothing to compare
            score = self.max_score
        elif np.isscalar(feature_value):
            # scalar feature
            if np.isfinite(feature_value):
                score = np.abs((feature_value - self.exp_mean)) / self.exp_std
            else:
                score = self.max_score

            if not np.isfinite(score):
                logger.debug(
                    f"Found score nan value {self.extrafel_feature_name} "
                    f"- std: {self.exp_std} - channel: {self.channel_ids}"
                )
                score = self.max_score
        else:
            score = masked_cosine_distance(
                np.asarray(self.exp_mean),
                np.asarray(feature_value)
            )

            if np.isnan(score):
                score = self.max_score

        if self.force_max_score:
            score = min(score, self.max_score)

        logger.debug("Calculated score for %s: %f", self.name, score)

        return score

    def __str__(self):
        """String representation"""

        return ("%s for %s with stim start %s and end %s, "
                "exp mean %s and std %s and AP threshold override %s"
                % (self.extrafel_feature_name,
                   self.recording_names,
                   self.stim_start,
                   self.stim_end,
                   self.exp_mean,
                   self.exp_std,
                   self.threshold)
                )
def _interpolate_response(response, fs=20.0):
from scipy.interpolate import interp1d
x = response["time"]
y = response["voltage"]
f = interp1d(x, y, axis=1)
xnew = np.arange(np.min(x), np.max(x), 1.0 / fs)
ynew = f(xnew) # use interpolation function returned by `interp1d`
response_new = {}
response_new["time"] = xnew
response_new["voltage"] = ynew
return response_new
def _filter_response(response, fcut=[0.5, 6000], order=2, filt_type="lfilter"):
import scipy.signal as ss
fs = 1 / np.mean(np.diff(response["time"])) * 1000
fn = fs / 2.0
trace = response["voltage"]
if isinstance(fcut, (float, int, np.floating, np.integer)):
btype = "highpass"
band = fcut / fn
else:
assert isinstance(fcut, (list, np.ndarray)) and len(fcut) == 2
btype = "bandpass"
band = np.array(fcut) / fn
b, a = ss.butter(order, band, btype=btype)
if len(trace.shape) == 2:
if filt_type == "filtfilt":
filtered = ss.filtfilt(b, a, trace, axis=1)
else:
filtered = ss.lfilter(b, a, trace, axis=1)
else:
if filt_type == "filtfilt":
filtered = ss.filtfilt(b, a, trace)
else:
filtered = ss.lfilter(b, a, trace)
response_new = {}
response_new["time"] = response["time"]
response_new["voltage"] = filtered
return response_new
def _get_waveforms(response, peak_times, snippet_len_ms):
times = response["time"]
traces = response["voltage"]
assert np.std(np.diff(times)) < 0.001 * np.mean(
np.diff(times)
), "Sampling frequency must be constant"
fs = 1.0 / np.mean(np.diff(times)) # kHz
reference_frames = (peak_times * fs).astype(int)
if isinstance(snippet_len_ms, (tuple, list, np.ndarray)):
snippet_len_before = int(snippet_len_ms[0] * fs)
snippet_len_after = int(snippet_len_ms[1] * fs)
else:
snippet_len_before = int((snippet_len_ms + 1) / 2 * fs)
snippet_len_after = int((snippet_len_ms - snippet_len_before) * fs)
num_snippets = len(peak_times)
if len(traces.shape) == 2:
num_channels = traces.shape[0]
else:
num_channels = 1
traces = traces[np.newaxis, :]
num_frames = len(times)
snippet_len_total = int(snippet_len_before + snippet_len_after)
waveforms = np.zeros(
(num_snippets, num_channels, snippet_len_total), dtype=traces.dtype
)
for i in range(num_snippets):
snippet_chunk = np.zeros(
(num_channels, snippet_len_total), dtype=traces.dtype
)
if 0 <= reference_frames[i] < num_frames:
snippet_range = np.array(
[
int(reference_frames[i]) - snippet_len_before,
int(reference_frames[i]) + snippet_len_after,
]
)
snippet_buffer = np.array([0, snippet_len_total], dtype="int")
# The following handles the out-of-bounds cases
if snippet_range[0] < 0:
snippet_buffer[0] -= snippet_range[0]
snippet_range[0] -= snippet_range[0]
if snippet_range[1] >= num_frames:
snippet_buffer[1] -= snippet_range[1] - num_frames
snippet_range[1] -= snippet_range[1] - num_frames
snippet_chunk[:, snippet_buffer[0]:snippet_buffer[1]] = \
traces[:, snippet_range[0]:snippet_range[1]]
waveforms[i] = snippet_chunk
return waveforms