Skip to content
This repository has been archived by the owner on Oct 10, 2023. It is now read-only.

Commit

Permalink
Update module_scraper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dstrande authored Aug 11, 2023
1 parent 7d522f9 commit 016488a
Showing 1 changed file with 47 additions and 52 deletions.
99 changes: 47 additions & 52 deletions flojoy/module_scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class FlojoyWrapper:
"""Class for creating a Flojoy wrapper of NumPy and SciPy functions.
Run the script from the nodes directory.
"""

INPUT_DEFAULT_DTYPE = "np.ndarray"
FORBIDDEN_OPTIONAL_ARGS = ["x", "data", "kwargs", "comparator", "a", "A"]
FORBIDDEN_TYPES = [
Expand All @@ -19,7 +20,7 @@ class FlojoyWrapper:
"sequence",
"... array_like",
"number or ndarray or sequence",
"complex"
"complex",
]
FORBIDDEN_RETURN = [
"tuple of ndarrays",
Expand Down Expand Up @@ -107,7 +108,7 @@ def __init__(self, func, parameters, module, argument_names):
self.data += f"{arg}: {dtype} = 2,\n\t"
elif dtype == "float" and def_val == "":
self.data += f"{arg}: {dtype} = 0.1,\n\t"
elif dtype not in ['int', 'float', 'None'] and def_val == "":
elif dtype not in ["int", "float", "None"] and def_val == "":
self.data += f"{arg}: {dtype},\n\t"
else:
self.data = ""
Expand All @@ -122,7 +123,7 @@ def __init__(self, func, parameters, module, argument_names):
return

# mvsdist and bayes_mvs both are incompatiable currently.
if 'mvs' in self.name:
if "mvs" in self.name:
self.data = ""
return

Expand Down Expand Up @@ -207,15 +208,15 @@ def write_func_args(self):
self.data += "\t" if idk < len(self.arguments) - 1 else ""

def gen_return_options(self):
self.return_options = re.findall(r' (.*?) : ', self.return_doc)
self.return_options = re.findall(r" (.*?) : ", self.return_doc)

def count_returns(self):
self.decomp_return = False
search = ' Returns\n -------'
search = " Returns\n -------"
find1 = self.org_docs.find(search)
find2 = self.org_docs.find('-----', find1 + len(search))
find2 = self.org_docs.find("-----", find1 + len(search))
self.return_doc = self.org_docs[find1:find2]
self.return_count = self.return_doc.count(' : ')
self.return_count = self.return_doc.count(" : ")
if self.return_count > 1:
self.decomp_return = True

Expand All @@ -237,16 +238,11 @@ def write_wrapper(self, mtype):

if self.module.__name__ == "numpy.linalg":
self.data = self.data.replace(
"OrderedPair | Matrix | Scalar",
"Matrix | Scalar"
)
self.data = self.data.replace(
"OrderedPair | Matrix",
"Matrix"
"OrderedPair | Matrix | Scalar", "Matrix | Scalar"
)
self.data = self.data.replace("OrderedPair | Matrix", "Matrix")
self.data = self.data.replace(
"import OrderedPair, flojoy,",
"import flojoy,"
"import OrderedPair, flojoy,", "import flojoy,"
)

self.data += f"\tresult = {self.module.__name__}.{self.name}(\n\t\t\t" + (
Expand Down Expand Up @@ -322,45 +318,49 @@ def custom_params(self):
This corrects those nodes with node_replace.txt.
"""
nodename = self.name.upper()
filename = f'{os.path.dirname(__file__)}/scraper/node_replace.txt'
replace = np.loadtxt(filename, delimiter='\t', dtype=str, skiprows=1).T
filename = f"{os.path.dirname(__file__)}/scraper/node_replace.txt"
replace = np.loadtxt(filename, delimiter="\t", dtype=str, skiprows=1).T
if nodename in replace[0]:
# print(repr(self.data[272:311]))
index = list(replace[0]).index(nodename)
to_replace = replace[1][index].replace('/n/t', '\n\t')
replacement = replace[2][index].replace('/n/t', '\n\t')
to_replace = replace[1][index].replace("/n/t", "\n\t")
replacement = replace[2][index].replace("/n/t", "\n\t")
self.data = self.data.replace(to_replace, replacement)

def write_test_script(self):
nodename = self.name.upper()

self.test_script = 'import numpy as np\n'
self.test_script += 'from flojoy import OrderedPair, Matrix, Scalar\n\n'
self.test_script += f'def test_{nodename}(mock_flojoy_decorator):'
self.test_script += f'\n\timport {nodename}\n\n\t'
self.test_script = "import numpy as np\n"
self.test_script += "from flojoy import OrderedPair, Matrix, Scalar\n\n"
self.test_script += f"def test_{nodename}(mock_flojoy_decorator):"
self.test_script += f"\n\timport {nodename}\n\n\t"

if self.module.__name__ == "numpy.linalg":
self.test_script += 'array1 = np.eye(5)\n\t'
self.test_script += 'array2 = np.eye(9)\n\t'
self.test_script += 'array2.shape = (3, 3, 3, 3)\n\n\t'
self.test_script += 'try:\n\t\t'
self.test_script += 'element_a = Matrix(m=array1)\n\t\t'
self.test_script += f'res = {nodename}.{nodename}(default=element_a)\n\t'
self.test_script += 'except np.linalg.LinAlgError:\n\t\t'
self.test_script += 'element_a = Matrix(m=array2)\n\t\t'
self.test_script += f'res = {nodename}.{nodename}(default=element_a)\n\n'
self.test_script += "array1 = np.eye(5)\n\t"
self.test_script += "array2 = np.eye(9)\n\t"
self.test_script += "array2.shape = (3, 3, 3, 3)\n\n\t"
self.test_script += "try:\n\t\t"
self.test_script += "element_a = Matrix(m=array1)\n\t\t"
self.test_script += f"res = {nodename}.{nodename}(default=element_a)\n\t"
self.test_script += "except np.linalg.LinAlgError:\n\t\t"
self.test_script += "element_a = Matrix(m=array2)\n\t\t"
self.test_script += f"res = {nodename}.{nodename}(default=element_a)\n\n"

else:
self.test_script += 'element_a = OrderedPair(x='
self.test_script += 'np.ones(50), y=np.arange(1, 51))\n\t'
self.test_script += f'res = {nodename}.{nodename}(default=element_a)\n\t'
self.test_script += "element_a = OrderedPair(x="
self.test_script += "np.ones(50), y=np.arange(1, 51))\n\t"
self.test_script += f"res = {nodename}.{nodename}(default=element_a)\n\t"

self.test_script += '\n\n\t# check that the outputs are one of the correct types.'
self.test_script += '\n\tassert isinstance(res, Scalar | OrderedPair | Matrix)\n'
self.test_script += (
"\n\n\t# check that the outputs are one of the correct types."
)
self.test_script += (
"\n\tassert isinstance(res, Scalar | OrderedPair | Matrix)\n"
)

# Some nodes tests require custom inputs for testing.
filename = f'{os.path.dirname(__file__)}/scraper/test_replace.txt'
replace = np.loadtxt(filename, delimiter='\t', dtype=str, skiprows=1).T
filename = f"{os.path.dirname(__file__)}/scraper/test_replace.txt"
replace = np.loadtxt(filename, delimiter="\t", dtype=str, skiprows=1).T
if nodename in replace[0]:
index = list(replace[0]).index(nodename)
to_replace = replace[1][index]
Expand Down Expand Up @@ -425,27 +425,20 @@ def scrape_function(func):
func, default_optional_params, submodule, all_arg_names
)
fw.write_wrapper(f"{module.upper()}_{submodule_name.upper()}")
if (
fw.data != ""
and "NoneType" not in fw.data
):
if fw.data != "" and "NoneType" not in fw.data:
try:
valid = ast.parse(fw.data) # Test node script.
this_nodes_directory = Path(NODE_DIR / f"{fw.name.upper()}")
this_nodes_directory.mkdir(exist_ok=True)

# Write node script.
script_name = f"{fw.name.upper()}.py"
with open(
this_nodes_directory / script_name, "w"
) as fh:
with open(this_nodes_directory / script_name, "w") as fh:
fh.write(fw.data)

# Write testing script.
test_name = f"{fw.name.upper()}_test_.py"
with open(
this_nodes_directory / test_name, "w"
) as fh:
with open(this_nodes_directory / test_name, "w") as fh:
fh.write(fw.test_script)

valids.append(fw.name)
Expand All @@ -462,9 +455,11 @@ def scrape_function(func):
# ) as fh:
# fh.write(fw.data)

print('invalids: ', invalids)
print('valids: ', valids)
print(f'Created nodes for {len(valids)} out of {len(valids) + len(invalids)} functions.')
print("invalids: ", invalids)
print("valids: ", valids)
print(
f"Created nodes for {len(valids)} out of {len(valids) + len(invalids)} functions."
)
with open(NODE_DIR / "__init__.py", "w") as fh:
functions = "__all__ = ["
for name in valids:
Expand Down

0 comments on commit 016488a

Please sign in to comment.