You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
643 lines
17 KiB
643 lines
17 KiB
import errno |
|
import functools |
|
import os |
|
import io |
|
import pickle |
|
import sys |
|
import time |
|
import string |
|
import warnings |
|
from importlib import import_module |
|
from math import sin, cos, radians, atan2, degrees |
|
from contextlib import contextmanager, ExitStack |
|
from math import gcd |
|
from pathlib import PurePath, Path |
|
import re |
|
|
|
import numpy as np |
|
|
|
from ase.formula import formula_hill, formula_metal |
|
|
|
# Names exported by a star-import of this module.
# NOTE(review): 'exec_' has no definition in the visible part of this file
# and 'FileNotFoundError' is a builtin -- confirm both entries are intended.
__all__ = ['exec_', 'basestring', 'import_module', 'seterr', 'plural',
           'devnull', 'gcd', 'convert_string_to_fd', 'Lock',
           'opencew', 'OpenLock', 'rotate', 'irotate', 'pbc2pbc', 'givens',
           'hsv2rgb', 'hsv', 'pickleload', 'FileNotFoundError',
           'formula_hill', 'formula_metal', 'PurePath', 'xwopen',
           'tokenize_version']
|
|
|
|
|
def tokenize_version(version_string: str):
    """Turn a dotted version string into a tuple usable for comparisons.

    Each dot-separated component contributes two entries: its leading
    integer (or -1 when the component has no leading digits) and whatever
    text follows the digits.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    parsed = []
    for part in version_string.split('.'):
        found = re.match(r'(\d*)(.*)', part)
        assert found is not None, f'Cannot parse component {part}'
        digits, suffix = found.group(1, 2)
        try:
            leading = int(digits)
        except ValueError:
            # No leading digits in this component.
            leading = -1
        parsed.extend((leading, suffix))
    return tuple(parsed)
|
|
|
|
|
# Python 2+3 compatibility stuff (let's try to remove these things):
# ``basestring`` is kept as a plain alias of ``str``.
basestring = str
# ``pickle.load`` with ``encoding='bytes'`` pre-applied.
# NOTE(review): unpickling executes arbitrary code; only use on trusted files.
pickleload = functools.partial(pickle.load, encoding='bytes')
|
|
|
|
|
def deprecated(msg, category=FutureWarning):
    """Return a decorator deprecating a function.

    Use like @deprecated('warning message and explanation').

    Parameters:

    msg: str or Warning
        Message to emit on every call of the decorated function.  A
        Warning instance is emitted as is; anything else is wrapped in
        ``category``.
    category: Warning subclass
        Warning class used when ``msg`` is not already a Warning.

    The wrapped function is called with unchanged arguments and its return
    value is passed through.
    """
    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            warning = msg
            if not isinstance(warning, Warning):
                warning = category(warning)
            # stacklevel=2 attributes the warning to the code calling the
            # deprecated function rather than to this wrapper, so the
            # reported location (and per-location filtering) is useful.
            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)
        return deprecated_function
    return deprecated_decorator
|
|
|
|
|
@contextmanager
def seterr(**kwargs):
    """Temporarily change how numpy handles floating-point errors.

    The keyword arguments are forwarded to np.seterr(); the settings that
    were active before are put back when the with-block is left.
    """
    previous = np.seterr(**kwargs)
    try:
        yield
    finally:
        np.seterr(**previous)
|
|
|
|
|
def plural(n, word):
    """Format a count with its noun, appending 's' unless n equals 1.

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
|
|
|
|
|
class DevNull:
    """Dummy text stream: writes are discarded and reads return ''.

    Every method is wrapped by ``_use_os_devnull`` and therefore emits a
    DeprecationWarning pointing users to ``open(os.devnull)``.
    """
    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0. Change to futurewarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    @_use_os_devnull
    def read(self, n=-1):
        """Always return the empty string."""
        return ''

    @_use_os_devnull
    def write(self, string):
        """Discard ``string``."""

    @_use_os_devnull
    def flush(self):
        """Do nothing."""

    @_use_os_devnull
    def close(self):
        """Do nothing (``closed`` stays False)."""

    @_use_os_devnull
    def seek(self, offset, whence=0):
        """Ignore the request and report position 0."""
        return 0

    @_use_os_devnull
    def tell(self):
        """Report position 0."""
        return 0

    @_use_os_devnull
    def isatty(self):
        """This stream is never a terminal."""
        return False


devnull = DevNull()
|
|
|
|
|
@deprecated('convert_string_to_fd does not facilitate proper resource '
            'management. '
            'Please use e.g. ase.utils.IOContext class instead.')
def convert_string_to_fd(name, world=None):
    """Return a text stream for output selected by ``name``.

    None (or any rank other than 0) gives a stream on os.devnull, '-'
    gives sys.stdout, and a str or path is opened for writing.  Anything
    else is assumed to be an open stream already and returned unchanged.
    """
    if world is None:
        from ase.parallel import world

    if name is None or world.rank != 0:
        stream = open(os.devnull, 'w')
    elif name == '-':
        stream = sys.stdout
    elif isinstance(name, (str, PurePath)):
        stream = open(str(name), 'w')  # str for py3.5 pathlib
    else:
        stream = name  # we assume name is already a file-descriptor
    return stream
|
|
|
|
|
# Flags for "create exclusively for writing": O_CREAT | O_EXCL makes the
# open fail if the file already exists.
# Only Windows has O_BINARY:
CEW_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY | getattr(os, 'O_BINARY', 0)
|
|
|
|
|
@contextmanager
def xwopen(filename, world=None):
    """Context manager around opencew(filename, world).

    The with-block receives whatever opencew() returns: an open binary
    file when the file could be created exclusively, or None when it
    already exists.  A non-None file is closed when the block exits.
    """
    handle = opencew(filename, world)
    with ExitStack() as cleanup:
        if handle is not None:
            cleanup.callback(handle.close)
        yield handle
|
|
|
|
|
#@deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Create and open filename exclusively for writing.

    Thin public wrapper around _opencew(); see there for the details.
    The caller owns the returned file and must close it.
    """
    return _opencew(filename, world)


def _opencew(filename, world=None):
    """Create ``filename`` exclusively on rank 0 and open it in 'wb' mode.

    Rank 0 opens the file with CEW_FLAGS so that creation fails when the
    file already exists; every other rank opens os.devnull instead.  The
    errno observed on rank 0 is shared with all ranks through world.sum().

    Returns the open binary file, or None on all ranks when the file
    already exists (EEXIST).  Any other error is raised as OSError.  Files
    opened here are closed again whenever no file is handed back.
    """
    if world is None:
        from ase.parallel import world

    closelater = []

    def opener(file, flags):
        return os.open(file, flags | CEW_FLAGS)

    try:
        error = 0
        if world.rank == 0:
            try:
                fd = open(filename, 'wb', opener=opener)
            except OSError as ex:
                error = ex.errno
            else:
                closelater.append(fd)
        else:
            fd = open(os.devnull, 'wb')
            closelater.append(fd)

        # Synchronize:
        error = world.sum(error)
        if error == errno.EEXIST:
            # Ranks other than 0 still hold a dummy os.devnull handle that
            # nobody will receive; close it instead of leaking it.
            for opened in closelater:
                opened.close()
            return None
        if error:
            raise OSError(error, 'Error', filename)

        return fd
    except BaseException:
        for fd in closelater:
            fd.close()
        raise
|
|
|
|
|
def opencew_text(*args, **kwargs):
    """Like opencew(), but wrap the binary file in an io.TextIOWrapper.

    None is passed through unchanged when opencew() returns None.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)
|
|
|
|
|
class Lock:
    """Lock based on exclusive creation of a file named ``name``.

    acquire() retries opencew() with growing sleep intervals until the
    file can be created or ``timeout`` seconds have passed; release()
    closes and removes the file.  Usable as a context manager.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        if world is None:
            from ase.parallel import world
        self.world = world
        self.name = str(name)
        self.timeout = timeout

    def acquire(self):
        """Block until the lock file is created; raise TimeoutError late."""
        pause = 0.2
        started = time.time()
        fd = opencew(self.name, self.world)
        while fd is None:
            remaining = self.timeout - (time.time() - started)
            if remaining <= 0:
                raise TimeoutError
            time.sleep(min(pause, remaining))
            pause *= 2  # back off exponentially between attempts
            fd = opencew(self.name, self.world)
        self.fd = fd

    def release(self):
        """Close the lock file and let rank 0 delete it."""
        world = self.world
        world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if world.rank == 0:
            os.remove(self.name)
        world.barrier()

    def __enter__(self):
        self.acquire()

    def __exit__(self, type, value, tb):
        self.release()
|
|
|
|
|
class OpenLock:
    """Lock look-alike whose operations all do nothing."""

    def acquire(self):
        """No-op."""

    def release(self):
        """No-op."""

    def __enter__(self):
        """No-op; the with-statement target is None."""

    def __exit__(self, type, value, tb):
        """No-op; exceptions are not suppressed."""
|
|
|
|
|
def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str or path-like (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator
        Only rank 0 performs the search; defaults to ase.parallel.world.

    Returns the hash as a string of hexadecimal digits.  None is returned
    on ranks other than 0, when no .git directory, HEAD file or ref file is
    found, and when the file read does not hold a non-empty hex string.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, (str, os.PathLike)):
        # Directory path
        dpath = os.fspath(arg)
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # realpath rather than abspath,
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file, 'r') as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref_file = os.path.join(git_dpath, line[5:])
    else:
        # Assuming detached HEAD state
        ref_file = HEAD_file
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file, 'r') as fd:
        line = fd.readline().strip()
    # all() is vacuously true for an empty string, so require content
    # explicitly; an empty file must not be reported as a hash.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None
|
|
|
|
|
def rotate(rotations, rotation=np.identity(3)):
    """Build a rotation matrix from a string such as '50x,-10y,120z'.

    Each comma-separated item is an angle in degrees followed by the axis
    letter.  The elementary rotations are multiplied onto ``rotation`` from
    the right in the order given, so '50x,40z' differs from '40z,50x'.
    An empty string returns a copy of ``rotation``.
    """
    if rotations == '':
        return rotation.copy()

    # Parse everything first: (axis index, angle in radians) pairs.
    steps = [('xyz'.index(item[-1]), radians(float(item[:-1])))
             for item in rotations.split(',')]

    for axis, angle in steps:
        s = sin(angle)
        c = cos(angle)
        if axis == 0:
            elementary = [(1, 0, 0),
                          (0, c, s),
                          (0, -s, c)]
        elif axis == 1:
            elementary = [(c, 0, -s),
                          (0, 1, 0),
                          (s, 0, c)]
        else:
            elementary = [(c, s, 0),
                          (-s, c, 0),
                          (0, 0, 1)]
        rotation = np.dot(rotation, elementary)
    return rotation
|
|
|
|
|
def givens(a, b):
    """Return (c, s, r) solving the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    The larger of |a| and |b| is used as divisor in each branch.
    """
    if b == 0:
        return np.sign(a), 0, abs(a)

    if abs(b) >= abs(a):
        ratio = a / b  # cotangent
        norm = np.sign(b) * (1 + ratio**2)**0.5
        s = 1. / norm
        c = s * ratio
        r = b * norm
    else:
        ratio = b / a  # tangent
        norm = np.sign(a) * (1 + ratio**2)**0.5
        c = 1. / norm
        s = c * ratio
        r = a * norm
    return c, s, r
|
|
|
|
|
def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (degrees) from a rotation matrix.

    ``initial`` is multiplied onto ``rotation`` from the left before the
    angles are extracted with three successive givens() eliminations.
    """
    m = np.dot(initial, rotation)
    cx, sx, rx = givens(m[2, 2], m[1, 2])
    cy, sy, ry = givens(rx, m[0, 2])
    first = cx * m[1, 1] - sx * m[2, 1]
    second = cy * m[0, 1] - sy * (sx * m[1, 1] + cx * m[2, 1])
    cz, sz, rz = givens(first, second)
    angle_x = degrees(atan2(sx, cx))
    angle_y = degrees(atan2(-sy, cy))
    angle_z = degrees(atan2(sz, cz))
    return angle_x, angle_y, angle_z
|
|
|
|
|
def pbc2pbc(pbc):
    """Return ``pbc`` as a new boolean array of length 3.

    A single value is repeated for all three entries.
    """
    flags = np.empty(3, dtype=bool)
    flags[...] = pbc
    return flags
|
|
|
|
|
def hsv2rgb(h, s, v):
    """http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError when h lies outside [0, 360[ (h is not inspected
    when v == 0 or s == 0, where the result does not depend on it).
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    # Sector of the colour wheel (0..5) and position inside that sector.
    i, f = divmod(h / 60., 1)
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        # h == 360 gives sector 6 and is rejected, hence the open interval.
        raise RuntimeError('h must be in [0, 360[')


def hsv(array, s=.9, v=.9):
    """Map the values of ``array`` to RGB colours by hue.

    The values are scaled linearly so that the minimum gets hue 0 and the
    maximum hue 359; ``s`` and ``v`` are passed on to hsv2rgb().  If all
    values are equal, every element gets hue 0.

    Returns an array of shape ``array.shape + (3,)``.
    """
    array = np.asarray(array)
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # Constant input: avoid 0/0 and use one hue for everything.
        hues = np.zeros(array.shape)
    else:
        # Subtract (not add) the minimum so hues start at 0 for any offset.
        hues = (array - lowest) * 359. / span
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, hues.shape + (3,))
|
|
|
|
|
# This code does the same, but requires pylab |
|
# def cmap(array, name='hsv'): |
|
# import pylab |
|
# a = (array + array.min()) / array.ptp() |
|
# rgba = getattr(pylab.cm, name)(a) |
|
# return rgba[:-1] # return rgb only (not alpha) |
|
|
|
|
|
def longsum(x):
    """Sum ``x`` in np.longdouble precision and return a Python float.

    The width of np.longdouble is platform dependent.
    """
    extended = np.asarray(x, dtype=np.longdouble)
    return float(extended.sum())
|
|
|
|
|
@contextmanager
def workdir(path, mkdir=False):
    """Run the with-block inside ``path``, then return to the old directory.

    With mkdir=True the directory (including parents) is created first if
    it does not exist yet.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(str(target))  # py3.6 allows chdir(path) but we still need 3.5
    try:
        yield  # Yield the Path or dirname maybe?
    finally:
        os.chdir(previous)
|
|
|
|
|
class iofunction:
    """Decorator factory: let func take a filename in place of a file.

    When the first argument is a str or path, it is opened with ``mode``,
    handed to func, and closed again afterwards; any other object is
    passed straight through and left open.

    (Won't work on functions that return a generator.)"""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already a file-like object; the caller owns it.
                return func(file, *args, **kwargs)
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)
        return iofunc
|
|
|
|
|
def writer(func):
    """Decorate func with iofunction in write mode ('w')."""
    decorate = iofunction('w')
    return decorate(func)
|
|
|
|
|
def reader(func):
    """Decorate func with iofunction in read mode ('r')."""
    decorate = iofunction('r')
    return decorate(func)
|
|
|
|
|
# The next two functions are for hotplugging into a JSONable class |
|
# using the jsonable decorator. We are supposed to have this kind of stuff |
|
# in ase.io.jsonio, but we'd rather import them from a 'basic' module |
|
# like ase/utils than one which triggers a lot of extra (cyclic) imports. |
|
|
|
def write_json(self, fd):
    """Write this object to the JSON file ``fd``."""
    # Imported here so that loading this module stays lightweight.
    from ase.io.jsonio import write_json as dump_json
    dump_json(fd, self)
|
|
|
|
|
@classmethod  # type: ignore
def read_json(cls, fd):
    """Read a new instance of ``cls`` from the JSON file ``fd``.

    The decoded object must be exactly of type ``cls``."""
    # Imported here so that loading this module stays lightweight.
    from ase.io.jsonio import read_json as load_json
    instance = load_json(fd)
    assert type(instance) is cls
    return instance
|
|
|
|
|
def jsonable(name):
    """Class decorator adding JSON-based read() and write() to a class.

    The class gets ``ase_objtype = name`` and must provide todict(),
    otherwise TypeError is raised.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading."""
    def jsonableclass(cls):
        setattr(cls, 'ase_objtype', name)
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        for attribute, function in (('write', write_json),
                                    ('read', read_json)):
            setattr(cls, attribute, function)
        return cls
    return jsonableclass
|
|
|
|
|
class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions marked @experimental."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Every call of the decorated function emits an
    ExperimentalFeatureWarning naming the function, then calls it with
    unchanged arguments and returns its result.
    """
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        # stacklevel=2 attributes the warning to the caller of the
        # experimental function rather than to this wrapper.
        warnings.warn('This function may change or misbehave: {}()'
                      .format(func.__qualname__),
                      ExperimentalFeatureWarning,
                      stacklevel=2)
        return func(*args, **kwargs)
    return expfunc
|
|
|
|
|
def lazymethod(meth):
    """Decorator caching the result of a no-argument method per instance.

    Example::

      class MyClass:

         @lazymethod
         def thing(self):
             return expensive_calculation()

    The first call of thing() runs the method body and stores the result
    in the instance's ``_lazy_cache`` dict under the method name; later
    calls return the stored value without running the body again."""
    name = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        if not hasattr(self, '_lazy_cache'):
            self._lazy_cache = {}
        store = self._lazy_cache

        if name in store:
            return store[name]
        store[name] = value = meth(self)
        return value
    return getter
|
|
|
|
|
def atoms_to_spglib_cell(atoms):
    """Return the (cell, scaled positions, atomic numbers) tuple for spglib."""
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return (cell, scaled_positions, numbers)
|
|
|
|
|
def warn_legacy(feature_name):
    """Emit a FutureWarning that ``feature_name`` is untested legacy code."""
    message = (
        f'The {feature_name} feature is untested and ASE developers do not '
        'know whether it works or how to use it. Please rehabilitate it '
        '(by writing unittests) or it may be removed.')
    warnings.warn(message, FutureWarning)
|
|
|
|
|
def lazyproperty(meth):
    """Decorator like lazymethod, but exposing the value as a property."""
    cached_getter = lazymethod(meth)
    return property(cached_getter)
|
|
|
|
|
class IOContext:
    """Mixin keeping track of files that must be closed together.

    Files registered through closelater() (or opened by openfile()) are
    entered into an ExitStack, which close() -- also called when leaving
    a with-block -- unwinds.
    """

    @lazyproperty
    def _exitstack(self):
        return ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register context manager ``fd`` and return what it yields."""
        stack = self._exitstack
        return stack.enter_context(fd)

    def close(self):
        """Close everything registered so far."""
        self._exitstack.close()

    def openfile(self, file, comm=None, mode='w'):
        """Return a file object for ``file``.

        Objects that have a close attribute are returned as they are and
        are not registered for closing.  None, or any rank other than 0 of
        ``comm``, yields a registered handle on os.devnull; '-' yields
        sys.stdout; anything else is opened with ``mode`` and registered.
        """
        from ase.parallel import world
        if comm is None:
            comm = world

        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        if file is not None and comm.rank == 0:
            if file == '-':
                return sys.stdout
            target = file
        else:
            target = os.devnull

        return self.closelater(open(target, mode=mode))
|
|
|