Skip to content
This repository has been archived by the owner on Oct 30, 2020. It is now read-only.

Commit

Permalink
Merge pull request #49 from karlhigley/fix/minhash-bands
Browse files Browse the repository at this point in the history
Fix signature element grouping in `BandingCollisionStrategy`
  • Loading branch information
karlhigley authored Jul 4, 2016
2 parents 449b3bc + e6f8df7 commit 666eb07
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ private[neighbors] class BandingCollisionStrategy(
*/
def apply(hashTables: RDD[_ <: HashTableEntry[_]]): RDD[(Product, Point)] = {
val bandEntries = hashTables.flatMap(entry => {
val banded = entry.sigElements.grouped(bands).zipWithIndex
val elements = entry.sigElements
val banded = elements.grouped(elements.size / bands).zipWithIndex
banded.map {
case (bandSig, band) => {
case (bandSig, bandNum) => {
// Arrays are mutable and can't be used in RDD keys
// Use a hash value (i.e. an int) as a substitute
val bandSigHash = MurmurHash3.arrayHash(bandSig)
val key = (entry.table, band, bandSigHash).asInstanceOf[Product]
val key = (entry.table, bandNum, bandSigHash).asInstanceOf[Product]
(key, (entry.id, entry.point))
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.github.karlhigley.spark.neighbors

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.SparseVector

class CollisionStrategySuite extends FunSuite with TestSparkContext {
val numPoints = 1000
val dimensions = 100
val density = 0.5

var points: RDD[(Long, SparseVector)] = _

override def beforeAll() {
super.beforeAll()
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
points = sc.parallelize(localPoints).zipWithIndex.map(_.swap)
}

test("SimpleCollisionStrategy produces the correct number of tuples") {
val ann =
new ANN(dimensions, "cosine")
.setTables(1)
.setSignatureLength(8)

val model = ann.train(points)

val hashTables = model.hashTables
val collidable = model.collisionStrategy(hashTables)

assert(collidable.count() == numPoints)
}

test("BandingCollisionStrategy produces the correct number of tuples") {
val numBands = 4

val ann =
new ANN(dimensions, "jaccard")
.setTables(1)
.setSignatureLength(8)
.setBands(numBands)
.setPrimeModulus(739)

val model = ann.train(points)

val hashTables = model.hashTables
val collidable = model.collisionStrategy(hashTables)

assert(collidable.count() == numPoints * numBands)
}

}

0 comments on commit 666eb07

Please sign in to comment.