Question: Effects of Bit Truncation on MinhashLSH? #237

Open
123epsilon opened this issue Mar 15, 2024 · 16 comments

@123epsilon
Contributor

Maybe a bit of a dumb question, but I'm a little confused by the _insert method in the MinhashLSH class:

def _insert(
    self,
    key: Hashable,
    minhash: Union[MinHash, WeightedMinHash],
    check_duplication: bool = True,
    buffer: bool = False,
):
    if len(minhash) != self.h:
        raise ValueError(
            "Expecting minhash with length %d, got %d" % (self.h, len(minhash))
        )
    if self.prepickle:
        key = pickle.dumps(key)
    if check_duplication and key in self.keys:
        raise ValueError("The given key already exists")
    Hs = [self._H(minhash.hashvalues[start:end]) for start, end in self.hashranges]
    self.keys.insert(key, *Hs, buffer=buffer)
    for H, hashtable in zip(Hs, self.hashtables):
        hashtable.insert(H, key, buffer=buffer)

It seems like this iterates over the bands of the MinHash signature and then applies self._H to each band before inserting the results as key-value pairs into the given storage backend. However, self._H is set in __init__ as either:

if hashfunc:
    self._H = self._hashed_byteswap
else:
    self._H = self._byteswap

But hashfunc is None by default. It seems that before we insert a given band into the storage backend, we are only swapping the bytes as opposed to applying a locality-sensitive hash function. Am I missing something here?

Basically, I'd like to understand 1) where the LSH hashing logic is located in this class, and 2) be able to determine the range of values that the LSH hash can map onto (i.e., for a given band, what range of integer values it can take on after applying datasketch's native LSH logic).
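
For concreteness, here is a rough sketch of my current understanding of the banding step (illustrative values only, not the actual datasketch code):

import numpy as np

num_perm, b, r = 128, 32, 4  # illustrative: b bands of r hash values each
hashvalues = np.random.randint(0, 1 << 32, size=num_perm, dtype=np.uint64)

# split the signature into bands and turn each band into a byte-string key,
# analogous to what _byteswap appears to do
hashranges = [(i * r, (i + 1) * r) for i in range(b)]
band_keys = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges]
print(len(band_keys), len(band_keys[0]))  # 32 keys of r * 8 bytes each here

If that reading is right, each band key is just the concatenated bytes of its r hash values rather than the output of a separate hash function.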

@ekzhu
Owner

ekzhu commented Mar 16, 2024

  1. where the LSH hashing logic is located in this class

It is inside MinHash.

  2. be able to determine the range of values that the LSH hash can map onto (i.e., for a given band, what range of integer values it can take on after applying datasketch's native LSH logic).

For a given band, the hash table in that band doesn't use an integer key. It's a concatenation of multiple hash values in the MinHash.

@123epsilon
Contributor Author

123epsilon commented Mar 16, 2024

For a given band, the hash table in that band doesn't use an integer key. It's a concatenation of multiple hash values in the MinHash.

Thanks, I just meant the size of the hash values for each band, which I believe should be 32 bits, right? We use sha1_hash32 to hash them and represent them as 64-bit values to avoid overflows, I assume, but the actual values are 32-bit integers, right? Meaning that if we have b bands of size r, then each band will have r 32-bit integers.

Also, would you anticipate any issues if one were to truncate them to be 16 bits instead?

@ekzhu
Owner

ekzhu commented Mar 16, 2024

We use sha1_hash32 to hash them and represent them as 64-bit values to avoid overflows, I assume, but the actual values are 32-bit integers, right? Meaning that if we have b bands of size r, then each band will have r 32-bit integers.

This is correct.

Also, would you anticipate any issues if one were to truncate them to be 16 bits instead?

More false positives, due to fewer bits being required to make a collision. Would love to see a benchmark result on # of bits vs. accuracy. See the 'benchmarks' directory for various benchmarks.
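
For intuition, a quick simulation sketch (not datasketch code): two unrelated b-bit values collide with probability about 2^-b, so a whole band of r truncated values collides spuriously with probability about 2^-(b*r).

import random

def spurious_band_collision_rate(b, r, trials=100_000):
    # chance that two unrelated bands of r independent b-bit values match exactly
    hits = 0
    for _ in range(trials):
        band1 = [random.getrandbits(b) for _ in range(r)]
        band2 = [random.getrandbits(b) for _ in range(r)]
        hits += band1 == band2
    return hits / trials

for b in (1, 2, 4):
    print(b, spurious_band_collision_rate(b, r=2))  # roughly 2 ** -(b * 2)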

@123epsilon
Contributor Author

I got some interesting results:

[image: benchmark plots - accuracy (left) and average runtime (right)]

Note: 32-bit and 64-bit accuracy were exactly the same, which is why you can't see the red line on the accuracy graph.

Also note: to get a better idea of performance, I ran each trial 5 times and took the average time; that is what is plotted on the right-hand side. Otherwise, the setup is exactly the same as in minhash_benchmark.py.

The code for this benchmark is here: https://github.com/123epsilon/datasketch/blob/8abf5cc72eebd4198a373c4f65ad8840114124c0/benchmark/sketches/truncate_minhash_benchmark.py

Very interesting that even a 16-bit hash can achieve virtually the same accuracy as 32/64 bits, and with better performance - at least on this benchmark. A question about performance: I'm not super familiar with hashing theory and implementation, but do you think there are any tweaks I could make to the MinHash class to improve performance when we truncate the number of bits used to represent a hash value?

@ekzhu
Owner

ekzhu commented Mar 18, 2024

Thanks for the results! Would love for this benchmark to be added. Can you submit a PR for this?

There is a paper on b-bit minwise hashing: https://arxiv.org/pdf/0910.3349.pdf. Have you tried bits between 8 and 16? I think for lower bits, it is important to use the unbiased estimator from the paper, not the standard one for MinHash.

We have implemented b-bit minhash here: https://github.com/ekzhu/datasketch/blob/master/datasketch/b_bit_minhash.py though it is not an optimized implementation. Would be great to combine this with the standard MinHash class.
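
For reference, the simplified form of the unbiased estimator looks roughly like this (a sketch assuming the sets are small relative to the hash universe, so the paper's correction constants C1,b and C2,b both reduce to about 2^-b; the paper gives the exact correction terms in terms of the relative set sizes):

import numpy as np

def b_bit_jaccard_estimate(hashvalues1, hashvalues2, b):
    # keep only the lowest b bits of each MinHash value
    mask = np.uint64((1 << b) - 1)
    e_b = np.mean((hashvalues1 & mask) == (hashvalues2 & mask))  # raw match rate
    c = 2.0 ** -b  # approximate chance of an accidental b-bit match
    return (e_b - c) / (1.0 - c)  # debiased Jaccard estimate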

@123epsilon
Contributor Author

@ekzhu Actually, I'm suspicious of the results I shared - I inspected some of the hashvalues for the 16-bit case and found that they were still 64-bit integers after applying the modular arithmetic in the update function. I believe this is because of the hard-coded values for _mersenne_prime, _hash_range, and _max_hash in minhash.py. All I was doing was taking the output of sha1 and truncating it, using that as a hashfunc (similarly to how you've implemented sha1_hash32).

I'm kind of guessing here, but I experimented by parameterizing MinHash with those values and substituting those variables as follows (feedback on this would be nice):

def get_params(num_bits):
	if num_bits == 8:
		return {
			"mersenne_prime": np.uint64((1 << 13) - 1),
			"hash_range": np.uint64((1 << 8) - 1),
			"max_hash": 1 << 8
		}
	elif num_bits == 16:
		return {
			"mersenne_prime": np.uint64((1 << 31) - 1),
			"hash_range": np.uint64((1 << 16) - 1),
			"max_hash": 1 << 16
		}
	elif num_bits == 32:
		return {} # default, we don't need to specify params
	elif num_bits == 64:
		return {}

Then this shows degraded results - perhaps for the reasons you cited above:

[image: benchmark results with truncated hash parameters]

Here, 32 and 64 bits perform the same in terms of error, and 8 and 16 bits perform the same as each other, with marked degradation in the 16-bit case. Do you think this has to do with the biased estimation problem you alluded to above? Is the correct step to use bBitMinHash instead?

@123epsilon
Contributor Author

Ok, so instead of truncating the hash values as I did above, I used the default MinHash configuration and instead truncated the hashvalues precision by initializing a bBitMinHash from a MinHash object. I verified that the hashvalues occupy the expected range as a sanity check, and I got these results:

[image: benchmark results using bBitMinHash truncation]

I'm a little surprised that everything above 8 bits performs basically the same; as expected, the variance decreases greatly as we increase k. One open question is whether this same trend holds for whole documents, since we would be cramming a much larger set of entities into a limited number of hash buckets - I might try to benchmark this against a real text dataset (or a manageable subset of one).
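
For reference, the truncation step in this benchmark is essentially the following (a sketch; assuming bBitMinHash takes the source MinHash and the bit width b, as in b_bit_minhash.py):

from datasketch import MinHash
from datasketch.b_bit_minhash import bBitMinHash

def truncated_minhash(tokens, num_perm, num_bits):
    # build a full-precision MinHash, then reduce each hash value to num_bits
    m = MinHash(num_perm=num_perm)
    for t in tokens:
        m.update(t.encode('utf8'))
    return bBitMinHash(m, b=num_bits)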

On a side note: the current b_bit_minhash_benchmark.py uses a lot of Python 2 functionality that is no longer supported (in particular the use of the pyhash module) - if you want, I can replace the contents of that file with the benchmark I ran above.

@ekzhu
Owner

ekzhu commented Mar 23, 2024

On a side note: the current b_bit_minhash_benchmark.py uses a lot of Python 2 functionality that is no longer supported (in particular the use of the pyhash module) - if you want, I can replace the contents of that file with the benchmark I ran above.

Yes please. We can also make it much more performant.

Thanks! More benchmarks on longer docs would be useful. E.g., we can use the Wikipedia corpus.

@123epsilon 123epsilon changed the title Question: where is LSH hashing happening in MinhashLSH? Question: Effects of Bit Truncation on MinhashLSH? Mar 23, 2024
@123epsilon
Contributor Author

@ekzhu Ok! Just ran a great benchmark on bBitMinHash against the Wikipedia-Simple text dataset, 20220301.simple. Instead of randomly sampling two documents to compare (where the average Jaccard similarity was something like 0.06), we do 30 trials for each (num_perm, num_bits) pair, where we sample one document, doc1, at random and treat the set s1 as the entire doc1. Then we treat the second set s2 as a random contiguous subset of doc1. That is:

import random

i1 = random.randint(0, N_DOCS - 1)

overlap = random.uniform(0, 1)  # fraction of doc1 to reuse as doc2
doc1 = wiki_data['train'][i1]['text']
# generate a random overlapping region of text given a start point
# this isn't perfect since we could be cutting off the start/ending words in the region
# but that won't affect the estimation too much
overlap_size = int(len(doc1) * overlap)
overlap_start = random.randint(0, len(doc1) - overlap_size)
doc2 = doc1[overlap_start:overlap_start + overlap_size]
m1, s1 = _run_acc(doc1, num_perm, num_bits)
m2, s2 = _run_acc(doc2, num_perm, num_bits)

In practice I saw that this gave us a much better range of true similarity scores to compare the MinHash implementation against. I used the huggingface datasets package to load the Wikipedia data.
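
For reproducibility, loading the data looks roughly like this:

from datasets import load_dataset

# Simple English Wikipedia snapshot used above
wiki_data = load_dataset("wikipedia", "20220301.simple")
N_DOCS = len(wiki_data["train"])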

These are the results:

[image: benchmark results on the Wikipedia-Simple dataset]

I'm pleasantly surprised that the results are so good even for very few bits in the representation. What are your thoughts? Is there anything suspicious about these results?

@123epsilon
Contributor Author

@ekzhu Ok, one more benchmark :)

I was interested in assessing the agreement between LSH and LSH using truncated minhash signatures.
Procedure: Take 60% of the documents from wikipedia-20220301.simple and insert them into LSH indices, truncating their representations using bBitMinHash for num_bits between 1 and 32. Then query the remaining 40% of documents against those indices and measure the percent agreement between the truncated indices and the original indices in terms of whether they report that the query document is a duplicate (True/False). (Another idea might be to assess top-k agreement, but I haven't done that yet.)
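
The agreement check itself is essentially this (a rough sketch; full_index and truncated_index are MinHashLSH indices built from the full-precision and truncated signatures, and query_pairs holds (full, truncated) signatures for the held-out 40%):

def percent_agreement(full_index, truncated_index, query_pairs):
    # a query counts as a "duplicate" if the index returns any candidates
    agree = 0
    for full_m, trunc_m in query_pairs:
        full_verdict = len(full_index.query(full_m)) > 0
        trunc_verdict = len(truncated_index.query(trunc_m)) > 0
        agree += full_verdict == trunc_verdict
    return agree / len(query_pairs)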

[image: percent agreement between truncated and original LSH indices vs. num_bits]

I guess this is expected, but I had a key question - does the current LSH implementation take advantage of this truncation in any way to make the actual index smaller? Or is that expected to happen implicitly if we construct a prefix tree where most entries have all-leading-zero bytes due to truncation (though the byte swapping would prevent this from being exploited, right)?

@ekzhu
Owner

ekzhu commented Apr 4, 2024

does the current LSH implementation take advantage of this truncation in any way to make the actual index smaller?

Currently, MinHashLSH is agnostic to the byte size of the hash values. So, if the hashvalues attribute of the input object uses a smaller byte size per hash value, the resulting concatenated hash key for the LSH hash table will be smaller, which results in a smaller index. However, MinHashLSH will not attempt to truncate zeros off the hash values after byte swapping them.

What might be useful as an optimization is:

  1. improve bBitMinHash implementation to make it really memory efficient by using a bit array
  2. implement "band view" for bBitMinHash so we can access the b-bit hash values without expanding them to 32/64-bit integers. E.g., output a compacted array (see the sketch after this list).
  3. MinHashLSH can use the "band view" to access the b-bit hash values for a specific band.
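
For item 2, a compacted band view could look roughly like this (just a sketch of the idea using numpy bit packing, not a proposed API):

import numpy as np

def packed_band(hashvalues, start, end, b):
    # keep only the lowest b bits of each value in the band
    band = hashvalues[start:end].astype(np.uint64) & np.uint64((1 << b) - 1)
    # expand to individual bits, then repack into ceil(r * b / 8) bytes
    bits = ((band[:, None] >> np.arange(b, dtype=np.uint64)) & np.uint64(1)).astype(np.uint8)
    return np.packbits(bits.ravel()).tobytes()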

I'm pleasantly surprised that the results are so good even for very few bits in the representation. What are your thoughts? Is there anything suspicious about these results?

The result is expected, as b-bit MinHash was proven to work years ago. Perhaps it is time for us to do the engineering work to make it usable.

@123epsilon
Contributor Author

@ekzhu Ok, I've been playing around with this - how do you see it being implemented? Are there libraries/docs I can refer to? By experiment, I've found that just changing the numpy datatype shortens the resulting bytestring as expected (cutting off the leading zeros), which should be helpful for reducing the hash table size.

Compacted arrays from the builtin array module are also feasible, though a lot of the datasketch code assumes we are using numpy arrays, so some other changes would be required in lsh, for instance.

I'm also not sure how to profile the memory usage exactly. I've been using tracemalloc to measure peak consumption, but that doesn't really show much difference between the two approaches for the bBitMinHash by itself, though I'll try to measure the LSH index as well.

@ekzhu
Owner

ekzhu commented Apr 17, 2024

By experiment, I've found that just changing the numpy datatype shortens the resulting bytestring as expected (cutting off the leading zeros), which should be helpful for reducing the hash table size.

Can you give an example of how to do this?

Compacted arrays from the builtin array module are also feasible, though a lot of the datasketch code assumes we are using numpy arrays, so some other changes would be required in lsh, for instance.

Right. Perhaps this can be a future step as it requires addressing both LSH and MinHash.

I'm also not sure how to profile the memory usage exactly. I've been using tracemalloc to measure peak consumption, but that doesn't really show much difference between the two approaches for the bBitMinHash by itself, though I'll try to measure the LSH index as well.

Probably most of the memory saving will come from the LSH index having smaller keys. You can also use the Redis storage backend to get a snapshot of the index in Redis and see how large it is.
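
For example (a sketch; the storage_config follows the documented Redis backend format, and the memory number comes from the standard Redis INFO command):

import redis
from datasketch import MinHashLSH

lsh = MinHashLSH(
    threshold=0.5,
    num_perm=128,
    storage_config={"type": "redis", "redis": {"host": "localhost", "port": 6379}},
)
# ... insert the full-precision or truncated signatures here ...

r = redis.Redis(host="localhost", port=6379)
print(r.info("memory")["used_memory_human"])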

@123epsilon
Contributor Author

Can you give an example of how to do this?

>>> import numpy as np
>>> l = [1,2,3,4,5]
>>> np.array(l, dtype=np.uint32).tobytes()
b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x05\x00\x00\x00'
>>> np.array(l, dtype=np.uint16).tobytes()
b'\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00'
>>> np.array(l, dtype=np.uint8).tobytes()
b'\x01\x02\x03\x04\x05'

Probably most of the memory saving will come from the LSH index having smaller keys. You can also use the Redis storage backend to get a snapshot of the index in Redis and see how large it is.

Ok, I'll evaluate it that way then - for now using the numpy datatype manipulation I showed above.

@ekzhu
Owner

ekzhu commented Apr 17, 2024

Thanks! I see the byte-size truncation happens at data type size boundaries. In that case, it might make sense to restrict the b-bit MinHash to b = {64, 32, 16, 8} to maximize efficiency.
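
i.e., something like this (a hypothetical helper, just to illustrate the boundaries):

import numpy as np

def dtype_for_bits(b):
    # smallest unsigned numpy dtype that can hold a b-bit hash value
    if b <= 8:
        return np.uint8
    if b <= 16:
        return np.uint16
    if b <= 32:
        return np.uint32
    return np.uint64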

@123epsilon
Contributor Author

Just an update on this - I was out for an internship over the summer, so I paused work on this. The relevant code already exists and probably provides the intended enhancement; I just need to finish the profiling code to get the memory numbers from Redis.

master...123epsilon:datasketch:optimize_bbit_minhash
