diff --git a/.github/workflows/pytest_postgres.yml b/.github/workflows/pytest_postgres.yml index 7a2e42b..84054c6 100644 --- a/.github/workflows/pytest_postgres.yml +++ b/.github/workflows/pytest_postgres.yml @@ -33,18 +33,8 @@ jobs: python -m pip install pytest-mock python -m pip install pytest-cov python -m pip install . - playwright install - - name: Run pytest and Generate coverage report + - name: Run pytest for postgres shell: bash -l {0} run: | - python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web -v --disable-warnings --cov=./ --cov-report=xml:coverage.xml - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 - with: - token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.xml - flags: unittests - env_vars: OS,PYTHON - name: codecov-umbrella - fail_ci_if_error: false - verbose: true + python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web -v --disable-warnings + diff --git a/elm/web/rhub.py b/elm/web/rhub.py index 5b01ba6..d56d1fe 100644 --- a/elm/web/rhub.py +++ b/elm/web/rhub.py @@ -585,26 +585,23 @@ def authors(self): """ pa = self.get('personAssociations') + if not pa: + return None + authors = [] for r in pa: name = r.get('name') - if name: - first = name.get('firstName') - last = name.get('lastName') + if not name: + continue - if first and last: - full = first + ' ' + last - elif first: - full = first - elif last: - full = last - else: - full = None + first = name.get('firstName') + last = name.get('lastName') - if full: - authors.append(full) + full = " ".join(filter(bool, [first, last])) + + authors.append(full) out = ', '.join(authors) @@ -660,11 +657,15 @@ def abstract(self): """ abstract = self.get('abstract') - if abstract: - text = abstract.get('text')[0] - value = text.get('value') - else: - value = None + if not abstract: + return None + + text = abstract.get('text') + + if not text: + return None + + value = text[0].get('value') return value diff --git a/elm/wizard.py b/elm/wizard.py index d65f754..840645e 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -450,29 +450,28 @@ def __init__(self, db_host, db_port, db_name, self.psycopg2 = try_import('psycopg2') if meta_columns is None: - self.meta_columns = ['title', 'url'] + self.meta_columns = ['title', 'url', 'id'] else: self.meta_columns = meta_columns + assert 'id' in self.meta_columns, "Please include the 'id' column!" + if cursor is None: - self.db_host = db_host - self.db_port = db_port - self.db_name = db_name - self.db_schema = db_schema - self.db_table = db_table - self.db_user = os.getenv("EWIZ_DB_USER") - self.db_password = os.getenv('EWIZ_DB_PASSWORD') - assert self.db_user is not None, "Must set EWIZ_DB_USER!" - assert self.db_password is not None, "Must set EWIZ_DB_PASSWORD!" - self.db_kwargs = dict(user=self.db_user, password=self.db_password, - host=self.db_host, port=self.db_port, - database=self.db_name) + db_user = os.getenv("EWIZ_DB_USER") + db_password = os.getenv('EWIZ_DB_PASSWORD') + assert db_user is not None, "Must set EWIZ_DB_USER!" + assert db_password is not None, "Must set EWIZ_DB_PASSWORD!" + self.db_kwargs = dict(user=db_user, password=db_password, + host=db_host, port=db_port, + database=db_name) self.conn = self.psycopg2.connect(**self.db_kwargs) self.cursor = self.conn.cursor() else: self.cursor = cursor + self.db_schema = db_schema + self.db_table = db_table self.tag = tag self.probes = probes @@ -639,20 +638,20 @@ def _format_refs(self, refs, ids): ref_list = [] for item in refs: - ref_dict = {} - for icol, col in enumerate(self.meta_columns): - value = item[icol] - value = str(value).replace(chr(34), '') - ref_dict[col] = value + ref_dict = {col: str(value).replace(chr(34), '') + for col, value in zip(self.meta_columns, item)} ref_list.append(ref_dict) - seen = set() + assert len(ref_list) > 0, ("The Wizard did not return any " + "references. Please check your database " + "connection or query.") + unique_ref_list = [] for ref_dict in ref_list: - if str(ref_dict) not in seen: - seen.add(str(ref_dict)) - unique_ref_list.append(ref_dict) + if any(ref_dict == d for d in unique_ref_list): + continue + unique_ref_list.append(ref_dict) ref_list = unique_ref_list if 'id' in ref_list[0]: @@ -703,7 +702,7 @@ def make_ref_list(self, ids): raise RuntimeError(msg) from exc else: conn.commit() - refs = cursor.fetchall() + refs = cursor.fetchall() ref_list = self._format_refs(refs, ids) diff --git a/tests/test_wizard_postgres.py b/tests/test_wizard_postgres.py index 63a5c6a..81b166d 100644 --- a/tests/test_wizard_postgres.py +++ b/tests/test_wizard_postgres.py @@ -5,6 +5,7 @@ import ast import json from io import BytesIO +import numpy as np from elm import TEST_DATA_DIR from elm.wizard import EnergyWizardPostgres @@ -21,6 +22,9 @@ QUERY_TUPLE = ast.literal_eval(QUERY_TEXT) REF_TUPLE = ast.literal_eval(REF_TEXT) +os.environ["EWIZ_DB_USER"] = "user" +os.environ["EWIZ_DB_PASSWORD"] = "password" + class Cursor: """Dummy class for mocking database cursor objects""" @@ -66,16 +70,12 @@ def invoke_model(self, **kwargs): # pylint: disable=unused-argument def test_postgres(mocker): """Test to ensure correct response vector db.""" - os.environ["EWIZ_DB_USER"] = "user" - os.environ["EWIZ_DB_PASSWORD"] = "password" - mock_conn_cm = mocker.MagicMock() mock_conn = mock_conn_cm.__enter__.return_value mock_conn.cursor.return_value = Cursor() mock_connect = mocker.patch('psycopg2.connect') mock_connect.return_value = mock_conn_cm - wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy', db_name='Dummy', db_schema='Dummy', db_table='Dummy', @@ -95,3 +95,84 @@ def test_postgres(mocker): assert 'title' in str(ref_list) assert 'url' in str(ref_list) assert 'research-hub.nrel.gov' in str(ref_list) + + +def test_ref_replace(mocker): + """Test to ensure removal of double quotes from references.""" + mock_conn_cm = mocker.MagicMock() + mock_conn = mock_conn_cm.__enter__.return_value + mock_conn.cursor.return_value = Cursor() + + mock_connect = mocker.patch('psycopg2.connect') + mock_connect.return_value = mock_conn_cm + + wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy', + db_name='Dummy', db_schema='Dummy', + db_table='Dummy', + boto_client=BotoClient(), + meta_columns=['title', 'url', 'id']) + + refs = [(chr(34), 'test.com', '5a'), + ('remove "double" quotes', 'test_2.com', '7b')] + + ids = np.array(['7b', '5a']) + + out = wizard._format_refs(refs, ids) + + assert len(out) > 1 + for i in out: + assert json.loads(i) + + +def test_ids(mocker): + """Test to ensure only records with valid ids are returned.""" + mock_conn_cm = mocker.MagicMock() + mock_conn = mock_conn_cm.__enter__.return_value + mock_conn.cursor.return_value = Cursor() + + mock_connect = mocker.patch('psycopg2.connect') + mock_connect.return_value = mock_conn_cm + + wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy', + db_name='Dummy', db_schema='Dummy', + db_table='Dummy', + boto_client=BotoClient(), + meta_columns=['title', 'url', 'id']) + + refs = [('title', 'test.com', '5a'), + ('title2', 'test_2.com', '7b')] + + ids = np.array(['7c', '5a']) + + out = wizard._format_refs(refs, ids) + + assert len(out) == 1 + assert '7b' not in out + + +def test_sorted_refs(mocker): + """Test to ensure references are sorted in same order as ids.""" + mock_conn_cm = mocker.MagicMock() + mock_conn = mock_conn_cm.__enter__.return_value + mock_conn.cursor.return_value = Cursor() + + mock_connect = mocker.patch('psycopg2.connect') + mock_connect.return_value = mock_conn_cm + + wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy', + db_name='Dummy', db_schema='Dummy', + db_table='Dummy', + boto_client=BotoClient(), + meta_columns=['title', 'url', 'id']) + + refs = [('title', 'test.com', '5a'), + ('title2', 'test_2.com', '7b')] + + ids = np.array(['7b', '5a']) + + expected = ['{"title": "title2", "url": "test_2.com", "id": "7b"}', + '{"title": "title", "url": "test.com", "id": "5a"}'] + + out = wizard._format_refs(refs, ids) + + assert expected == out