Source code for osl_ephys.source_recon.batch

"""Batch processing for source reconstruction.

"""

# Authors: Chetan Gohil <chetan.gohil@psych.ox.ac.uk>
#          Mats van Es <mats.vanes@psych.ox.ac.uk>           

import os
import sys
import traceback
import pprint
import inspect

from copy import deepcopy
from time import localtime, strftime
from functools import partial
from dask.distributed import Variable, Queue

import numpy as np
import yaml
import mne

from . import rhino, wrappers, freesurfer_utils
from ..report import src_report
from ..utils import logger as osl_logger
from ..utils import validate_outdir, find_run_id, parallel
from ..utils.misc import set_random_seed

import logging
[docs]logger = logging.getLogger(__name__)
[docs]def load_config(config): """Load config. Parameters ---------- config : str or dict Path to yaml file or str to convert to dict or a dict. Returns ------- config : dict Source reconstruction config. """ if type(config) not in [str, dict]: raise ValueError("config must be a str or dict, got {}.".format(type(config))) if isinstance(config, str): try: # See if we have a filepath with open(config, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) except (UnicodeDecodeError, FileNotFoundError, OSError): # We have a str config = yaml.load(config, Loader=yaml.FullLoader) # Validation if "source_recon" not in config: raise ValueError("source_recon must be included in the config.") return config
[docs]def find_func(method, extra_funcs): """Find a source reconstruction function. Parameters ---------- method : str Function name. extra_funcs : list of functions Custom functions. Returns ------- func : function Function to use. """ func = None # Look in custom functions if extra_funcs is not None: func_ind = [idx if (f.__name__ == method) else -1 for idx, f in enumerate(extra_funcs) ] if np.max(func_ind) > -1: func = extra_funcs[np.argmax(func_ind)] # Look in osl_ephys.source_recon.wrappers if func is None and hasattr(wrappers, method): func = getattr(wrappers, method) return func
[docs]def run_src_chain( config, outdir, subject, preproc_file=None, smri_file=None, epoch_file=None, surface_extraction_method='fsl', logsdir=None, reportdir=None, gen_report=True, verbose="INFO", mneverbose="WARNING", extra_funcs=None, random_seed='auto', ): """Source reconstruction. Parameters ---------- config : str or dict Source reconstruction config. outdir : str Source reconstruction directory. subject : str Subject name. surface_extraction_method : str Can be 'fsl' or 'freesurfer'. preproc_file : str Preprocessed fif file. smri_file : str Structural MRI file. epoch_file : str Epoched fif file. logsdir : str Directory to save log files to. reportdir : str Directory to save report files to. gen_report : bool Should we generate a report? verbose : str Level of verbose. mneverbose : str Level of MNE verbose. extra_funcs : list of functions Custom functions. random_seed : 'auto' (default), int or None Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy. If None, no random seed is set. Returns ------- flag : bool Flag indicating whether source reconstruction was successful. """ if surface_extraction_method == 'fsl': rhino.fsl_utils.check_fsl() elif surface_extraction_method == 'freesurfer': freesurfer_utils.check_freesurfer() # Directories outdir = validate_outdir(outdir) logsdir = validate_outdir(logsdir or outdir / "logs") reportdir = validate_outdir(reportdir or outdir / "src_report") # Use the subject ID for the run ID run_id = subject # Generate log filename name_base = "{run_id}_{ftype}.{fext}" logbase = os.path.join(logsdir, name_base) logfile = logbase.format(run_id=run_id, ftype="src", fext="log") mne.utils._logging.set_log_file(logfile) # Finish setting up loggers osl_logger.set_up(prefix=run_id, log_file=logfile, level=verbose, startup=False) mne.set_log_level(mneverbose) logger = logging.getLogger(__name__) now = strftime("%Y-%m-%d %H:%M:%S", localtime()) logger.info("{0} : Starting osl-ephys Processing".format(now)) logger.info("input : {0}".format(outdir / subject)) # Set random seed if random_seed == 'auto': set_random_seed() elif random_seed is None: pass else: set_random_seed(random_seed) # Load config if not isinstance(config, dict): config = load_config(config) # Check what files are in the output directory preproc_filename = f"{outdir}/{subject}/{subject}_preproc-raw.fif" epoch_filename = f"{outdir}/{subject}/{subject}_epo.fif" if os.path.exists(preproc_filename) and os.path.exists(epoch_filename): if preproc_file is None and epoch_file is None: raise ValueError( "Both preproc and epoch fif files found. " "Please pass preproc_file=True or epoch_file=True." ) elif os.path.exists(preproc_filename): preproc_file = preproc_filename elif os.path.exists(epoch_filename): epoch_file = epoch_filename # Validation doing_coreg = ( any(["compute_surfaces" in method for method in config["source_recon"]]) or any(["coregister" in method for method in config["source_recon"]]) ) if doing_coreg and smri_file is None: raise ValueError("smri_file must be passed if we're doing coregistration.") # MAIN BLOCK - Run source reconstruction and catch any exceptions try: for stage in deepcopy(config["source_recon"]): method, userargs = next(iter(stage.items())) func = find_func(method, extra_funcs=extra_funcs) if func is None: avail_funcs = inspect.getmembers(wrappers, inspect.isfunction) avail_names = [name for name, _ in avail_funcs] if method not in avail_names: raise NotImplementedError( f"{method} not available.\n" + "Please pass via extra_funcs " + f"or use available functions: {avail_names}." ) def wrapped_func(**kwargs): sig = inspect.signature(func) args = [param for param in sig.parameters.keys() if param != 'kwargs'] defaults = [param.default for param in sig.parameters.values() if param.default is not inspect.Parameter.empty] args_with_defaults = args[-len(defaults):] if defaults else [] kwargs_to_pass = {} for a in args: if a in kwargs: kwargs_to_pass[a] = kwargs[a] elif a not in args_with_defaults: raise ValueError(f"{a} needs to be passed to {func.__name__}") return func(**kwargs_to_pass) wrapped_func( outdir=outdir, subject=subject, surface_extraction_method=surface_extraction_method, preproc_file=preproc_file, smri_file=smri_file, epoch_file=epoch_file, reportdir=reportdir, logsdir=logsdir, **userargs, ) except Exception as e: logger.critical("*********************************") logger.critical("* SOURCE RECONSTRUCTION FAILED! *") logger.critical("*********************************") ex_type, ex_value, ex_traceback = sys.exc_info() logger.error("{0} : {1}".format(method, func)) logger.error(ex_type) logger.error(ex_value) logger.error(traceback.print_tb(ex_traceback)) with open(logfile.replace(".log", ".error.log"), "w") as f: f.write("osl-ephys SOURCE RECON failed at: {0}".format(now)) f.write("\n") f.write('Processing failed during stage : "{0}"'.format(method)) f.write(str(ex_type)) f.write("\n") f.write(str(ex_value)) f.write("\n") traceback.print_tb(ex_traceback, file=f) return False if gen_report: # Generate data and individual HTML data for the report src_report.gen_html_data(config, outdir, subject, reportdir, extra_funcs=extra_funcs) # Generate individual subject HTML report src_report.gen_html_page(reportdir) return True
[docs]def run_src_batch( config, outdir, subjects, preproc_files=None, smri_files=None, epoch_files=None, surface_extraction_method='fsl', logsdir=None, reportdir=None, gen_report=True, verbose="INFO", mneverbose="WARNING", extra_funcs=None, dask_client=False, random_seed='auto', ): """Batch source reconstruction. Parameters ---------- config : str or dict Source reconstruction config. outdir : str Source reconstruction directory. subjects : list of str Subject names. surface_extraction_method : str Can be 'fsl' or 'freesurfer'. preproc_files : list of str Preprocessed fif files. smri_files : list of str or str Structural MRI files. Can be 'standard' to use MNI152_T1_2mm.nii for the structural. epoch_files : list of str Epoched fif file. logsdir : str Directory to save log files to. reportdir : str Directory to save report files to. gen_report : bool Should we generate a report? verbose : str Level of verbose. mneverbose : str Level of MNE verbose. extra_funcs : list of functions Custom functions. dask_client : bool Are we using a dask client? random_seed : 'auto' (default), int or None Random seed to set. If 'auto', a random seed will be generated. Random seeds are set for both Python and NumPy. If None, no random seed is set. Returns ------- flags : list of bool Flags indicating whether coregistration was successful. """ if surface_extraction_method == 'fsl': rhino.fsl_utils.check_fsl() elif surface_extraction_method == 'freesurfer': freesurfer_utils.check_freesurfer() # Directories outdir = validate_outdir(outdir) logsdir = validate_outdir(logsdir or outdir / "logs") reportdir = validate_outdir(reportdir or outdir / "src_report") # Initialise Loggers mne.set_log_level(mneverbose) logfile = os.path.join(logsdir, 'batch_src.log') osl_logger.set_up(log_file=logfile, level=verbose, startup=False) logger.info('Starting osl-ephys Batch Source Reconstruction') # Set random seed if random_seed == 'auto': random_seed = set_random_seed() elif random_seed is None: pass else: set_random_seed(random_seed) # Load config config = load_config(config) config_str = pprint.PrettyPrinter().pformat(config) logger.info('Running config\n {0}'.format(config_str)) # Number of files (subjects) to process n_subjects = len(subjects) # Validation if preproc_files is not None and epoch_files is not None: raise ValueError("Please pass either preproc_file or epoch_files, not both.") if preproc_files and epoch_files: raise ValueError( "Cannot pass both preproc_files=True and epoch_files=True. " "Please only pass one of these." ) if isinstance(preproc_files, list): n_files = len(preproc_files) if n_subjects != n_files: raise ValueError(f"Got {n_subjects} subjects and {n_files} preproc_files.") elif isinstance(epoch_files, list): n_files = len(epoch_files) if n_subjects != n_files: raise ValueError(f"Got {n_subjects} subjects and {n_files} epoch_files.") else: # Check what files are in the output directory preproc_files_list = [] epoch_files_list = [] for subject in subjects: preproc_file = f"{outdir}/{subject}/{subject}_preproc-raw.fif" epoch_file = f"{outdir}/{subject}/{subject}_epo.fif" if os.path.exists(preproc_file) and os.path.exists(epoch_file): if preproc_files is None and epoch_files is None: raise ValueError( "Both preproc and epoch fif files found. " "Please pass preproc_files=True or epoch_files=True." ) elif os.path.exists(preproc_file): preproc_files_list.append(preproc_file) elif os.path.exists(epoch_file): epoch_files_list.append(epoch_file) if len(preproc_files_list) > 0: preproc_files = preproc_files_list elif len(epoch_files_list) > 0: epoch_files = epoch_files_list doing_coreg = ( any(["compute_surfaces" in method for method in config["source_recon"]]) or any(["coregister" in method for method in config["source_recon"]]) ) if doing_coreg and smri_files is None: raise ValueError("smri_files must be passed if we are coregistering.") elif smri_files is None or isinstance(smri_files, str): smri_files = [smri_files] * n_subjects if preproc_files is None: preproc_files = [None] * n_subjects if epoch_files is None: epoch_files = [None] * n_subjects # Create partial function with fixed options pool_func = partial( run_src_chain, surface_extraction_method=surface_extraction_method, logsdir=logsdir, reportdir=reportdir, gen_report=gen_report, verbose=verbose, mneverbose=mneverbose, extra_funcs=extra_funcs, random_seed=random_seed, ) # Loop through input files to generate arguments for run_coreg_chain args = [] for subject, preproc_file, smri_file, epoch_file, in zip(subjects, preproc_files, smri_files, epoch_files): args.append((config, outdir, subject, preproc_file, smri_file, epoch_file)) # Actually run the processes if dask_client: flags = parallel.dask_parallel_bag(pool_func, args) else: flags = [pool_func(*aa) for aa in args] osl_logger.set_up(log_file=logfile, level=verbose, startup=False) logger.info("Processed {0}/{1} files successfully".format(int(np.sum(flags)), len(flags))) if gen_report and int(np.sum(flags)) > 0: # Generate individual subject HTML report src_report.gen_html_page(reportdir) # Generate a summary report if src_report.gen_html_summary(reportdir): logger.info("******************************" + "*" * len(str(reportdir))) logger.info(f"* REMEMBER TO CHECK REPORT: {reportdir} *") logger.info("******************************" + "*" * len(str(reportdir))) return flags