Source code for osl_ephys.preprocessing.plot_ica

"""Plotting functions for ICA.

"""

# Authors: Mats van Es <mats.vanes@psych.ox.ac.uk>
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from mne import channel_type

import matplotlib

# Force matplotlib to use an interactive backend:
[docs]backends_to_try = ['Qt5Agg', 'QtAgg', 'GTK3Agg', 'GTK4Agg', 'macosx', 'TkAgg', 'GTK3Cairo', 'GTK4Cairo', 'wxAgg' ]
for backend in backends_to_try: try: matplotlib.use(backend, force=True) break except ImportError: continue # Configure logging
[docs]logger = logging.getLogger(__name__)
[docs]def plot_ica( ica, inst, picks=None, start=None, stop=None, title=None, show=True, block=False, show_first_samp=False, show_scrollbars=True, time_format="float", n_channels=10, bad_labels_list=["eog", "ecg", "emg", "hardware", "other"], ): """osl-ephys' adaptation of MNE's :py:meth:`mne.preprocessing.ICA.plot_sources <mne.preprocessing.ICA.plot_sources>` function to plot estimated latent sources given the unmixing matrix. Typical usecases: 1. plot evolution of latent sources over time based on (Raw input) 2. plot latent source around event related time windows (Epochs input) 3. plot time-locking in ICA space (Evoked input) Parameters ---------- ica : :py:class:`mne.preprocessing.ICA <mne.preprocessing.ICA>`. The ICA solution. inst : :py:class:`mne.io.Raw <mne.io.Raw>`, :py:class:`mne.Epochs <mne.Epochs>`, or :py:class:`mne.Evoked <mne.Evoked>`. The object to plot the sources from. picks : str Channel types to pick. start, stop : float | int | None If ``inst`` is a :py:class:`mne.io.Raw <mne.io.Raw>` or an :py:class:`mne.Evoked <mne.Evoked>` object, the first and last time point (in seconds) of the data to plot. If ``inst`` is a :py:class:`mne.io.Raw <mne.io.Raw>` object, ``start=None`` and ``stop=None`` will be translated into ``start=0.`` and ``stop=3.``, respectively. For :py:class:`mne.Evoked <mne.Evoked>`, ``None`` refers to the beginning and end of the evoked signal. If ``inst`` is an :py:class:`mne.Epochs <mne.Epochs>` object, specifies the index of the first and last epoch to show. title : str | None The window title. If None a default is provided. show : bool Show figure if True. block : bool Whether to halt program execution until the figure is closed. Useful for interactive selection of components in raw and epoch plotter. For evoked, this parameter has no effect. Defaults to False. show_first_samp : bool If True, show time axis relative to the ``raw.first_samp``. n_channels : int Number of channels to show at the same time (default: 10) bad_labels_list : list of str list of bad labels to show in the bad labels list that can be used to mark the type of bad component. Defaults to ``["eog", "ecg", "emg", "hardware", "other"]``. Returns ------- fig : instance of Figure The figure. Notes ----- For raw and epoch instances, it is possible to select components for exclusion by clicking on the line. The selected components are added to ``ica.exclude`` on close. .. versionadded:: 0.10.0 """ from mne.io.base import BaseRaw from mne.io.pick import _picks_to_idx # OSL ADDITION from mne.evoked import Evoked from mne.epochs import BaseEpochs # silence warnings warnings.filterwarnings("ignore", message=".*more than 20 mm from head frame origin.*") warnings.filterwarnings("ignore", message=".*There are no gridspecs with layoutgrids. Possibly did not call parent GridSpec with the.*") warnings.filterwarnings("ignore", message=".*This figure was using a layout engine that is incompatible with subplots_adjust.*") exclude = ica.exclude picks = _picks_to_idx(ica.n_components_, picks, "all") if isinstance(inst, (BaseRaw, BaseEpochs)): fig = _plot_sources( ica, inst, picks, exclude, start=start, stop=stop, # OSL VERSION show=show, title=title, block=block, show_first_samp=show_first_samp, show_scrollbars=show_scrollbars, time_format=time_format, n_channels=n_channels, bad_labels_list=bad_labels_list, ) elif isinstance(inst, Evoked): if start is not None or stop is not None: inst = inst.copy().crop(start, stop) sources = ica.get_sources(inst) fig = _plot_ica_sources_evoked( evoked=sources, picks=picks, exclude=exclude, title=title, labels=getattr(ica, "labels_", None), show=show, ica=ica, n_channels=n_channels, bad_labels_list=bad_labels_list, ) else: raise ValueError("Data input must be of Raw or Epochs type") return fig
[docs]def _plot_sources( ica, inst, picks, exclude, start, stop, show, title, block, show_scrollbars, show_first_samp, time_format, n_channels, bad_labels_list, ): """Adaptation of MNE's `mne.preprocessing.ica._plot_sources` function to allow for OSL additions. """ """Plot the ICA components as a RawArray or EpochsArray.""" # from mne.viz._figure import _get_browser from mne.viz.utils import _compute_scalings, _make_event_color_dict, plt_show from mne import EpochsArray, BaseEpochs from mne.io import RawArray, BaseRaw from mne import create_info from mne import pick_types from mne.defaults import _handle_default # handle defaults / check arg validity is_raw = isinstance(inst, BaseRaw) is_epo = isinstance(inst, BaseEpochs) sfreq = inst.info["sfreq"] color = _handle_default("color", (0.0, 0.0, 0.0)) units = _handle_default("units", None) scalings = ( _compute_scalings(None, inst) if is_raw else _handle_default("scalings_plot_raw") ) scalings["misc"] = 5.0 scalings["whitened"] = 1.0 unit_scalings = _handle_default("scalings", None) # data if is_raw: data = ica._transform_raw(inst, 0, len(inst.times))[picks] else: data = ica._transform_epochs(inst, concatenate=True)[picks] # events if is_epo: event_id_rev = {v: k for k, v in inst.event_id.items()} event_nums = inst.events[:, 2] event_color_dict = _make_event_color_dict(None, inst.events, inst.event_id) # channel properties / trace order / picks ch_names = list(ica._ica_names) # copy ch_types = ["misc" for _ in picks] # add EOG/ECG channels if present eog_chs = pick_types(inst.info, meg=False, eog=True, ref_meg=False) extra_picks = pick_types(inst.info, meg=False, ecg=True, eog=True, ref_meg=False) for idx in extra_picks[::-1]: ch_names.insert(0, inst.ch_names[idx]) ch_types.insert(0, "eog" if idx in eog_chs else "ecg") if len(extra_picks): if is_raw: eog_ecg_data, _ = inst[extra_picks, :] else: eog_ecg_data = np.concatenate(inst.get_data(extra_picks), axis=1) data = np.append(eog_ecg_data, data, axis=0) picks = np.concatenate((picks, ica.n_components_ + np.arange(len(extra_picks)))) ch_order = np.arange(len(picks)) n_channels = min([n_channels, len(picks)]) ch_names_picked = [ch_names[x] for x in picks] # because we added channels to the beginning of the data, we need to adjust exclude: exclude = [x + len(extra_picks) for x in exclude if x in picks] # create info info = create_info(ch_names_picked, sfreq, ch_types=ch_types) with info._unlock(): info["meas_date"] = inst.info["meas_date"] info["bads"] = [ch_names[x] for x in exclude if x in picks] if is_raw: inst_array = RawArray(data, info, inst.first_samp) inst_array.set_annotations(inst.annotations) else: data = data.reshape(-1, len(inst), len(inst.times)).swapaxes(0, 1) inst_array = EpochsArray(data, info) # handle time dimension start = 0 if start is None else start _last = inst.times[-1] if is_raw else len(inst.events) stop = min(start + 20, _last) if stop is None else stop first_time = inst._first_time if show_first_samp else 0 if is_raw: duration = stop - start start += first_time else: n_epochs = stop - start total_epochs = len(inst) epoch_n_times = len(inst.times) n_epochs = min(n_epochs, total_epochs) n_times = total_epochs * epoch_n_times duration = n_epochs * epoch_n_times / sfreq event_times = ( np.arange(total_epochs) * epoch_n_times + inst.time_as_index(0) ) / sfreq # NB: this includes start and end of data: boundary_times = np.arange(total_epochs + 1) * epoch_n_times / sfreq if duration <= 0: raise RuntimeError("Stop must be larger than start.") # misc bad_color = "lightgray" title = "ICA components" if title is None else title # OSL ADDITION # define some colors for bad component labels import matplotlib.colors as mcolors c = list(mcolors.TABLEAU_COLORS.keys()) idx = [c.index(i) for i in c if "red" in i] for i in idx: del c[i] c = c[: len(bad_labels_list) + 1] # keep as many as required. params = dict( inst=inst_array, ica=ica, ica_inst=inst, info=info, # channels and channel order ch_names=np.array(ch_names_picked), ch_types=np.array(ch_types), ch_order=ch_order, picks=picks, n_channels=n_channels, picks_data=list(), bad_labels_list=bad_labels_list, # OSL ADDITION # time t_start=start if is_raw else boundary_times[start], duration=duration, n_times=inst.n_times if is_raw else n_times, first_time=first_time, time_format=time_format, decim=1, # events event_times=None if is_raw else event_times, # preprocessing projs=list(), projs_on=np.array([], dtype=bool), apply_proj=False, remove_dc=True, # for EOG/ECG filter_coefs=None, filter_bounds=None, noise_cov=None, # scalings scalings=scalings, units=units, unit_scalings=unit_scalings, # colors ch_color_bad=bad_color, ch_color_dict=color, bad_label_colors=c, # display butterfly=False, clipping=None, scrollbars_visible=show_scrollbars, scalebars_visible=False, window_title=title, ) if is_epo: params.update( n_epochs=n_epochs, boundary_times=boundary_times, event_id_rev=event_id_rev, event_color_dict=event_color_dict, event_nums=event_nums, epoch_color_bad=(1, 0, 0), epoch_colors=None, xlabel="Epoch number", ) fig = _get_browser(**params) fig.mne.ch_start = len(extra_picks) # this is necessary to make sure to plot the EOG/ECG only once fig._update_picks() # update data, and plot fig._update_trace_offsets() fig._update_data() fig._draw_traces() # OSL VERSION # plot annotations (if any) if is_raw: fig._setup_annotation_colors() fig._update_annotation_segments() fig._draw_annotations() plt_show(show, block=block) return fig
from mne.viz._mpl_figure import MNEBrowseFigure
[docs]backend = None
[docs]def _get_browser(**kwargs): """OSL Adaptation of MNE's `mne.viz._figure._get_browser` function that instantiate a new MNE browse-style figure. """ from mne.viz.utils import _get_figsize_from_config from mne.viz._figure import _init_browser_backend figsize = kwargs.setdefault("figsize", _get_figsize_from_config()) if figsize is None or np.any(np.array(figsize) < 8): kwargs["figsize"] = (8, 8) # Initialize browser backend _init_browser_backend() # Initialize Browser browser = _init_browser( backend, **kwargs ) # OSL ADDITION IN ORDER TO USE OSL'S FIGURE CLASS FROM _INIT_BROWSER return browser
[docs]def _init_browser(backend, **kwargs): # OSL ADDITION IN ORDER TO USE OSL'S FIGURE CLASS from mne.viz._mpl_figure import _figure """OSL's adaptation of MNE's `mne.viz._mpl_figure._init_browser` that instantiate a new MNE browse-style figure. """ fig = _figure(toolbar=False, FigureClass=osl_MNEBrowseFigure, **kwargs) # initialize zen mode # (can't do in __init__ due to get_position() calls) fig.canvas.draw() fig._update_zen_mode_offsets() fig._resize(None) # needed for MPL >=3.4 # if scrollbars are supposed to start hidden, # set to True and then toggle if not fig.mne.scrollbars_visible: fig.mne.scrollbars_visible = True fig._toggle_scrollbars() return fig
[docs]class osl_MNEBrowseFigure(MNEBrowseFigure): """OSL's adaptatation of MNE's `mne.viz._mpl_figure.MNEBrowseFigure` that creates an interactive figure with scrollbars, for data browsing.""" def __init__(self, inst, figsize, ica=None, xlabel='Time (s)', **kwargs): from matplotlib.colors import to_rgba_array from matplotlib.patches import Rectangle from matplotlib.ticker import (FixedFormatter, FixedLocator, FuncFormatter, NullFormatter) from matplotlib.transforms import blended_transform_factory from matplotlib.widgets import Button from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from mpl_toolkits.axes_grid1.axes_size import Fixed # # OSL IMPORTS from mne import pick_types from mne import BaseEpochs from mne.io import BaseRaw from mne.preprocessing import ICA from mne.viz._figure import BrowserBase from mne.viz._mpl_figure import MNEFigure, _patched_canvas import mne from functools import partial
[docs] self.backend_name = "matplotlib"
kwargs.update({"inst": inst, "figsize": figsize, "ica": ica, "xlabel": xlabel}) BrowserBase.__init__(self, **kwargs) MNEFigure.__init__(self, **kwargs) # hook up a mouse press event self.canvas.mpl_connect("button_press_event", self._on_mouse_press) # MAIN AXES: default sizes (inches) # XXX simpler with constrained_layout? (when it's no longer "beta") l_margin = 0.8#1.0 r_margin = 1.0#0.1 b_margin = 0.45 t_margin = 0.35 scroll_width = 0.25 hscroll_dist = 0.25 vscroll_dist = 0.1 help_width = scroll_width * 2 # MVE: ADD SIZES FOR TOPOS extra_chans = pick_types(inst.info, meg=False, eeg=False, ref_meg=False, eog=True, ecg=True, exclude=[]) exist_meg = any(ct in np.unique(ica.get_channel_types()) for ct in ['mag', 'grad']) exist_eeg = 'eeg' in np.unique(ica.get_channel_types()) n_topos = len( np.unique( [ channel_type(ica.info, ch) for ch in mne.pick_types(ica.info, meg=exist_meg, eeg=exist_eeg) ] ) ) topo_width_ratio = 8 + n_topos # 1 topo_dist = self._inch_to_rel(0.05) # 0.25 # MAIN AXES: default margins (figure-relative coordinates) # self.canvas.figure.clear() # clear axes (inherited from MNE) # TODO: Do we need this? left = self._inch_to_rel(l_margin - vscroll_dist - help_width) right = 1 - self._inch_to_rel(r_margin) bottom = self._inch_to_rel(b_margin, horiz=False) top = 1 - self._inch_to_rel(t_margin, horiz=False) height = top - bottom # OSL ADDITION: ADAPT SIZES OF TIME COURSE SUBPLOT AND ADD TOPO PLOT SIZE fullwidth = right - left width = (topo_width_ratio - n_topos) * ( fullwidth - n_topos * topo_dist ) / topo_width_ratio - ( self._inch_to_rel(hscroll_dist) + self._inch_to_rel(scroll_width) ) # width = right - left topo_width = (fullwidth - topo_dist) / topo_width_ratio topo_height = ( height - self._inch_to_rel(hscroll_dist + b_margin) ) / self.mne.n_channels - topo_dist position = [ left + n_topos * (topo_width + topo_dist), bottom, width, height, ] # position = [left, bottom, width, height] # Main axes must be a subplot for subplots_adjust to work (so user can # adjust margins). That's why we don't use the Divider class directly. ax_main = self.add_axes( position ) # OSL ADDITION USE ADD_AXES INSTEAD OF ADD_SUBPLOT # OSL ADDITION: CREATE TOPO AXES ax_topo = np.empty((n_topos, self.mne.n_channels), dtype=object) for i in np.arange(n_topos): for j in np.arange(self.mne.n_channels): topo_position = [ left + i * (topo_width + topo_dist), bottom + ((self.mne.n_channels) - j) * (topo_height + topo_dist)*1.03 - self._inch_to_rel(0.13), topo_width, topo_height, ] ax_topo[i, j] = self.add_axes(topo_position) ax_topo[i, j].set_axis_off() self.subplotpars.update(left=left, bottom=bottom, top=top, right=right) div = make_axes_locatable(ax_main) # this only gets shown in zen mode self.mne.zen_xlabel = ax_main.set_xlabel(xlabel) self.mne.zen_xlabel.set_visible(not self.mne.scrollbars_visible) # make sure background color of the axis is set if 'bgcolor' in kwargs: ax_main.set_facecolor(kwargs['bgcolor']) # OSL ADDITION: GET POSITIONS FOR BAD LABELS LIST self.mne.bad_labels_xpos = 1 - self._inch_to_rel(r_margin + 0.35) self.mne.bad_labels_ypos = [] for i in range(len(self.mne.bad_labels_list)+2): self.mne.bad_labels_ypos.append(1 - self._inch_to_rel(t_margin + 0.5 + 0.3*(i+1), horiz=False)) # SCROLLBARS ax_hscroll = div.append_axes( position="bottom", size=Fixed(scroll_width), pad=Fixed(hscroll_dist) ) ax_vscroll = div.append_axes( position="right", size=Fixed(scroll_width), pad=Fixed(vscroll_dist) ) ax_hscroll.get_yaxis().set_visible(False) ax_hscroll.set_xlabel(xlabel) ax_vscroll.set_axis_off() # HORIZONTAL SCROLLBAR PATCHES (FOR MARKING BAD EPOCHS) if self.mne.is_epochs: epoch_nums = self.mne.inst.selection for ix, _ in enumerate(epoch_nums): start = self.mne.boundary_times[ix] width = np.diff(self.mne.boundary_times[:2])[0] ax_hscroll.add_patch( Rectangle( (start, 0), width, 1, color="none", zorder=self.mne.zorder["patch"])) # both axes, major ticks: gridlines for _ax in (ax_main, ax_hscroll): _ax.xaxis.set_major_locator(FixedLocator(self.mne.boundary_times[1:-1])) _ax.xaxis.set_major_formatter(NullFormatter()) grid_kwargs = dict( color=self.mne.fgcolor, axis="x", zorder=self.mne.zorder["grid"] ) ax_main.grid(linewidth=2, linestyle="dashed", **grid_kwargs) ax_hscroll.grid(alpha=0.5, linewidth=0.5, linestyle="solid", **grid_kwargs) # main axes, minor ticks: ticklabel (epoch number) for every epoch ax_main.xaxis.set_minor_locator(FixedLocator(self.mne.midpoints)) ax_main.xaxis.set_minor_formatter(FixedFormatter(epoch_nums)) # hscroll axes, minor ticks: up to 20 ticklabels (epoch numbers) ax_hscroll.xaxis.set_minor_locator( FixedLocator(self.mne.midpoints, nbins=20) ) ax_hscroll.xaxis.set_minor_formatter( FuncFormatter(lambda x, pos: self._get_epoch_num_from_time(x)) ) # hide some ticks ax_main.tick_params(axis="x", which="major", bottom=False) ax_hscroll.tick_params(axis="x", which="both", bottom=False) else: # RAW / ICA X-AXIS TICK & LABEL FORMATTING # TODO: OSL NOT SURE IF THIS BREAKS WITH PLOTTING FUNCTING ax_main.xaxis.set_major_formatter( FuncFormatter(partial(self._xtick_formatter, ax_type="main")) ) ax_hscroll.xaxis.set_major_formatter( FuncFormatter(partial(self._xtick_formatter, ax_type="hscroll")) ) if self.mne.time_format != "float": for _ax in (ax_main, ax_hscroll): _ax.set_xlabel("Time (HH:MM:SS)") # VERTICAL SCROLLBAR PATCHES (COLORED BY CHANNEL TYPE) ch_order = self.mne.ch_order for ix, pick in enumerate(ch_order[len(extra_chans):]): this_color = ( self.mne.ch_color_bad if self.mne.ch_names[pick] in self.mne.info["bads"] else self.mne.ch_color_dict ) if isinstance(this_color, dict): this_color = this_color[self.mne.ch_types[pick]] ax_vscroll.add_patch( Rectangle( (0, ix), 1, 1, color=this_color, zorder=self.mne.zorder["patch"] ) ) ax_vscroll.set_ylim(len(ch_order) - len(extra_chans), 0) ax_vscroll.set_visible(not self.mne.butterfly) # SCROLLBAR VISIBLE SELECTION PATCHES sel_kwargs = dict( alpha=0.3, linewidth=4, clip_on=False, edgecolor=self.mne.fgcolor ) vsel_patch = Rectangle( (0, 0), 1, self.mne.n_channels - len(extra_chans), facecolor=self.mne.bgcolor, **sel_kwargs ) ax_vscroll.add_patch(vsel_patch) hsel_facecolor = np.average( np.vstack( (to_rgba_array(self.mne.fgcolor), to_rgba_array(self.mne.bgcolor)) ), axis=0, weights=(3, 1), ) # 75% foreground, 25% background hsel_patch = Rectangle( (self.mne.t_start, 0), self.mne.duration, 1, facecolor=hsel_facecolor, **sel_kwargs, ) ax_hscroll.add_patch(hsel_patch) ax_hscroll.set_xlim( self.mne.first_time, self.mne.first_time + self.mne.n_times / self.mne.info["sfreq"], ) # VLINE vline_color = (0.0, 0.75, 0.0) vline_kwargs = dict( visible=False, animated=True, zorder=self.mne.zorder["vline"] ) if self.mne.is_epochs: x = np.arange(self.mne.n_epochs) vline = ax_main.vlines(x, 0, 1, colors=vline_color, **vline_kwargs) vline.set_transform( blended_transform_factory(ax_main.transData, ax_main.transAxes) ) vline_hscroll = None else: vline = ax_main.axvline(0, color=vline_color, **vline_kwargs) vline_hscroll = ax_hscroll.axvline(0, color=vline_color, **vline_kwargs) vline_text = ax_hscroll.text( self.mne.first_time, 1.2, "", fontsize=10, ha="right", va="bottom", color=vline_color, **vline_kwargs, ) # HELP BUTTON: initialize in the wrong spot... ax_help = div.append_axes( position="left", size=Fixed(help_width), pad=Fixed(vscroll_dist) ) # HELP BUTTON: ...move it down by changing its locator loc = div.new_locator(nx=0, ny=0) ax_help.set_axes_locator(loc) # HELP BUTTON: make it a proper button with _patched_canvas(ax_help.figure): self.mne.button_help = Button(ax_help, "Help") # PROJ BUTTON ax_proj = None if len(self.mne.projs) and not inst.proj: proj_button_pos = [ 1 - self._inch_to_rel(r_margin + scroll_width), # left self._inch_to_rel(b_margin, horiz=False), # bottom self._inch_to_rel(scroll_width), # width self._inch_to_rel(scroll_width, horiz=False), # height ] loc = div.new_locator(nx=4, ny=0) ax_proj = self.add_axes(proj_button_pos) ax_proj.set_axes_locator(loc) with _patched_canvas(ax_help.figure): self.mne.button_proj = Button(ax_proj, "Prj") # INIT TRACES self.mne.trace_kwargs = dict(antialiased=True, linewidth=0.5) self.mne.traces = ax_main.plot( np.full((1, self.mne.n_channels), np.nan), **self.mne.trace_kwargs ) # MVE: INITIALLY THIS IS WHERE I INITIALIZED THE TOPOS. TURNS OUT ITS REDUNDANT BECAUSE IT IS TAKEN CARE OF IN # THE INTERACTIVE PART OF THE FIGURE. IT ALSO SOLVES THE EXTRA BONUS FIGURE # INIT TOPOS # NOTE: Commenting the next line out seems to not break the code, but to solve the bonus figure that is created # upon running the code. # self.plot_topos(ica, ax_topo, self.mne.picks[:self.mne.n_channels]) # SAVE UI ELEMENT HANDLES vars(self.mne).update( ax_main=ax_main, ax_help=ax_help, ax_proj=ax_proj, ax_hscroll=ax_hscroll, ax_vscroll=ax_vscroll, vsel_patch=vsel_patch, hsel_patch=hsel_patch, vline=vline, vline_hscroll=vline_hscroll, vline_text=vline_text, )
[docs] def _update_picks(self): """Compute which channel indices to show.""" n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg'])) if self.mne.butterfly and self.mne.ch_selections is not None: selections_dict = self._make_butterfly_selections_dict() self.mne.picks = np.concatenate(tuple(selections_dict.values())) elif self.mne.butterfly: self.mne.picks = self.mne.ch_order else: # this is replaced: # _slice = slice(self.mne.picks[n_extra_chans], # self.mne.picks[n_extra_chans] + self.mne.n_channels) # self.mne.picks = self.mne.ch_order[_slice] _slice = slice(self.mne.ch_start, self.mne.ch_start + self.mne.n_channels - n_extra_chans ) self.mne.picks = np.concatenate([np.arange(n_extra_chans), self.mne.ch_order[_slice]]) self.mne.n_channels = len(self.mne.picks) assert isinstance(self.mne.picks, np.ndarray) assert self.mne.picks.dtype.kind == 'i'
[docs] def _draw_traces(self): """Draw (or redraw) the channel data.""" from matplotlib.colors import to_rgba_array from matplotlib.patches import Rectangle # OSL ADDITION from mne import pick_types # clear scalebars if self.mne.scalebars_visible: self._hide_scalebars() # get info about currently visible channels picks = self.mne.picks ch_names = self.mne.ch_names[picks] ch_types = self.mne.ch_types[picks] offset_ixs = (picks if self.mne.butterfly and self.mne.ch_selections is None else slice(None)) offsets = self.mne.trace_offsets[offset_ixs] bad_bool = np.isin(ch_names, self.mne.info["bads"]) # OSL ADDITION bad_int = list(np.ones(len(picks))*-1) extra_chans = [picks[k] for k, ch_type in enumerate(ch_types) if ch_type == 'eog' or ch_type=='ecg'] for cnt, ch in enumerate([self.mne.ch_names[ii] for ii in picks]): if cnt < len(extra_chans): continue i = self.mne.ica._ica_names.index(ch) if ch in self.mne.info["bads"]: if len(list(self.mne.ica.labels_.values())) > 0 and i in np.concatenate(list(self.mne.ica.labels_.values())): i = int(i) ix = np.where([i in self.mne.ica.labels_[k] for k in self.mne.ica.labels_.keys()])[0][0] lbl = list(self.mne.ica.labels_.keys())[ix].split('/')[0] if lbl == 'unknown': bad_int[cnt] = int(0) else: bad_int[cnt] = int(self.mne.bad_labels_list.index(lbl) + 1) else: bad_int[cnt] = int(0) else: if len(list(self.mne.ica.labels_.values())) > 0 and i in np.concatenate(list(self.mne.ica.labels_.values())): # remove entry i = int(i) whichkeys = [list(self.mne.ica.labels_.keys())[k] for k in np.where([i in self.mne.ica.labels_[k] for k in self.mne.ica.labels_.keys()])[0]] for k in whichkeys: self.mne.ica.labels_[k] = list(np.setdiff1d(self.mne.ica.labels_[k], i)) bad_int[cnt] = -1 # colors good_ch_colors = [self.mne.ch_color_dict[_type] for _type in ch_types] c = [ self.mne.ch_color_bad ] + self.mne.bad_label_colors # OSL ADDITION: match colors to specific artifact labels ch_colors = to_rgba_array( [c[_bad] if _bad >= 0 else _color for _bad, _color in zip(bad_int, good_ch_colors)]) self.mne.ch_colors = np.array(good_ch_colors) # use for unmarking bads labels = self.mne.ax_main.yaxis.get_ticklabels() if self.mne.butterfly: for label in labels: label.set_color(self.mne.fgcolor) else: for label, color in zip(labels, ch_colors): label.set_color(color) # decim decim = np.ones_like(picks) data_picks_mask = np.in1d(picks, self.mne.picks_data) decim[data_picks_mask] = self.mne.decim # decim can vary by channel type, so compute different `times` vectors decim_times = { decim_value: self.mne.times[::decim_value] + self.mne.first_time for decim_value in set(decim) } # add more traces if needed n_picks = len(picks) if n_picks > len(self.mne.traces): n_new_chs = n_picks - len(self.mne.traces) new_traces = self.mne.ax_main.plot( np.full((1, n_new_chs), np.nan), **self.mne.trace_kwargs ) self.mne.traces.extend(new_traces) # remove extra traces if needed extra_traces = self.mne.traces[n_picks:] for trace in extra_traces: self.mne.ax_main.lines.remove(trace) self.mne.traces = self.mne.traces[:n_picks] # check for bad epochs time_range = (self.mne.times + self.mne.first_time)[[0, -1]] if self.mne.instance_type == "epochs": epoch_ix = np.searchsorted(self.mne.boundary_times, time_range) epoch_ix = np.arange(epoch_ix[0], epoch_ix[1]) epoch_nums = self.mne.inst.selection[epoch_ix[0] : epoch_ix[-1] + 1] visible_bad_epochs = epoch_nums[ np.in1d(epoch_nums, self.mne.bad_epochs).nonzero() ] while len(self.mne.epoch_traces): _trace = self.mne.epoch_traces.pop(-1) self.mne.ax_main.lines.remove(_trace) # handle custom epoch colors (for autoreject integration) if self.mne.epoch_colors is None: # shape: n_traces × RGBA → n_traces × n_epochs × RGBA custom_colors = np.tile( ch_colors[:, None, :], (1, self.mne.n_epochs, 1) ) else: custom_colors = np.empty((len(self.mne.picks), self.mne.n_epochs, 4)) for ii, _epoch_ix in enumerate(epoch_ix): this_colors = self.mne.epoch_colors[_epoch_ix] custom_colors[:, ii] = to_rgba_array( [this_colors[_ch] for _ch in picks] ) # override custom color on bad epochs for _bad in visible_bad_epochs: _ix = epoch_nums.tolist().index(_bad) _cols = np.array([self.mne.epoch_color_bad, self.mne.ch_color_bad])[ bad_bool.astype(int) ] custom_colors[:, _ix] = to_rgba_array(_cols) # update traces ylim = self.mne.ax_main.get_ylim() for ii, line in enumerate(self.mne.traces): this_name = ch_names[ii] this_type = ch_types[ii] this_offset = self.mne.trace_offsets[ii] this_times = decim_times[decim[ii]] this_data = this_offset - self.mne.data[ii] * self.mne.scale_factor this_data = this_data[..., :: decim[ii]] # clip if self.mne.clipping == "clamp": this_data = np.clip(this_data, -0.5, 0.5) elif self.mne.clipping is not None: clip = self.mne.clipping * (0.2 if self.mne.butterfly else 1) bottom = max(this_offset - clip, ylim[1]) height = min(2 * clip, ylim[0] - bottom) rect = Rectangle( xy=np.array([time_range[0], bottom]), width=time_range[1] - time_range[0], height=height, transform=self.mne.ax_main.transData, ) line.set_clip_path(rect) # prep z order is_bad_ch = this_name in self.mne.info["bads"] this_z = self.mne.zorder["bads" if is_bad_ch else "data"] if self.mne.butterfly and not is_bad_ch: this_z = self.mne.zorder.get(this_type, this_z) # plot each trace multiple times to get the desired epoch coloring. # use masked arrays to plot discontinuous epochs that have the same # color in a single plot() call. if self.mne.instance_type == "epochs": this_colors = custom_colors[ii] for cix, color in enumerate(np.unique(this_colors, axis=0)): bool_ixs = (this_colors == color).all(axis=1) mask = np.zeros_like(this_times, dtype=bool) _starts = self.mne.boundary_times[epoch_ix][bool_ixs] _stops = self.mne.boundary_times[epoch_ix + 1][bool_ixs] for _start, _stop in zip(_starts, _stops): _mask = np.logical_and(_start < this_times, this_times <= _stop) mask = mask | _mask _times = np.ma.masked_array(this_times, mask=~mask) # always use the existing traces first if cix == 0: line.set_xdata(_times) line.set_ydata(this_data) line.set_color(color) line.set_zorder(this_z) else: # make new traces as needed _trace = self.mne.ax_main.plot( _times, this_data, color=color, zorder=this_z, **self.mne.trace_kwargs, ) self.mne.epoch_traces.extend(_trace) else: line.set_xdata(this_times) line.set_ydata(this_data) line.set_color(ch_colors[ii]) line.set_zorder(this_z) # update xlim self.mne.ax_main.set_xlim(*time_range) # draw scalebars maybe if self.mne.scalebars_visible: self._show_scalebars() # redraw event lines if self.mne.event_times is not None: self._draw_event_lines() # OSL ADDITION: ADD TOPOS: exist_meg = any(ct in np.unique(self.mne.ica.get_channel_types()) for ct in ['mag', 'grad']) exist_eeg = 'eeg' in np.unique(self.mne.ica.get_channel_types()) n_topos = len(picks) n_chtype = len( np.unique( [ channel_type(self.mne.ica.info, ch) for ch in pick_types(self.mne.ica.info, meg=exist_meg, eeg=exist_eeg) ] ) ) ax_topo = np.reshape( self.get_axes()[1 : n_topos * n_chtype + 1], (n_chtype, n_topos) ) self.plot_topos(self.mne.ica, ax_topo, self.mne.picks) # OSL ADDITION: ADD BAD LABELS for i in range(len(self.mne.bad_labels_list)+2): if i == 0: plt.figtext(self.mne.bad_labels_xpos, self.mne.bad_labels_ypos[i], "bad component \ntype:", fontweight='bold') elif i == 1: plt.figtext(self.mne.bad_labels_xpos, self.mne.bad_labels_ypos[i], "unknown", color=self.mne.ch_color_bad, fontweight='semibold') else: plt.figtext(self.mne.bad_labels_xpos, self.mne.bad_labels_ypos[i], f'{i-1}: ' + self.mne.bad_labels_list[i - 2], color=self.mne.bad_label_colors[i - 2], fontweight='semibold') self._update_vscroll() # takes care of the vsel_patch, because it's too big when there's extra chans
# def plot_topos(self, ica, ax_topo, picks): # OSL ADDITION FOR TOPOS # import mne # from mne.viz.topomap import _plot_ica_topomap # extra_chans = [k for k, ch_type in enumerate(self.mne.ch_types[picks]) if ch_type == 'eog' or ch_type == 'ecg'] # exist_meg = any(ct in np.unique(ica.get_channel_types()) for ct in ['mag', 'grad']) # exist_eeg = 'eeg' in np.unique(ica.get_channel_types()) # n_topos = len(picks) # ica_tmp = ica.copy() # ica_tmp._ica_names = ["" for i in ica_tmp._ica_names] # nchans, ncomps = ica_tmp.get_components().shape # chtype = np.unique( # [ # channel_type(ica.info, ch) # for ch in mne.pick_types(ica.info, meg=exist_meg, eeg=exist_eeg) # ] # ) # n_chtype = len(chtype) # for i in range(n_chtype): # for j in range(n_topos): # if picks[j]<len(extra_chans): # ax_topo[i, j].clear() # ax_topo[i, j].set_axis_off() # else: # _plot_ica_topomap( # ica_tmp, # idx=picks[j]-len(extra_chans), # ch_type=chtype[i], # axes=ax_topo[i, j], # vmin=None, # vmax=None, # cmap="RdBu_r", # colorbar=False, # title=None, # show=True, # outlines="head", # contours=0, # image_interp="cubic", # res=64, # sensors=False, # allow_ref_meg=False, # sphere=None, # ) # if j==0: # ax_topo[i, j].set_title(f"{chtype[i]}") # else: # ax_topo[i, j].set_title('')
[docs] def plot_topos(self, ica, ax_topo, picks): # OSL: robust topo plotting import mne from mne.viz.topomap import _plot_ica_topomap # which channel types are topo-able? (skip EOG/ECG header rows) def _is_extra(idx): return self.mne.ch_types[idx] in ("eog", "ecg") exist_meg = any(ct in np.unique(ica.get_channel_types()) for ct in ("mag", "grad")) exist_eeg = "eeg" in np.unique(ica.get_channel_types()) chtypes = np.unique([ channel_type(ica.info, ch) for ch in mne.pick_types(ica.info, meg=exist_meg, eeg=exist_eeg) ]) n_chtype = len(chtypes) n_topos = len(picks) for i, ch_type in enumerate(chtypes): for j in range(n_topos): ax = ax_topo[i, j] ax.clear() ax.set_axis_off() row_idx = picks[j] # skip the EOG/ECG rows reliably if _is_extra(row_idx): if j == 0: ax.set_title(f"{ch_type}") continue # map displayed row -> component index via name ch_name = self.mne.ch_names[row_idx] try: comp_idx = self.mne.ica._ica_names.index(ch_name) except ValueError: # not an ICA component (defensive) if j == 0: ax.set_title(f"{ch_type}") continue _plot_ica_topomap( ica, idx=comp_idx, ch_type=ch_type, axes=ax, vmin=None, vmax=None, cmap="RdBu_r", colorbar=False, title=None, show=True, outlines="head", contours=0, image_interp="cubic", res=64, sensors=False, allow_ref_meg=False, sphere=None, ) if j == 0: ax.set_title(f"{ch_type}") else: ax.set_title("")
[docs] def _keypress(self, event): from mne.viz.utils import _events_off """Handle keypress events.""" key = event.key n_channels = self.mne.n_channels n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg'])) if self.mne.is_epochs: last_time = self.mne.n_times / self.mne.info["sfreq"] else: last_time = self.mne.inst.times[-1] # scroll up/down if key in ('down', 'up', 'shift+down', 'shift+up'): key = key.split('+')[-1] direction = -1 if key == 'up' else 1 # butterfly case if self.mne.butterfly: return # group_by case elif self.mne.fig_selection is not None: buttons = self.mne.fig_selection.mne.radio_ax.buttons labels = [label.get_text() for label in buttons.labels] current_label = buttons.value_selected current_idx = labels.index(current_label) selections_dict = self.mne.ch_selections penult = current_idx < (len(labels) - 1) pre_penult = current_idx < (len(labels) - 2) has_custom = selections_dict.get('Custom', None) is not None def_custom = len(selections_dict.get('Custom', list())) up_ok = key == 'up' and current_idx > 0 down_ok = key == 'down' and ( pre_penult or (penult and not has_custom) or (penult and has_custom and def_custom)) if up_ok or down_ok: buttons.set_active(current_idx + direction) # normal case else: ceiling = len(self.mne.ch_order) - (n_channels - n_extra_chans) ch_start = self.mne.ch_start + direction * (n_channels - n_extra_chans) self.mne.ch_start = np.clip(ch_start, n_extra_chans, ceiling) self._update_picks() self._update_vscroll() self._redraw() # scroll left/right elif key in ("right", "left", "shift+right", "shift+left"): old_t_start = self.mne.t_start direction = 1 if key.endswith("right") else -1 if self.mne.is_epochs: denom = 1 if key.startswith("shift") else self.mne.n_epochs else: denom = 1 if key.startswith("shift") else 4 t_max = last_time - self.mne.duration t_start = self.mne.t_start + direction * self.mne.duration / denom self.mne.t_start = np.clip(t_start, self.mne.first_time, t_max) if self.mne.t_start != old_t_start: self._update_hscroll() self._redraw(annotations=True) # scale traces elif key in ("=", "+", "-"): scaler = 1 / 1.1 if key == "-" else 1.1 self.mne.scale_factor *= scaler self._redraw(update_data=False) # change number of visible channels elif ( key in ("pageup", "pagedown") and self.mne.fig_selection is None and not self.mne.butterfly ): new_n_ch = n_channels + (1 if key == "pageup" else -1) self.mne.n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order)) # add new chs from above if we're at the bottom of the scrollbar ch_end = self.mne.ch_start + self.mne.n_channels if ch_end > len(self.mne.ch_order) and self.mne.ch_start > 0: self.mne.ch_start -= 1 self._update_vscroll() # redraw only if changed if self.mne.n_channels != n_channels: self._update_picks() self._update_trace_offsets() self._redraw(annotations=True) # change duration elif key in ("home", "end"): old_dur = self.mne.duration dur_delta = 1 if key == "end" else -1 if self.mne.is_epochs: # prevent from showing zero epochs, or more epochs than we have self.mne.n_epochs = np.clip( self.mne.n_epochs + dur_delta, 1, len(self.mne.inst) ) # use the length of one epoch as duration change min_dur = len(self.mne.inst.times) / self.mne.info["sfreq"] new_dur = self.mne.duration + dur_delta * min_dur else: # never show fewer than 3 samples min_dur = 3 * np.diff(self.mne.inst.times[:2])[0] # use multiplicative dur_delta dur_delta = 5 / 4 if dur_delta > 0 else 4 / 5 new_dur = self.mne.duration * dur_delta self.mne.duration = np.clip(new_dur, min_dur, last_time) if self.mne.duration != old_dur: if self.mne.t_start + self.mne.duration > last_time: self.mne.t_start = last_time - self.mne.duration self._update_hscroll() self._redraw(annotations=True) elif key == "?": # help window self._toggle_help_fig(event) elif key == "a": # annotation mode self._toggle_annotation_fig() elif key == "b" and self.mne.instance_type != "ica": # butterfly mode self._toggle_butterfly() elif key == "d": # DC shift self.mne.remove_dc = not self.mne.remove_dc self._redraw() elif key == "h" and self.mne.instance_type == "epochs": # histogram self._toggle_epoch_histogram() elif key == "j" and len(self.mne.projs): # SSP window self._toggle_proj_fig() elif key == 'J' and len(self.mne.projs): self._toggle_proj_checkbox(event, toggle_all=True) elif key == "p": # toggle draggable annotations self._toggle_draggable_annotations(event) if self.mne.fig_annotation is not None: checkbox = self.mne.fig_annotation.mne.drag_checkbox with _events_off(checkbox): checkbox.set_active(0) elif key == "s": # scalebars self._toggle_scalebars(event) elif key == "w": # toggle noise cov whitening self._toggle_whitening() elif key == "z": # zen mode: hide scrollbars and buttons self._toggle_scrollbars() self._redraw(update_data=False) elif key == "t": self._toggle_time_format() # OSL ADDITION: labeling artifact type of bad components elif str(key).isnumeric() and ( int(key) in range(len(self.mne.bad_labels_list) + 1) ): if len(self.mne.info["bads"]) > 0 and self.mne.info["bads"][-1] in self.mne.ica._ica_names: last_bad_component = self.mne.ica._ica_names.index(self.mne.info["bads"][-1]) all_labels = list(self.mne.ica.labels_.keys()) # first remove from a the key it was in before, if applicable: if len(list(self.mne.ica.labels_.values())) > 0 and last_bad_component in list(self.mne.ica.labels_.values())[0]: ix = \ np.where([last_bad_component in self.mne.ica.labels_[k] for k in self.mne.ica.labels_.keys()])[0] for ixx in ix: lbl = all_labels[ix] self.mne.ica.labels_[lbl] = np.setdiff1d(self.mne.ica.labels_[lbl], last_bad_component) # create label based on label list and put it into MNE style tmp_label = self.mne.bad_labels_list[ int(key) - 1] if tmp_label == 'eog': tmp_label = tmp_label + '/3/manual' else: tmp_label = tmp_label + '/manual' # save bad component label in corresponding dict. if tmp_label in self.mne.ica.labels_ and len(self.mne.ica.labels_[tmp_label]) > 0: self.mne.ica.labels_[tmp_label].append(last_bad_component) else: self.mne.ica.labels_[tmp_label] = [last_bad_component] self._draw_traces() # This makes sure the traces are given the corresponding color right away else: # check for close key / fullscreen toggle super()._keypress(event)
[docs] def _on_mouse_press(self, event): """Handle mouse clicks for jumping the vertical scrollbar selection.""" # left click only if event.button != 1: return # only react to clicks inside the vertical scrollbar if event.inaxes is not self.mne.ax_vscroll: return # don't change anything in butterfly mode or when grouped selections UI is open if self.mne.butterfly or self.mne.fig_selection is not None: return # self._vscroll_go_to(event.ydata) event.key='down' self._keypress(event) # use same logic as keypress for consistency
# def _vscroll_go_to(self, y): # """Center the visible window on click y (data coords) with correct clamping.""" # import numpy as np # if y is None: # clicked outside axes # return # n_extra = int(np.sum([1 for _, t in enumerate(self.mne.ch_types) if t in ("eog", "ecg")])) # total_rows = len(self.mne.ch_order) # rows that can scroll # n_vis = self.mne.n_channels - n_extra # visible scrollable rows # if n_vis < 1 or total_rows <= n_extra: # return # # y is already in "row units" because ax_vscroll.set_ylim(total_rows-n_extra, 0). # # First clip the click to the drawable range: # y = float(np.clip(y, 0, total_rows - n_extra)) # # Work in centers, not bottoms. The selection patch center should match y. # # Clamp the desired center so the selection can still fully fit in range. # half = n_vis / 2.0 # desired_center = np.clip(y, half, (total_rows - n_extra) - half) # # Convert center -> start row (anchor = bottom), then re-add the extra header rows. # proposed_start = int(round(desired_center - half)) + n_extra # # Final guardrail identical to your keyboard logic. # ceiling = len(self.mne.ch_order) - n_vis # self.mne.ch_start = int(np.clip(proposed_start, n_extra, ceiling)) # # Apply # self._update_picks() # self._update_vscroll() # self._redraw()
[docs] def _update_vscroll(self): """Update the vertical scrollbar (channel) selection indicator.""" n_extra_chans = int(np.sum([1 for k, ch_type in enumerate(self.mne.ch_types) if ch_type == 'eog' or ch_type == 'ecg'])) ceiling = len(self.mne.ch_order) - (self.mne.n_channels - n_extra_chans) self.mne.ch_start = np.clip(self.mne.ch_start, n_extra_chans, ceiling) self.mne.vsel_patch.set_xy((0, self.mne.ch_start - n_extra_chans)) self.mne.vsel_patch.set_height(self.mne.n_channels - n_extra_chans) self._update_yaxis_labels()
[docs] def _close(self, event): # OSL VERSION - SIMILAR TO OLD MNE VERSION TODO: Check if we need to adopt this """Handle close events (via keypress or window [x]).""" from matplotlib.pyplot import close from mne.utils import set_config # write out bad epochs (after converting epoch numbers to indices) if self.mne.instance_type == "epochs": bad_ixs = np.in1d(self.mne.inst.selection, self.mne.bad_epochs).nonzero()[0] self.mne.inst.drop(bad_ixs) # write bad channels back to instance (don't do this for proj; # proj checkboxes are for viz only and shouldn't modify the instance) if self.mne.instance_type in ("raw", "epochs"): self.mne.inst.info["bads"] = self.mne.info["bads"] # OSL ADDITION # ICA excludes elif self.mne.instance_type == "ica": # remove artefact channels from exclude (if present) rm = [] for cnt, ch in enumerate(self.mne.info['bads']): if ch not in self.mne.ica._ica_names: rm.append(cnt) [self.mne.info['bads'].pop(i) for i in np.sort(rm)[::-1]] self.mne.ica.exclude = [ self.mne.ica._ica_names.index(ch) for ch in self.mne.info["bads"] ] # OSL ADDITION: remove bad component labels that were reversed to good component tmp = list(self.mne.ica.labels_.values())[:] try: tmp = np.unique(np.concatenate(tmp)) except: tmp = [] for ch in tmp: ch = int(ch) if ch not in self.mne.ica.exclude: # find in which label it has allix = np.where([ch in self.mne.ica.labels_[key] for key in self.mne.ica.labels_.keys()])[0] for ix in allix: self.mne.ica.labels_[list(self.mne.ica.labels_.keys())[ix]] = \ np.setdiff1d(self.mne.ica.labels_[list(self.mne.ica.labels_.keys())[ix]], ch) # label bad components without a manual label as "unknown" for ch in self.mne.ica.exclude: ch = int(ch) tmp = list(self.mne.ica.labels_.values()) if len(tmp)==0: tmp = [] else: tmp = np.concatenate(tmp) if ch not in tmp: if "unknown" not in self.mne.ica.labels_.keys(): self.mne.ica.labels_["unknown"] = [] self.mne.ica.labels_["unknown"] = list(self.mne.ica.labels_["unknown"]) self.mne.ica.labels_["unknown"].append(ch) # Add to labels_ a generic eog/ecg field if len(list(self.mne.ica.labels_.keys())) > 0: if "ecg" not in self.mne.ica.labels_: self.mne.ica.labels_["ecg"] = [] if "eog" not in self.mne.ica.labels_: self.mne.ica.labels_["eog"] = [] for key in self.mne.ica.labels_.keys(): self.mne.ica.labels_[key] = list(self.mne.ica.labels_[key]) for key in self.mne.ica.labels_.keys(): self.mne.ica.labels_[key] = list(self.mne.ica.labels_[key]) for k in list(self.mne.ica.labels_.keys()): if "ecg" in k.lower() and k.lower() != "ecg": tmp = self.mne.ica.labels_[k] if type(tmp) is list and tmp: tmp = tmp[0] self.mne.ica.labels_["ecg"].append(tmp) elif "eog" in k.lower() and k.lower() != "eog": tmp = self.mne.ica.labels_[k] if type(tmp) is list and tmp: tmp = tmp[0] self.mne.ica.labels_["eog"].append(tmp) # make sure that the labels are unique and not empty self.mne.ica.labels_["ecg"] = [int(v) for v in self.mne.ica.labels_["ecg"] if not isinstance(v, list)] self.mne.ica.labels_["eog"] = [int(v) for v in self.mne.ica.labels_["eog"] if not isinstance(v, list)] self.mne.ica.labels_["ecg"] = np.unique(self.mne.ica.labels_["ecg"]).tolist() self.mne.ica.labels_["eog"] = np.unique(self.mne.ica.labels_["eog"]).tolist() for key in self.mne.ica.labels_.keys(): self.mne.ica.labels_[key] = list(self.mne.ica.labels_[key]) # write logs logger.info(f"Components marked as bad: {sorted(self.mne.ica.exclude) or 'none'}") for lb in self.mne.ica.labels_.keys(): if 'manual' in lb or lb=='unknown': logger.info(f"Components manually labeled as '{lb.split('/')[0]}': {sorted(self.mne.ica.labels_[lb])}") # write window size to config size = ",".join(self.get_size_inches().astype(str)) set_config("MNE_BROWSE_RAW_SIZE", size, set_env=False) # Clean up child figures (don't pop(), child figs remove themselves) while len(self.mne.child_figs): fig = self.mne.child_figs[-1] close(fig)
[docs]def flatten_recursive(lst): """Flatten a list using recursion.""" for item in lst: if isinstance(item, list): yield from flatten_recursive(item) else: yield item
# TODO: OSL IMPLEMENT PLOT_ICA FOR EVOKED DATA
[docs]def _plot_ica_sources_evoked( evoked, picks, exclude, title, show, ica, labels=None, n_channels=10, bad_labels_list=None, ): """Plot average over epochs in ICA space. Parameters ---------- evoked : instance of mne.Evoked The Evoked to be used. %(picks_base)s all sources in the order as fitted. exclude : array-like of int The components marked for exclusion. If None (default), ICA.exclude will be used. title : str The figure title. show : bool Show figure if True. labels : None | dict The ICA labels attribute. """ raise ValueError("plot_ica is not yet supported for Evoked data") import matplotlib.pyplot as plt from matplotlib import patheffects if title is None: title = "Reconstructed latent sources, time-locked" fig, axes = plt.subplots(1) ax = axes axes = [axes] times = evoked.times * 1e3 # plot unclassified sources and label excluded ones lines = list() texts = list() picks = np.sort(picks) idxs = [picks] if labels is not None: labels_used = [k for k in labels if "/" not in k] exclude_labels = list() for ii in picks: if ii in exclude: line_label = ica._ica_names[ii] if labels is not None: annot = list() for this_label in labels_used: indices = labels[this_label] if ii in indices: annot.append(this_label) if annot: line_label += " – " + ", ".join(annot) # Unicode en-dash exclude_labels.append(line_label) else: exclude_labels.append(None) label_props = [("k", "-") if lb is None else ("r", "-") for lb in exclude_labels] styles = ["-", "--", ":", "-."] if labels is not None: # differentiate categories by linestyle and components by color col_lbs = [it for it in exclude_labels if it is not None] cmap = plt.get_cmap("tab10", len(col_lbs)) unique_labels = set() for label in exclude_labels: if label is None: continue elif " – " in label: unique_labels.add(label.split(" – ")[1]) else: unique_labels.add("") # Determine up to 4 different styles for n categories cat_styles = dict( zip( unique_labels, map( lambda ux: styles[int(ux % len(styles))], range(len(unique_labels)) ), ) ) for label_idx, label in enumerate(exclude_labels): if label is not None: color = cmap(col_lbs.index(label)) if " – " in label: label_name = label.split(" – ")[1] else: label_name = "" style = cat_styles[label_name] label_props[label_idx] = (color, style) for exc_label, ii in zip(exclude_labels, picks): color, style = label_props[ii] # ensure traces of excluded components are plotted on top zorder = 2 if exc_label is None else 10 lines.extend( ax.plot( times, evoked.data[ii].T, picker=True, zorder=zorder, color=color, linestyle=style, label=exc_label, ) ) lines[-1].set_pickradius(3.0) ax.set(title=title, xlim=times[[0, -1]], xlabel="Time (ms)", ylabel="(NA)") if len(exclude) > 0: plt.legend(loc="best") tight_layout(fig=fig) texts.append( ax.text( 0, 0, "", zorder=3, verticalalignment="baseline", horizontalalignment="left", fontweight="bold", alpha=0, ) ) # this is done to give the structure of a list of lists of a group of lines # in each subplot lines = [lines] ch_names = evoked.ch_names path_effects = [patheffects.withStroke(linewidth=2, foreground="w", alpha=0.75)] params = dict( axes=axes, texts=texts, lines=lines, idxs=idxs, ch_names=ch_names, need_draw=False, path_effects=path_effects, ) fig.canvas.mpl_connect("pick_event", partial(_butterfly_onpick, params=params)) fig.canvas.mpl_connect( "button_press_event", partial(_butterfly_on_button_press, params=params) ) plt_show(show) return fig