From a2d22428ac07fb7b5af4e9ceb1cb6fa78b77b5d4 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Wed, 15 Nov 2023 13:00:13 -0500 Subject: [PATCH] Fix a jaccard distance problem that was causing #192 Signed-off-by: Alexandr Guzhva --- tests/ut/test_search.cc | 56 ++++++++++++++++++++++++++++ thirdparty/faiss/faiss/IndexFlat.cpp | 13 +++++-- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index aea7896b3..f5e7e2ab1 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -373,6 +373,62 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { } } +// this is a special case that once triggered a problem in clustering.cpp +TEST_CASE("Test Mem Index With Binary Vector", "[float metrics][special case 1]") { + using Catch::Approx; + + const int64_t nb = 10, nq = 1; + const int64_t dim = 16; + + auto metric = GENERATE(as{}, knowhere::metric::JACCARD); + auto topk = GENERATE(as{}, 1, 1); + auto version = GenTestVersionList(); + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::HAMMING) ? 10.0 : 0.1; + json[knowhere::meta::RANGE_FILTER] = 0.0; + return json; + }; + + auto flat_gen = base_gen; + auto ivfflat_gen = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 1; + return json; + }; + + const auto train_ds = GenBinDataSet(nb, dim); + const auto query_ds = GenBinDataSet(nq, dim); + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, topk}, + }; + + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + SECTION("Test Search") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + REQUIRE(idx.Size() > 0); + REQUIRE(idx.Count() == nb); + auto results = idx.Search(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + float recall = GetKNNRecall(*gt.value(), *results.value()); + REQUIRE(recall > kKnnRecallThreshold); + } +} + TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { using Catch::Approx; diff --git a/thirdparty/faiss/faiss/IndexFlat.cpp b/thirdparty/faiss/faiss/IndexFlat.cpp index e2a150f12..bb7367cd5 100644 --- a/thirdparty/faiss/faiss/IndexFlat.cpp +++ b/thirdparty/faiss/faiss/IndexFlat.cpp @@ -64,10 +64,15 @@ void IndexFlat::search( } else if (metric_type == METRIC_L2) { float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances}; knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel); - } else if (is_similarity_metric(metric_type)) { - float_minheap_array_t res = {size_t(n), size_t(k), labels, distances}; - knn_extra_metrics( - x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res); + // // aguzhva: the following branch from the Faiss baseline is commented out. + // // Jaccard distance is handled differently. + // } else if (is_similarity_metric(metric_type)) { + // float_minheap_array_t res = {size_t(n), size_t(k), labels, distances}; + // knn_extra_metrics( + // x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res); + } else if (metric_type == METRIC_Jaccard) { + float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances}; + knn_jaccard(x, get_xb(), d, n, ntotal, &res, sel); } else { FAISS_THROW_IF_NOT(!sel); float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};