Better way to set number of threads used by NumPy


Background

When NumPy is linked against a multithreaded BLAS implementation (such as MKL or OpenBLAS), the computationally intensive parts of a program automatically run on multiple cores, sometimes all of them.
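
A quick way to see which BLAS your NumPy build is linked against is numpy.show_config(); it prints the build-time configuration, in which MKL or OpenBLAS usually shows up (this is just a convenience check and plays no role in the implementation below):

    import numpy

    # prints the build-time BLAS/LAPACK configuration; MKL or OpenBLAS
    # usually appears in the library names listed here
    numpy.show_config()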

This automatic multithreading is bad when:

  • you are sharing the machine's cores with other users or processes, or
  • you know of a better way to parallelize your program.

In these cases it is reasonable to restrict the number of threads used by MKL/OpenBLAS to 1, and parallelize your program manually.
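
As a rough illustration of "parallelize manually" (a hedged sketch: the svd_many helper, the worker count, and the use of a process pool are my own example, built on the single_threaded context manager defined further down):

    from multiprocessing import Pool

    import numpy

    def _svd_worker(matrix):
        # each worker process runs one SVD; with BLAS pinned to one thread
        # the workers do not oversubscribe the cores
        return numpy.linalg.svd(matrix)

    def svd_many(matrices, n_workers=4):
        # hypothetical helper: restrict BLAS to a single thread, then spread
        # the independent SVDs over n_workers processes
        with single_threaded(numpy):
            with Pool(processes=n_workers) as pool:
                return pool.map(_svd_worker, matrices)

This assumes the workers are forked after the thread limit is set (the default start method on Linux); with the "spawn" start method each worker would have to apply the limit itself.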

My solution below involves loading the libraries at runtime and calling the corresponding C functions from Python.
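
In essence, the mechanism is nothing more than ctypes plus the libraries' own thread-control functions; a minimal sketch for the MKL case (the library file name here is an assumption, and in the real implementation the path is discovered via ldd):

    import ctypes

    # load the BLAS runtime that NumPy links against and call its C API directly
    mkl = ctypes.CDLL('libmkl_rt.so')   # assumed name; found via ldd below
    print(mkl.MKL_Get_Max_Threads())    # current thread limit
    mkl.MKL_Set_Num_Threads(1)          # restrict MKL to a single thread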

Questions

  1. Are there any best/better practices in solving this problem?
  2. What are the pitfalls of my approach?
  3. Please comment on code quality in general.

Example of use

    import numpy

    # this uses however many threads MKL/OpenBLAS uses
    result = numpy.linalg.svd(matrix)

    # this uses one thread
    with single_threaded(numpy):
        result = numpy.linalg.svd(matrix)
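
To check that the switch actually happens, you can poke at the thread count through the same get_blas helper used internally (a hedged example; the printed values depend on your machine and BLAS):

    import numpy

    blas = get_blas(numpy)  # None if neither MKL nor OpenBLAS is linked in

    if blas is not None:
        print(blas.kind, blas.get_n_threads())  # e.g. 'openblas' and the default thread count
        with single_threaded(numpy):
            print(blas.get_n_threads())         # expected to print 1
        print(blas.get_n_threads())             # restored to the previous value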

Implementation

  1. Imports

    import subprocess
    import re
    import sys
    import os
    import glob
    import warnings
    import ctypes
  2. Class BLAS, abstracting a BLAS library with methods to get and set the number of threads (MKL and OPENBLAS are module-level string constants identifying the two supported kinds; their definitions are not shown here):

    class BLAS:
        def __init__(self, cdll, kind):
            if kind not in (MKL, OPENBLAS):
                raise ValueError(f'kind must be {MKL} or {OPENBLAS}, got {kind} instead.')

            self.kind = kind
            self.cdll = cdll

            if kind == MKL:
                self.get_n_threads = cdll.MKL_Get_Max_Threads
                self.set_n_threads = cdll.MKL_Set_Num_Threads
            else:
                self.get_n_threads = cdll.openblas_get_num_threads
                self.set_n_threads = cdll.openblas_set_num_threads
  3. Function get_blas, returning a BLAS object given an imported NumPy module (a standalone sketch of the ldd-parsing step follows this list).

    def get_blas(numpy_module):

        LDD = 'ldd'
        LDD_PATTERN = r'^\t(?P<lib>.*{}.*) => (?P<path>.*) \(0x.*$'

        NUMPY_PATH = os.path.join(numpy_module.__path__[0], 'core')
        MULTIARRAY_PATH = glob.glob(os.path.join(NUMPY_PATH, 'multiarray*.so'))[0]

        ldd_result = subprocess.run(
            args=[LDD, MULTIARRAY_PATH],
            check=True,
            stdout=subprocess.PIPE,
            universal_newlines=True
        )

        output = ldd_result.stdout

        if MKL in output:
            kind = MKL
        elif OPENBLAS in output:
            kind = OPENBLAS
        else:
            return None

        pattern = LDD_PATTERN.format(kind)
        match = re.search(pattern, output, flags=re.MULTILINE)

        if match:
            lib = ctypes.CDLL(match.groupdict()['path'])
            return BLAS(lib, kind)
        else:
            return None
  4. Context manager single_threaded, which takes an imported NumPy module, sets the number of threads to 1 on enter, and resets it to the previous value on exit.

    class single_threaded:
        def __init__(self, numpy_module):
            self.blas = get_blas(numpy_module)

        def __enter__(self):
            if self.blas is not None:
                self.old_n_threads = self.blas.get_n_threads()
                self.blas.set_n_threads(1)
            else:
                warnings.warn(
                    'No MKL/OpenBLAS found, assuming NumPy is single-threaded.'
                )

        def __exit__(self, *args):
            if self.blas is not None:
                self.blas.set_n_threads(self.old_n_threads)
                if self.blas.get_n_threads() != self.old_n_threads:
                    message = (
                        f'Failed to reset {self.blas.kind} '
                        f'to {self.old_n_threads} threads (previous value).'
                    )
                    raise RuntimeError(message)
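
To make the parsing in get_blas concrete, here is a small, self-contained sketch of how LDD_PATTERN is meant to match one line of ldd output (the sample line and library path are made up):

    import re

    LDD_PATTERN = r'^\t(?P<lib>.*{}.*) => (?P<path>.*) \(0x.*$'

    # a hypothetical line as printed by `ldd .../multiarray.so`
    sample = '\tlibopenblasp-r0.2.18.so => /usr/lib/libopenblasp-r0.2.18.so (0x00007f2ae9fb8000)'

    match = re.search(LDD_PATTERN.format('openblas'), sample, flags=re.MULTILINE)
    if match:
        print(match.groupdict()['path'])  # -> /usr/lib/libopenblasp-r0.2.18.so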