Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MinHashLSH, top-k results, parallel processing enhancement from kyao #42

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ scipy>=1.1.0
matplotlib>=2.0.0
dask>=0.19.2
distributed>=1.23
datasketch>=1.5.3
pyrallel.lib
3 changes: 2 additions & 1 deletion rltk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
__version__ = '2.0.0-a020'

from rltk.record import Record, AutoGeneratedRecord,\
cached_property, generate_record_property_cache, validate_record, remove_raw_object, set_id
cached_property, generate_record_property_cache, validate_record, remove_raw_object, set_id, \
PrioritizedRecord, BoundedSizeRecordHeap
from rltk.dataset import Dataset
from rltk.io import *
from rltk.similarity import *
Expand Down
1 change: 1 addition & 0 deletions rltk/blocking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rltk.blocking.canopy_block_generator import CanopyBlockGenerator
from rltk.blocking.sorted_neighbourhood_block_generator import SortedNeighbourhoodBlockGenerator
from rltk.blocking.blocking_helper import BlockingHelper
from rltk.blocking.block_utils import ngram, generate_minhash_blocking_keys

Blocker = BlockGenerator
HashBlocker = HashBlockGenerator
Expand Down
32 changes: 32 additions & 0 deletions rltk/blocking/block_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import hashlib
from typing import List, Set

from datasketch import MinHash, MinHashLSH

def ngram(n: int, s: str, sep: str = ' ', padded: bool = False) -> List[str]:
    """Generate the sequence of character n-grams of a string.

    Spaces in ``s`` are replaced by ``sep`` before grams are extracted.
    When ``padded`` is true, ``n - 1`` copies of ``sep`` are attached to
    both ends so boundary characters appear in full-length grams.

    Args:
        n: Gram length.
        s: Input string.
        sep: Separator substituted for spaces in ``s``.
        padded: Whether to pad both ends of the string with ``sep``.

    Returns:
        The list of n-grams; empty for an empty input, and a single-element
        list holding the whole normalized string when it is shorter than ``n``.
    """
    if not s:
        return []
    if padded:
        padding = sep * (n - 1)
        s = padding + s + padding
    # Normalize: spaces become the configured separator.
    normalized = sep.join(s.split(' '))
    if len(normalized) < n:
        return [normalized]
    last_start = len(normalized) - n + 1
    return [normalized[start:start + n] for start in range(last_start)]

def generate_minhash_blocking_keys(
        tokens: List[str], num_perm: int, threshold: float, key_len: int = 10) -> Set[str]:
    """Generate blocking keys using MinHash Locality Sensitive Hashing.

    Args:
        tokens: Token strings folded into a single MinHash signature.
        num_perm: Number of permutation functions for the MinHash.
        threshold: Jaccard similarity threshold used to size the LSH bands.
        key_len: Length of each returned hex key (prefix of a SHA-1 digest).

    Returns:
        One key per LSH band: the truncated SHA-1 hex digest of the band's
        bucket key for this signature.
    """
    signature = MinHash(num_perm=num_perm)
    for token in tokens:
        signature.update(token.encode('utf8'))

    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    lsh.insert("m", signature)

    blocking_keys = set()
    for table in lsh.hashtables:
        # NOTE(review): reads datasketch's private `_dict`; each band table
        # holds exactly one bucket key because a single item was inserted.
        band_key = next(iter(table._dict))
        blocking_keys.add(hashlib.sha1(band_key).hexdigest()[:key_len])
    return blocking_keys
17 changes: 16 additions & 1 deletion rltk/blocking/token_block_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from multiprocessing import Pool
from typing import Callable, TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -13,7 +14,8 @@ class TokenBlockGenerator(BlockGenerator):
"""

def block(self, dataset, function_: Callable = None, property_: str = None,
block: Block = None, block_black_list: BlockBlackList = None, base_on: Block = None):
block: Block = None, block_black_list: BlockBlackList = None, base_on: Block = None,
processes: int = 1, chunk_size: int = 100):
"""
The return of `property_` or `function_` should be list or set.
"""
Expand All @@ -36,6 +38,19 @@ def block(self, dataset, function_: Callable = None, property_: str = None,
if block_black_list:
block_black_list.add(v, block)

elif processes > 1 and function_:
with Pool(processes) as p:
for r, value in zip(dataset, p.imap(function_, dataset, chunk_size)):
if not isinstance(value, list) and not isinstance(value, set):
raise ValueError('Return of the function or property should be a list')
for v in value:
if not isinstance(v, str):
raise ValueError('Elements in return list should be string')
if block_black_list and block_black_list.has(v):
continue
block.add(v, dataset.id, r.id)
if block_black_list:
block_black_list.add(v, block)
else:
for r in dataset:
value = function_(r) if function_ else getattr(r, property_)
Expand Down
61 changes: 53 additions & 8 deletions rltk/record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from typing import Callable
from dataclasses import dataclass, field
from heapq import heappush, heappushpop, nlargest
from typing import Callable, List


# Record ID should be string
Expand All @@ -12,7 +14,7 @@
class Record(object):
"""
Record representation. Properties should be defined for further usage.

Args:
raw_object (dict): Raw data which will be used to create properties.
"""
Expand All @@ -32,7 +34,7 @@ def id(self):
def __eq__(self, other):
"""
Only if both instances have the same class and id.

Returns:
bool: Equal or not.
"""
Expand Down Expand Up @@ -84,7 +86,7 @@ def remove_raw_object(cls):
def generate_record_property_cache(obj):
"""
Generate final value on all cached_property decorated methods.

Args:
obj (Record): Record instance.
"""
Expand All @@ -101,10 +103,10 @@ def generate_record_property_cache(obj):
def validate_record(obj):
"""
Property validator of record instance.

Args:
obj (Record): Record instance.

Raises:
TypeError: if id is not valid
"""
Expand All @@ -117,10 +119,10 @@ def validate_record(obj):
def get_property_names(cls: type):
"""
Get keys of property and cached_property from a record class.

Args:
cls (type): Record class

Returns:
list: Property names in class
"""
Expand Down Expand Up @@ -193,3 +195,46 @@ def id(self):
if function_:
id_ = function_(id_)
return id_

@dataclass(order=True)
class PrioritizedRecord:
    """A record paired with a numeric priority.

    Ordering (and equality) consider only ``priority``; the wrapped record
    is excluded from comparisons so instances can be stored in a heap.
    """
    # Sort key: larger value means higher priority.
    priority: float
    # Payload record; never participates in ordering comparisons.
    record: Record = field(compare=False)


class BoundedSizeRecordHeap:
    """Fixed-capacity min-heap that retains the highest-priority records.

    Record ids currently in the heap are mirrored in a set so duplicate
    pushes (by record id) are rejected in O(1).
    """
    def __init__(self, size: int = 15):
        """
        Args:
            size: max size of the heap. Prefer sizes that are powers of two minus one (2^n-1).
        """
        self._size = size
        self._heap = []  # min-heap of PrioritizedRecord; lowest priority at index 0
        self._ids = set()  # ids of the records currently stored in the heap

    def push(self, item: PrioritizedRecord, *, debug=False):
        """Add a record, evicting the lowest-priority one when the heap is full.

        Args:
            item: Candidate record wrapped with its priority.
            debug: When true, print the evicted record (if any).
        """
        if item in self:
            # Duplicate record id already tracked; ignore.
            return
        if len(self._heap) < self._size:
            heappush(self._heap, item)
            self._ids.add(item.record.id)
        elif item > self._heap[0]:
            # item outranks the current minimum: push it and pop the minimum.
            # Because item > heap[0], the popped element is never item itself.
            popped = heappushpop(self._heap, item)
            if debug:
                # NOTE(review): assumes the record exposes `.value` — confirm.
                print(f'Remove: {popped.priority:5.2} {popped.record.id:8} {popped.record.value}')
            # BUG FIX: the original only removed the evicted id and never
            # registered the new item's id, leaving self._ids out of sync
            # with the heap (later duplicates of item would be accepted and
            # membership tests would wrongly report False).
            self._ids.discard(popped.record.id)
            self._ids.add(item.record.id)

    def nlargest(self, n: int) -> List[PrioritizedRecord]:
        """Return up to ``n`` records with the highest priority, descending."""
        return nlargest(n, self._heap)

    def __contains__(self, item):
        """Membership by record id; accepts PrioritizedRecord, Record, or a raw id."""
        if isinstance(item, PrioritizedRecord):
            return item.record.id in self._ids
        elif isinstance(item, Record):
            return item.id in self._ids
        else:
            return item in self._ids