Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…arch into jparismorgan/tdb-matrix-multirange-slices
  • Loading branch information
jparismorgan committed Oct 15, 2024
2 parents aa311a0 + 9f1a046 commit 805780f
Show file tree
Hide file tree
Showing 16 changed files with 157 additions and 93 deletions.
2 changes: 1 addition & 1 deletion apis/python/src/tiledb/vector_search/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,7 +1190,6 @@ def centralised_kmeans(
with tiledb.scope_ctx(ctx_or_config=config):
logger = setup(config, verbose)
group = tiledb.Group(index_group_uri)
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
if training_sample_size >= partitions:
if training_source_uri:
if training_source_type is None:
Expand Down Expand Up @@ -1261,6 +1260,7 @@ def centralised_kmeans(
# raise ValueError(f"We have a training_sample_size of {training_sample_size} but {partitions} partitions - training_sample_size must be >= partitions")
centroids = np.random.rand(dimensions, partitions)

centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
logger.debug("Writing centroids to array %s", centroids_uri)
with tiledb.open(centroids_uri, mode="w", timestamp=index_timestamp) as A:
A[0:dimensions, 0:partitions] = centroids
Expand Down
36 changes: 18 additions & 18 deletions apis/python/src/tiledb/vector_search/type_erased_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,13 @@ void init_type_erased_module(py::module_& m) {
"__init__",
[](IndexVamana& instance,
const tiledb::Context& ctx,
const std::string& group_uri,
const std::string& index_uri,
std::optional<TemporalPolicy> temporal_policy) {
new (&instance) IndexVamana(ctx, group_uri, temporal_policy);
new (&instance) IndexVamana(ctx, index_uri, temporal_policy);
},
py::keep_alive<1, 2>(), // IndexVamana should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_uri"),
py::arg("temporal_policy") = std::nullopt)
.def(
"__init__",
Expand Down Expand Up @@ -432,14 +432,14 @@ void init_type_erased_module(py::module_& m) {
"write_index",
[](IndexVamana& index,
const tiledb::Context& ctx,
const std::string& group_uri,
const std::string& index_uri,
std::optional<TemporalPolicy> temporal_policy,
const std::string& storage_version) {
index.write_index(ctx, group_uri, temporal_policy, storage_version);
index.write_index(ctx, index_uri, temporal_policy, storage_version);
},
py::keep_alive<1, 2>(), // IndexVamana should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_uri"),
py::arg("temporal_policy") = std::nullopt,
py::arg("storage_version") = "")
.def("feature_type_string", &IndexVamana::feature_type_string)
Expand All @@ -450,34 +450,34 @@ void init_type_erased_module(py::module_& m) {
.def_static(
"clear_history",
[](const tiledb::Context& ctx,
const std::string& group_uri,
const std::string& index_uri,
uint64_t timestamp) {
IndexVamana::clear_history(ctx, group_uri, timestamp);
IndexVamana::clear_history(ctx, index_uri, timestamp);
},
py::keep_alive<1, 2>(), // IndexVamana should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_uri"),
py::arg("timestamp"));

py::class_<IndexIVFPQ>(m, "IndexIVFPQ")
.def(
"__init__",
[](IndexIVFPQ& instance,
const tiledb::Context& ctx,
const std::string& group_uri,
const std::string& index_uri,
IndexLoadStrategy index_load_strategy,
size_t memory_budget,
std::optional<TemporalPolicy> temporal_policy) {
new (&instance) IndexIVFPQ(
ctx,
group_uri,
index_uri,
index_load_strategy,
memory_budget,
temporal_policy);
},
py::keep_alive<1, 2>(), // IndexIVFPQ should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_uri"),
py::arg("index_load_strategy") = IndexLoadStrategy::PQ_INDEX,
py::arg("memory_budget") = 0,
py::arg("temporal_policy") = std::nullopt)
Expand Down Expand Up @@ -518,14 +518,14 @@ void init_type_erased_module(py::module_& m) {
"write_index",
[](IndexIVFPQ& index,
const tiledb::Context& ctx,
const std::string& group_uri,
const std::string& index_uri,
std::optional<TemporalPolicy> temporal_policy,
const std::string& storage_version) {
index.write_index(ctx, group_uri, temporal_policy, storage_version);
index.write_index(ctx, index_uri, temporal_policy, storage_version);
},
py::keep_alive<1, 2>(), // IndexIVFPQ should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_uri"),
py::arg("temporal_policy") = std::nullopt,
py::arg("storage_version") = "")
.def("feature_type_string", &IndexIVFPQ::feature_type_string)
Expand All @@ -537,13 +537,13 @@ void init_type_erased_module(py::module_& m) {
.def_static(
"clear_history",
[](const tiledb::Context& ctx,
const std::string& group_uri,
const std::string& index_uri,
uint64_t timestamp) {
IndexIVFPQ::clear_history(ctx, group_uri, timestamp);
IndexIVFPQ::clear_history(ctx, index_uri, timestamp);
},
py::keep_alive<1, 2>(), // IndexIVFPQ should keep ctx alive.
py::arg("ctx"),
py::arg("group_uri"),
py::arg("index_uri"),
py::arg("timestamp"));

py::class_<IndexIVFFlat>(m, "IndexIVFFlat")
Expand Down
2 changes: 1 addition & 1 deletion src/include/api/ivf_pq_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ class IndexIVFPQ {
}

uint64_t nlist() const override {
return impl_index_.nlist();
return impl_index_.partitions();
}

uint32_t num_subspaces() const override {
Expand Down
53 changes: 53 additions & 0 deletions src/include/detail/linalg/tdb_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,59 @@ std::vector<T> read_vector(
return read_vector_helper<T>(ctx, uri, 0, 0, temporal_policy, true);
}

/**
 * Read a set of ranges of a TileDB array into a std::vector.
 *
 * One range per entry of `slices` is added on dimension 0 of a subarray, and
 * the attribute at index 0 of the array schema is read into the result.
 *
 * @tparam T Type of data element stored.
 * @tparam Slice Type of the slice bounds. Each bound is cast to `int` before
 * it is added as a range.
 * @param ctx The TileDB context.
 * @param uri The URI of the TileDB array.
 * @param slices The slices to read. Each slice is a pair of [start, end] (i.e.
 * they are inclusive).
 * @param total_slices_size Number of elements allocated for the result and
 * passed to the query as the data buffer size. If it is 0, an empty vector is
 * returned without opening the array.
 * NOTE(review): presumably this must equal the combined length of all slices;
 * that is not checked here -- confirm against callers.
 * @param temporal_policy The temporal policy for the read.
 * @return The vector of data, sized to `total_slices_size` elements.
 */
template <class T, typename Slice>
std::vector<T> read_vector(
    const tiledb::Context& ctx,
    const std::string& uri,
    const std::vector<std::pair<Slice, Slice>>& slices,
    size_t total_slices_size,
    TemporalPolicy temporal_policy = {}) {
  // Nothing was requested, so return before touching the array.
  if (total_slices_size == 0) {
    return {};
  }
  scoped_timer _{tdb_func__ + " " + std::string{uri}};

  auto array = tiledb_helpers::open_array(
      tdb_func__, ctx, uri, TILEDB_READ, temporal_policy);
  auto schema = array->schema();

  // The data is read from the first attribute of the schema.
  const size_t first_attr = 0;
  const std::string attr_name = schema.attribute(first_attr).name();

  // Build a multi-range subarray: one inclusive range on dimension 0 for
  // every requested slice.
  tiledb::Subarray ranges(ctx, *array);
  for (const auto& [first, last] : slices) {
    ranges.add_range(0, static_cast<int>(first), static_cast<int>(last));
  }

  // @todo: use something non-initializing
  std::vector<T> result(total_slices_size);

  tiledb::Query query(ctx, *array);
  query.set_subarray(ranges);
  query.set_data_buffer(attr_name, result.data(), total_slices_size);
  tiledb_helpers::submit_query(tdb_func__, uri, query);
  _memory_data.insert_entry(tdb_func__, total_slices_size * sizeof(T));

  array->close();
  // Debug-only check that the whole read fit in the provided buffer.
  assert(tiledb::Query::Status::COMPLETE == query.query_status());

  return result;
}

template <class T>
auto sizes_to_indices(const std::vector<T>& sizes) {
std::vector<T> indices(size(sizes) + 1);
Expand Down
6 changes: 6 additions & 0 deletions src/include/index/index_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,12 @@ class base_index_group {
if (opened_for_ == TILEDB_WRITE) {
set_dimensions(dimensions);
}
if (empty(this->version_)) {
this->version_ = current_storage_version;
}
if (storage_formats.find(this->version_) == storage_formats.end()) {
throw std::runtime_error("Invalid storage version: " + this->version_);
}
}

/**
Expand Down
3 changes: 0 additions & 3 deletions src/include/index/ivf_flat_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ class ivf_flat_group : public base_index_group<index_type> {
}

void create_default_impl() {
if (empty(this->version_)) {
this->version_ = current_storage_version;
}
this->init_valid_array_names();

static const int32_t tile_size{
Expand Down
3 changes: 0 additions & 3 deletions src/include/index/ivf_pq_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ class ivf_pq_group : public base_index_group<index_type> {
* Create a ready-to-use group with default arrays
****************************************************************************/
void create_default_impl() {
if (empty(this->version_)) {
this->version_ = current_storage_version;
}
this->init_valid_array_names();

static const int32_t tile_size{
Expand Down
2 changes: 1 addition & 1 deletion src/include/index/ivf_pq_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,7 @@ class ivf_pq_index {
return reassign_ratio_;
}

uint64_t nlist() const {
uint64_t partitions() const {
return num_partitions_;
}

Expand Down
3 changes: 0 additions & 3 deletions src/include/index/vamana_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,6 @@ class vamana_index_group : public base_index_group<index_type> {
}

void create_default_impl() {
if (empty(this->version_)) {
this->version_ = current_storage_version;
}
this->init_valid_array_names();

static const int32_t tile_size{
Expand Down
2 changes: 1 addition & 1 deletion src/include/test/unit_api_ivf_pq_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ TEST_CASE("storage_version", "[api_ivf_pq_index]") {
// Throw with the wrong version.
CHECK_THROWS_WITH(
index.write_index(ctx, index_uri, std::nullopt, "0.4"),
"Version mismatch. Requested 0.4 but found 0.3");
"Invalid storage version: 0.4");
// Succeed without a version.
index.write_index(ctx, index_uri);
// Succeed with the same version.
Expand Down
2 changes: 1 addition & 1 deletion src/include/test/unit_api_vamana_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ TEST_CASE("storage_version", "[api_vamana_index]") {
// Throw with the wrong version.
CHECK_THROWS_WITH(
index.write_index(ctx, index_uri, std::nullopt, "0.4"),
"Version mismatch. Requested 0.4 but found 0.3");
"Invalid storage version: 0.4");
// Succeed without a version.
index.write_index(ctx, index_uri);
// Succeed with the same version.
Expand Down
2 changes: 1 addition & 1 deletion src/include/test/unit_ivf_flat_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ TEST_CASE("mismatched storage version", "[ivf_flat_group]") {
TemporalPolicy{TimeTravel, 0},
"different_version",
10),
"Version mismatch. Requested different_version but found 0.3");
"Invalid storage version: different_version");
}

TEST_CASE("clear history", "[ivf_flat_group]") {
Expand Down
2 changes: 1 addition & 1 deletion src/include/test/unit_ivf_pq_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ TEST_CASE("mismatched storage version", "[ivf_pq_group]") {
TemporalPolicy{TimeTravel, 0},
"different_version",
10),
"Version mismatch. Requested different_version but found 0.3");
"Invalid storage version: different_version");
}

TEST_CASE("clear history", "[ivf_pq_group]") {
Expand Down
Loading

0 comments on commit 805780f

Please sign in to comment.