Skip to content

Commit

Permalink
added some docs + tests and refactored some code. Also removed Delaye…
Browse files Browse the repository at this point in the history
…d class.
  • Loading branch information
Matthew Chan committed Oct 26, 2017
1 parent cdb2f94 commit 3db72d9
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 64 deletions.
111 changes: 111 additions & 0 deletions decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from functools import wraps


#
# Decorators
#

def onetime(varname: str):
"""
Prevents basic Python types from being changed after setting.
Numpy arrays also need the read-only flag set on them.
For use with Python setters, must be decorated before (closer to the function) @setter is
called.
Parameters
----------
varname
Name of the variable being tested. It does not need to be defined at interpretation
time.
Returns
-------
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if getattr(self, varname, None) is not None:
print("Trying to set a one-time attribute {varname}. Ignored")
return
func(self, *args, **kwargs)

return wrapper

return decorator


def cache(varname: str):
"""
Returns a variable if it is already defined. Otherwise, it calls the code in the property.
Must be decorated before (closer to the function) @property.
Parameters
----------
varname
Name of variable to cache. It does not need to be defined at interpretation.
Returns
-------
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
val = getattr(self, varname, None)
if val is None:
val = func(self, *args, **kwargs)
setattr(self, varname, val)
return val

return wrapper

return decorator


def delayed(func):
"""
Safety check to make sure the instance has the second stage of instantiation.
The function decorated will return AttributeError if the function with @finalize has not
been called yet.
Parameters
----------
func
Returns
-------
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
if not getattr(self, "init_finished", False):
print("Instance must finalize instantiation before calling compute functions.")
raise AttributeError
return func(self, *args, **kwargs)

return wrapper


def finalize(func):
"""
When the function decorated completes, the instance will be marked as fully instantiated.
Parameters
----------
func
Returns
-------
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self.init_finished = True

return wrapper
80 changes: 16 additions & 64 deletions example_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,76 +2,25 @@
import grid
import iodata
import numpy as np
from functools import wraps
from abc import ABCMeta
from meanfield import *


#
# Decorators
#

def onetime(varname):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if getattr(self, varname) is not None:
print("Trying to set a one-time attribute {varname}. Ignored")
return
func(self, *args, **kwargs)

return wrapper

return decorator


def cache(varname):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if getattr(self, varname) is None:
val = func(self, *args, **kwargs)
setattr(self, varname, val)
return getattr(self, varname)

return wrapper

return decorator


def delayed(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.init_finished:
print("Instance must finalize instantiation before calling compute functions.")
raise AttributeError
return func(self, *args, **kwargs)

return wrapper


def finalize(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self.init_finished = True

return wrapper


#
# Options classes
#

class Options:
class Options(metaclass=ABCMeta):
"""
Users instantiate child classes of this class to specify which options they want
for calculations.
"""

# __init__ is for options. finish_init is for inputs used in calculations.
def finish_init(self):
raise NotImplementedError


class DelayedInit:
init_finished = False


class Molecule(Options):
def __init__(self, coords, atomic_numbers, pseudo_numbers=None, charge=None, multiplicity=None):
self._coords = coords
Expand All @@ -90,6 +39,9 @@ def __init__(self, coords, atomic_numbers, pseudo_numbers=None, charge=None, mul

self._multiplicity = multiplicity or default_multiplicity

def finish_init(self):
pass

@classmethod
def from_file(cls, filename):
# TODO: load from file
Expand Down Expand Up @@ -143,7 +95,7 @@ def nelec_beta(self):
return self.nelec[1]


class Basis(Options, DelayedInit):
class Basis(Options):
def __init__(self, bset=None):
"""Specify basis set options.
bset can be a string containing the basis set name, or a nested tuple like ((int, string),)
Expand Down Expand Up @@ -243,7 +195,7 @@ def make_guess(self):
guess_core_hamiltonian(self._olp, self._one, *self._orbs)


class Grid(Options, DelayedInit):
class Grid(Options):
def __init__(self, accuracy='coarse'):
self._accuracy = accuracy

Expand All @@ -260,7 +212,7 @@ def grid(self):
return self._grid


class Orbitals(Options, DelayedInit):
class Orbitals(Options):
def __init__(self, spin="U"):
self._spin = spin

Expand Down Expand Up @@ -316,7 +268,7 @@ def __init__(self, *args, **kwargs):
# Compute classes
#

class Method:
class Method(Options):
def __init__(self):
self._ham = None
self._basis = None
Expand Down Expand Up @@ -349,7 +301,7 @@ def _calculate_energy(self):
return self._ham.energy


class HF(Options, Method, DelayedInit):
class HF(Method):
@finalize
def finish_init(self, coords, basis, scf, occ_model, orb, grid):
super().finish_init(coords, basis, scf)
Expand Down Expand Up @@ -379,7 +331,7 @@ def finish_init(self, coords, basis, scf, occ_model, orb, grid):
raise NotImplementedError


class DFT(Options, Method, DelayedInit):
class DFT(Method):
def __init__(self, xc=None, x=None, c=None, frac=None):
if xc and (x or c or frac):
print("Cannot specify xc and also x or c functionals or exchange fraction")
Expand Down
52 changes: 52 additions & 0 deletions test/test_example_glue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from numpy.testing import assert_raises

from decorators import onetime, cache, delayed, finalize


def test_onetime():
class A:
_x = 5

@property
def x(self):
return self._x

@x.setter
@onetime("_x")
def x(self, y):
self._x = y

a = A()
a.x = 10
assert a.x == 5


def test_cache():
class A:
_x = 5

@property
@cache("_x")
def x(self):
return 10

a = A()
assert a.x == 5


def test_delayed_finalize():
class A:
@finalize
def finish(self):
pass

@delayed
def do(self):
pass

a = A()
with assert_raises(AttributeError):
a.do()

a.finish()
a.do()

0 comments on commit 3db72d9

Please sign in to comment.