# Source code for osl_ephys.preprocessing.batch

#!/usr/bin/env python

"""Tools for batch preprocessing.

"""

# Authors: Andrew Quinn <a.quinn@bham.ac.uk>
#          Chetan Gohil <chetan.gohil@psych.ox.ac.uk>
#          Mats van Es  <mats.vanes@psych.ox.ac.uk>

import argparse
import matplotlib
import matplotlib.pyplot as plt
import os
import sys
import pprint
import traceback
import re
import logging
import pickle
from pathlib import Path
from copy import deepcopy
from functools import partial, wraps
from time import localtime, strftime
from datetime import datetime
import inspect

import mne
import numpy as np
import yaml

from . import mne_wrappers, osl_wrappers
from ..utils import find_run_id, validate_outdir, process_file_inputs
from ..utils import logger as osl_logger
from ..utils.parallel import dask_parallel_bag
from ..utils.version_utils import check_version
from ..utils.misc import set_random_seed

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)


# --------------------------------------------------------------
# Decorators


# --------------------------------------------------------------
# Data importers
def import_data(infile, preload=True):
    """Import data from a file, picking the MNE reader from the file name.

    Parameters
    ----------
    infile : str or os.PathLike
        Path to file to read. Recognised formats are BTI (``c,rfDC``), fif,
        edf, CTF (``.ds`` directory or ``.meg4`` file), BrainVision
        (``.vhdr``), EEGLAB (``.set``), Ricoh/KIT (``.con``/``.sqd``), bdf,
        EGI (``.mff``) and Curry. Anything else is passed to
        :py:func:`mne.io.read_raw` for automatic detection.
    preload : bool
        Should we load the data in the file?

    Returns
    -------
    raw : :py:class:`mne.io.Raw <mne.io.Raw>`
        Data as an MNE Raw object.

    Raises
    ------
    ValueError
        If ``infile`` is not a str/path, contains spaces, or its file type
        could not be determined.
    """
    # Accept pathlib.Path and other path-like objects as well as plain str.
    if isinstance(infile, os.PathLike):
        infile = os.fspath(infile)
    if not isinstance(infile, str):
        raise ValueError(
            "infile must be a str. Got type(infile)={0}.".format(type(infile))
        )
    if " " in infile:
        raise ValueError("filename cannot contain spaces.")

    logger.info("IMPORTING: {0}".format(infile))

    # normpath drops a trailing separator (e.g. "data.ds/") so the extension
    # of a directory-style recording is still detected.
    norm = os.path.normpath(infile)
    ext = os.path.splitext(norm)[1]
    # Curry files can have double extensions (e.g. ".cdt.dpa"), which
    # splitext() alone never reports, so these are matched with endswith().
    curry_suffixes = (".dat", ".dap", ".rs3", ".cdt", ".cdt.dpa", ".cdt.cef", ".cef")

    if os.path.basename(infile) == "c,rfDC":
        # BTI scan; use the head shape file if it sits next to the data
        logger.info("Detected BTI file format, using: mne.io.read_raw_bti")
        if os.path.isfile(os.path.join(os.path.dirname(infile), "hs_file")):
            head_shape_fname = "hs_file"
        else:
            head_shape_fname = None
        raw = mne.io.read_raw_bti(infile, head_shape_fname=head_shape_fname, preload=preload)
    elif ext == ".fif":
        logger.info("Detected fif file format, using: mne.io.read_raw_fif")
        raw = mne.io.read_raw_fif(infile, preload=preload)
    elif ext.lower() == ".edf":
        logger.info("Detected edf file format, using: mne.io.read_raw_edf")
        raw = mne.io.read_raw_edf(infile, preload=preload)
    elif ext == ".ds":
        # CTF data in a .ds directory
        logger.info("Detected CTF file format, using: mne.io.read_raw_ctf")
        raw = mne.io.read_raw_ctf(infile, preload=preload)
    elif ext == ".meg4":
        # CTF data file given directly: read its enclosing .ds directory
        logger.info("Detected CTF file format, using: mne.io.read_raw_ctf")
        raw = mne.io.read_raw_ctf(os.path.dirname(infile), preload=preload)
    elif ext == ".vhdr":
        logger.info("Detected brainvision file format, using: mne.io.read_raw_brainvision")
        raw = mne.io.read_raw_brainvision(infile, preload=preload)
    elif ext == ".set":
        logger.info("Detected EEGLAB file format, using: mne.io.read_raw_eeglab")
        raw = mne.io.read_raw_eeglab(infile, preload=preload)
    elif ext in (".con", ".sqd"):
        logger.info("Detected Ricoh/KIT file format, using: mne.io.read_raw_kit")
        raw = mne.io.read_raw_kit(infile, preload=preload)
    elif ext == ".bdf":
        logger.info("Detected BDF file format, using: mne.io.read_raw_bdf")
        raw = mne.io.read_raw_bdf(infile, preload=preload)
    elif ext == ".mff":
        logger.info("Detected EGI file format, using mne.io.read_raw_egi")
        raw = mne.io.read_raw_egi(infile, preload=preload)
    elif norm.endswith(curry_suffixes):
        logger.info("Detected Curry file format, using mne.io.read_raw_curry")
        raw = mne.io.read_raw_curry(infile, preload=preload)
    else:
        # Unknown extension: let MNE try to work it out
        logger.info("Trying to automatically detect file type")
        try:
            raw = mne.io.read_raw(infile, preload=preload)
        except Exception as err:
            msg = "Unable to determine file type of input {0}".format(infile)
            logger.error(msg)
            raise ValueError(msg) from err
    return raw
# --------------------------------------------------------------
# Batch processing utilities


def find_func(method, target="raw", extra_funcs=None):
    """Find a preprocessing function.

    Function priority:

    1. User custom function
    2. MNE/osl-ephys wrapper
    3. MNE method on Raw or Epochs (specified by target)

    Parameters
    ----------
    method : str
        Function name.
    target : str
        Type of MNE object to preprocess. Can be ``'raw'``, ``'epochs'``,
        ``'evoked'``, ``'power'`` or ``'itc'``.
    extra_funcs : list
        List of user-defined functions.

    Returns
    -------
    function
        Function to preprocess an MNE object, or ``None`` if nothing matched.
    """
    func = None

    # 1) User custom function. If several share the requested name, the last
    #    one in ``extra_funcs`` is used. An empty list is treated like None.
    if extra_funcs:
        matches = [f for f in extra_funcs if f.__name__ == method]
        if matches:
            func = print_custom_func_info(matches[-1])

    # 2) osl-ephys wrapper first, then MNE wrapper, from the local modules
    for module, prefix in ((osl_wrappers, "run_osl_"), (mne_wrappers, "run_mne_")):
        if func is None and hasattr(module, prefix + method):
            func = getattr(module, prefix + method)

    # 3) A method on the MNE class selected by ``target``
    if func is None:
        if target == "raw":
            mne_cls = mne.io.Raw
        elif target == "epochs":
            mne_cls = mne.Epochs
        elif target in ("power", "itc"):
            mne_cls = mne.time_frequency.EpochsTFR
        else:
            mne_cls = None
        if mne_cls is not None and callable(getattr(mne_cls, method, None)):
            func = partial(mne_wrappers.run_mne_anonymous, method=method)

    if func is None:
        logger.critical("Function not found! {}".format(method))

    return func
def _normalise_stages(stages, label):
    """Check each stage is a single-key dict and turn null options into ``{}``.

    ``label`` is used in error messages (e.g. ``'Preprocessing'``).
    """
    for stage in stages:
        if not isinstance(stage, dict):
            raise ValueError(
                "{0} stage '{1}' is a {2} not a dict".format(label, stage, type(stage))
            )
        if len(stage) != 1:
            raise ValueError(
                "{0} stage '{1}' should only have a single key".format(label, stage)
            )
        for key in list(stage):
            # internally we want options to be an empty dict (for now at least)
            if stage[key] in ["null", "None", None]:
                stage[key] = {}


def load_config(config):
    """Load and validate a processing config.

    Parameters
    ----------
    config : str or dict
        Path to a yaml file, a yaml string, or a dict. A dict is validated and
        normalised in place.

    Returns
    -------
    dict
        Processing config. ``config['meta']`` always contains
        ``'event_codes'``, and ``config['preproc']`` / ``config['group']`` are
        always present (``None`` when not given). Null stage options are
        replaced by ``{}``.

    Raises
    ------
    ValueError
        If ``config`` is not a str/dict, does not parse to a dict, or contains
        a malformed stage.
    KeyError
        If neither preprocessing nor group processing steps are specified.
    """
    if not isinstance(config, (str, dict)):
        raise ValueError("config must be a str or dict, got {}.".format(type(config)))

    if isinstance(config, str):
        # NOTE(review): FullLoader can construct some Python objects; only load
        # configs from trusted sources.
        try:
            # See if we have a filepath
            with open(config, "r") as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except (UnicodeDecodeError, FileNotFoundError, OSError):
            # We have a yaml string
            config = yaml.load(config, Loader=yaml.FullLoader)
        if config is None:
            # Empty file/string: report it through the missing-steps error below
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                "config must parse to a dict, got {}.".format(type(config))
            )

    # The literal string 'None' at the top level means null
    for key in config:
        if config[key] == 'None':
            config[key] = None

    # Initialise missing values in config. A null 'meta' is treated as missing.
    meta = config.get("meta")
    if meta is None:
        config["meta"] = {"event_codes": None}
    else:
        meta.setdefault("event_codes", None)
        meta.setdefault("versions", None)

    if "preproc" not in config and "group" not in config:
        raise KeyError("Please specify preprocessing and/or group processing steps in config.")

    if config.get("preproc") is not None:
        _normalise_stages(config["preproc"], "Preprocessing")
        for step in config["preproc"]:
            # NOTE(review): ``step`` maps a method name to its options, so this
            # tests the option values rather than the method name and does not
            # fire for a ``find_events`` stage. Kept unchanged so existing
            # configs without event_codes keep working.
            if config["meta"]["event_codes"] is None and "find_events" in step.values():
                raise KeyError(
                    "event_codes must be passed in config if we are finding events."
                )
    else:
        config["preproc"] = None

    if config.get("group") is not None:
        _normalise_stages(config["group"], "Group processing")
    else:
        config["group"] = None

    return config
def check_config_versions(config):
    """Check installed package versions against those requested in a config.

    Parameters
    ----------
    config : dictionary or yaml string
        Preprocessing configuration to check. Version requirements are read
        from ``config['meta']['version_assert']`` and
        ``config['meta']['version_warn']``.

    Raises
    ------
    AssertionError
        Raised if package version mismatch found in 'version_assert'
    Warning
        Raised if package version mismatch found in 'version_warn'
    """
    config = load_config(config)
    meta = config["meta"]

    # Missing or null lists mean there is nothing to check
    for vers in meta.get("version_assert") or []:
        check_version(vers, mode="assert")

    for vers in meta.get("version_warn") or []:
        check_version(vers, mode="warn")
def get_config_from_fif(inst):
    """Get config(s) from a preprocessed fif file.

    Reads the ``inst.info['description']`` field of a fif file to get the
    preprocessing config of every batch-processing run recorded there.

    Parameters
    ----------
    inst : :py:class:`mne.io.Raw <mne.io.Raw>`, :py:class:`mne.Epochs <mne.Epochs>`, :py:class:`mne.Evoked <mne.Evoked>`
        Preprocessed MNE object.

    Returns
    -------
    list of dict
        One preprocessing config per ``%% config start %%`` ...
        ``%% config end %%`` section, in order. Empty if there is none
        (including when the description is ``None``).
    """
    # The description field can be None on data that was never annotated
    description = inst.info["description"] or ""
    config_texts = re.findall(
        "%% config start %%(.*?)%% config end %%",
        description,
        flags=re.DOTALL,
    )
    return [load_config(text) for text in config_texts]
def append_preproc_info(dataset, config, extra_funcs=None):
    """Append a record of the processing applied to ``info['description']``.

    The record holds the date, the osl-ephys version, the config and, if
    given, the source code of any custom functions. It is appended to the
    description of ``dataset['raw']`` and, when present, ``dataset['epochs']``.

    Parameters
    ----------
    dataset : dict
        Preprocessed dataset.
    config : dict
        Preprocessing config.
    extra_funcs : list or None
        User-defined functions whose source should be recorded.

    Returns
    -------
    dict
        Dataset dict containing the preprocessed data edited in place.
    """
    from .. import __version__  # here to avoid circular import

    preproc_info = (
        "\n\nOSL-EPHYS BATCH PROCESSING APPLIED ON "
        + f"{datetime.today().strftime('%d/%m/%Y %H:%M:%S')} \n"
        + f"VERSION: {__version__}\n"
        + f"%% config start %% \n{config} \n%% config end %%"
    )

    if extra_funcs:
        preproc_info += "\n\nCUSTOM FUNCTIONS USED:\n"
        for func in extra_funcs:
            try:
                source = inspect.getsource(func)
            except (OSError, TypeError):
                # Source is unavailable (e.g. defined interactively or a
                # builtin). Record the name rather than failing the whole run.
                source = f"# source unavailable for {getattr(func, '__name__', repr(func))}"
            preproc_info += f"%% extra_funcs start %% \n{source}\n%% extra_funcs end %%"

    for key in ("raw", "epochs"):
        inst = dataset.get(key)
        if inst is None:
            continue
        # description may be None on freshly loaded data
        inst.info["description"] = (inst.info["description"] or "") + preproc_info

    return dataset
def write_dataset(dataset, outbase, run_id, ftype='preproc-raw', overwrite=False, skip=None):
    """Write a processed dataset to disk.

    Each key of ``dataset`` is written with a format matching its type:

    - ``raw``, ``epochs``, ``ica`` and ``tfr`` as fif (via their ``save``).
    - ``events`` as npy.
    - ``event_id`` as yml.
    - ``glm`` via its ``save_pkl``.
    - Each figure in ``fig`` as png.
    - Any other key as a pickle.

    Parameters
    ----------
    dataset : dict
        Preprocessed dataset.
    outbase : str
        Output filename template with ``{run_id}``, ``{ftype}`` and ``{fext}``
        fields.
    run_id : str
        ID for the output file. ``_preproc-raw`` and ``_raw`` are removed.
    ftype: str
        Extension for the fif file (default ``preproc-raw``)
    overwrite : bool
        Should we overwrite if the file already exists?
    skip : list or None
        List of keys to skip writing to disk. If None, we don't skip any keys.

    Returns
    -------
    outnames : dict
        Output filename per key. ``outnames['raw']`` is ``None`` if raw was
        skipped; ``outnames['fig']`` is a dict of filenames per figure.

    Raises
    ------
    ValueError
        If the raw output file already exists and ``overwrite`` is False.
    """
    skip = [] if skip is None else skip
    for key in skip:
        logger.info("Skip saving of dataset['{}']".format(key))

    # Strip "_preproc-raw" or "_raw" from the run id
    for suffix in ("_preproc-raw", "_raw"):
        run_id = run_id.replace(suffix, "")

    def _fname(kind, fext):
        return outbase.format(run_id=run_id, ftype=kind, fext=fext)

    def _wanted(key):
        return key in dataset and key not in skip and dataset[key] is not None

    if "raw" in skip:
        outnames = {"raw": None}
    else:
        outnames = {"raw": _fname(ftype, "fif")}
        if Path(outnames["raw"]).exists() and not overwrite:
            raise ValueError(
                "{} already exists. Please delete or do use overwrite=True.".format(outnames['raw'])
            )
        logger.info(f"Saving dataset['raw'] as {outnames['raw']}")
        dataset["raw"].save(outnames['raw'], overwrite=overwrite)

    if _wanted("events"):
        outnames['events'] = _fname("events", "npy")
        logger.info(f"Saving dataset['events'] as {outnames['events']}")
        np.save(outnames['events'], dataset["events"])

    if _wanted("event_id"):
        outnames['event_id'] = _fname("event-id", "yml")
        logger.info(f"Saving dataset['event_id'] as {outnames['event_id']}")
        with open(outnames['event_id'], "w") as f:
            yaml.dump(dataset["event_id"], f)

    # Objects that save themselves: (key, ftype, extension, save method)
    for key, kind, fext, saver in (
        ("epochs", "epo", "fif", "save"),
        ("ica", "ica", "fif", "save"),
        ("tfr", "tfr", "fif", "save"),
        ("glm", "glm", "pkl", "save_pkl"),
    ):
        if _wanted(key):
            outnames[key] = _fname(kind, fext)
            logger.info(f"Saving dataset['{key}'] as {outnames[key]}")
            getattr(dataset[key], saver)(outnames[key], overwrite=overwrite)

    if _wanted("fig"):
        outnames['fig'] = {}
        for name, figure in dataset["fig"].items():
            outnames['fig'][name] = _fname(name, "png")
            logger.info(f"Saving dataset['fig'][{name}] as {outnames['fig'][name]}")
            figure.savefig(outnames['fig'][name])

    # save remaining keys as pickle files
    for key in dataset:
        if key in outnames or key in skip:
            continue
        # The filename is recorded even when nothing is written (e.g. a None
        # value) so every run returns the same keys; run_proc_batch builds the
        # group inputs from these keys.
        outnames[key] = _fname(key, "pkl")
        if dataset[key] is None:
            continue
        if os.path.exists(outnames[key]) and not overwrite:
            logger.info(f"Not overwriting existing file {outnames[key]}")
            continue
        logger.info(f"Saving dataset['{key}'] as {outnames[key]}")
        with open(outnames[key], "wb") as f:
            pickle.dump(dataset[key], f)

    return outnames
def read_dataset(fif, preload=False, ftype=None):
    """Reads ``fif``/``npy``/``yml`` files associated with a dataset.

    Parameters
    ----------
    fif : str or os.PathLike
        Path to raw fif file (can be preprocessed).
    preload : bool
        Should we load the raw fif data?
    ftype : str
        Extension for the fif file (will be replaced for e.g. ``'_events.npy'``
        or ``'_ica.fif'``). If ``None``, we assume the fif file is preprocessed
        with ``osl-ephys`` and has the extension ``'_preproc-raw'``. If this
        fails, we guess the extension as whatever comes after the last ``'_'``
        in the file name.

    Returns
    -------
    dataset : dict
        Contains keys: ``'raw'``, ``'events'``, ``'event_id'``, ``'epochs'``,
        ``'ica'``. Entries whose file does not exist are ``None``.

    Raises
    ------
    ValueError
        If ``ftype`` is None and cannot be guessed from the file name.
    """
    fif = os.fspath(fif)
    print("Loading dataset:")

    print("Reading", fif)
    raw = mne.io.read_raw_fif(fif, preload=preload)

    # Only the file name is inspected/modified, so underscores or suffixes in
    # directory names cannot corrupt the sibling paths.
    fif_dir, fif_name = os.path.split(fif)

    # Guess extension
    if ftype is None:
        logger.info("Guessing the preproc extension")
        if "preproc-raw" in fif_name:
            logger.info('Assuming fif file type is "preproc-raw"')
            ftype = "preproc-raw"
        elif "_" in fif_name:
            logger.info('Assuming fif file type is the last "_" separated string')
            ftype = fif_name.split("_")[-1].split('.')[-2]
        else:
            msg = "Unable to guess the fif file extension"
            logger.error(msg)
            raise ValueError(msg)

    # add extension to fif file name
    suffix = ftype + ".fif"

    def _read_sibling(new_suffix, reader):
        """Read the file named like ``fif`` with ``suffix`` swapped, if it exists."""
        path = Path(fif_dir, fif_name.replace(suffix, new_suffix))
        if not path.exists():
            return None
        print("Reading", path)
        return reader(path)

    def _load_yaml(path):
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only read event-id files from trusted sources.
        with open(path, "r") as file:
            return yaml.load(file, Loader=yaml.Loader)

    dataset = {
        "raw": raw,
        "events": _read_sibling("events.npy", np.load),
        "event_id": _read_sibling("event-id.yml", _load_yaml),
        "epochs": _read_sibling("epo.fif", mne.read_epochs),
        "ica": _read_sibling("ica.fif", mne.preprocessing.read_ica),
    }

    return dataset
def plot_preproc_flowchart(
    config,
    outname=None,
    show=False,
    stagecol="wheat",
    startcol="red",
    fig=None,
    ax=None,
    title=None,
):
    """Make a summary flowchart of a preprocessing chain.

    Parameters
    ----------
    config : dict
        Preprocessing config to plot.
    outname : str
        Output filename.
    show : bool
        Should we show the plot?
    stagecol : str
        Stage colour.
    startcol : str
        Start colour.
    fig : matplotlib.figure
        Matplotlib figure to plot on.
    ax : :py:class:`matplotlib.axes <matplotlib.axes>`
        Matplotlib axes to plot on.
    title : str
        Title for the plot.

    Returns
    -------
    fig : :py:class:`matplotlib.figure <matplotlib.figure>`
    ax : :py:class:`matplotlib.axes <matplotlib.axes>`
    """
    config = load_config(config)

    # Draw on the caller's axes only when both a figure and axes were given
    if ax is None or fig is None:
        fig = plt.figure(figsize=(8, 12))
        plt.subplots_adjust(top=0.95, bottom=0.05)
        ax = plt.subplot(111, frame_on=False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("osl-ephys Processing Recipe" if title is None else title, fontsize=24)

    # One slot per box: the header of each section plus its stages
    tmp_h = 1
    if config["preproc"] is not None:
        tmp_h += 1 + len(config["preproc"])
    if config["group"] is not None:
        tmp_h += 1 + len(config["group"])
    stage_height = 1 / tmp_h

    box = dict(boxstyle="round", facecolor=stagecol, alpha=1, pad=0.3)
    startbox = dict(boxstyle="round", facecolor=startcol, alpha=1)
    font = {
        "family": "serif",
        "color": "k",
        "weight": "normal",
        "size": 16,
    }

    stages = [{"input": ""}]
    if config['preproc'] is not None:
        stages += [{"preproc": ""}, *config["preproc"]]
    if config['group'] is not None:
        stages += [{"group": ""}, *config["group"]]
    stages.append({"output": ""})

    stage_str = "$\\bf{{{0}}}$ {1}"

    ax.arrow(
        0.5, 1, 0.0, -1+0.02, fc="k", ec="k", head_width=0.045,
        head_length=0.035, length_includes_head=True,
    )

    for idx, stage in enumerate(stages):
        method, userargs = next(iter(stage.items()))
        # Escape underscores so mathtext renders them literally
        method = method.replace("_", r"\_")
        if method in ["input", "preproc", "group", "output"]:
            b = startbox
        else:
            b = box
            method = method + ":"

        ax.text(
            0.5,
            1 - stage_height * idx,
            stage_str.format(method, str(userargs)[1:-1]),
            ha="center",
            va="center",
            bbox=b,
            fontdict=font,
            wrap=True,
        )

    ax.set_ylim(0, 1.05)
    ax.set_xlim(0.25, 0.75)

    if outname is not None:
        fig.savefig(outname, dpi=300, transparent=True)

    if show:
        fig.show()

    return fig, ax
# --------------------------------------------------------------
# Batch processing


def _write_error_log(logfile, now, method, ex_type, ex_value, ex_traceback):
    """Write a summary of a failed processing chain to ``<logfile>.error.log``."""
    error_logfile = os.path.splitext(logfile)[0] + ".error.log"
    with open(error_logfile, "w") as f:
        f.write("OSL-EPHYS PREPROCESSING CHAIN FAILED AT: {0}".format(now))
        f.write("\n")
        f.write('Processing failed during stage : "{0}"'.format(method))
        f.write("\n")
        f.write(str(ex_type))
        f.write("\n")
        f.write(str(ex_value))
        f.write("\n")
        traceback.print_tb(ex_traceback, file=f)


def _gen_subject_report(dataset, reportdir, outnames, run_id, ftype, logsdir, now):
    """Write the HTML report data for one run and refresh the report page."""
    from ..report import gen_html_data, gen_html_page  # avoids circular import

    # Switch to a non-GUI plotting backend; restore it even if plotting fails
    mpl_backend = matplotlib.pyplot.get_backend()
    matplotlib.use('Agg')
    try:
        logger.info("{0} : Generating Report".format(now))
        if outnames["raw"] is not None:
            report_name = Path(outnames["raw"]).stem.replace(f"_{ftype}", "")
        else:
            # Nothing was written to disk: name the report after the run ID
            report_name = run_id
        report_data_dir = validate_outdir(reportdir / report_name)

        # Empty or missing figure dicts are passed on as None
        custom_figures = dataset.get('fig') or None

        gen_html_data(
            dataset["raw"],
            report_data_dir,
            ica=dataset["ica"],
            events=dataset["events"],
            event_id=dataset["event_id"],
            preproc_fif_filename=outnames["raw"],
            logsdir=logsdir,
            run_id=run_id,
            custom_figures=custom_figures,
        )
        gen_html_page(reportdir)
    finally:
        matplotlib.use(mpl_backend)


def run_proc_chain(
    config,
    infile,
    subject=None,
    ftype='preproc-raw',
    outdir=None,
    logsdir=None,
    reportdir=None,
    ret_dataset=True,
    gen_report=None,
    overwrite=False,
    skip_save=None,
    extra_funcs=None,
    random_seed='auto',
    verbose="INFO",
    mneverbose="WARNING",
):
    """Run preprocessing for a single file.

    Parameters
    ----------
    config : str or dict
        Preprocessing config.
    infile : str or :py:class:`mne.io.Raw <mne.io.Raw>`
        Path to input file, or an already loaded MNE Raw object.
    subject : str
        Subject ID. This will be the sub-directory in outdir.
    ftype: str
        Extension for the fif file (default ``preproc-raw``)
    outdir : str
        Output directory.
    logsdir : str
        Directory to save log files to.
    reportdir : str
        Directory to save report files to.
    ret_dataset : bool
        Should we return a dataset dict?
    gen_report : bool
        Should we generate a report?
    overwrite : bool
        Should we overwrite the output file if it already exists?
    skip_save: list or None (default)
        List of keys to skip writing to disk. If None, we don't skip any keys.
    extra_funcs : list
        User-defined functions.
    random_seed : 'auto' (default), int or None
        Random seed to set. If 'auto', a random seed will be generated. Random
        seeds are set for both Python and NumPy. If None, no random seed is set.
    verbose : str
        Level of info to print.
        Can be: ``'CRITICAL'``, ``'ERROR'``, ``'WARNING'``, ``'INFO'``,
        ``'DEBUG'`` or ``'NOTSET'``.
    mneverbose : str
        Level of info from MNE to print.
        Can be: ``'CRITICAL'``, ``'ERROR'``, ``'WARNING'``, ``'INFO'``,
        ``'DEBUG'`` or ``'NOTSET'``.

    Returns
    -------
    dict or bool or tuple
        If ``ret_dataset=True``, a dict containing the preprocessed dataset
        with the following keys: ``raw``, ``ica``, ``epochs``, ``events``,
        ``event_id``. An empty dict is returned if preprocessing fails.

        If ``ret_dataset=False``, a flag indicating whether preprocessing was
        successful. When the loaded config has a ``'group'`` key (always the
        case after :py:func:`load_config`), this is a tuple ``(flag, outnames)``
        where ``outnames`` is the dict of written files (``None`` on failure
        or skip).
    """
    # Get run (subject) ID
    run_id = subject or find_run_id(infile)
    name_base = "{run_id}_{ftype}.{fext}"

    if not ret_dataset:
        # Let's make sure we have an output directory
        outdir = outdir or os.getcwd()

    if outdir is not None:
        # We're saving the output to disk. Generate a report by default; this
        # is overridden if the user passes gen_report=False.
        gen_report = True if gen_report is None else gen_report

        # Create output directories if they don't exist
        outdir = validate_outdir(outdir)
        logsdir = validate_outdir(logsdir or outdir / "logs")
        reportdir = validate_outdir(reportdir or outdir / "preproc_report")
        outdir = validate_outdir(outdir / run_id)
    else:
        # We're not saving the output to disk. Don't generate a report by
        # default; this is overridden if the user passes something for
        # reportdir or gen_report=True.
        gen_report = gen_report or reportdir is not None or False
        if gen_report:
            # Make sure we have a directory to write the report to
            reportdir = validate_outdir(reportdir or os.getcwd() + "/preproc_report")

        # Allow the user to create a log if they pass logsdir
        if logsdir is not None:
            logsdir = validate_outdir(logsdir)

    # Create output filename
    outbase = os.path.join(outdir, name_base) if outdir is not None else None

    # Generate log filename
    if logsdir is not None:
        logbase = os.path.join(logsdir, name_base)
        logfile = logbase.format(run_id=run_id, ftype=ftype.replace("-raw", ""), fext="log")
        mne.utils._logging.set_log_file(logfile, overwrite=overwrite)
    else:
        logfile = None

    # Finish setting up loggers
    osl_logger.set_up(prefix=run_id, log_file=logfile, level=verbose, startup=False)
    mne.set_log_level(mneverbose)
    logger = logging.getLogger(__name__)
    now = strftime("%Y-%m-%d %H:%M:%S", localtime())
    logger.info("{0} : Starting osl-ephys Processing".format(now))
    logger.info("input : {0}".format(infile))

    # Set random seed
    if random_seed == 'auto':
        set_random_seed()
    elif random_seed is not None:
        set_random_seed(random_seed)

    # Load config first so every return path below can use the same form
    if not isinstance(config, dict):
        config = load_config(config)

    def _result(flag, outnames=None):
        # run_proc_batch expects (flag, outnames) when the config has 'group'
        if 'group' in config:
            return flag, outnames
        return flag

    # Check for existing outputs - should be a .fif at least
    if outdir is not None:
        fifout = outbase.format(run_id=run_id, ftype=ftype, fext='fif')
        if os.path.exists(fifout) and not overwrite:
            logger.critical('Skipping preprocessing - existing output detected')
            return False if ret_dataset else _result(False)

    # Stage being run; reported if anything below fails
    method, func = 'import_data', import_data
    outnames = {"raw": None}

    # MAIN BLOCK - Run the preproc chain and catch any exceptions
    try:
        if isinstance(infile, str):
            raw = import_data(infile)
        elif isinstance(infile, mne.io.BaseRaw):
            raw = infile
            infile = raw.filenames[0]  # assuming only one file here
        else:
            raise ValueError(
                "infile must be a str or an MNE Raw object, got {0}".format(type(infile))
            )

        # Create a dataset dict to hold the preprocessed dataset
        dataset = {
            "raw": raw,
            "events": None,
            "epochs": None,
            "event_id": config["meta"]["event_codes"],
            "ica": None,
            "fig": {},
        }

        # Do the preprocessing
        for stage in deepcopy(config["preproc"] or []):
            method, userargs = next(iter(stage.items()))
            target = userargs.get("target", "raw")  # Raw is default
            func = find_func(method, target=target, extra_funcs=extra_funcs)
            # Actual function call
            dataset = func(dataset, userargs)

        # Add preprocessing info to dataset dict
        dataset = append_preproc_info(dataset, config, extra_funcs)

        if outdir is not None:
            outnames = write_dataset(dataset, outbase, run_id, overwrite=overwrite,
                                     skip=skip_save, ftype=ftype)

        # Generate report data
        if gen_report:
            _gen_subject_report(dataset, reportdir, outnames, run_id, ftype, logsdir, now)

    except Exception:
        # Preprocessing failed
        ex_type, ex_value, ex_traceback = sys.exc_info()
        logger.critical("**********************")
        logger.critical("* PROCESSING FAILED! *")
        logger.critical("**********************")
        logger.error("{0} : {1}".format(method, func))
        logger.error(ex_type)
        logger.error(ex_value)
        # format_tb returns the text; print_tb would print it and return None
        logger.error("".join(traceback.format_tb(ex_traceback)))

        # An error log can only be written if a log file was set up
        if logfile is not None:
            _write_error_log(logfile, now, method, ex_type, ex_value, ex_traceback)

        if ret_dataset:
            # We return an empty dict to indicate preproc failed. This ensures
            # the function consistently returns one variable type.
            return {}
        return _result(False)

    now = strftime("%Y-%m-%d %H:%M:%S", localtime())
    logger.info("{0} : Processing Complete".format(now))

    if outnames["raw"] is not None:
        logger.info("Output file is {}".format(outnames["raw"]))

    if ret_dataset:
        return dataset
    return _result(True, outnames)
def _confirm(prompt):
    """Ask a yes/no question on stdin; anything other than y/yes means no."""
    answer = input("{0} [y/N]: ".format(prompt))
    return answer.strip().lower() in ("y", "yes")


def run_proc_batch(
    config,
    files,
    subjects=None,
    ftype='preproc-raw',
    outdir=None,
    logsdir=None,
    reportdir=None,
    gen_report=True,
    overwrite=False,
    skip_save=None,
    extra_funcs=None,
    covs=None,
    random_seed='auto',
    verbose="INFO",
    mneverbose="WARNING",
    strictrun=False,
    dask_client=False,
):
    """Run batched preprocessing.

    This function will write output to disk (i.e. will not return the
    preprocessed data).

    Parameters
    ----------
    config : str or dict
        Preprocessing config.
    files : str or list or mne.Raw
        Can be a list of Raw objects or a list of filenames (or ``.ds`` dir
        names if CTF data) or a path to a textfile list of filenames (or
        ``.ds`` dir names if CTF data).
    subjects : list of str
        Subject directory names. These are sub-directories in outdir.
    ftype: None or str
        Extension of the preprocessed fif files. Default option is
        `_preproc-raw`.
    outdir : str
        Output directory.
    logsdir : str
        Directory to save log files to.
    reportdir : str
        Directory to save report files to.
    gen_report : bool
        Should we generate a report?
    overwrite : bool
        Should we overwrite the output file if it exists?
    skip_save: list or None (default)
        List of keys to skip writing to disk. If None, we don't skip any keys.
    extra_funcs : list
        User-defined functions.
    covs : dict or pd.DataFrame
        Covariates to use for building the GLM design
    random_seed : 'auto' (default), int or None
        Random seed to set. If 'auto', a random seed will be generated. Random
        seeds are set for both Python and NumPy. If None, no random seed is set.
    verbose : str
        Level of info to print (also used for each file's processing chain).
        Can be: ``'CRITICAL'``, ``'ERROR'``, ``'WARNING'``, ``'INFO'``,
        ``'DEBUG'`` or ``'NOTSET'``.
    mneverbose : str
        Level of info from MNE to print (also used for each file's chain).
        Can be: ``'CRITICAL'``, ``'ERROR'``, ``'WARNING'``, ``'INFO'``,
        ``'DEBUG'`` or ``'NOTSET'``.
    strictrun : bool
        Should we ask for confirmation of user inputs before starting? Answers
        are read from standard input; anything other than ``y``/``yes`` stops
        the run.
    dask_client : bool
        Indicate whether to use a previously initialised
        :py:class:`dask.distributed.Client <distributed.Client>` instance.

    Returns
    -------
    list of bool
        Flags indicating whether preprocessing was successful for each input
        file.

    Notes
    -----
    If you are using a :py:class:`dask.distributed.Client <distributed.Client>`
    instance, you must initialise it before calling this function. For example:

    >>> from dask.distributed import Client
    >>> client = Client(threads_per_worker=1, n_workers=4)
    """
    if outdir is None:
        # Use the current working directory
        outdir = os.getcwd()

    # Validate the parent outdir - later do so for each subdirectory
    tmpoutdir = validate_outdir(str(outdir).split('{')[0])
    logsdir = validate_outdir(logsdir or tmpoutdir / "logs")
    reportdir = validate_outdir(reportdir or tmpoutdir / "preproc_report")

    # Initialise Loggers
    mne.set_log_level(mneverbose)
    if strictrun and verbose not in ['INFO', 'DEBUG']:
        # override logger level if strictrun requested but user won't see any info...
        verbose = 'INFO'
    logfile = os.path.join(logsdir, 'batch_preproc.log')
    osl_logger.set_up(log_file=logfile, level=verbose, startup=False)

    logger.info('Starting osl-ephys Batch Processing')

    # Set random seed ('auto' draws one here so every file uses the same seed)
    if random_seed == 'auto':
        random_seed = set_random_seed()
    elif random_seed is not None:
        set_random_seed(random_seed)

    # Check through inputs and parameters
    infiles, good_files_outnames, good_files = process_file_inputs(files)

    # Specify filenames for the output data
    if subjects is None:
        subjects = good_files_outnames
    elif len(subjects) != len(good_files_outnames):
        # NOTE(review): processing continues; zip() below only pairs files
        # and subjects up to the shorter of the two lists.
        logger.critical(
            f"Number of subjects ({len(subjects)}) does not match "
            f"number of good files {len(good_files_outnames)}. "
            "Please fix the subjects list or pass subjects=None."
        )

    if strictrun:
        if not _confirm('Is this correct set of inputs?'):
            logger.critical('Stopping : User indicated incorrect number of input files')
            sys.exit(1)
        logger.info('User confirms input files')

    logger.info('Outputs saving to: {0}'.format(outdir))
    if strictrun:
        if not _confirm('Is this correct output directory?'):
            logger.critical('Stopping : User indicated incorrect output directory')
            sys.exit(1)
        logger.info('User confirms output directory')

    config = load_config(config)
    config_str = pprint.PrettyPrinter().pformat(config)
    logger.info('Running config\n {0}'.format(config_str))
    if strictrun:
        if not _confirm('Is this the correct config?'):
            logger.critical('Stopping : User indicated incorrect preproc config')
            sys.exit(1)
        logger.info('User confirms input config')

    if config['preproc'] is not None:
        # Create partial function with fixed options
        pool_func = partial(
            run_proc_chain,
            outdir=outdir,
            ftype=ftype,
            logsdir=logsdir,
            reportdir=reportdir,
            ret_dataset=False,
            gen_report=gen_report,
            overwrite=overwrite,
            skip_save=skip_save,
            extra_funcs=extra_funcs,
            random_seed=random_seed,
            verbose=verbose,
            mneverbose=mneverbose,
        )

        # Arguments for run_proc_chain, one tuple per input file
        args = [(config, infile, subject) for infile, subject in zip(infiles, subjects)]

        # Actually run the processes
        if dask_client:
            results = dask_parallel_bag(pool_func, args)
        else:
            results = [pool_func(*aa) for aa in args]

        # run_proc_chain returns (flag, outnames) or a bare flag; accept both
        # so mixed or empty result lists are handled.
        proc_flags, group_inputs = [], []
        for res in results:
            flag, outn = res if isinstance(res, tuple) else (res, None)
            proc_flags.append(flag)
            group_inputs.append(outn)

        osl_logger.set_up(log_file=logfile, level=verbose, startup=False)
        logger.info("Processed {0}/{1} files successfully".format(
            np.sum(proc_flags), len(proc_flags)))

        # Generate a report
        if gen_report and len(infiles) > 0:
            from ..report import preproc_report  # avoids circular import
            preproc_report.gen_html_page(reportdir)
    else:
        group_inputs = [{"raw": infile} for infile in infiles]
        proc_flags = [os.path.exists(sub) for sub in infiles]
        osl_logger.set_up(log_file=logfile, level=verbose, startup=False)
        logger.info("No preprocessing steps specified. Skipping preprocessing.")

    # start group processing
    custom_figures = None
    if config['group'] is not None:
        logger.info("Starting Group Processing")
        logger.info(
            "Valid input files {0}/{1}".format(
                np.sum(proc_flags), len(proc_flags)
            )
        )
        # NOTE(review): a failed subject contributes None here, which makes
        # the lookups below fail; confirm how failures should be handled.
        dataset = {}
        # Keys taken from the per-subject inputs are not re-saved
        group_skip = []
        for key in group_inputs[0]:
            dataset[key] = [inputs[key] for inputs in group_inputs]
            group_skip.append(key)
        if covs is not None:
            dataset['covs'] = covs
        dataset['fig'] = {}

        for stage in deepcopy(config["group"]):
            method, userargs = next(iter(stage.items()))
            # make sure the function always knows it's a group processing
            userargs['run_on_group'] = True
            target = userargs.get("target", "raw")  # Raw is default
            func = find_func(method, target=target, extra_funcs=extra_funcs)
            # Actual function call
            dataset = func(dataset, userargs)

        outbase = os.path.join(outdir, "{ftype}.{fext}")
        write_dataset(dataset, outbase, '', ftype='', overwrite=overwrite, skip=group_skip)
        custom_figures = dataset.get('fig', None)

    # rerun the summary report
    if gen_report:
        from ..report import preproc_report  # avoids circular import
        if preproc_report.gen_html_summary(reportdir, logsdir, custom_figures=custom_figures):
            logger.info("******************************" + "*" * len(str(reportdir)))
            logger.info(f"* REMEMBER TO CHECK REPORT: {reportdir} *")
            logger.info("******************************" + "*" * len(str(reportdir)))

    # Return flags
    return proc_flags
# ----------------------------------------------------------
# Main CLI user function


def main(argv=None):
    """Main function for command line interface.

    Parameters
    ----------
    argv : list
        Command line arguments. Defaults to ``sys.argv[1:]``.
    """
    if argv is None:
        argv = sys.argv[1:]

    def _str2bool(value):
        # argparse's type=bool treats any non-empty string (even "False") as
        # True, so parse the usual spellings explicitly.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser(description="Batch preprocess some fif files.")
    parser.add_argument("config", type=str, help="yaml defining preproc")
    parser.add_argument(
        "files",
        type=str,
        help="plain text file containing full paths to files to be processed",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        default=None,
        help="Path to output directory to save data in",
    )
    parser.add_argument(
        "--logsdir", type=str, default=None, help="Path to logs directory"
    )
    parser.add_argument(
        "--reportdir", type=str, default=None, help="Path to report directory"
    )
    parser.add_argument(
        "--gen_report", type=_str2bool, default=True, help="Should we generate a report?"
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite previous output files if they're in the way",
    )
    parser.add_argument(
        "--verbose",
        type=str,
        default="INFO",
        help="Set the logging level for osl-ephys functions",
    )
    parser.add_argument(
        "--mneverbose",
        type=str,
        default="WARNING",
        help="Set the logging level for MNE functions",
    )
    parser.add_argument(
        "--strictrun",
        action="store_true",
        help="Will ask the user for confirmation before starting",
    )

    parser.usage = parser.format_help()
    args = parser.parse_args(argv)

    run_proc_batch(**vars(args))
# Run the command line interface when this module is executed as a script.
if __name__ == "__main__":
    main()