Skip to content

Commit

Permalink
Merge branch 'release/1.0.3'
Browse files Browse the repository at this point in the history
  • Loading branch information
lukostaz committed Jun 7, 2019
2 parents 1aa91a4 + 8d858fa commit ac825df
Show file tree
Hide file tree
Showing 35 changed files with 392 additions and 54 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]
Copyright 2019 The AmpliGraph Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,21 @@ Install from pip or conda:
**CPU-only**

```
pip install tensorflow==1.12.0
pip install tensorflow==1.13.1
or
conda install tensorflow=1.12.0
conda install tensorflow=1.13.1
```

**GPU support**

```
pip install tensorflow-gpu==1.12.0
pip install tensorflow-gpu==1.13.1
or
conda install tensorflow-gpu=1.12.0
conda install tensorflow-gpu=1.13.1
```


Expand Down Expand Up @@ -114,7 +114,7 @@ pip install -e .
```python
>> import ampligraph
>> ampligraph.__version__
'1.0.2'
'1.0.3'
```


Expand Down
9 changes: 8 additions & 1 deletion ampligraph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""Explainable Link Prediction is a library for relational learning on knowledge graphs."""
import logging.config
import pkg_resources

__version__ = '1.0.2'
__version__ = '1.0.3'
__all__ = ['datasets', 'latent_features', 'evaluation']

logging.config.fileConfig(pkg_resources.resource_filename(__name__, 'logger.conf'), disable_existing_loggers=False)
7 changes: 7 additions & 0 deletions ampligraph/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""Helper functions to load knowledge graphs."""

from .datasets import load_from_csv, load_from_rdf, load_fb15k, load_wn18, load_fb15k_237, load_from_ntriples, \
Expand Down
34 changes: 32 additions & 2 deletions ampligraph/datasets/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import pandas as pd
import os
import numpy as np
Expand Down Expand Up @@ -293,7 +300,7 @@ def _load_dataset(dataset_metadata, data_home=None, check_md5hash=False):
The location to save the dataset to. Defaults to None.
check_md5hash : boolean
If true check the md4hash of the files after they are downloaded.
If True check the md5hash of the files after they are downloaded. Defaults to False.
"""

if dataset_metadata.dataset_name is None:
Expand Down Expand Up @@ -336,6 +343,10 @@ def load_wn18(check_md5hash=False):
The dataset includes a large number of inverse relations, and its use in experiments has been deprecated.
Use WN18RR instead.
Parameters
----------
check_md5hash : bool
If ``True`` check the md5hash of the files. Defaults to ``False``.
Returns
-------
Expand Down Expand Up @@ -394,6 +405,9 @@ def load_wn18rr(check_md5hash=False, clean_unseen=True):
clean_unseen : bool
If ``True``, filters triples in validation and test sets that include entities not present in the training set.
check_md5hash : bool
If ``True`` check the md5hash of the dataset files. Defaults to ``False``.
Returns
-------
Expand Down Expand Up @@ -448,6 +462,12 @@ def load_fb15k(check_md5hash=False):
The dataset includes a large number of inverse relations, and its use in experiments has been deprecated.
Use FB15k-237 instead.
Parameters
----------
check_md5hash : boolean
If ``True`` check the md5hash of the files. Defaults to ``False``.
Returns
-------
Expand Down Expand Up @@ -475,7 +495,7 @@ def load_fb15k(check_md5hash=False):
train_checksum='5a87195e68d7797af00e137a7f6929f2', valid_checksum='275835062bb86a86477a3c402d20b814',
test_checksum='71098693b0efcfb8ac6cd61cf3a3b505')

return _load_dataset(FB15K, data_home=None, check_md5hash=False)
return _load_dataset(FB15K, data_home=None, check_md5hash=check_md5hash)


def load_fb15k_237(check_md5hash=False, clean_unseen=True):
Expand Down Expand Up @@ -506,6 +526,9 @@ def load_fb15k_237(check_md5hash=False, clean_unseen=True):
Parameters
----------
check_md5hash : boolean
If ``True`` check the md5hash of the files. Defaults to ``False``.
clean_unseen : bool
If ``True``, filters triples in validation and test sets that include entities not present in the training set.
Expand Down Expand Up @@ -559,6 +582,13 @@ def load_yago3_10(check_md5hash=False, clean_unseen = True):
YAGO3-10 1,079,040 5,000 5,000 123,182 37
========= ========= ======= ======= ============ ===========
Parameters
----------
check_md5hash : boolean
If ``True`` check the md5hash of the files. Defaults to ``False``.
clean_unseen : bool
If ``True``, filters triples in validation and test sets that include entities not present in the training set.
Returns
-------
Expand Down
7 changes: 7 additions & 0 deletions ampligraph/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""The module includes performance metrics for neural graph embeddings models, along with model selection routines,
negatives generation, and an implementation of the learning-to-rank-based evaluation protocol used in literature."""

Expand Down
7 changes: 7 additions & 0 deletions ampligraph/evaluation/metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import numpy as np
import logging

Expand Down
76 changes: 55 additions & 21 deletions ampligraph/evaluation/protocol.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import numpy as np
from tqdm import tqdm

Expand All @@ -10,7 +17,7 @@
logger.setLevel(logging.DEBUG)


def train_test_split_no_unseen(X, test_size=5000, seed=0, allow_duplication=False):
def train_test_split_no_unseen(X, test_size=100, seed=0, allow_duplication=False):
"""Split into train and test sets.
This function carves out a test set that contains only entities
Expand Down Expand Up @@ -116,10 +123,14 @@ def train_test_split_no_unseen(X, test_size=5000, seed=0, allow_duplication=Fals
# in case can't find solution
if loop_count == tolerance:
if allow_duplication:
raise Exception("Not possible to split the dataset...")
raise Exception("Cannot create a test split of the desired size. "
"Some entities will not occur in both training and test set. "
"Change seed values, or set test_size to a smaller value.")
else:
raise Exception("Not possible to split the dataset. \
Maybe set allow_duplication = True can help...")
raise Exception("Cannot create a test split of the desired size. "
"Some entities will not occur in both training and test set. "
"Set allow_duplication=True, or "
"change seed values, or set test_size to a smaller value.")

logger.debug('Completed random search.')

Expand All @@ -129,6 +140,7 @@ def train_test_split_no_unseen(X, test_size=5000, seed=0, allow_duplication=Fals

return X[idx_train, :], X[idx_test, :]


def _create_unique_mappings(unique_obj, unique_rel):
obj_count = len(unique_obj)
rel_count = len(unique_rel)
Expand Down Expand Up @@ -462,9 +474,11 @@ def evaluate_performance(X, model, filter_triples=None, verbose=False, strict=Tr
Run the relational learning evaluation protocol defined in :cite:`bordes2013translating`.
It computes the ranks of each positive triple against all possible negatives created in compliance with
the local closed world assumption (LCWA), as described in :cite:`nickel2016review`.
It computes the rank of each positive triple against a number of negatives generated on the fly.
Such negatives are compliant with the local closed world assumption (LCWA),
as described in :cite:`nickel2016review`. In practice, that means only one side of the triple is corrupted
(i.e. either the subject or the object).
.. note::
When *filtered* mode is enabled (i.e. `filter_triples` is not ``None``),
to speed up the procedure, we adopt a hashing-based strategy to handle the set difference problem.
Expand Down Expand Up @@ -496,7 +510,7 @@ def evaluate_performance(X, model, filter_triples=None, verbose=False, strict=Tr
.. hint::
When ``rank_against_ent=None``, the method will use all distinct entities in the knowledge graph ``X``
to generate negatives to rank against. If ``X`` includes more than 1 million unique
to generate negatives to rank against. If ``X`` includes more than 2.5 million unique
entities and relations, the method will return a runtime error.
To solve the problem, it is recommended to pass the desired entities to use to generate corruptions
to ``rank_against_ent``. Besides, trying to rank a positive against an extremely large number of negatives
Expand Down Expand Up @@ -524,37 +538,57 @@ def evaluate_performance(X, model, filter_triples=None, verbose=False, strict=Tr
- 's': corrupt only subject.
- 'o': corrupt only object
- 's+o': corrupt both subject and object
- 's+o': corrupt both subject and object. The same behaviour is obtained with ``use_default_protocol=True``.
.. note::
If ``corrupt_side='s+o'`` the function will return 2*n ranks.
If ``corrupt_side='s'`` or ``corrupt_side='o'``, it will return n ranks, where n is the
number of statements in X.
The first n elements of ranks are obtained against subject corruptions. From n+1 until 2n ranks are obtained
against object corruptions.
use_default_protocol: bool
Flag to indicate whether to evaluate head and tail corruptions separately (default: True).
If this is set to true, it will also ignore the ``corrupt_side`` argument and corrupt both head and tail
separately and rank triples.
Flag to indicate whether to use the standard protocol used in literature defined in
:cite:`bordes2013translating` (default: True).
If set to ``True`` it is equivalent to ``corrupt_side='s+o'``.
This corresponds to the evaluation protocol used in literature, where head and tail corruptions
are evaluated separately.
.. note::
When ``use_default_protocol=True`` the function will return 2*n ranks.
The first n elements of ranks are obtained against subject corruptions. From n+1 until 2n ranks are obtained
against object corruptions.
Returns
-------
ranks : ndarray, shape [n]
ranks : ndarray, shape [n] or [2*n]
An array of ranks of positive test triples.
When ``use_default_protocol=True`` or ``corrupt_side='s+o'``, the function returns 2*n ranks instead of n.
In that case the first n elements of ranks are obtained against subject corruptions. From n+1 until 2n ranks
are obtained against object corruptions.
Examples
--------
>>> import numpy as np
>>> from ampligraph.datasets import load_wn18
>>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import evaluate_performance
>>> from ampligraph.evaluation import evaluate_performance, mrr_score, hits_at_n_score
>>>
>>> X = load_wn18()
>>> model = ComplEx(batches_count=10, seed=0, epochs=1, k=150, eta=10,
>>> loss='pairwise', optimizer='adagrad')
>>> model = ComplEx(batches_count=10, seed=0, epochs=10, k=150, eta=1,
...                 loss='nll', optimizer='adam')
>>> model.fit(np.concatenate((X['train'], X['valid'])))
>>>
>>> filter = np.concatenate((X['train'], X['valid'], X['test']))
>>> ranks = evaluate_performance(X['test'][:5], model=model, filter_triples=filter)
>>> ranks = evaluate_performance(X['test'][:5], model=model,
...                              filter_triples=filter,
...                              corrupt_side='s+o',
...                              use_default_protocol=False)
>>> ranks
array([ 2, 4, 1, 1, 28550], dtype=int32)
[1, 582, 543, 6, 31]
>>> mrr_score(ranks)
0.55000700525394053
0.24049691297347323
>>> hits_at_n_score(ranks, n=10)
0.8
0.4
"""

logger.debug('Evaluating the performance of the embedding model.')
Expand Down
7 changes: 7 additions & 0 deletions ampligraph/latent_features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""This module includes neural graph embedding models and support functions.
Knowledge graph embedding models are neural architectures that encode concepts from a knowledge graph
Expand Down
7 changes: 7 additions & 0 deletions ampligraph/latent_features/loss_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import tensorflow as tf
import abc
import logging
Expand Down
7 changes: 7 additions & 0 deletions ampligraph/latent_features/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import numpy as np
import logging

Expand Down
16 changes: 15 additions & 1 deletion ampligraph/latent_features/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Copyright 2019 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the License is available in LICENSE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import numpy as np
import tensorflow as tf
from sklearn.utils import check_random_state
Expand Down Expand Up @@ -346,6 +353,10 @@ def _load_model_from_trained_params(self):
def get_embeddings(self, entities, embedding_type='entity'):
"""Get the embeddings of entities or relations.
.. Note ::
Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to visualize the embeddings with TensorBoard.
Parameters
----------
entities : array-like, dtype=int, shape=[n]
Expand Down Expand Up @@ -1053,6 +1064,9 @@ def _fn(e_s, e_p, e_o):
def get_embeddings(self, entities, type='entity'):
"""Get the embeddings of entities or relations.
.. Note ::
Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to visualize the embeddings with TensorBoard.
Parameters
----------
entities : array-like, dtype=int, shape=[n]
Expand Down Expand Up @@ -1113,7 +1127,7 @@ def predict(self, X, from_idx=False, get_ranks=False):
positive_scores = self.rnd.uniform(low=0, high=1, size=len(X)).tolist()
if get_ranks:
corruption_entities = self.eval_config.get('corruption_entities', DEFAULT_CORRUPTION_ENTITIES)
if corruption_entities is None:
if corruption_entities == "all":
corruption_length = len(self.ent_to_idx)
else:
corruption_length = len(corruption_entities)
Expand Down
Loading

0 comments on commit ac825df

Please sign in to comment.