Skip to content

Commit

Permalink
Added test for using script to sum up a float vector
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz committed Jan 29, 2020
1 parent eed6db2 commit d8ccda1
Showing 1 changed file with 38 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@ class ElastiKnnVectorFieldMapperSuite
} yield Succeeded
}

test("index and retrieve float vectors") {

val indexName = "test-index-retrieve-float-vectors"
val ekvs: Seq[ElastiKnnVector] = FloatVector.randoms(10, 5).map(ElastiKnnVector(_))

def index(ekvs: Seq[ElastiKnnVector], indexName: String): Future[Unit] =
for {
_ <- client.execute(deleteIndex(indexName))

Expand All @@ -60,6 +56,11 @@ class ElastiKnnVectorFieldMapperSuite
indexRes <- client.execute(bulk(indexReqs).refresh(RefreshPolicy.IMMEDIATE))
_ = indexRes.shouldBeSuccess
_ = indexRes.result.errors shouldBe false
} yield ()

def testIndexRetrieve(ekvs: Seq[ElastiKnnVector], indexName: String): Future[Assertion] =
for {
_ <- index(ekvs, indexName)

getRes <- client.execute(search(indexName).query(matchAllQuery()))
_ = getRes.shouldBeSuccess
Expand All @@ -82,6 +83,38 @@ class ElastiKnnVectorFieldMapperSuite
}
}

test("index and retrieve float vectors") {
testIndexRetrieve(FloatVector.randoms(10, 9).map(ElastiKnnVector(_)), "test-index-retrieve-float-vectors")
}

test("index and retrieve sparse bool vectors") {
testIndexRetrieve(SparseBoolVector.randoms(10, 9).map(ElastiKnnVector(_)), "test-index-retrieve-bool-vectors")
}

test("index float vectors and use script to sum them") {
val indexName = "test-index-script-sum"
val ekvs = FloatVector.randoms(10, 9).map(ElastiKnnVector(_))
for {
_ <- index(ekvs, indexName)
searchReq = search(indexName).query(
scriptScoreQuery(requests.script.Script(
"""
|def vec = doc[params.field];
|double sum = 0.0;
|for (n in vec) sum += n;
|return sum;
|""".stripMargin,
params = Map("field" -> fieldName)
)))
searchRes <- client.execute(searchReq)
_ = searchRes.shouldBeSuccess
scores = searchRes.result.hits.hits.map(_.score).sorted
correct = ekvs.flatMap(_.vector.floatVector).map(_.values.sum.toFloat).sorted
_ = scores should have length correct.length
} yield
forAll(scores.zip(correct)) {
case (a, b) => a shouldBe (b +- 1e-5f)
}
}

test("index and script search float vectors") {
Expand Down

0 comments on commit d8ccda1

Please sign in to comment.