Source code for osl_ephys.report.preproc_report

"""Reporting tool for preprocessing.

"""

# Authors: Andrew Quinn <a.quinn@bham.ac.uk>
#          Chetan Gohil <chetan.gohil@psych.ox.ac.uk>
#          Mats van Es <mats.vanes@psych.ox.ac.uk>

import os
import mne
import sys
import yaml
import sails
import argparse
import tempfile
import pickle
import pathlib
import re
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from jinja2 import Template
from tabulate import tabulate
from mne.channels.channels import channel_type
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from pathlib import Path
from copy import deepcopy
from glob import glob

try:
    from scipy.ndimage import uniform_filter1d
except ImportError:
    from scipy.ndimage.filters import uniform_filter1d

from ..utils import process_file_inputs, validate_outdir
from ..utils.logger import log_or_print
from ..source_recon.parcellation import plot_source_topo


# ----------------------------------------------------------------------------------
# Report generation
mne.viz.set_browser_backend("matplotlib")

[docs]def gen_report_from_fif(infiles, outdir, ftype=None, logsdir=None, run_id=None): """Generate web-report for a set of MNE data objects. Parameters ---------- infiles : list of str List of paths to fif files. outdir : str Directory to save HTML report and figures to. ftype : str Type of fif file, e.g., ``'raw'`` or ``'preproc-raw'``. logsdir : str Directory the log files were saved into. If None, log files are assumed to be in outdir.replace('report', 'logs') run_id : str Run ID. """ from ..preprocessing import read_dataset # Validate input files and directory to save html file and plots to infiles, outnames, good_files = process_file_inputs(infiles) outdir = validate_outdir(outdir) # Generate HTML data for infile in infiles: print("Generating report for", infile) dataset = read_dataset(infile, ftype=ftype) run_id = get_header_id(dataset['raw']) htmldatadir = validate_outdir(outdir / run_id) gen_html_data( dataset['raw'], htmldatadir, ica=dataset['ica'], logsdir=logsdir, run_id=run_id, ) # Create report gen_html_page(outdir) print("************" + "*" * len(str(outdir))) print(f"* REPORT: {outdir} *") print("************" + "*" * len(str(outdir)))
[docs]def get_header_id(raw): """Extract scan name from MNE data object. Parameters ---------- raw : mne.io.:py:class:`mne.io.Raw <mne.io.Raw>`. MNE Raw object. Returns ------- id : str Scan name. """ return str(raw.filenames[0]).split('/')[-1].strip('.fif')
[docs]def gen_html_data( raw, outdir, ica=None, events=None, event_id=None, preproc_fif_filename=None, logsdir=None, run_id=None, custom_figures=None ): """Generate HTML web-report for an MNE data object. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. outdir : string Directory to write HTML data and plots to. ica : :py:class:`mne.preprocessing.ICA <mne.preprocessing.ICA>` ICA object. events : array, shape (n_events, 3) The events. event_id : dict The event IDs. preproc_fif_filename : str Filename of file output by preprocessing logsdir : str Directory the log files were saved into. If None, log files are assumed to be in reportdir.replace('report', 'logs') run_id : str Run ID. custom_figures : dict Dictionary of custom figures to include in the report. Keys should be the figure names, and values should be the matplotlib figure objects. """ # Silence all matplotlib show() warnings warnings.filterwarnings("ignore", message=".*is non-interactive, and thus cannot be shown.*") warnings.filterwarnings("ignore", message="nperseg.*is greater than input length.*") warnings.filterwarnings("ignore", message=".*more than 20 mm from head frame origin.*") data = {} data['filename'] = raw.filenames[0] data['preproc_fif_filename'] = preproc_fif_filename data['fif_id'] = run_id or get_header_id(raw) # Scan info data['projname'] = raw.info['proj_name'] data['edataperimenter'] = raw.info['experimenter'] data['meas_date'] = raw.info['meas_date'].__str__() data['acq_samples'] = raw.n_times data['acq_sfreq'] = raw.info['sfreq'] data['acq_duration'] = raw.n_times/raw.info['sfreq'] # Channels/coils data['nchans'] = raw.info['nchan'] try: data['nhpi'] = len(raw.info['hpi_meas'][0]['hpi_coils']) except: log_or_print("No HPI info in fif file") chtype = [channel_type(raw.info, c) for c in range(data['nchans'])] chs, chcounts = np.unique(chtype, return_counts=True) data['chantable'] = tabulate(np.c_[chs, chcounts], tablefmt='html', headers=['Channel Type', 'Number Acquired']) # Head digitisation dig_codes = ('Cardinal', 'HPI', 'EEG', 'Extra') dig_counts = np.zeros((4,)) # Only run this if digitisation is available if raw.info['dig']: for ii in range(1, 5): dig_counts[ii-1] = np.sum([d['kind'] == ii for d in raw.info['dig']]) #d, dcounts = np.unique(digs, return_counts=True) data['digitable'] = tabulate(np.c_[dig_codes, dig_counts], tablefmt='html', headers=['Digitisation Stage', 'Points Acquired']) # Events stim_channel = mne.pick_types(raw.info, meg=False, ref_meg=False, stim=True) if len(stim_channel) > 0: ev = mne.find_events(raw, min_duration=5/raw.info['sfreq'], verbose=False) ev, evcounts = np.unique(ev[:, 2], return_counts=True) data['eventtable'] = tabulate(np.c_[ev, evcounts], tablefmt='html', headers=['Event Code', 'Value']) # Bad segments durs = np.array([r['duration'] for r in raw.annotations]) full_dur = raw.n_times/raw.info['sfreq'] types = [r['description'] for r in raw.annotations] data['bad_seg'] = [] for modality in ['grad', 'mag', 'eeg']: # set bad segs to mean bad_tc = np.zeros(raw.n_times) for aa in raw.annotations: if "bad_segment" in aa["description"] and aa["description"].find(modality) > 0: time_inds = np.where((raw.times >= aa["onset"]-raw.first_time) & (raw.times <= (aa["onset"] + aa["duration"] - raw.first_time)))[0] bad_tc[time_inds] = 1 mod_dur = np.sum(bad_tc)/raw.info['sfreq'] pc = np.sum(bad_tc) / len(bad_tc) * 100 s = 'Modality {0} - {1:.2f}/{2} seconds rejected ({3:.2f}%)' if full_dur > 0: data['bad_seg'].append(s.format(modality, mod_dur, full_dur, pc)) # For summary report: data['bad_seg_pc_' + modality] = pc # Bad channels bad_chans = raw.info['bads'] if len(bad_chans) == 0: try: # check if the bad channels were saved elsewhere because of interpolation data['bad_chans'] = 'Bad channels: {}'.format(', '.join(raw.info['temp']['bads'])) bad_chans = raw.info['temp']['bads'] raw.info['bads'] = bad_chans except: data['bad_chans'] = 'No bad channels.' else: data['bad_chans'] = 'Bad channels: {}'.format(', '.join(bad_chans)) # For summary report: data['bad_chan_num'] = len(bad_chans) # Path to save plots savebase = str(outdir / '{0}.png') # Generate plots for the report data["plt_config"] = plot_flowchart(raw, savebase) data["plt_rawdata"] = plot_rawdata(raw, savebase) data['plt_temporalsumsq'] = plot_channel_time_series(raw, savebase, exclude_bads=False) data['plt_temporalsumsq_exclude_bads'] = plot_channel_time_series(raw, savebase, exclude_bads=True) data['plt_badchans'] = plot_sensors(raw, savebase) data['plt_channeldev'] = plot_channel_dists(raw, savebase, exclude_bads=False) data['plt_channeldev_exclude_bads'] = plot_channel_dists(raw, savebase, exclude_bads=True) data['plt_spectra'], data['plt_zoomspectra'] = plot_spectra(raw, savebase) data['plt_freqbands'] = plot_freqbands(raw, savebase) data['plt_digitisation'] = plot_digitisation_2d(raw, savebase) data['plt_artefacts_eog'] = plot_eog_summary(raw, savebase) data['plt_artefacts_ecg'] = plot_ecg_summary(raw, savebase) data['plt_events'] = plot_events(events, event_id, savebase) # add custom figures data['plt_custom_figures'] = plot_custom_figures(custom_figures, savebase) #filenames = plot_artefact_scan(raw, savebase) #data.update(filenames) # add extra funcs if they exist extra_funcs = re.findall( "%% extra_funcs start %%(.*?)%% extra_funcs end %%", raw.info["description"], flags=re.DOTALL, ) if len(extra_funcs) > 0: data["extra_funcs"] = "" for func in extra_funcs: data["extra_funcs"] = data["extra_funcs"] + func # Add ICA if it's been passed if ica is not None: data['ica_ncomps_rej'] = len(ica.exclude) data['ica_ncomps_rej_ecg'] = [len(ica.labels_['ecg']) if 'ecg' in ica.labels_ else 'N/A'][0] data['ica_ncomps_rej_eog'] = [len(ica.labels_['eog']) if 'eog' in ica.labels_ else 'N/A'][0] if len(ica.exclude) > 0: data['plt_ica'] = plot_bad_ica(raw, ica, savebase) # add maxfilter info if possible # we have to guess the exact filename g=[] for ext in ['tsss', 'trans', 'transdef', 'autobad', 'autobad_sss']: g.append(glob(str(data['filename']).replace(f"{ext}.fif",'') + '*.log')) g = list(np.concatenate(g)) for ig in g: with open(ig, 'r') as log_file: if 'autobad' in ig: data['maxlog_autobad'] = log_file.read() elif 'sss' in ig: data['maxlog_sss'] = log_file.read() elif 'trans' in ig: data['maxlog_trans'] = log_file.read() # add log files if possible if logsdir is None: logsdir = outdir._str.replace('preproc_report', 'logs') elif type(logsdir)==pathlib.PosixPath: logsdir = os.path.join(logsdir._str, outdir._str.split('/')[-1]) if os.path.exists(logsdir + '.log'): with open(logsdir + '.log', 'r') as log_file: data['log'] = log_file.read() if os.path.exists(logsdir + '.error.log'): with open(logsdir + '.error.log', 'r') as log_file: data['errlog'] = log_file.read() # Save data that will be used to create html page with open(outdir / 'data.pkl', 'wb') as outfile: pickle.dump(data, outfile)
[docs]def gen_html_page(outdir, logsdir=None, custom_figures=None): """Generate an HTML page from a report directory. Parameters ---------- outdir : str Directory to generate HTML report with. Returns ------- success : bool Whether the report was successfully generated. """ outdir = pathlib.Path(outdir) # Subdirectories which contains plots for each fif file subdirs = sorted( [d.stem for d in pathlib.Path(outdir).iterdir() if d.is_dir()] ) # Load data for each fif file data = [] for subdir in subdirs: subdir = pathlib.Path(subdir) # Just generate the html page with the successful runs try: data.append(pickle.load(open(outdir / subdir / "data.pkl", "rb"))) except: pass total = len(data) if total == 0: return False # Add info to data indicating the total number of files # and an id for each file for i in range(total): data[i]["num"] = i + 1 data[i]["total"] = total # Create panels panels = [] panel_template = load_template('raw_subject_panel') for i in range(total): panels.append(panel_template.render(data=data[i])) # Hyperlink to each panel on the page filenames = "" for i in range(total): filename = pathlib.Path(data[i]["preproc_fif_filename"]).name.replace('_preproc-raw.fif', '') filenames += "{0}. {1}<br>".format(i + 1, filename) # Render the full page page_template = load_template('subject_report') page = page_template.render(panels=panels, filenames=filenames) # Write the output file outpath = pathlib.Path(outdir) / 'subject_report.html' with open(outpath, 'w') as f: f.write(page) return True
[docs]def gen_html_summary(reportdir, logsdir=None, custom_figures=None): """Generate an HTML summary from a report directory. Parameters ---------- reportdir : str Directory to generate HTML summary report with. logsdir : str Directory the log files were saved into. If None, log files are assumed to be in reportdir.replace('report', 'logs') """ reportdir = Path(reportdir) # Subdirectories which contains plots for each fif file subdirs = sorted( [d.stem for d in Path(reportdir).iterdir() if d.is_dir()] ) # Load HTML data subject_data = [] for subdir in subdirs: subdir = Path(subdir) # Just generate the html page with the successful runs try: subject_data.append( pickle.load(open(reportdir / subdir / "data.pkl", "rb")) ) except: pass total = len(subject_data) if total == 0: return False # Data used in the summary report data = {} data["total"] = total # Create plots os.makedirs(f"{reportdir}/summary", exist_ok=True) data["plt_config"] = subject_data[0]["plt_config"] existing_figs = {ig.split('/')[-1].split('custom_')[1].split('.png')[0]: ig.replace(str(reportdir) + '/', '') for ig in glob(f"{reportdir}/summary/*.png")} data['plt_custom_figures'] = plot_custom_figures(custom_figures, f"{reportdir}/summary/{{}}.png") if len(existing_figs)>0: if data['plt_custom_figures'] is None: data['plt_custom_figures'] = existing_figs else: data['plt_custom_figures'] |= existing_figs if "extra_funcs" in subject_data[0]: data["extra_funcs"] = subject_data[0]["extra_funcs"] data['tbl'] = gen_summary_data(subject_data) # log files if logsdir is None: logsdir = reportdir.__str__().replace('preproc_report', 'logs') elif type(logsdir)==pathlib.PosixPath: logsdir = logsdir.__str__() if os.path.exists(os.path.join(logsdir, 'batch_preproc.log')): with open(os.path.join(logsdir, 'batch_preproc.log'), 'r') as log_file: data['batchlog'] = log_file.read() g = glob(os.path.join(logsdir, '*.error.log')) if len(g)>0: data['errlog'] = {} for ig in sorted(g): with open(ig, 'r') as log_file: data['errlog'][ig.split('/')[-1].split('.error.log')[0]] = log_file.read() # Create panel panel_template = load_template('raw_summary_panel') panel = panel_template.render(data=data) # List of filenames filenames = "" for i in range(total): filename = Path(subject_data[i]["filename"]).name filenames += "{0}. {1}<br>".format(i + 1, filename) # Render the full page page_template = load_template('summary_report') page = page_template.render(panel=panel, filenames=filenames) # Write the output file outpath = Path(reportdir) / 'summary_report.html' with open(outpath, 'w') as f: f.write(page) return True
[docs]def gen_summary_data(subject_data): df = pd.DataFrame() column_headers = ['Session ID', 'Measurement month', 'Duration (s)', 'Bad segments (%) - mag', 'Bad segments (%) - grad)', 'Bad segments (%) - eeg)', 'Bad channels (#)', 'Bad channels (id)', 'Bad ICA (total)', 'Bad ICA (ECG)', 'Bad ICA (EOG)'] fields = ['fif_id', 'meas_date', 'acq_duration', 'bad_seg_pc_mag', 'bad_seg_pc_grad', 'bad_seg_pc_eeg', 'bad_chan_num', 'bad_chans', 'ica_ncomps_rej', 'ica_ncomps_rej_ecg', 'ica_ncomps_rej_eog'] for field, hdr in zip(fields, column_headers): if field == 'meas_date': # anonymise date (only year/month) df[hdr] = ["-".join(subject_data[i][field].split(' ')[0].split('-')[:-1]) if field in subject_data[i].keys() else None for i in range(len(subject_data))] else: df[hdr] = [subject_data[i][field] if field in subject_data[i].keys() else None for i in range(len(subject_data))] # remove columns that only contains None df = df.dropna(axis=1, how='all') df.index += 1 # start index at 1 tbl = df.to_html(classes="display", table_id="preproc_tbl") return tbl
[docs]def load_template(tname): """Load an HTML template from the templates directory. Parameters ---------- tname : str Name of the template to load. Returns ------- template : jinja2.Template The loaded template. """ basedir = os.path.dirname(os.path.realpath(__file__)) fname = os.path.join(basedir, 'templates', '{0}.html'.format(tname)) with open(fname, 'r') as file: template = Template(file.read()) return template
# ---------------------------------------------------------------------------------- # Scan stats and figures
[docs]def plot_flowchart(raw, savebase=None): """Plots preprocessing flowchart(s) Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ from ..preprocessing import get_config_from_fif, plot_preproc_flowchart # Get config info from raw.info['description'] config_list = get_config_from_fif(raw) # Number of subplots, i.e. the number of times osl-ephys preprocessing was applied nrows = len(config_list) if nrows == 0: # No preprocessing was applied return None # Plot flowchart in subplots fig, ax = plt.subplots(nrows=nrows, ncols=1, figsize=(12, 6*(nrows+1))) cnt=0 for config in config_list: if type(ax)==np.ndarray: axes=ax[cnt] title=f"osl-ephys Preprocessing Stage {cnt+1}" else: axes=ax title=None fig, axes = plot_preproc_flowchart(config, fig=fig, ax=axes, title=title) cnt=cnt+1 if savebase is not None: # Save figure figname = savebase.format('flowchart') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('flowchart') else: fpath = None return fpath
[docs]def plot_rawdata(raw, savebase): """Plots raw data. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ fig = raw.pick(['meg', 'eeg']).plot(n_channels=np.inf, duration=raw.times[-1], use_opengl=False, show=False) if savebase is not None: figname = savebase.format('rawdata') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('rawdata') else: fpath = None return fpath
[docs]def plot_channel_time_series(raw, savebase=None, exclude_bads=False): """Plots sum-square time courses. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. exclude_bads : bool Whether to exclude bad channels and bad segments. Returns ------- fpath : str Path to saved figure. """ if exclude_bads: # excludes bad channels and bad segments exclude = 'bads' else: # includes bad channels and bad segments exclude = [] is_ctf = raw.info["dev_ctf_t"] is not None if is_ctf: # Note that with CTF mne.pick_types will return: # ~274 axial grads (as magnetometers) if {picks: 'mag', ref_meg: False} # ~28 reference axial grads if {picks: 'grad'} channel_types = { 'Axial Grads (chtype=mag)': mne.pick_types( raw.info, meg='mag', ref_meg=False, exclude=exclude ), 'Ref Axial Grad (chtype=ref_meg)': mne.pick_types( raw.info, meg='grad', exclude=exclude ), 'EEG': mne.pick_types(raw.info, eeg=True), 'CSD': mne.pick_types(raw.info, csd=True), } else: channel_types = { 'Magnetometers': mne.pick_types(raw.info, meg='mag', exclude=exclude), 'Gradiometers': mne.pick_types(raw.info, meg='grad', exclude=exclude), 'EEG': mne.pick_types(raw.info, eeg=True), 'CSD': mne.pick_types(raw.info, csd=True), } t = raw.times x = raw.get_data() # Number of subplots, i.e. the number of different channel types in the fif file nrows = 0 for _, c in channel_types.items(): if len(c) > 0: nrows += 1 if nrows == 0: return None # Make sum-square plots fig, ax = plt.subplots(nrows=nrows, ncols=1, figsize=(16, 4)) if nrows == 1: ax = [ax] row = 0 for name, chan_inds in channel_types.items(): if len(chan_inds) == 0: continue ss = np.sum(x[chan_inds] ** 2, axis=0) # calculate ss value to give to bad segments for plotting purposes good_data = raw.get_data(picks=chan_inds, reject_by_annotation='NaN') # get indices of good data good_inds = np.where(~np.isnan(good_data[0,:]))[0] ss_bad_value = np.mean(ss[good_inds]) if exclude_bads: # set bad segs to mean for aa in raw.annotations: if "bad_segment" in aa["description"]: time_inds = np.where((raw.times >= aa["onset"]-raw.first_time) & (raw.times <= (aa["onset"] + aa["duration"] - raw.first_time)))[0] ss[time_inds] = ss_bad_value ss = uniform_filter1d(ss, int(raw.info['sfreq'])) ax[row].plot(t, ss) ax[row].legend([name], frameon=False, fontsize=16) ax[row].set_xlim(t[0], t[-1]) ylim = ax[row].get_ylim() for a in raw.annotations: if "bad_segment" in a["description"]: ax[row].axvspan( a["onset"] - raw.first_time, a["onset"] + a["duration"] - raw.first_time, color="red", alpha=0.8, ) row += 1 ax[0].set_title('Sum-Square Across Channels') ax[-1].set_xlabel('Time (seconds)') # Save if savebase is not None: plt.tight_layout() if exclude_bads: plot_name = 'temporal_sumsq_exclude_bads' else: plot_name = 'temporal_sumsq' figname = savebase.format(plot_name) fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format(plot_name) else: fpath = None return fpath
[docs]def plot_sensors(raw, savebase=None): """Plots sensors with bad channels highlighted. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ # skip if there is no channel position info if raw.get_montage() is None: return None # plot channel types seperately for neuromag306 (3 coils in same location) if 3012 in np.unique([i['coil_type'] for i in raw.info['chs']]): with open(str(Path(__file__).parent.parent) + "/utils/neuromag306_info.yml", 'r') as f: channels = yaml.safe_load(f) if 3024 in np.unique([i['coil_type'] for i in raw.info['chs']]): coil_types = ['mag', 'grad_longitude', 'grad_lattitude'] else: coil_types = ['grad_longitude', 'grad_lattitude'] fig, ax = plt.subplots(1,len(coil_types), figsize=(16,4)) for k in range(len(coil_types)): raw.copy().pick_channels(channels[coil_types[k]]).plot_sensors(axes=ax[k], show=False, verbose='ERROR') ax[k].set_title(f"{coil_types[k].replace('_', ' ')}") plt.tight_layout() else: fig, ax = plt.subplots(1, 3, figsize=(16, 4)) ax[0].axis('off') ax[2].axis('off') raw.plot_sensors(show=False, axes=ax[1]) if savebase is not None: figname = savebase.format('bad_chans') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('bad_chans') else: fpath = None return fpath
[docs]def plot_channel_dists(raw, savebase=None, exclude_bads=True): """Plot distributions of temporal standard deviation. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. exclude_bads : bool Whether to exclude bad channels and bad segments. Returns ------- fpath : str Path to saved figure. """ if exclude_bads: # excludes bad channels and bad segments exclude = 'bads' reject_by_annotation = 'omit' else: # includes bad channels and bad segments exclude = [] reject_by_annotation = None is_ctf = raw.info["dev_ctf_t"] is not None if is_ctf: # Note that with CTF mne.pick_types will return: # ~274 axial grads (as magnetometers) if {picks: 'mag', ref_meg: False} # ~28 reference axial grads if {picks: 'grad'} channel_types = { 'Axial Grads (chtype=''mag'')': mne.pick_types(raw.info, meg='mag', ref_meg=False, exclude=exclude), 'Ref Axial Grad (chtype=''ref_meg'')': mne.pick_types(raw.info, meg='grad', exclude=exclude), 'EEG': mne.pick_types(raw.info, eeg=True, exclude=exclude), 'CSD': mne.pick_types(raw.info, csd=True, exclude=exclude), } else: channel_types = { 'Magnetometers': mne.pick_types(raw.info, meg='mag', exclude=exclude), 'Gradiometers': mne.pick_types(raw.info, meg='grad', exclude=exclude), 'EEG': mne.pick_types(raw.info, eeg=True, exclude=exclude), 'CSD': mne.pick_types(raw.info, csd=True, exclude=exclude), } x = raw.get_data(reject_by_annotation=reject_by_annotation) # Number of subplots, i.e. the number of different channel types in the fif file ncols = 0 for _, c in channel_types.items(): if len(c) > 0: ncols += 1 if ncols == 0: return None # Make plots fig, ax = plt.subplots(nrows=1, ncols=ncols, figsize=(9, 3.5)) if ncols == 1: ax = [ax] row = 0 for name, chan_inds in channel_types.items(): if len(chan_inds) == 0: continue ax[row].hist(x[chan_inds, :].std(axis=1), bins=24, histtype='step') ax[row].legend(['Temporal Std-Dev'], frameon=False) ax[row].set_xlabel('Std-Dev') ax[row].set_ylabel('Channel Count') ax[row].set_title(name) row += 1 if exclude_bads: # excludes bad channels and bad segments plot_name = 'channel_dev_exclude_bads' else: # includes bad channels and bad segments plot_name = 'channel_dev' # Save if savebase is not None: plt.tight_layout() figname = savebase.format(plot_name) fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format(plot_name) else: fpath = None return fpath
[docs]def plot_spectra(raw, savebase=None): """Plot power spectra for each sensor modality. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath1 : str Path to saved figure (full spectra). fpath2 : str Path to saved figure (zoomed in spectra). """ is_ctf = raw.info["dev_ctf_t"] is not None if is_ctf: # Note that with CTF mne.pick_types will return: # ~274 axial grads (as magnetometers) if {picks: 'mag', ref_meg: False} # ~28 reference axial grads if {picks: 'grad'} channel_types = { 'Axial Grad (chtype=''mag'')': mne.pick_types(raw.info, meg='mag', ref_meg=False, exclude='bads'), 'EEG': mne.pick_types(raw.info, eeg=True, exclude='bads'), } else: channel_types = { 'Magnetometers': mne.pick_types(raw.info, meg='mag', exclude='bads'), 'Gradiometers': mne.pick_types(raw.info, meg='grad', exclude='bads'), 'EEG': mne.pick_types(raw.info, eeg=True, exclude='bads'), } # Number of subplots, i.e. the number of different channel types in the fif file nrows = 0 for _, c in sorted(channel_types.items()): if len(c) > 0: nrows += 1 if nrows == 0: return None fig, ax = plt.subplots(nrows=nrows, ncols=1, figsize=(8, 7)) if nrows == 1: ax = [ax] row = 0 for name, chan_inds in sorted(channel_types.items()): if len(chan_inds) == 0: continue # Plot spectra raw.compute_psd( picks=chan_inds, n_fft=int(raw.info['sfreq']*2), verbose=0).plot( axes=ax[row], show=False,) ax[row].set_title(name, fontsize=12) row += 1 # Save full spectra if savebase is not None: figname = savebase.format('spectra_full') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Make plots fig, ax = plt.subplots(nrows=nrows, ncols=1, figsize=(8, 7)) if nrows == 1: ax = [ax] row = 0 for name, chan_inds in sorted(channel_types.items()): if len(chan_inds) == 0: continue # Plot zoomed in spectra raw.compute_psd( picks=chan_inds, fmin=1, fmax=48, n_fft=int(raw.info['sfreq']*2), verbose=0).plot( axes=ax[row], show=False, ) ax[row].set_title(name, fontsize=12) row += 1 # Save zoomed in spectra if savebase is not None: figname = savebase.format('spectra_zoom') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return filenames savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath1 = filebase.format('spectra_full') fpath2 = filebase.format('spectra_zoom') else: fpath1 = None fpath2 = None return fpath1, fpath2
[docs]def plot_freqbands(raw, savebase=None): """Plot power spectra for each sensor modality. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- figname : str Path to saved figure. """ if 'dev_ctf_t' in raw.info.keys() and raw.info["dev_ctf_t"] is not None: is_ctf = True else: is_ctf = False is_parc = np.any(['parcel' in ch for ch in raw.info['ch_names']]) if is_ctf: # Note that with CTF mne.pick_types will return: # ~274 axial grads (as magnetometers) if {picks: 'mag', ref_meg: False} # ~28 reference axial grads if {picks: 'grad'} channel_types = { 'Axial Grad (chtype=''mag'')': mne.pick_types(raw.info, meg='mag', ref_meg=False, exclude='bads'), 'EEG': mne.pick_types(raw.info, eeg=True, exclude='bads'), } elif is_parc: channel_types = { 'Parcel': mne.pick_types(raw.info, misc=True, ref_meg=False, exclude='bads'), } else: channel_types = { 'Magnetometers': mne.pick_types(raw.info, meg='mag', exclude='bads'), 'Gradiometers': mne.pick_types(raw.info, meg='grad', exclude='bads'), 'EEG': mne.pick_types(raw.info, eeg=True, exclude='bads'), } # only keep those channel types that have sensor position information. for key, picks in list(channel_types.items()): try: if raw.copy().pick(picks).get_montage() is None: del channel_types[key] except Exception: del channel_types[key] # Number of subplots, i.e. the number of different channel types in the fif file nrows = 0 for _, c in sorted(channel_types.items()): if len(c) > 0: nrows += 1 if nrows == 0: return None freq_bands = [(1,4), (4,8), (8,13), (13,30), (30,80), (80,150)] ix = [i for i in range(len(freq_bands)) if np.logical_and(freq_bands[i][1] > raw.info['highpass'], freq_bands[i][0] < raw.info['lowpass'])] freq_bands = [freq_bands[i] for i in ix] freq_band_names = [['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma', 'High Gamma'][i] for i in ix] ncols = len(freq_bands) fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 0.5 + 1.5*nrows)) row = 0 for name, chan_inds in sorted(channel_types.items()): if len(chan_inds) == 0: continue # Plot spectra raw_zscore = deepcopy(raw).apply_function(lambda x: ((x - np.mean(x)) / np.std(x)), picks=chan_inds) psd = raw_zscore.compute_psd( picks=chan_inds, n_fft=int(raw.info['sfreq']*2), verbose=0) if nrows==1: iax = ax else: iax = ax[row] for ifrq, (f1, f2) in enumerate(freq_bands): inds = np.logical_and(psd.freqs>f1, psd.freqs<=(f2)) p = psd._data[:, inds].mean(axis=1) if is_parc: # normalize because nilearwn doesn't plot small values plot_source_topo(p/p.std(), axis=iax[ifrq], cmap='hot') else: mne.viz.plot_topomap(p, psd.info, axes=iax[ifrq], cmap='hot') if row==0: iax[ifrq].set_title(f"{freq_band_names[ifrq]}\n({f1}-{f2} Hz)") if ifrq==0: iax[ifrq].text(-0.1, 0.5, name, transform=iax[ifrq].transAxes, rotation=90, va='center', ha='center', fontsize=10) row += 1 # Save power bands spectra if savebase is not None: figname = savebase.format('freqbands') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('freqbands') return fpath else: return None
[docs]def plot_digitisation_2d(raw, savebase=None): """Plots the digitisation and headshape. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ # Return if no digitisation available if not raw.info['dig']: return None # Make plot fig = plt.figure(figsize=(16, 4)) plt.subplot(141) plt.gca().set_aspect('equal') for dp in raw.info['dig']: if dp['kind'] == 1: plt.plot(dp['r'][0], dp['r'][1], 'r^') if dp['kind'] == 2: plt.plot(dp['r'][0], dp['r'][1], 'mo') if dp['kind'] == 3: plt.plot(dp['r'][0], dp['r'][1], 'g+') if dp['kind'] == 4: plt.plot(dp['r'][0], dp['r'][1], 'b.') plt.title('Top View') plt.subplot(142) plt.gca().set_aspect('equal') for dp in raw.info['dig']: if dp['kind'] == 1: plt.plot(dp['r'][0], dp['r'][2], 'r^') if dp['kind'] == 2: plt.plot(dp['r'][0], dp['r'][2], 'mo') if dp['kind'] == 3: plt.plot(dp['r'][0], dp['r'][2], 'g+') if dp['kind'] == 4: plt.plot(dp['r'][0], dp['r'][2], 'b.') plt.title('Front View') plt.subplot(143) plt.gca().set_aspect('equal') for dp in raw.info['dig']: if dp['kind'] == 1: plt.plot(dp['r'][1], dp['r'][2], 'r^') if dp['kind'] == 2: plt.plot(dp['r'][1], dp['r'][2], 'mo') if dp['kind'] == 3: plt.plot(dp['r'][1], dp['r'][2], 'g*') if dp['kind'] == 4: plt.plot(dp['r'][1], dp['r'][2], 'b.') plt.title('Side View') plt.subplot(144, frameon=False) plt.xticks([]) plt.yticks([]) legend_elements = [Line2D([0], [0], marker='^', color='w', lw=4, label='Fiducial', markerfacecolor='r',markersize=14), Line2D([0], [0], marker='o', color='w', label='HPI', markerfacecolor='m', markersize=14), Line2D([0], [0], marker='*', color='w', label='EEG', markerfacecolor='g', markersize=14), Line2D([0], [0], marker='.', color='w', label='Headshape', markerfacecolor='b', markersize=14)] plt.legend(handles=legend_elements, loc='center', frameon=False) if savebase is not None: plt.tight_layout() figname = savebase.format('digitisation') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('digitisation') else: fpath = None return fpath
[docs]def plot_headmovement(raw, savebase=None): """Plot headmovement - WORK IN PROGRESS... seems v-slow atm""" chpi_amplitudes = mne.chpi.compute_chpi_amplitudes(raw) chpi_locs = mne.chpi.compute_chpi_locs(raw.info, chpi_amplitudes) head_pos = mne.chpi.compute_head_pos(raw.info, chpi_locs, verbose=False) fig = mne.viz.plot_head_positions(head_pos, mode='traces') if savebase is not None: figname = savebase.format('headpos') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig)
[docs]def plot_eog_summary(raw, savebase=None): """Plot raw EOG time series.""" # Get the raw EOG data chan_inds = mne.pick_types(raw.info, eog=True) if len(chan_inds) == 0: return None t = raw.times x = raw.get_data(chan_inds).T # Make the plot fig, ax = plt.subplots(figsize=(16, 2)) ax.plot(t, x) ax.set_xlim([t[0], t[-1]]) # Save if savebase is not None: plt.tight_layout() figname = savebase.format('EOG') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('EOG') else: fpath = None return fpath
[docs]def plot_ecg_summary(raw, savebase=None): """Plot ECG summary. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ # Get the raw ECG data chan_inds = mne.pick_types(raw.info, ecg=True) if len(chan_inds) == 0: return None t = raw.times x = raw.get_data(chan_inds).T # Make the plot fig, ax = plt.subplots(figsize=(16, 2)) ax.plot(t, x) ax.set_xlim([t[0], t[-1]]) # Save if savebase is not None: plt.tight_layout() figname = savebase.format('ECG') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('ECG') else: fpath = None return fpath
[docs]def plot_bad_ica(raw, ica, savebase): """Plot ICA characteristics for rejected components. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. ica : :py:class:`mne.preprocessing.ICA <mne.preprocessing.ICA>` MNE ICA object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ # Set MNE to only show critical messages mne.set_log_level('ERROR') exclude_uniq = np.sort(np.unique(ica.exclude))[::-1] nbad = len(exclude_uniq) # Handle the case when ica.labels_ and ica.exclude don't match if len(ica.labels_) != nbad: # Make a dummy labels_ dict ica.labels_ = {ica._ica_names[exc]: '' for exc in exclude_uniq} # Create figure fig = plt.figure(figsize=(16, 5 * np.max([nbad,1])), facecolor=[0.95] * 3) if nbad == 0: plt.subplot(111, frameon=False) plt.text(0.5, 0.5, 'No Components Rejected', ha='center', va='center') plt.xticks([]) plt.yticks([]) else: axes = [] for i in np.arange(nbad): lowerlimit = 0.1 + i / (nbad * 1.1) multiplier = nbad * 1.3 # Create axis for subplot # adapted from mne/viz/ica._create_properties_layout if 'mag' in raw.get_channel_types() and 'grad' in raw.get_channel_types(): topomap1_layout = [0.08, lowerlimit + 0.5 / multiplier, (0.3-0.08)/2+0.08, 0.45 / multiplier] topomap2_layout = [(0.3-0.08)/2+0.08, lowerlimit + 0.5 / multiplier, 0.3, 0.45 / multiplier] else: topomap1_layout = [0.08, lowerlimit + 0.5 / multiplier, 0.3, 0.45 / multiplier] axes_params = (('topomap', topomap1_layout), ('image', [0.5, lowerlimit + 0.6 / multiplier, 0.45, 0.35 / multiplier]), ('erp', [0.5, lowerlimit + 0.5 / multiplier, 0.45, 0.1 / multiplier]), ('spectrum', [0.08, lowerlimit + 0.1 / multiplier, 0.32, 0.3 / multiplier]), ('variance', [0.5, lowerlimit + 0.025 / multiplier, 0.45, 0.25 / multiplier])) axes += [[fig.add_axes(loc, label=name) for name, loc in axes_params]] ica.plot_properties(raw, picks=exclude_uniq[i], axes=axes[i], show=False, verbose=0) # Add another topo if we're dealing with two sensor types if 'mag' in raw.get_channel_types() and 'grad' in raw.get_channel_types(): ax2 = fig.add_axes(topomap2_layout, label='topomap2') kind, dropped_indices, epochs_src, data = mne.viz.ica._prepare_data_ica_properties( raw, ica, reject_by_annotation=True, reject='auto') mne.viz.ica._plot_ica_topomap(ica, exclude_uniq[i], ch_type='grad', axes=ax2, show=False) if np.any([x in ica.labels_.keys() for x in ica._ica_names]): # this is for the osl_plot_ica convention title = "".join((ica._ica_names[exclude_uniq[i]]," - ", ica.labels_[ica._ica_names[exclude_uniq[i]]].upper())) elif 'eog' in ica.labels_.keys() and 'ecg' in ica.labels_.keys(): # this is for the MNE automatic labelling convention flag_eog = exclude_uniq[i] in ica.labels_['eog'] flag_ecg = exclude_uniq[i] in ica.labels_['ecg'] title = "".join((ica._ica_names[exclude_uniq[i]]," - ", flag_eog*'EOG', flag_ecg*flag_eog*'/', flag_ecg*'ECG')) elif 'eog' in ica.labels_.keys(): flag_eog = exclude_uniq[i] in ica.labels_['eog'] title = "".join((ica._ica_names[exclude_uniq[i]]," - ", flag_eog*'EOG')) elif 'ecg' in ica.labels_.keys(): flag_ecg = exclude_uniq[i] in ica.labels_['ecg'] title = "".join((ica._ica_names[exclude_uniq[i]]," - ", flag_ecg*'ECG')) if 'mag' in raw.get_channel_types() and 'grad' in raw.get_channel_types(): title = title + '\n (mag)' ax2.set_title(title.replace('mag', 'grad')) else: # this is for if there is nothing in ica.labels_ title = None if title is not None: axes[i][0].set_title(title, fontsize=12) # check if the bads were excluded if np.isclose(ica._transform_raw(raw, 0, len(raw.times))[exclude_uniq[i],:].mean(), 0, atol=1e-9): fig.text(0.5, lowerlimit + 1 / multiplier, 'Excluded from preprocessed data', ha='center', va='center', fontsize=14, color='black') else: fig.text(0.5, lowerlimit + 1 / multiplier, 'Not yet excluded from data', ha='center', va='center', fontsize=14, color='red') if savebase is not None: figname = savebase.format('ica') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('ica') else: fpath = None return fpath
[docs]def plot_summary_bad_segs(subject_data, reportdir): """Plot summary of bad channels over subjects. Parameters ---------- subject_data : list list of data for each subject, typically generated by gen_html_data reportdir : str Path to report directory. We will save the plot in this directory. Returns ------- path : str Path to plot. """ output_file = reportdir / "summary/bad_segs.png" fig, ax = plt.subplots() for modality in ['grad', 'mag', 'eeg']: if 'bad_seg_pc_' + modality in subject_data[0].keys(): pc_segs = [] for data in subject_data: # check if key exists data if 'bad_seg_pc_' + modality in data.keys(): pc = data['bad_seg_pc_' + modality] pc_segs.append(pc) else: print('Session {} missing bad_seg_pc_{}'.format(data['filename'], modality)) pc_segs.append(0) ax.plot( range(1, len(subject_data) + 1), pc_segs, label=modality, ) ax.legend() ax.set_xlabel("Subject") ax.set_ylabel("Percentage time marked bad") fig.savefig(output_file, dpi=300) plt.close(fig) return f"summary/bad_segs.png"
[docs]def plot_summary_bad_chans(subject_data, reportdir): """Plot summary of bad channels over subjects. Parameters ---------- subject_data : list list of data for each subject, typically generated by gen_html_data reportdir : str Path to report directory. We will save the plot in this directory. Returns ------- path : str Path to plot. """ output_file = reportdir / "summary/bad_chans.png" fig, ax = plt.subplots() # For summary report: num_bads = [] for data in subject_data: num_bads.append(data['bad_chan_num']) ax.plot( range(1, len(num_bads) + 1), num_bads, ) ax.set_xlabel("Subject") ax.set_ylabel("Number of bad channels") fig.savefig(output_file, dpi=300) plt.close(fig) return f"summary/bad_chans.png"
[docs]def plot_artefact_scan(raw, savebase=None): """Plot artefact scan. Parameters ---------- raw : :py:class:`mne.io.Raw <mne.io.Raw>` MNE Raw object. savebase : str Base string for saving figures. Returns ------- fpath : str Path to saved figure. """ events = mne.find_events(raw, min_duration=2/raw.info['sfreq']) modalities = ['mag', 'grad'] # Plot eye movements event_dict = {'moves': 1} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-1, tmax=16.5, preload=True) x = epochs.get_data() yl = np.abs(x[:,:2,:]).max() * 1.1 plt.figure(figsize=(16,5)) ax = plt.subplot(121) patches = [Rectangle((-1,-yl),1,yl*2, alpha=0.2, facecolor='grey'), Rectangle((0,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((2,-yl),2,yl*2, alpha=0.2, facecolor='grey'), Rectangle((4,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((6,-yl),2,yl*2, alpha=0.2, facecolor='grey'), Rectangle((8,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((10,-yl),2,yl*2, alpha=0.2, facecolor='grey'), Rectangle((12,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((14,-yl),2,yl*2, alpha=0.2, facecolor='grey')] for rec in patches: ax.add_artist(rec) plt.text(-0.5, yl, 'fix', ha='center') plt.text(1, yl, 'Up', ha='center') plt.text(5, yl, 'Right', ha='center') plt.text(9, yl, 'Down', ha='center') plt.text(13, yl, 'Left', ha='center') plt.plot(epochs.times, x[:,0,:].T) plt.ylim(-yl, yl*1.1) ax = plt.subplot(122) patches = [Rectangle((-1,-yl),1,yl*2, alpha=0.2, facecolor='grey'), Rectangle((0,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((2,-yl),2,yl*2, alpha=0.2, facecolor='grey'), Rectangle((4,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((6,-yl),2,yl*2, alpha=0.2, facecolor='grey'), Rectangle((8,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((10,-yl),2,yl*2, alpha=0.2, facecolor='grey'), Rectangle((12,-yl),2,yl*2, alpha=0.2, facecolor='r'), Rectangle((14,-yl),2,yl*2, alpha=0.2, facecolor='grey')] for rec in patches: ax.add_artist(rec) plt.text(-0.5, yl, 'fix', ha='center') plt.text(1, yl, 'Up', ha='center') plt.text(5, yl, 'Right', ha='center') plt.text(9, yl, 'Down', ha='center') plt.text(13, yl, 'Left', ha='center') plt.plot(epochs.times, x[:,1,:].T) plt.ylim(-yl, yl*1.1) plt.savefig(savebase.format('eyemove_eog'), dpi=300, transparent=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'eyemove_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Plot blinks event_dict = {'blinks': 5} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.2, tmax=2, preload=True) x = epochs.get_data() plt.figure(figsize=(16,4)) plt.subplot(121) plt.plot(epochs.times, x[:,0,:].T) plt.title(raw.info['ch_names'][0]) plt.subplot(122) plt.plot(epochs.times, x[:,1,:].T) plt.title(raw.info['ch_names'][1]) name = 'blink_eog'.format(m) plt.savefig(savebase.format(name), dpi=300, transparent=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'blink_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Plot swallow event_dict = {'swallow': 6} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.2, tmax=2, preload=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'swallow_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Plot breathe event_dict = {'breath': 7} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.5, tmax=5, preload=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'breathe_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Plot shrug event_dict = {'shrug': 8} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.5, tmax=3, preload=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'shrug_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Plot clench if len(np.where(events[:,2]==9)[0]) > 0: event_dict = {'clench': 9} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.5, tmax=3, preload=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'clench_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Plot button press event_dict = {'button_press': 257} epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.5, tmax=2, preload=True) ev = epochs.average() for m in modalities: fig = ev.plot_joint(show=False, picks=m) name = 'buttonpress_{0}'.format(m) fig.savefig(savebase.format(name), dpi=300, transparent=True) plt.close('all') # Return filenames savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name filenames = { 'plt_eyemove_grad': filebase.format('eyemove_grad'), 'plt_eyemove_mag': filebase.format('eyemove_mag'), 'plt_eyemove_eog': filebase.format('eyemove_eog'), 'plt_blink_grad': filebase.format('blink_grad'), 'plt_blink_mag': filebase.format('blink_mag'), 'plt_blink_eog': filebase.format('blink_eog'), 'plt_swallow_grad': filebase.format('swallow_grad'), 'plt_swallow_mag': filebase.format('swallow_mag'), 'plt_breathe_grad': filebase.format('breathe_grad'), 'plt_breathe_mag': filebase.format('breathe_mag'), 'plt_shrug_grad': filebase.format('shrug_grad'), 'plt_shrug_mag': filebase.format('shrug_mag'), 'plt_clench_grad': filebase.format('clench_grad'), 'plt_clench_mag': filebase.format('clench_mag'), 'plt_buttonpress_grad': filebase.format('buttonpress_grad'), 'plt_buttonpress_mag': filebase.format('buttonpress_mag'), } return filenames
[docs]def plot_events(events, event_id, savebase): if events is not None: fig = mne.viz.plot_events(events, event_id=event_id, show=False) figname = savebase.format('events') fig.savefig(figname, dpi=150, transparent=True) plt.close(fig) # Return the filename savebase = pathlib.Path(savebase) filebase = savebase.parent.name + "/" + savebase.name fpath = filebase.format('events') return fpath else: return None
[docs]def plot_custom_figures(custom_figures, savebase): plots = {} if custom_figures is not None and type(custom_figures)==dict: for key in custom_figures.keys(): if type(custom_figures[key])==plt.Figure: figname = savebase.format(f'custom_{key}') custom_figures[key].savefig(figname, dpi=150, transparent=True) plt.close(custom_figures[key]) # Return the filename savebase_path = pathlib.Path(savebase) filebase = savebase_path.parent.name + "/" + savebase_path.name plots[key] = filebase.format(f'custom_{key}') else: log_or_print(f"Custom figure {key} is not a matplotlib figure. Skipping.") if len(plots)==0: plots = None return plots