Skip to content

Commit

Permalink
Neo4j Publisher to support desired state of relation (#69)
Browse files Browse the repository at this point in the history
* [AMD-120] Add relation pre-processor in Neo4jPublisher

* Update

* Added DeleteRelationPreprocessor

* Added DeleteRelationPreprocessor

* Update

* Update
  • Loading branch information
jinhyukchang authored Jun 5, 2019
1 parent 014690e commit edce3cb
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 49 deletions.
120 changes: 81 additions & 39 deletions databuilder/publisher/neo4j_csv_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from typing import Set, List # noqa: F401

from databuilder.publisher.base_publisher import Publisher
from databuilder.publisher.neo4j_preprocessor import NoopRelationPreprocessor


# Config keys
# A directory that contains CSV files for nodes
Expand All @@ -23,6 +25,8 @@
NEO4J_END_POINT_KEY = 'neo4j_endpoint'
# A transaction size that determines how often it commits.
NEO4J_TRANSCATION_SIZE = 'neo4j_transaction_size'
# A progress report frequency that determines how often it reports the progress.
NEO4J_PROGRESS_REPORT_FREQUENCY = 'neo4j_progress_report_frequency'
# A boolean flag to make it fail if relationship is not created
NEO4J_RELATIONSHIP_CREATION_CONFIRM = 'neo4j_relationship_creation_confirm'

Expand All @@ -40,6 +44,8 @@
# Neo4j property name for published tag
PUBLISHED_TAG_PROPERTY_NAME = 'published_tag'

RELATION_PREPROCESSOR = 'relation_preprocessor'

# CSV HEADER
# A header with this suffix will be passed to Neo4j statement without quote
UNQUOTED_SUFFIX = ':UNQUOTED'
Expand Down Expand Up @@ -69,8 +75,10 @@
RELATION_TYPE, RELATION_REVERSE_TYPE}

DEFAULT_CONFIG = ConfigFactory.from_dict({NEO4J_TRANSCATION_SIZE: 500,
NEO4J_PROGRESS_REPORT_FREQUENCY: 500,
NEO4J_RELATIONSHIP_CREATION_CONFIRM: False,
NEO4J_MAX_CONN_LIFE_TIME_SEC: 50})
NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
RELATION_PREPROCESSOR: NoopRelationPreprocessor()})

NODE_MERGE_TEMPLATE = Template("""MERGE (node:$LABEL {key: '${KEY}'})
ON CREATE SET ${create_prop_body}
Expand Down Expand Up @@ -107,6 +115,8 @@ def init(self, conf):
# type: (ConfigTree) -> None
conf = conf.with_fallback(DEFAULT_CONFIG)

self._count = 0 # type: int
self._progress_report_frequency = conf.get_int(NEO4J_PROGRESS_REPORT_FREQUENCY)
self._node_files = self._list_files(conf, NODE_FILES_DIR)
self._node_files_iter = iter(self._node_files)

Expand All @@ -129,6 +139,8 @@ def init(self, conf):
if not self.publish_tag:
raise Exception('{} should not be empty'.format(JOB_PUBLISH_TAG))

self._relation_preprocessor = conf.get(RELATION_PREPROCESSOR)

LOGGER.info('Publishing Node csv files {}, and Relation CSV files {}'
.format(self._node_files, self._relation_files))

Expand All @@ -146,7 +158,7 @@ def _list_files(self, conf, path_key):
path = conf.get_string(path_key)
return [join(path, f) for f in listdir(path) if isfile(join(path, f))]

def publish_impl(self):
def publish_impl(self): # noqa: C901
# type: () -> None
"""
Publishes Nodes first and then Relations
Expand All @@ -160,23 +172,33 @@ def publish_impl(self):
self._create_indices(node_file=node_file)

LOGGER.info('Publishing Node files: {}'.format(self._node_files))
while True:
try:
node_file = next(self._node_files_iter)
self._publish_node(node_file)
except StopIteration:
break

LOGGER.info('Publishing Relationship files: {}'.format(self._relation_files))
while True:
try:
relation_file = next(self._relation_files_iter)
self._publish_relation(relation_file)
except StopIteration:
break

# TODO: Add statsd support
LOGGER.info('Successfully published. Elapsed: {} seconds'.format(time.time() - start))
try:
tx = self._session.begin_transaction()
while True:
try:
node_file = next(self._node_files_iter)
tx = self._publish_node(node_file, tx=tx)
except StopIteration:
break

LOGGER.info('Publishing Relationship files: {}'.format(self._relation_files))
while True:
try:
relation_file = next(self._relation_files_iter)
tx = self._publish_relation(relation_file, tx=tx)
except StopIteration:
break

tx.commit()
LOGGER.info('Committed total {} statements'.format(self._count))

# TODO: Add statsd support
LOGGER.info('Successfully published. Elapsed: {} seconds'.format(time.time() - start))
except Exception as e:
LOGGER.exception('Failed to publish. Rolling back.')
if not tx.closed():
tx.rollback()
raise e

def get_scope(self):
# type: () -> str
Expand All @@ -200,8 +222,8 @@ def _create_indices(self, node_file):

LOGGER.info('Indices have been created.')

def _publish_node(self, node_file):
# type: (str) -> None
def _publish_node(self, node_file, tx):
# type: (str, Transaction) -> Transaction
"""
Iterate over the csv records of a file, each csv record transform to Merge statement and will be executed.
All nodes should have a unique key, and this method will try to create unique index on the LABEL when it sees
Expand All @@ -218,14 +240,12 @@ def _publish_node(self, node_file):
:param node_file:
:return:
"""
tx = self._session.begin_transaction()

with open(node_file, 'r') as node_csv:
for count, node_record in enumerate(csv.DictReader(node_csv)):
stmt = self.create_node_merge_statement(node_record=node_record)
tx = self._execute_statement(stmt, tx, count)

tx.commit()
LOGGER.info('Committed {} records'.format(count + 1))
tx = self._execute_statement(stmt, tx)
return tx

def is_create_only_node(self, node_record):
# type: (dict) -> bool
Expand Down Expand Up @@ -257,8 +277,8 @@ def create_node_merge_statement(self, node_record):

return NODE_MERGE_TEMPLATE.substitute(params)

def _publish_relation(self, relation_file):
# type: (str) -> None
def _publish_relation(self, relation_file, tx):
# type: (str, Transaction) -> Transaction
"""
Creates relation between two nodes.
(In Amundsen, all relation is bi-directional)
Expand All @@ -273,15 +293,33 @@ def _publish_relation(self, relation_file):
:return:
"""

tx = self._session.begin_transaction()
if self._relation_preprocessor.is_perform_preprocess():
LOGGER.info('Pre-processing relation with {}'.format(self._relation_preprocessor))

count = 0
with open(relation_file, 'r') as relation_csv:
for rel_record in csv.DictReader(relation_csv):
stmt, params = self._relation_preprocessor.preprocess_cypher(
start_label=rel_record[RELATION_START_LABEL],
end_label=rel_record[RELATION_END_LABEL],
start_key=rel_record[RELATION_START_KEY],
end_key=rel_record[RELATION_END_KEY],
relation=rel_record[RELATION_TYPE],
reverse_relation=rel_record[RELATION_REVERSE_TYPE])

if stmt:
tx = self._execute_statement(stmt, tx=tx, params=params)
count += 1

LOGGER.info('Executed pre-processing Cypher statement {} times'.format(count))

with open(relation_file, 'r') as relation_csv:
for count, rel_record in enumerate(csv.DictReader(relation_csv)):
stmt = self.create_relationship_merge_statement(rel_record=rel_record)
tx = self._execute_statement(stmt, tx, count,
tx = self._execute_statement(stmt, tx,
expect_result=self._confirm_rel_created)

tx.commit()
LOGGER.info('Committed {} records'.format(count + 1))
return tx

def create_relationship_merge_statement(self, rel_record):
# type: (dict) -> str
Expand Down Expand Up @@ -352,9 +390,9 @@ def _create_props_body(self,
def _execute_statement(self,
stmt,
tx,
count,
params=None,
expect_result=False):
# type: (str, Transaction, int, bool) -> Transaction
# type: (str, Transaction, bool) -> Transaction

"""
Executes statement against Neo4j. If execution fails, it rollsback and raise exception.
Expand All @@ -367,20 +405,24 @@ def _execute_statement(self,
"""
try:
if LOGGER.isEnabledFor(logging.DEBUG):
LOGGER.debug('Executing statement: {}'.format(stmt))
LOGGER.debug('Executing statement: {} with params {}'.format(stmt, params))

if six.PY2:
result = tx.run(unicode(stmt, errors='ignore')) # noqa
result = tx.run(unicode(stmt, errors='ignore'), parameters=params) # noqa
else:
result = tx.run(str(stmt).encode('utf-8', 'ignore'))
result = tx.run(str(stmt).encode('utf-8', 'ignore'), parameters=params)
if expect_result and not result.single():
raise RuntimeError('Failed to executed statement: {}'.format(stmt))

if count > 1 and count % self._transaction_size == 0:
self._count += 1
if self._count > 1 and self._count % self._transaction_size == 0:
tx.commit()
LOGGER.info('Committed {} records so far'.format(count))
LOGGER.info('Committed {} statements so far'.format(self._count))
return self._session.begin_transaction()

if self._count > 1 and self._count % self._progress_report_frequency == 0:
LOGGER.info('Processed {} statements so far'.format(self._count))

return tx
except Exception as e:
LOGGER.exception('Failed to execute Cypher query')
Expand Down
Loading

0 comments on commit edce3cb

Please sign in to comment.