Source code for osl_ephys.utils.parallel

"""Utility functions for parallel processing.

"""

# Authors: Andrew Quinn <a.quinn@bham.ac.uk>

from functools import partial
import dask.bag as db
from dask.distributed import Client, LocalCluster, wait, default_client

# Housekeeping for logging
import logging
# Module-level logger, named after this module so its output can be
# filtered/configured through the standard logging hierarchy.
osl_logger = logging.getLogger(name=__name__)
def dask_parallel_bag(func, iter_args, func_args=None, func_kwargs=None):
    """A maybe more consistent alternative to ``dask_parallel``.

    Runs ``func`` once per element of ``iter_args`` on the currently active
    dask.distributed client, using a dask bag.

    Parameters
    ----------
    func : function
        The function to run in parallel.
    iter_args : list
        A list of iterables to pass to func. If every element is a list or
        tuple, each element is unpacked as the leading positional arguments
        of one call; otherwise each element is passed as a single positional
        argument. Any iterable (including a generator) is accepted.
    func_args : list, optional
        A list of positional arguments to pass to func. They are appended
        after the per-call arguments on every call. A tuple is also accepted.
    func_kwargs : dict, optional
        A dictionary of keyword arguments to pass to func on every call.

    Returns
    -------
    flags : list
        A list of return values from func, one per element of ``iter_args``.

    Notes
    -----
    A dask.distributed ``Client`` must already exist; it is looked up with
    ``dask.distributed.default_client``, which raises if no client has been
    created.

    References
    ----------
    https://docs.dask.org/en/stable/bag.html
    """
    func_args = [] if func_args is None else func_args
    func_kwargs = {} if func_kwargs is None else func_kwargs

    # Get connection to currently active cluster
    client = default_client()

    # Print some helpful info (lazy %-args: only rendered if the record is emitted)
    osl_logger.info('Dask Client : %r', client)
    osl_logger.info('Dask Client dashboard link: %s', client.dashboard_link)
    osl_logger.debug('Running function : %r', func)
    osl_logger.debug('User args : %s', func_args)
    osl_logger.debug('User kwargs : %s', func_kwargs)

    # Set kwargs - need to handle args on function call to preserve order.
    run_func = partial(func, **func_kwargs)
    osl_logger.info('Function defined : %s', run_func)

    # Materialise once: the all() check below would otherwise consume a
    # generator/iterator, leaving nothing (or only a tail) to build tasks from.
    iter_args = list(iter_args)

    # Ensure input iter_args is list of lists
    if not all(isinstance(aa, (list, tuple)) for aa in iter_args):
        iter_args = [[aa] for aa in iter_args]

    # Add fixed positional args; list() so a tuple can be concatenated too.
    fixed_args = list(func_args)
    iter_args = [list(aa) + fixed_args for aa in iter_args]

    # Make dask bag from inputs: https://docs.dask.org/en/stable/bag.html
    bag = db.from_sequence(iter_args)

    # Map iterable arguments to function using dask bag + current client;
    # starmap unpacks each element into positional arguments.
    mapped = bag.starmap(run_func)

    # Actually run the computation
    flags = mapped.compute()
    osl_logger.info('Computation complete')
    return flags