"""MPI helpers: communicator selection, rank-aware I/O and broadcasting."""
import os
import atexit
import functools
import pickle
import sys
import time
import warnings

import numpy as np

# Explicit public API.  This module object is replaced in sys.modules by a
# ParallelModuleWrapper (see bottom of file), whose own __dict__ is empty;
# ``from ase.parallel import *`` resolves names through ``__all__`` (which the
# wrapper forwards to the real module), so without it nothing was exported.
__all__ = ['get_txt', 'paropen', 'parprint', 'DummyMPI', 'MPI', 'MPI4PY',
           'world', 'barrier', 'broadcast', 'parallel_function',
           'parallel_generator', 'register_parallel_cleanup_function',
           'distribute_cpus']


def get_txt(txt, rank):
    """Return a writable text stream for log output on this rank.

    Parameters:

    txt: file-like object, str or None
        An object with a ``write`` method is returned unchanged on every
        rank.  Otherwise, on rank 0: ``None`` gives ``os.devnull``, ``'-'``
        gives ``sys.stdout`` and any other string is a filename opened
        line-buffered for writing.
    rank: int
        Rank of the calling process; ranks other than 0 get ``os.devnull``.
    """
    if hasattr(txt, 'write'):
        # Note: User-supplied object might write to files from many ranks.
        return txt
    elif rank == 0:
        if txt is None:
            return open(os.devnull, 'w')
        elif txt == '-':
            return sys.stdout
        else:
            return open(txt, 'w', 1)
    else:
        return open(os.devnull, 'w')


def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes.  In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    # Only the first character of mode is inspected, so e.g. 'r+' also
    # counts as read mode and opens the real file on every rank.
    if comm.rank > 0 and mode[0] != 'r':
        name = os.devnull
    return open(name, mode, buffering, encoding)


def parprint(*args, **kwargs):
    """MPI-safe print - prints only from master. """
    if world.rank == 0:
        print(*args, **kwargs)


class DummyMPI:
    """Serial stand-in for an MPI communicator: one process, rank 0."""
    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number is
        # returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        if hasattr(a, '__array__'):
            a = a.__array__()
        assert isinstance(a, np.ndarray)
        return None

    def sum(self, a, root=-1):
        return self._returnval(a)

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass

    def new_communicator(self, ranks):
        """Return a communicator for the processes listed in *ranks*.

        In a serial run the only process is rank 0: a fresh DummyMPI is
        returned when 0 is in *ranks*, otherwise None (this process is not
        a member).  Needed by :func:`distribute_cpus`.
        """
        if 0 in ranks:
            return DummyMPI()
        return None


class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * Asap
    * a dummy implementation for serial runs

    """
    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails.  If 'comm' itself
        # is missing (instance created without __init__, e.g. by copy.copy
        # or unpickling), fail cleanly instead of recursing forever.
        if name == 'comm':
            raise AttributeError(name)
        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)


def _get_comm():
    """Get the correct MPI world object."""
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    if '_gpaw' in sys.modules:
        import _gpaw
        if hasattr(_gpaw, 'Communicator'):
            return _gpaw.Communicator()
    if '_asap' in sys.modules:
        import _asap
        if hasattr(_asap, 'Communicator'):
            return _asap.Communicator()
    return DummyMPI()


class MPI4PY:
    """Adapter exposing an mpi4py communicator through the ASE interface."""

    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        """Sum *a* over all processes.

        With ``root=-1`` every process receives the result, otherwise only
        rank *root* does.  Scalars: the sum is returned.  Arrays: *a* is
        overwritten in place and None is returned.  On non-root ranks of a
        rooted sum the result is undefined: scalars are returned unchanged
        and arrays are left untouched.
        """
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            # mpi4py's signature is reduce(sendobj, op=SUM, root=0), so root
            # must be passed by keyword or it would be taken as the op.
            b = self.comm.reduce(a, root=root)
            if self.rank != root:
                # reduce() returns None on non-root ranks.
                return a if np.isscalar(a) else None
        return self._returnval(a, b)

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        """Broadcast *a* from *root*: scalars are returned, arrays filled in place."""
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return None
        return self._returnval(a, b)


world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    # Defer the choice until the communicator is first used.
    world = MPI()


def barrier():
    """Synchronise all processes of the world communicator."""
    world.barrier()


def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    # Note: the default comm is bound once, when this function is defined.
    # Two rounds: first the pickle length, so that other ranks can allocate
    # a receive buffer, then the pickled bytes themselves.
    if comm.rank == root:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())


def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # NOTE: because of short-circuit evaluation, 'parallel' is only
        # popped from kwargs when the first two tests are false; in serial
        # runs (or with self.serial) it is forwarded to func unchanged.
        if (world.size == 1 or
            args and getattr(args[0], 'serial', False) or
            not kwargs.pop('parallel', True)):
            # Disable:
            return func(*args, **kwargs)

        # Only rank 0 runs func; its result or exception is broadcast so
        # that every rank returns (or raises) the same thing.
        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func


def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        # Same short-circuit caveat for 'parallel' as in parallel_function.
        if (world.size == 1 or
            args and getattr(args[0], 'serial', False) or
            not kwargs.pop('parallel', True)):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        # Every message is (exception, more, value).  The explicit 'more'
        # flag marks the end of the stream, so a generator that yields None
        # no longer makes the slaves stop while the master keeps sending.
        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, True, result))
                    yield result
            except Exception as ex:
                broadcast((ex, False, None))
                raise
            broadcast((None, False, None))
        else:
            while True:
                ex2, more, result = broadcast((None, False, None))
                if ex2 is not None:
                    raise ex2
                if not more:
                    return
                yield result

    return new_generator


def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""

    if world.size == 1:
        return

    # sys, time and world are bound as defaults at registration time,
    # presumably so they remain reachable while the interpreter shuts down.
    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if error:
            sys.stdout.flush()
            sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages):
            time.sleep(3)
            world.abort(42)

    atexit.register(cleanup)


def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """

    assert size <= comm.size
    assert comm.size % size == 0

    tasks_rank = comm.rank // size

    # Consecutive blocks of `size` ranks form one calculator's communicator.
    r0 = tasks_rank * size
    ranks = np.arange(r0, r0 + size)
    mycomm = comm.new_communicator(ranks)

    return mycomm, comm.size // size, tasks_rank


class ParallelModuleWrapper:
    """Module proxy that deprecates ``ase.parallel.rank`` / ``.size``.

    It replaces this module in ``sys.modules``; every other attribute is
    looked up on the real module object.
    """

    def __getattr__(self, name):
        if name == 'rank' or name == 'size':
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn('ase.parallel.{name} has been deprecated. '
                          'Please use ase.parallel.world.{name} instead.'
                          .format(name=name),
                          FutureWarning, stacklevel=2)
            return getattr(world, name)
        return getattr(_parallel, name)

    def __dir__(self):
        return dir(_parallel)


# Use __name__ instead of a hard-coded 'ase.parallel' so that importing this
# file under another name does not fail with KeyError.
_parallel = sys.modules[__name__]
sys.modules[__name__] = ParallelModuleWrapper()  # type: ignore