diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 33edb70bf..aea7896b3 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -32,9 +32,9 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { const int64_t nb = 1000, nq = 10; const int64_t dim = 128; - const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto topk = GENERATE(as{}, 5, 120); auto version = GenTestVersionList(); auto base_gen = [=]() { @@ -89,7 +89,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { knowhere::Json json = base_gen(); json[knowhere::indexparam::HNSW_M] = 128; json[knowhere::indexparam::EFCONSTRUCTION] = 200; - json[knowhere::indexparam::EF] = 64; + json[knowhere::indexparam::EF] = 200; return json; }; @@ -270,9 +270,9 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { const int64_t nb = 1000, nq = 10; const int64_t dim = 1024; - const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD); + auto topk = GENERATE(as{}, 5, 120); auto version = GenTestVersionList(); auto base_gen = [=]() { knowhere::Json json; @@ -288,7 +288,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { auto ivfflat_gen = [base_gen]() { knowhere::Json json = base_gen(); json[knowhere::indexparam::NLIST] = 16; - json[knowhere::indexparam::NPROBE] = 8; + json[knowhere::indexparam::NPROBE] = 14; return json; }; @@ -296,7 +296,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { knowhere::Json json = base_gen(); json[knowhere::indexparam::HNSW_M] = 128; json[knowhere::indexparam::EFCONSTRUCTION] = 200; - json[knowhere::indexparam::EF] = 64; + json[knowhere::indexparam::EF] = 200; return json; }; @@ -377,11 +377,11 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { using Catch::Approx; const int64_t nb = 1000, nq = 10; - const int64_t topk = 5; auto dim = GENERATE(as{}, 8, 16, 32, 64, 128, 256, 512, 160); auto version = GenTestVersionList(); auto metric = GENERATE(as{}, knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE); + auto topk = GENERATE(as{}, 5, 100); auto base_gen = [=]() { knowhere::Json json; @@ -441,11 +441,17 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { auto code_size = dim / 8; for (int64_t i = 0; i < nq; i++) { const uint8_t* query_vector = (const uint8_t*)query_ds->GetTensor() + i * code_size; - std::vector ids_v(ids + i * topk, ids + (i + 1) * topk); - auto ds = GenIdsDataSet(topk, ids_v); + // filter out -1 when the result num less than topk + int64_t real_topk = 0; + for (; real_topk < topk; real_topk++) { + if (ids[i * topk + real_topk] < 0) + break; + } + std::vector ids_v(ids + i * topk, ids + i * topk + real_topk); + auto ds = GenIdsDataSet(real_topk, ids_v); auto gv_res = idx.GetVectorByIds(*ds); REQUIRE(gv_res.has_value()); - for (int64_t j = 0; j < topk; j++) { + for (int64_t j = 0; j < real_topk; j++) { const uint8_t* res_vector = (const uint8_t*)gv_res.value()->GetTensor() + j * code_size; if (metric == knowhere::metric::SUPERSTRUCTURE) { REQUIRE(faiss::is_subset(res_vector, query_vector, code_size));