
Investigate improvements to the get-unique-points algorithm #781

Closed
khurram-ghani opened this issue Aug 24, 2023 · 1 comment
Labels
enhancement New feature or request

Comments

@khurram-ghani
Collaborator

Describe the feature you'd like
Investigate whether this algorithm can be parallelised. There is a previous parallel but incomplete implementation and the current sequential one. See this PR for context.

@khurram-ghani khurram-ghani added the enhancement New feature or request label Aug 24, 2023
@j-wilson

j-wilson commented Dec 5, 2023

This can be done in O(N + M^2) time and space by using a bloom filter, where N is the number of points and M is the number of duplicate points. Roughly:

  1. Quantize each point to the desired precision.
  2. Convert each quantized point to a string type.
  3. Hash each stringified+quantized point.
  4. Test for equality within each bin where a collision occurred.
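As a rough illustration of these four steps in plain Python (a sketch with a hypothetical helper name, not the TensorFlow implementation below):

```python
def get_unique_row_indices(rows, precision=1):
    """Return indices of unique rows after rounding to `precision` decimals.

    Plain-Python sketch of the quantize -> stringify -> hash -> compare steps;
    `get_unique_row_indices` is a hypothetical name, not part of any library.
    """
    buckets = {}  # hash value -> quantized rows already seen in that bin
    unique = []
    for i, row in enumerate(rows):
        quantized = tuple(round(x, precision) for x in row)  # step 1
        key = hash(str(quantized))                           # steps 2 and 3
        bucket = buckets.setdefault(key, [])
        # step 4: exact comparison only among rows whose hashes collide
        if quantized not in bucket:
            bucket.append(quantized)
            unique.append(i)
    return unique
```

With `precision=1`, rows that agree to one decimal place collapse onto the same bucket entry and only the first occurrence survives.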

Here is a naive implementation with O(N^2) time and space complexity that computes all pairwise distances and returns the indices of each unique row.

from typing import Any, Optional

import tensorflow as tf


@tf.function
def get_unique_rows_dense(matrix: tf.Tensor, precision: Optional[int] = None) -> tf.Tensor:
    matrix = (
        matrix
        if precision is None
        else tf.math.round(10 ** precision * matrix)
    )
    sq_norms = tf.reduce_sum(tf.square(matrix), axis=-1, keepdims=True)
    sq_dists = (
        sq_norms 
        + tf.transpose(sq_norms) 
        - tf.matmul(2 * matrix, matrix, transpose_b=True)
    )
    argmin = tf.argmin(sq_dists, axis=-1) 
    unique = tf.where(argmin == tf.range(tf.shape(matrix)[0], dtype=argmin.dtype))
    return tf.squeeze(unique, axis=-1)
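The pairwise-distance trick above relies on the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, with ties in the argmin resolving to the lowest row index. A NumPy mirror of the same routine (an illustrative sketch with a hypothetical name, not part of the codebase) looks like:

```python
import numpy as np


def unique_row_indices_dense(matrix, precision=None):
    """NumPy sketch of the dense O(N^2) deduplication (hypothetical helper)."""
    m = matrix if precision is None else np.round(10.0 ** precision * matrix)
    sq_norms = np.sum(m * m, axis=-1, keepdims=True)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, computed for all pairs at once
    sq_dists = sq_norms + sq_norms.T - 2.0 * (m @ m.T)
    argmin = np.argmin(sq_dists, axis=-1)  # ties resolve to the lowest index
    return np.flatnonzero(argmin == np.arange(m.shape[0]))
```

A row is kept exactly when its nearest neighbour (distance zero, including itself) is its own index, i.e. when no earlier row duplicates it.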

And here is a fancy implementation using the approach suggested above:

@tf.function
def get_unique_rows(
    matrix: tf.Tensor, precision: Optional[int] = None, **kwargs: Any,
) -> tf.Tensor:
    matrix = (
        matrix
        if precision is None
        else tf.math.round(10 ** precision * matrix)
    )
    strings = tf.strings.reduce_join(tf.strings.as_string(matrix), axis=-1)
    bin_ids = tf.strings.to_hash_bucket_fast(strings, 2 ** 63 - 1)
    unique_ids, membership = tf.unique(bin_ids, out_idx=tf.int64)

    nrows = tf.shape(matrix)[0]
    nbins = tf.size(unique_ids)  # number of occupied bins
    if nbins == nrows:
        return tf.range(nrows, dtype=tf.int64)

    def deduplicate(k):
        indices = tf.squeeze(tf.where(membership == k), axis=-1)
        rows = tf.gather(matrix, indices)
        return tf.gather(indices, get_unique_rows_dense(rows))

    bins = tf.range(nbins, dtype=tf.int64)
    bin_counts = tf.math.bincount(tf.cast(membership, tf.int32))
    collisions = bin_counts > 1

    # Handle single occupancy bins
    singletons = tf.where(tf.reduce_any(membership[:, None] == bins[~collisions], -1))

    # Resolve collisions
    deduplications = tf.map_fn(deduplicate, bins[collisions], **kwargs)

    return tf.squeeze(tf.concat([singletons, deduplications], axis=0), axis=-1)

Running on my laptop with X = tf.random.uniform(shape=[16384, 4], dtype=tf.float64) gives

%timeit get_unique_rows_dense(X, precision=1)
%timeit get_unique_rows(X, precision=1)
728 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
177 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

and

%timeit get_unique_rows_dense(X, precision=None)
%timeit get_unique_rows(X, precision=None)
713 ms ± 9.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
20.7 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

In the latter case, the additional speedup occurs because no points hash to the same bin (hence nbins == nrows).
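For intuition on why collisions are so unlikely here: hashing N = 16384 rows uniformly into M = 2^63 - 1 buckets, the birthday bound puts the probability of any collision near N^2 / (2M), on the order of 1e-11. A quick back-of-the-envelope check (assuming the hash behaves uniformly; the helper name is made up for illustration):

```python
import math


def any_collision_probability(n, m):
    """Birthday-bound estimate: P(any two of n items share one of m buckets)."""
    return 1.0 - math.exp(-n * (n - 1) / (2.0 * m))


p = any_collision_probability(16384, 2 ** 63 - 1)  # roughly 1.5e-11
```

So for inputs of this size, the `nbins == nrows` fast path is overwhelmingly likely to be taken whenever the rows are actually distinct.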
