"""Wrappers for MNE functions to perform preprocessing.
"""
# Authors: Andrew Quinn <a.quinn@bham.ac.uk>
# Chetan Gohil <chetan.gohil@psych.ox.ac.uk>
# Mats van Es <mats.vanes@psych.ox.ac.uk>
import logging
import mne
import numpy as np
import sails
import yaml
import pickle
from os.path import exists
from scipy import stats
from pathlib import Path
import matplotlib.pyplot as plt
import glmtools
from ..glm import glm_epochs, glm_spectrum, glm_irasa, group_glm_epochs, group_glm_spectrum, MaxStatPermuteGLMSpectrum, ClusterPermuteGLMSpectrum
from ..glm.glm_base import SensorMaxStatPerm, SensorClusterPerm
[docs]logger = logging.getLogger(__name__)
# --------------------------------------------------------------
# osl-ephys preprocessing functions
#
[docs]def gesd(x, alpha=0.05, p_out=.1, outlier_side=0):
"""Detect outliers using Generalized ESD test
Parameters
----------
x : vector
Data set containing outliers
alpha : scalar
Significance level to detect at (default = 0.05)
p_out : int
Maximum number of outliers to detect (default = 10% of data set)
outlier_side : {-1,0,1}
Specify sidedness of the test
- outlier_side = -1 -> outliers are all smaller
- outlier_side = 0 -> outliers could be small/negative or large/positive (default)
- outlier_side = 1 -> outliers are all larger
Returns
-------
idx : boolean vector
Boolean array with TRUE wherever a sample is an outlier
x2 : vector
Input array with outliers removed
References
----------
B. Rosner (1983). Percentage Points for a Generalized ESD Many-Outlier Procedure. Technometrics 25(2), pp. 165-172.
http://www.jstor.org/stable/1268549?seq=1
"""
if outlier_side == 0:
alpha = alpha/2
if not isinstance(x, np.ndarray):
x = np.asarray(x)
n_out = int(np.ceil(len(x)*p_out))
if np.any(np.isnan(x)):
# Need to find outliers only in finite x
y = np.where(np.isnan(x))[0]
idx1, x2 = gesd(x[np.isfinite(x)], alpha, n_out, outlier_side)
# idx1 has the indexes of y which were marked as outliers
# the value of y contains the corresponding indexes of x that are outliers
idx = np.zeros_like(x).astype(bool)
idx[y[idx1]] = True
n = len(x)
temp = x.copy()
R = np.zeros((n_out,))
rm_idx = np.zeros((n_out,), dtype=int)
lam = np.zeros((n_out,))
for j in range(0, int(n_out)):
i = j+1
if outlier_side == -1:
rm_idx[j] = np.nanargmin(temp)
sample = np.nanmin(temp)
R[j] = np.nanmean(temp) - sample
elif outlier_side == 0:
rm_idx[j] = int(np.nanargmax(abs(temp-np.nanmean(temp))))
R[j] = np.nanmax(abs(temp-np.nanmean(temp)))
elif outlier_side == 1:
rm_idx[j] = np.nanargmax(temp)
sample = np.nanmax(temp)
R[j] = sample - np.nanmean(temp)
R[j] = R[j] / np.nanstd(temp)
temp[int(rm_idx[j])] = np.nan
p = 1-alpha/(n-i+1)
t = stats.t.ppf(p, n-i-1)
lam[j] = ((n-i) * t) / (np.sqrt((n-i-1+t**2)*(n-i+1)))
# Create a boolean array of outliers
idx = np.zeros((n,)).astype(bool)
idx[rm_idx[np.where(R > lam)[0]]] = True
x2 = x[~idx]
return idx, x2
[docs]def _find_outliers_in_dims(X, axis=-1, metric_func=np.std, gesd_args=None):
"""Find outliers across specified dimensions of an array"""
if gesd_args is None:
gesd_args = {}
if axis == -1:
axis = np.arange(X.ndim)[axis]
squashed_axes = tuple(np.setdiff1d(np.arange(X.ndim), axis))
metric = metric_func(X, axis=squashed_axes)
rm_ind, _ = gesd(metric, **gesd_args)
return rm_ind
[docs]def _find_outliers_in_segments(X, axis=-1, segment_len=100,
metric_func=np.std, gesd_args=None,
channel_wise = False, channel_axis = 0, threshold = 0.05):
"""
Identify outlier segments within an array.
Parameters:
- X: np.ndarray
Input data array with dimensions (channel, time).
- axis: int
The axis along which to segment the data (default is -1, the last axis).
- channel_axis: int
The axis along which channels are stored (default is 0, the first axis).
- segment_len: iant
Length of each segment along the specified axis.
- metric_func: callable
Function to compute the metric for each segment (default is np.std).
- gesd_args: dict
Additional arguments to pass to the GESD.
- channel_wise: bool
If True, the function will treat each channel seperately when detecting bad segments.
- channel_axis: int
The axis to treat as the channel axis. Only used when ``channel_wise=True``.
- threshold: str or float
Threshold for outlier detection. Only used when ``channel_wise=True``.
If 'any', a segment is marked as an outlier if
any of its channels is an outlier.
If a float, a segment is marked as an outlier if the fraction of outlier
channels exceeds the threshold.
Returns:
- bads: np.ndarray
Boolean array indicating outlier segments.
"""
if gesd_args is None:
gesd_args = {}
if axis == -1:
axis = np.arange(X.ndim)[axis]
# Preallocate some variables and prepare to slice data array
starts = np.arange(0, X.shape[axis], segment_len)
num_segments = len(starts)
bad_inds = np.zeros(X.shape[axis])*np.nan
slc = [slice(None)] * X.ndim
if channel_wise:
if channel_axis == -1:
channel_axis = np.arange(X.ndim)[channel_axis]
if axis == channel_axis:
raise ValueError('The time axis and channel axis cannot be the same.')
num_channels = X.shape[channel_axis]
if threshold != 'any':
if not isinstance(threshold, (int, float)):
raise ValueError("Threshold must be an integer or float or 'any'.")
if not 0 < threshold <= 1:
raise ValueError('Threshold must be between 0 and 1 or "any".')
if num_channels*threshold < 1:
raise ValueError('Threshold*n_channels must be at least 1 channel.')
metric = np.zeros((num_channels, num_segments))
for ii, start in enumerate(starts):
if ii == num_segments - 1:
stop = None
else:
stop = start + segment_len
# Update slice on dim of interest
slc[axis] = slice(start, stop)
for ch in range(num_channels):
# Update the slice object for channels
slc[channel_axis] = ch
# Compute metric for current chunk
metric[ch, ii] = np.nan_to_num(metric_func(X[tuple(slc)]), nan=0)
bad_inds[slc[axis]] = ii
bads = np.zeros_like(X, dtype=bool)
for ch in range(num_channels):
metric_ch = metric[ch]
# Apply the GESD test to identify outlier segments for each channel
rm_ind, _ = gesd(metric_ch, **gesd_args)
# Convert to int indices
rm_ind = np.where(rm_ind)[0]
# Convert to bool in original space of defined axis
bads_ch = np.isin(bad_inds, rm_ind)
# Store the boolean array for each channel
bads[ch] = bads_ch
# Combine the boolean arrays for each channel
if threshold != 'any':
bads = np.sum(bads,axis=0) >= threshold*num_channels
else:
bads = np.any(bads,axis=0)
else:
metric = np.zeros((num_segments, ))
for ii, start in enumerate(starts):
if ii == num_segments - 1:
stop = None
else:
stop = start + segment_len
# Update slice on dim of interest
slc[axis] = slice(start, stop)
# Compute metric for current chunk
metric[ii] = np.nan_to_num(metric_func(X[tuple(slc)]), nan=0)
# Store which chunk we've used
bad_inds[slc[axis]] = ii
rm_ind, _ = gesd(metric, **gesd_args)
rm_ind = np.where(rm_ind)[0]
bads = np.isin(bad_inds, rm_ind)
return bads
[docs]def detect_artefacts(X, axis=None, reject_mode='dim', metric_func=np.std,
segment_len=100, gesd_args=None, ret_mode='bad_inds',
channel_wise = False, channel_axis = 0, channel_threshold = 0.05):
"""Detect bad observations or segments in a dataset
Parameters
----------
X : ndarray
Array to find artefacts in.
axis : int
Index of the axis to detect artefacts in
reject_mode : {'dim' | 'segments'}
Flag indicating whether to detect outliers across a dimension (dim;
default) or whether to split a dim into segments and detect outliers in
the them (segments)
metric_func : function
Function defining metric to detect outliers on. Defaults to np.std but
can be any function taking an array and returning a single number.
segement_len : int > 0
Integer window length of dummy epochs for bad_segment detection
gesd_args : dict
Dictionary of arguments to pass to gesd
ret_mode : {'good_inds','bad_inds','zero_bads','nan_bads'}
Flag indicating whether to return the indices for good observations,
indices for bad observations (default), the input data with outliers
removed (zero_bads) or the input data with outliers replaced with nans
(nan_bads)
channel_wise : bool
If True, the function will treat each channel seperately when detecting bad segments,
only used when ``reject_mode='segments'``.
channel_axis : int
The axis to treat as the channel axis. Only used when ``channel_wise=True``.
channel_threshold : str or float
The treshold to use for channel-wise detection. Only used when ``channel_wise=True``.
Returns
-------
ndarray
If ret_mode is ``'bad_inds'`` or ``'good_inds'``, this returns a boolean vector
of length ``X.shape[axis]`` indicating good or bad samples. If ``ret_mode`` is
``'zero_bads'`` or ``'nan_bads'`` this returns an array copy of the input data
``X`` with bad samples set to zero or ``np.nan`` respectively.
"""
if reject_mode not in ['dim', 'segments']:
raise ValueError("reject_mode: '{0}' not recognised".format(reject_mode))
if ret_mode not in ['bad_inds', 'good_inds', 'zero_bads', 'nan_bads']:
raise ValueError("ret_mode: '{0}' not recognised")
if axis is None or axis > X.ndim:
raise ValueError('bad axis')
if reject_mode == 'dim':
bad_inds = _find_outliers_in_dims(X, axis=axis, metric_func=metric_func, gesd_args=gesd_args)
elif reject_mode == 'segments':
bad_inds = _find_outliers_in_segments(X, axis=axis,
segment_len=segment_len,
metric_func=metric_func,
gesd_args=gesd_args,channel_wise=channel_wise,
channel_axis = channel_axis,
threshold = channel_threshold)
if ret_mode == 'bad_inds':
return bad_inds
elif ret_mode == 'good_inds':
return bad_inds == False # noqa: E712
elif ret_mode in ['zero_bads', 'nan_bads']:
out = X.copy()
slc = []
for ii in range(X.ndim):
if ii == axis:
slc.append(bad_inds)
else:
slc.append(slice(None))
slc = tuple(slc)
if ret_mode == 'zero_bads':
out[slc] = 0
return out
elif ret_mode == 'nan_bads':
out[slc] = np.nan
return out
[docs]def detect_maxfilt_zeros(raw, use_maxfilter_log=True):
"""This function tries to load the maxfilter log files in order
to annotate zeroed out data in the :py:class:`mne.io.Raw <mne.io.Raw>` object. It
assumes that the log file is in the same directory as the
raw file and has the same name, but with the extension ``.log``.
If the log file can't be found, it will look for zeros in the data.
Parameters
----------
raw : :py:class:`mne.io.Raw <mne.io.Raw>`
MNE raw object.
Returns
-------
bad_inds : np.array of bool (n_times,) or None
Boolean array indicating which time points are zeroed out.
"""
if 'mag' in raw.get_channel_types() or 'grad' in raw.get_channel_types():
logger.info("No MEG data detected - not looking for MaxFilter zeroed-out data.")
return np.zeros(raw.n_times).astype(bool)
if raw.filenames[0] is not None:
log_fname = str(raw.filenames[0]).replace('.fif', '.log')
if 'log_fname' in locals() and exists(log_fname) and use_maxfilter_log:
try:
starttime = raw.first_time
endtime = raw._last_time
with open(log_fname) as f:
lines = f.readlines()
# for determining the start, end and point
phrase_ndataseg = ['(', ' data buffers)']
gotduration = False
# for detecting zeroed out data
zeroed=[]
phrase_zero = ['Time ', ': cont HPI is off, data block is skipped!']
for line in lines:
if gotduration == False and phrase_ndataseg[1] in line:
gotduration = True
n_dataseg = float(line.split(phrase_ndataseg[0])[1].split(phrase_ndataseg[1])[0]) # number of segments
if phrase_zero[1] in line:
zeroed.append(float(line.split(phrase_zero[0])[1].split(phrase_zero[1])[0])) # in seconds
duration = raw.n_times/n_dataseg # duration of each data segment in samples
starts = (np.array(zeroed) - starttime) * raw.info['sfreq'] # in samples
bad_inds = np.zeros(raw.n_times)
for ii in range(len(starts)):
stop = starts[ii] + duration # in samples
bad_inds[int(starts[ii]):int(stop)] = 1
except:
logger.info("detecting zeroed-out data from maxfilter log file failed")
detect_maxfilt_zeros(raw, use_maxfilter_log=False)
else:
logger.info("Detecting zeroed-out data from the data itself")
d = raw.get_data(picks='meg', reject_by_annotation='omit')
bad_inds = np.all(d == 0, axis=0)
# check if most of the data is marked as bad
if np.sum(bad_inds)/len(bad_inds) > 0.5:
raise RuntimeError("More than half of the data is marked as bad. This often happens when maxfilter movement compensation is used but not enough HPI coils are useable. Please check your data and/or maxfilter settings.")
return bad_inds.astype(bool)
[docs]def bad_segments(
raw,
picks,
segment_len=1000,
significance_level=0.05,
metric='std',
ref_meg='auto',
mode=None,
detect_zeros=True,
channel_wise=False,
channel_axis = 0,
channel_threshold = 0.05,
):
"""Set bad segments in an MNE :py:class:`Raw <mne.io.Raw>` object as defined by the Generalized ESD test
in :py:func:`osl_ephys.preprocessing.osl_wrappers.gesd <osl_ephys.preprocessing.osl_wrappers.gesd>`.
This function is typically used by calling :py:func:`run_osl_bad_segments <osl_ephys.preprocessing.osl_wrappers.run_osl_bad_segments>`.
Parameters
----------
raw : :py:class:`mne.io.Raw <mne.io.Raw>`
MNE raw object.
picks : str
Channel types to pick. See Notes for recommendations.
segment_len : int
Window length to divide the data into (non-overlapping).
significance_level : float
Significance level for detecting outliers. Must be between 0-1.
metric : str
Metric to use. Could be ``'std'``, ``'var'`` or ``'kurtosis'``.
ref_meg : str
ref_meg argument to pass with :py:func:`mne.pick_types <mne.pick_types>`.
mode : str
Should be ``None`` ``'diff'`` or ``'maxfilter'``.
When ``mode='diff'`` we calculate a difference time series before
detecting bad segments. When ``mode='maxfilter'`` we only mark the
segments with zeros from MaxFiltering as bad.
detect_zeros : bool
Should we detect segments of zeros based on the maxfilter files?
channel_wise : bool
If True, the function will treat each channel seperately.
channel_axis : int
The axis to treat as the channel axis. Only used when ``channel_wise=True``.
channel_threshold : str or float
The treshold to use for channel-wise detection. Only used when ``channel_wise=True``.
Returns
-------
raw : :py:class:`mne.io.Raw <mne.io.Raw>`
MNE raw object with bad segments annotated.
Notes
-----
Note that for Elekta/MEGIN data, we recommend using ``picks: 'mag'`` or ``picks: 'grad'`` separately (in no particular order).
Note that with CTF data, mne.pick_types will return:
~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
~28 reference axial grads if ``{picks: 'grad'}``.
Thus, it is recommended to use ``picks:'mag'`` in combination with ``ref_mag: False``, and ``picks:'grad'`` separately (in no particular order).
"""
gesd_args = {'alpha': significance_level}
if (picks == "mag") or (picks == "grad"):
chinds = mne.pick_types(raw.info, meg=picks, ref_meg=ref_meg, exclude='bads')
elif picks == "meg":
chinds = mne.pick_types(raw.info, meg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "eeg":
chinds = mne.pick_types(raw.info, eeg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "eog":
chinds = mne.pick_types(raw.info, eog=True, ref_meg=ref_meg, exclude='bads')
elif picks == "ecg":
chinds = mne.pick_types(raw.info, ecg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "emg":
chinds = mne.pick_types(raw.info, emg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "misc":
chinds = mne.pick_types(raw.info, misc=True, exclude='bads')
else:
raise NotImplementedError(f"picks={picks} not available.")
if mode is None:
if detect_zeros:
bdinds_maxfilt = detect_maxfilt_zeros(raw)
else:
bdinds_maxfilt = None
XX, XX_times = raw.get_data(picks=chinds, reject_by_annotation='omit', return_times=True)
elif mode == "diff":
bdinds_maxfilt = None
XX, XX_times = raw.get_data(picks=chinds, reject_by_annotation='omit', return_times=True)
XX = np.diff(XX, axis=1)
XX_times = XX_times[1:] # remove the first time point
elif mode == "maxfilter":
bdinds_maxfilt = detect_maxfilt_zeros(raw)
XX, XX_times = raw.get_data(picks=chinds, reject_by_annotation='omit', return_times=True)
allowed_metrics = ["std", "var", "kurtosis"]
if metric not in allowed_metrics:
raise ValueError(f"metric {metric} unknown.")
if metric == "std":
metric_func = np.std
elif metric == "var":
metric_func = np.var
else:
def kurtosis(inputs):
return stats.kurtosis(inputs, axis=None)
metric_func = kurtosis
if mode == "maxfilter":
bad_indices = [bdinds_maxfilt]
else:
bdinds = detect_artefacts(
XX,
axis=1,
reject_mode="segments",
metric_func=metric_func,
segment_len=segment_len,
ret_mode="bad_inds",
gesd_args=gesd_args,
channel_wise = channel_wise,
channel_axis = channel_axis,
channel_threshold = channel_threshold,
)
bad_indices = [bdinds, bdinds_maxfilt]
for count, bdinds in enumerate(bad_indices):
if bdinds is None:
continue
if count==1:
descp1 = count * 'maxfilter_' # when count==0, should be ''
descp2 = ' (maxfilter)'
else:
descp1 = ''
descp2 = ''
onsets = np.where(np.diff(bdinds.astype(float)) == 1)[0]
if bdinds[0]:
onsets = np.r_[0, onsets]
offsets = np.where(np.diff(bdinds.astype(float)) == -1)[0]
if bdinds[-1]:
offsets = np.r_[offsets, len(bdinds) - 1]
assert len(onsets) == len(offsets)
descriptions = np.repeat("{0}bad_segment_{1}".format(descp1, picks), len(onsets))
logger.info("Found {0} bad segments".format(len(onsets)))
onsets_secs = raw.first_samp/raw.info["sfreq"] + XX_times[onsets.astype(int)]
offsets_secs = raw.first_samp/raw.info["sfreq"] + XX_times[offsets.astype(int)]
durations_secs = offsets_secs - onsets_secs
raw.annotations.append(onsets_secs, durations_secs, descriptions)
mod_dur = durations_secs.sum()
full_dur = raw.n_times / raw.info["sfreq"]
pc = (mod_dur / full_dur) * 100
s = "Modality {0}{1} - {2:02f}/{3} seconds rejected ({4:02f}%)"
logger.info(s.format("picks", descp2, mod_dur, full_dur, pc))
return raw
[docs]def bad_channels(raw, picks, ref_meg="auto", significance_level=0.05):
"""Set bad channels in an MNE :py:class:`Raw <mne.io.Raw>` object as defined by the Generalized ESD test in :py:func:`osl_ephys.preprocessing.osl_wrappers.gesd <osl_ephys.preprocessing.osl_wrappers.gesd>`.
Parameters
----------
raw : :py:class:`mne.io.Raw <mne.io.Raw>`
MNE raw object.
picks : str
Channel types to pick. See Notes for recommendations.
ref_meg : str
ref_meg argument to pass with :py:func:`mne.pick_types <mne.pick_types>`.
significance_level : float
Significance level for detecting outliers. Must be between 0-1.
Returns
-------
raw : :py:class:`mne.io.Raw <mne.io.Raw>`
MNE Raw object with bad channels marked.
Notes
-----
Note that for Elekta/MEGIN data, we recommend using ``picks:'mag'`` or ``picks:'grad'`` separately (in no particular order).
Note that with CTF data, mne.pick_types will return:
~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
~28 reference axial grads if ``{picks: 'grad'}``.
Thus, it is recommended to use ``picks:'mag'`` in combination with ``ref_mag: False``, and ``picks:'grad'`` separately (in no particular order).
"""
gesd_args = {'alpha': significance_level}
if (picks == "mag") or (picks == "grad"):
chinds = mne.pick_types(raw.info, meg=picks, ref_meg=ref_meg, exclude='bads')
elif picks == "meg":
chinds = mne.pick_types(raw.info, meg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "eeg":
chinds = mne.pick_types(raw.info, eeg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "eog":
chinds = mne.pick_types(raw.info, eog=True, ref_meg=ref_meg, exclude='bads')
elif picks == "ecg":
chinds = mne.pick_types(raw.info, ecg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "misc":
chinds = mne.pick_types(raw.info, misc=True, exclude='bads')
else:
raise NotImplementedError(f"picks={picks} not available.")
ch_names = np.array(raw.ch_names)[chinds]
bdinds = detect_artefacts(
raw.get_data(picks=chinds),
axis=0,
reject_mode="dim",
ret_mode="bad_inds",
gesd_args=gesd_args,
)
s = "Modality {0} - {1}/{2} channels rejected ({3:02f}%)"
pc = (bdinds.sum() / len(bdinds)) * 100
logger.info(s.format(picks, bdinds.sum(), len(bdinds), pc))
# concatenate newly found bads to existing bads
if np.any(bdinds):
raw.info["bads"].extend(list(ch_names[np.where(bdinds)[0]]))
return raw
[docs]def drop_bad_epochs(
epochs,
picks,
significance_level=0.05,
max_percentage=0.1,
outlier_side=0,
metric='std',
ref_meg='auto',
mode=None,
):
"""Drop bad epochs in an MNE :py:class:`Epochs <mne.Epochs>` object as defined by the Generalized ESD test in :py:func:`osl_ephys.preprocessing.osl_wrappers.gesd <osl_ephys.preprocessing.osl_wrappers.gesd>`.
Parameters
----------
epochs : :py:class:`mne.Epochs <mne.Epochs>`
MNE Epochs object.
picks : str
Channel types to pick.
significance_level : float
Significance level for detecting outliers. Must be between 0-1.
max_percentage : float
Maximum fraction of the epochs to drop. Should be between 0-1.
outlier_side : int
Specify sidedness of the test:
* outlier_side = -1 -> outliers are all smaller
* outlier_side = 0 -> outliers could be small/negative or large/positive (default)
* outlier_side = 1 -> outliers are all larger
metric : str
Metric to use. Could be ``'std'``, ``'var'`` or ``'kurtosis'``.
ref_meg : str
ref_meg argument to pass with :py:func:`mne.pick_types <mne.pick_types>`.
mode : str
Should be ``'diff'`` or ``None``. When ``mode='diff'`` we calculate a difference time
series before detecting bad segments.
Returns
-------
epochs : :py:meth:`mne.Epochs <mne.Epochs>`
MNE Epochs object with bad epoches marked.
Notes
-----
Note that with CTF data, mne.pick_types will return:
~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
~28 reference axial grads if ``{picks: 'grad'}``.
"""
gesd_args = {
'alpha': significance_level,
'p_out': max_percentage,
'outlier_side': outlier_side,
}
if (picks == "mag") or (picks == "grad"):
chinds = mne.pick_types(epochs.info, meg=picks, ref_meg=ref_meg, exclude='bads')
elif picks == "meg":
chinds = mne.pick_types(epochs.info, meg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "eeg":
chinds = mne.pick_types(epochs.info, eeg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "eog":
chinds = mne.pick_types(epochs.info, eog=True, ref_meg=ref_meg, exclude='bads')
elif picks == "ecg":
chinds = mne.pick_types(epochs.info, ecg=True, ref_meg=ref_meg, exclude='bads')
elif picks == "misc":
chinds = mne.pick_types(epochs.info, misc=True, ref_meg=ref_meg, exclude='bads')
else:
raise NotImplementedError(f"picks={picks} not available.")
if mode is None:
X = epochs.get_data(picks=chinds)
elif mode == "diff":
X = np.diff(epochs.get_data(picks=chinds), axis=-1)
# Get the function used to calculate the evaluation metric
allowed_metrics = ["std", "var", "kurtosis"]
if metric not in allowed_metrics:
raise ValueError(f"metric {metric} unknown.")
if metric == "std":
metric_func = np.std
elif metric == "var":
metric_func = np.var
else:
metric_func = stats.kurtosis
# Calculate the metric used to evaluate whether an epoch is bad
X = metric_func(X, axis=-1)
# Average over channels so we have a metric for each trial
X = np.mean(X, axis=1)
# Use gesd to find outliers
bad_epochs, _ = gesd(X, **gesd_args)
logger.info(
f"Modality {picks} - {np.sum(bad_epochs)}/{X.shape[0]} epochs rejected"
)
# Drop bad epochs
epochs.drop(bad_epochs)
return epochs
[docs]def detect_bad_channels_psd(raw, fmin=2, fmax=80, n_fft=2000, alpha=0.05):
"""
Detect bad channels using PSD and GESD outlier detection.
Parameters
----------
raw : mne.io.Raw
Raw data object.
fmin, fmax : float
Frequency range for PSD computation.
n_fft : int
FFT length for PSD.
alpha : float
Significance level for GESD outlier detection.
Returns
-------
list of str
Detected bad channel names.
"""
# Exclude already-marked bads
good_chans = [ch for ch in raw.ch_names if ch not in raw.info['bads']]
# Compute PSD (bad channels excluded by MNE)
psd = raw.compute_psd(
fmin=fmin, fmax=fmax, n_fft=n_fft,
reject_by_annotation=True, verbose=False
)
pow_data = psd.get_data()
if len(good_chans) != pow_data.shape[0]:
raise RuntimeError(
f"Channel mismatch: {len(good_chans)} chans vs PSD shape {pow_data.shape[0]}"
)
# Check for NaN or zero PSD
bad_forced = [
ch for ch, psd_ch in zip(good_chans, pow_data)
if np.any(np.isnan(psd_ch)) or np.all(psd_ch == 0)
]
if bad_forced:
raise RuntimeError(
f"PSD contains NaNs or all-zero values for channels: {bad_forced}"
)
# Log-transform PSD
pow_log = np.log10(pow_data)
# Detect artefacts with GESD
mask = detect_artefacts(
pow_log, axis=0, reject_mode="dim", gesd_args={"alpha": alpha}
)
return [ch for ch, is_bad in zip(good_chans, mask) if is_bad]
# Wrapper functions
[docs]def run_osl_read_dataset(dataset, userargs):
"""Reads ``fif``/``npy``/``yml`` files associated with a dataset.
Parameters
----------
fif : str
Path to raw fif file (can be preprocessed).
preload : bool
Should we load the raw fif data?
ftype : str
Extension for the fif file (will be replaced for e.g. ``'_events.npy'`` or
``'_ica.fif'``). If ``None``, we assume the fif file is preprocessed with
osl-ephys and has the extension ``'_preproc-raw'``. If this fails, we guess
the extension as whatever comes after the last ``'_'``.
extra_keys : str
Space separated list of extra keys to read in from the same directory as the fif file.
If no suffix is provided, it's assumed to be .pkl. e.g., 'glm' will read in '..._glm.pkl'
'events.npy' will read in '..._events.npy'.
Returns
-------
dataset : dict
Contains keys: ``'raw'``, ``'events'``, ``'event_id'``, ``'epochs'``, ``'ica'``.
"""
logger.info("OSL Stage - {0}".format( "read_dataset"))
logger.info("userargs: {0}".format(str(userargs)))
ftype = userargs.pop("ftype", None)
extra_keys = userargs.pop("extra_keys", [])
fif = dataset['raw'].filenames[0]
# Guess extension
if ftype is None:
logger.info("Guessing the preproc extension")
if "preproc-raw" in fif:
logger.info('Assuming fif file type is "preproc-raw"')
ftype = "preproc-raw"
else:
if len(fif.split("_"))<2:
logger.error("Unable to guess the fif file extension")
else:
logger.info('Assuming fif file type is the last "_" separated string')
ftype = fif.split("_")[-1].split('.')[-2]
# add extension to fif file name
ftype = ftype + ".fif"
events = Path(fif.replace(ftype, "events.npy"))
if events.exists():
print("Reading", events)
events = np.load(events)
else:
events = None
event_id = Path(fif.replace(ftype, "event-id.yml"))
if event_id.exists():
print("Reading", event_id)
with open(event_id, "r") as file:
event_id = yaml.load(file, Loader=yaml.Loader)
else:
event_id = None
epochs = Path(fif.replace(ftype, "epo.fif"))
if epochs.exists():
print("Reading", epochs)
epochs = mne.read_epochs(epochs)
else:
epochs = None
ica = Path(fif.replace(ftype, "ica.fif"))
if ica.exists():
print("Reading", ica)
ica = mne.preprocessing.read_ica(ica)
else:
ica = None
dataset['event_id'] = event_id
dataset['events'] = events
dataset['ica'] = ica
dataset['epochs'] = epochs
if len(extra_keys)>0:
extra_keys = extra_keys.split(" ")
for key in extra_keys:
extra_file = Path(fif.replace(ftype, key))
key = key.split(".")[0]
if '.' not in extra_file.name:
extra_file = extra_file.with_suffix('.pkl')
if extra_file.exists():
print("Reading", extra_file)
if '.pkl' in extra_file.name:
with open(extra_file, 'rb') as outp:
dataset[key] = pickle.load(outp)
elif '.npy' in extra_file.name:
dataset[key] = np.load(extra_file)
elif '.yml' in extra_file.name:
with open(extra_file, 'r') as file:
dataset[key] = yaml.load(file, Loader=yaml.Loader)
return dataset
[docs]def run_osl_bad_segments(dataset, userargs):
"""osl-ephys Batch wrapper for :py:meth:`bad_segments <osl_ephys.preprocessing.osl_wrappers.bad_segments>`.
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``raw``.
userargs: dict
Dictionary of additional arguments to be passed to :py:meth:`bad_segments <osl_ephys.preprocessing.osl_wrappers.bad_segments>`.
Returns
-------
dataset: dict
Input dictionary containing MNE objects that have been modified in place.
"""
target = userargs.pop("target", "raw")
logger.info("osl-ephys Stage - {0} : {1}".format(target, "bad_segments"))
logger.info("userargs: {0}".format(str(userargs)))
dataset["raw"] = bad_segments(dataset["raw"], **userargs)
return dataset
[docs]def run_osl_bad_channels(dataset, userargs):
"""osl-ephys Batch wrapper for :py:func:`bad_channels <osl_ephys.preprocessing.osl_wrappers.bad_channels>`.
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``raw``.
userargs: dict
Dictionary of additional arguments to be passed to :py:meth:`bad_channels <osl_ephys.preprocessing.osl_wrappers.bad_channels>`.
Returns
-------
dataset: dict
Input dictionary containing MNE objects that have been modified in place.
Notes
-----
Note that using 'picks' with CTF data, mne.pick_types will return:
~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
~28 reference axial grads if ``{picks: 'grad'}``.
"""
target = userargs.pop("target", "raw")
logger.info("osl-ephys Stage - {0} : {1}".format(target, "bad_channels"))
logger.info("userargs: {0}".format(str(userargs)))
dataset["raw"] = bad_channels(dataset["raw"], **userargs)
return dataset
[docs]def run_osl_drop_bad_epochs(dataset, userargs):
"""osl-ephys Batch wrapper for :py:meth:`drop_bad_epochs <osl_ephys.preprocessing.osl_wrappers.drop_bad_epochs>`.
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``raw``.
userargs: dict
Dictionary of additional arguments to be passed to :py:meth:`drop_bad_epochs <osl_ephys.preprocessing.osl_wrappers.drop_bad_epochs>`.
Returns
-------
dataset: dict
Input dictionary containing MNE objects that have been modified in place.
"""
target = userargs.pop("target", "raw")
logger.info("osl-ephys Stage - {0} : {1}".format(target, "drop_bad_epochs"))
logger.info("userargs: {0}".format(str(userargs)))
if dataset["epochs"] is None:
logger.info("no epoch object found! skipping..")
dataset["epochs"] = drop_bad_epochs(dataset["epochs"], **userargs)
return dataset
#%% GLM wrappers
[docs]def run_osl_zscore_present_data(dataset, userargs):
"""
z-scoring parametric regressors, without NaNs
Nans will be zeros in the z-scored version
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``covs``.
userargs: dict
Dictionary of additional arguments containing the keys ``keys``.
"""
keys = userargs.pop("keys", None)
# make sure keys is a single string or list of strings
if keys[0]=='[' and keys[-1]==']':
keys = keys[1:-1].split(' ')
for key in keys:
new = stats.zscore(dataset["covs"][key], nan_policy='omit')
new[np.isnan(dataset["covs"][key])] = 0
dataset["covs"][key] = new
return dataset
[docs]def run_osl_glm_add_regressor(dataset, userargs):
"""osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor <osl_ephys.preprocessing.osl_glm.add_regressor>`.
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``covs``.
userargs: dict
Dictionary of additional arguments containing the keys ``keys``.
"""
logger.info("osl-ephys Stage - {0}".format("GLM Add Regressor"))
if 'design_config' not in dataset or not isinstance(dataset['design_config'], glmtools.design.DesignConfig):
dataset['design_config'] = glmtools.design.DesignConfig()
rtype = userargs.pop("rtype", None)
name = userargs.pop("name", None)
codes = userargs.pop("codes", None)
preproc = userargs.pop("preproc", None)
key = userargs.pop("key", None)
if rtype == 'Constant':
dataset['design_config'].add_regressor(name, rtype)
elif rtype == 'Categorical':
if codes == 'unique': # add a regressor for each unique value
codes = np.unique(dataset['covs'][key])
for code in codes:
dataset['design_config'].add_regressor(name=name + '_{0}'.format(code), rtype=rtype, codes=code)
else:
codes = [float(codes)
if np.logical_or(type(codes) == int, type(codes) == float)
else np.array(codes[0].split(" ")).astype(float)][0]
dataset['design_config'].add_regressor(name=name, rtype=rtype, codes=codes)
elif rtype == 'Parametric':
dataset['design_config'].add_regressor(name=name, rtype=rtype, datainfo=key, preproc=preproc)
elif rtype == 'MeanEffects':
dataset['design_config'].add_regressor(name=name + '_{0}',rtype=rtype, datainfo=key)
else:
raise ValueError("Unknown regressor type")
return dataset
[docs]def run_osl_glm_add_contrast(dataset, userargs):
"""osl-ephys Batch wrapper for :py:func:`osl_ephys.preprocessing.osl_glm.add_regressor <osl_ephys.preprocessing.osl_glm.add_regressor>`.
Parameters
"""
logger.info("osl-ephys Stage - {0}".format("GLM Add Contrast"))
simple = userargs.pop("simple", False)
name = userargs.pop("name", None)
values = userargs.pop("values", None)
key = userargs.pop("key", None)
if simple:
dataset['design_config'].add_simple_contrasts()
else:
if values == 'unique':
values = np.unique(dataset['covs'][key])
values={f"{key}_{v}": 1/len(values) for v in values}
else:
for key, value in values.items():
if isinstance(values[key], str):
values[key] = float(eval(value))
else:
values[key] = float(value)
dataset['design_config'].add_contrast(name=name, values=values)
return dataset
[docs]def run_osl_glm_fit(dataset, userargs):
""" wrapper for the different glm functions in the glm module
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``covs``.
userargs: dict
Dictionary of additional arguments
Returns
-------
dataset: dict
Input dictionary containing MNE objects that have been modified in place.
"""
run_on_group = userargs.pop("run_on_group", False)
method = userargs.pop("method", None)
if method is None:
raise ValueError("method not specified")
target = userargs.pop("target", None)
if target is None:
if run_on_group:
target = "glm"
else:
if method in ['epochs', 'glm_epochs']:
target = "epochs"
elif method in ['spectrum', 'glm_spectrum']:
target = "raw"
name = userargs.pop("name", None)
if name is None:
if run_on_group:
name = "group_glm"
else:
name = "glm"
metric = userargs.pop("metric", 'copes')
plot_summary = userargs.pop("plot_summary", True)
plot_efficiency = userargs.pop("plot_efficiency", True)
plot_leverage = userargs.pop("plot_leverage", True)
if method == 'epochs' or method == 'glm_epochs':
baseline = userargs.pop("baseline", None)
if baseline is not None:
baseline = np.array(baseline[1:-1].split(" ")).astype(float)
if run_on_group:
dataset[name] = group_glm_epochs(dataset[target], dataset['design_config'], dataset['covs'], metric, baseline)
else:
dataset[name] = glm_epochs(dataset['design_config'], dataset[target])
elif method in ['spectrum', 'glm_spectrum', 'irasa', 'glm_irasa']:
if run_on_group:
dataset[name] = group_glm_spectrum(dataset[target], dataset['design_config'], dataset['covs'], metric, baseline)
else:
reg_categorical = userargs.pop("reg_categorical", None)
if reg_categorical[0]=='[' and reg_categorical[-1]==']':
reg_categorical = userargs["covs"][reg_categorical[1:-1].split(' ')]
else:
reg_categorical = userargs["covs"][reg_categorical]
reg_ztrans = userargs.pop("reg_ztrans", None)
if reg_ztrans[0]=='[' and reg_ztrans[-1]==']':
reg_ztrans = userargs["covs"][reg_ztrans[1:-1].split(' ')]
else:
reg_ztrans = userargs["covs"][reg_ztrans]
reg_unitmax = userargs.pop("reg_unitmax", None)
if reg_unitmax[0]=='[' and reg_unitmax[-1]==']':
reg_unitmax = userargs["covs"][reg_unitmax[1:-1].split(' ')]
else:
reg_unitmax = userargs["covs"][reg_unitmax]
if method in ['spectrum', 'glm_spectrum']:
dataset[name] = glm_spectrum(dataset[target], reg_unitmax=reg_unitmax, reg_ztrans=reg_ztrans, reg_categorical=reg_categorical, **userargs)
else:
dataset[name] = glm_irasa(dataset[target], reg_unitmax=reg_unitmax, reg_ztrans=reg_ztrans, reg_categorical=reg_categorical, **userargs)
if plot_summary:
dataset['fig'][name + 'design_summary'] = dataset[name].design.plot_summary(show=False)
if plot_efficiency:
dataset['fig'][name + 'design_efficiency'] = dataset[name].design.plot_efficiency(show=False)
if plot_leverage:
dataset['fig'][name + 'design_leverage'] = dataset[name].design.plot_leverage(show=False)
return dataset
[docs]def run_osl_glm_permutations(dataset, userargs):
""" wrapper for the different permutation options in the glm module
Parameters
----------
dataset: dict
Dictionary containing at least an MNE object with the key ``covs``.
userargs: dict
Dictionary of additional arguments
Returns
-------
dataset: dict
Input dictionary containing MNE objects that have been modified in place.
"""
run_on_group = userargs.pop("run_on_group", False)
target = userargs.pop("target", "group_glm")
name = userargs.pop("name", "group_glm_perm")
method = userargs.pop("method", None)
if method is None:
raise ValueError("method not specified")
type = userargs.pop("type", None)
if type is None:
raise ValueError("type not specified (e.g. 'max', 'cluster')")
thresh = userargs.pop("threshold", 95)
plot_sig = userargs.pop("plot_sig", True)
contrast = userargs.pop("contrast", None)
contrast = dataset[target].contrast_names.index(contrast)
fl_contrast = userargs.pop("fl_contrast", 0)
if fl_contrast != 0:
fl_contrast = dataset[target].fl_contrast_names.index(fl_contrast)
if type in ['max', 'maxstat']:
if method == 'epochs' or method == 'glm_epochs':
dataset[name] = SensorMaxStatPerm(dataset[target], contrast, fl_contrast, **userargs)
elif method == 'spectrum' or method == 'glm_spectrum':
dataset[name] = MaxStatPermuteGLMSpectrum(dataset[target], contrast, fl_contrast, **userargs)
elif type == 'cluster':
if method == 'epochs' or method == 'glm_epochs':
dataset[name] = SensorClusterPerm(dataset[target], contrast, fl_contrast, **userargs)
elif method == 'spectrum' or method == 'glm_spectrum':
dataset[name] = ClusterPermuteGLMSpectrum(dataset[target], contrast, fl_contrast, **userargs)
if plot_sig:
fig, ax = plt.subplots()
dataset[name].plot_sig_clusters(thresh, ax=ax)
dataset['fig'][name + 'sig' + str(thresh)] = fig
return dataset