Skip to content

Commit

Permalink
knowhere support multi data type
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>

Add (b)float16 interface for pyknowhere

Signed-off-by: Writer-X <[email protected]>
  • Loading branch information
cqy123456 committed Jan 4, 2024
1 parent 3aee32f commit 92c0aeb
Show file tree
Hide file tree
Showing 45 changed files with 1,604 additions and 399 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
CIBW_ARCHS: ${{ matrix.arch }}
CIBW_BEFORE_ALL_LINUX: "bash scripts/python_deps.sh && rm -rf build && mkdir build && cd build && conan install .. --build=missing -o with_diskann=True -s compiler.libcxx=libstdc++11 -s build_type=Release && conan build .. && cd -"
# CIBW_BEFORE_ALL_MACOS: "bash scripts/python_deps.sh && rm -rf build && mkdir build && cd build && CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ conan install .. --build=missing -s build_type=Release && CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ conan build .. && cd -"
CIBW_BEFORE_BUILD: "pip3 install pytest numpy faiss-cpu"
CIBW_BEFORE_BUILD: "pip3 install pytest numpy faiss-cpu bfloat16"
# CIBW_ENVIRONMENT_MACOS: >
# _PYTHON_HOST_PLATFORM=macosx-10.15-${{ matrix.arch }}
# CIBW_BEFORE_BUILD: "bash scripts/python_deps.sh && pip3 install pytest numpy faiss-cpu"
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ut.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ jobs:
sudo apt update \
&& sudo apt install -y cmake g++ gcc libopenblas-dev libaio-dev libcurl4-openssl-dev libevent-dev libgflags-dev python3 python3-pip python3-setuptools \
&& pip3 install conan==1.58.0 pytest faiss-cpu numpy wheel \
&& pip3 install bfloat16 \
&& conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local
- name: Build
run: |
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ install dependency:

```
sudo apt install swig python3-dev
pip3 install bfloat16
```

after build knowhere:
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ TEST_F(Benchmark_float, TEST_DISKANN) {
std::shared_ptr<knowhere::FileManager> file_manager = std::make_shared<knowhere::LocalFileManager>();
auto diskann_index_pack = knowhere::Pack(file_manager);

index_ = knowhere::IndexFactory::Instance().Create(
index_ = knowhere::IndexFactory::Instance().Create<float>(
index_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), diskann_index_pack);
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ TEST_F(Benchmark_float_bitset, TEST_DISKANN) {
auto diskann_index_pack = knowhere::Pack(file_manager);

auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack);
index_ = knowhere::IndexFactory::Instance().Create<float>(index_type_, version, diskann_index_pack);
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = nullptr;
index_.Build(*ds_ptr, conf);
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float_range_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ TEST_F(Benchmark_float_range_bitset, TEST_DISKANN) {
auto diskann_index_pack = knowhere::Pack(file_manager);

auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack);
index_ = knowhere::IndexFactory::Instance().Create<float>(index_type_, version, diskann_index_pack);
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = nullptr;
index_.Build(*ds_ptr, conf);
Expand Down
4 changes: 2 additions & 2 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
create_index(const std::string& index_file_name, const knowhere::Json& conf) {
auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
printf("[%.3f s] Creating index \"%s\"\n", get_time_diff(), index_type_.c_str());
index_ = knowhere::IndexFactory::Instance().Create(index_type_, version);
index_ = knowhere::IndexFactory::Instance().Create<float>(index_type_, version);

try {
printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str());
Expand All @@ -120,7 +120,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {

std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index";
printf("[%.3f s] Creating golden index \"%s\"\n", get_time_diff(), golden_index_type_.c_str());
golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version);
golden_index_ = knowhere::IndexFactory::Instance().Create<float>(golden_index_type_, version);

try {
printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str());
Expand Down
4 changes: 4 additions & 0 deletions include/knowhere/comp/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@
#include "knowhere/bitsetview.h"
#include "knowhere/dataset.h"
#include "knowhere/factory.h"
#include "knowhere/operands.h"

namespace knowhere {

class BruteForce {
public:
template <typename DataType>
static expected<DataSetPtr>
Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset);

template <typename DataType>
static Status
SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset);

template <typename DataType>
static expected<DataSetPtr>
RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset);
Expand Down
3 changes: 1 addition & 2 deletions include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

namespace knowhere {

class DataSet {
class DataSet : public std::enable_shared_from_this<const DataSet> {
public:
typedef std::variant<const float*, const size_t*, const int64_t*, const void*, int64_t, std::string, std::any> Var;
DataSet() = default;
Expand Down Expand Up @@ -227,7 +227,6 @@ class DataSet {
bool is_owner = true;
};
using DataSetPtr = std::shared_ptr<DataSet>;

inline DataSetPtr
GenDataSet(const int64_t nb, const int64_t dim, const void* xb) {
auto ret_ds = std::make_shared<DataSet>();
Expand Down
38 changes: 34 additions & 4 deletions include/knowhere/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,53 @@
namespace knowhere {
class IndexFactory {
public:
template <typename DataType>
Index<IndexNode>
Create(const std::string& name, const int32_t& version, const Object& object = nullptr);
template <typename DataType>
const IndexFactory&
Register(const std::string& name, std::function<Index<IndexNode>(const int32_t& version, const Object&)> func);
static IndexFactory&
Instance();

private:
typedef std::map<std::string, std::function<Index<IndexNode>(const int32_t&, const Object&)>> FuncMap;
struct FunMapValueBase {
virtual ~FunMapValueBase() = default;
};
template <typename T1>
struct FunMapValue : FunMapValueBase {
public:
FunMapValue(std::function<T1(const int32_t&, const Object&)>& input) : fun_value(input) {
}
std::function<T1(const int32_t&, const Object&)> fun_value;
};
typedef std::map<std::string, std::unique_ptr<FunMapValueBase>> FuncMap;
IndexFactory();
static FuncMap&
MapInstance();
template <typename DataType>
std::string
GetMapKey(const std::string& name);
};

#define KNOWHERE_CONCAT(x, y) x##y
#define KNOWHERE_REGISTER_GLOBAL(name, func) \
const IndexFactory& KNOWHERE_CONCAT(index_factory_ref_, name) = IndexFactory::Instance().Register(#name, func)
#define KNOWHERE_CONCAT(x, y) index_factory_ref_##x##y
#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type) \
const IndexFactory& KNOWHERE_CONCAT(name, data_type) = IndexFactory::Instance().Register<data_type>(#name, func)
#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, ...) \
KNOWHERE_REGISTER_GLOBAL( \
name, \
[](const int32_t& version, const Object& object) { \
return (Index<index_node<data_type, ##__VA_ARGS__>>::Create(version, object)); \
}, \
data_type)
#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, ...) \
KNOWHERE_REGISTER_GLOBAL( \
name, \
[](const int32_t& version, const Object& object) { \
return (Index<IndexNodeDataMockWrapper<data_type>>::Create( \
std::make_unique<index_node<MockData<data_type>::type, ##__VA_ARGS__>>(version, object))); \
}, \
data_type)
} // namespace knowhere

#endif /* INDEX_FACTORY_H */
1 change: 0 additions & 1 deletion include/knowhere/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "knowhere/index_node.h"

namespace knowhere {

template <typename T1>
class Index {
public:
Expand Down
2 changes: 1 addition & 1 deletion include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include "knowhere/dataset.h"
#include "knowhere/expected.h"
#include "knowhere/object.h"
#include "knowhere/operands.h"
#include "knowhere/version.h"

namespace knowhere {

class IndexNode : public Object {
public:
IndexNode(const int32_t ver) : version_(ver) {
Expand Down
100 changes: 100 additions & 0 deletions include/knowhere/index_node_data_mock_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#ifndef INDEX_NODE_DATA_MOCK_WRAPPER_H
#define INDEX_NODE_DATA_MOCK_WRAPPER_H

#include "knowhere/index_node.h"
namespace knowhere {

template <typename DataType>
class IndexNodeDataMockWrapper : public IndexNode {
public:
IndexNodeDataMockWrapper(std::unique_ptr<IndexNode> index_node) : index_node_(std::move(index_node)) {
}

Status
Build(const DataSet& dataset, const Config& cfg) override;

Status
Train(const DataSet& dataset, const Config& cfg) override;

Status
Add(const DataSet& dataset, const Config& cfg) override;
expected<DataSetPtr>
Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override;

expected<DataSetPtr>
RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override;

expected<std::vector<std::shared_ptr<iterator>>>
AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override;

expected<DataSetPtr>
GetVectorByIds(const DataSet& dataset) const override;

bool
HasRawData(const std::string& metric_type) const override {
return index_node_->HasRawData(metric_type);
}

expected<DataSetPtr>
GetIndexMeta(const Config& cfg) const override {
return index_node_->GetIndexMeta(cfg);
}

Status
Serialize(BinarySet& binset) const override {
return index_node_->Serialize(binset);
}

Status
Deserialize(const BinarySet& binset, const Config& config) override {
return index_node_->Deserialize(binset, config);
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
return index_node_->DeserializeFromFile(filename, config);
}

std::unique_ptr<BaseConfig>
CreateConfig() const override {
return index_node_->CreateConfig();
}

int64_t
Dim() const override {
return index_node_->Dim();
}

int64_t
Size() const override {
return index_node_->Size();
}

int64_t
Count() const override {
return index_node_->Count();
}

std::string
Type() const override {
return index_node_->Type();
}

private:
std::unique_ptr<IndexNode> index_node_;
};

} // namespace knowhere

#endif /* INDEX_NODE_DATA_MOCK_WRAPPER_H */
Loading

0 comments on commit 92c0aeb

Please sign in to comment.