diff --git a/cpp/Enums.h b/cpp/Enums.h new file mode 100644 index 00000000..95d11b74 --- /dev/null +++ b/cpp/Enums.h @@ -0,0 +1,50 @@ +#pragma once + +/** + * The space (i.e. distance metric) to use for searching. + */ +enum SpaceType : unsigned char { + Euclidean = 0, + InnerProduct = 1, + Cosine = 2, +}; + +/** + * The datatype used to use when storing vectors on disk. + * Affects precision and memory usage. + */ +enum class StorageDataType : unsigned char { + Float8 = 1 << 4, + Float32 = 2 << 4, + + // An 8-bit floating point format that uses + // four bits for exponent, 3 bits for mantissa, + // allowing representation of values from 2e-9 to 448. + E4M3 = 3 << 4, +}; + +inline const std::string toString(StorageDataType sdt) { + switch (sdt) { + case StorageDataType::Float8: + return "Float8"; + case StorageDataType::Float32: + return "Float32"; + case StorageDataType::E4M3: + return "E4M3"; + default: + return "Unknown storage data type (value " + std::to_string((int)sdt) + ")"; + } +} + +inline const std::string toString(SpaceType space) { + switch (space) { + case SpaceType::Euclidean: + return "Euclidean"; + case SpaceType::Cosine: + return "Cosine"; + case SpaceType::InnerProduct: + return "InnerProduct"; + default: + return "Unknown space type (value " + std::to_string((int)space) + ")"; + } +} \ No newline at end of file diff --git a/cpp/Index.h b/cpp/Index.h index 9aa678a8..f6fe581c 100644 --- a/cpp/Index.h +++ b/cpp/Index.h @@ -25,33 +25,11 @@ #include #include +#include "Enums.h" #include "StreamUtils.h" #include "array_utils.h" #include "hnswlib.h" -/** - * The space (i.e. distance metric) to use for searching. - */ -enum SpaceType { - Euclidean, - InnerProduct, - Cosine, -}; - -/** - * The datatype used to use when storing vectors on disk. - * Affects precision and memory usage. - */ -enum class StorageDataType { - Float8, - Float32, - - // An 8-bit floating point format that uses - // four bits for exponent, 3 bits for mantissa, - // allowing representation of values from 2e-9 to 448. - E4M3, -}; - /** * A C++ wrapper class for a Voyager index, which accepts * and returns floating-point data. diff --git a/cpp/Metadata.h b/cpp/Metadata.h new file mode 100644 index 00000000..67d1bb25 --- /dev/null +++ b/cpp/Metadata.h @@ -0,0 +1,123 @@ +#pragma once +/*- + * -\-\- + * voyager + * -- + * Copyright (C) 2016 - 2023 Spotify AB + * + * This file is heavily based on hnswlib (https://github.com/nmslib/hnswlib, + * Apache 2.0-licensed, no copyright author listed) + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +#include "Enums.h" +#include "StreamUtils.h" + +namespace voyager { +namespace Metadata { +/** + * @brief A basic metadata class that stores the number of dimensions, + * the SpaceType, StorageDataType, and number of dimensions. + */ +class V1 { +public: + V1(int numDimensions, SpaceType spaceType, StorageDataType storageDataType) + : numDimensions(numDimensions), spaceType(spaceType), + storageDataType(storageDataType) {} + + V1() {} + virtual ~V1() {} + + int version() const { return 1; } + + int getNumDimensions() { return numDimensions; } + + StorageDataType getStorageDataType() { return storageDataType; } + + SpaceType getSpaceType() { return spaceType; } + + void setNumDimensions(int newNumDimensions) { + numDimensions = newNumDimensions; + } + + void setStorageDataType(StorageDataType newStorageDataType) { + storageDataType = newStorageDataType; + } + + void setSpaceType(SpaceType newSpaceType) { spaceType = newSpaceType; } + + virtual void serializeToStream(std::shared_ptr stream) { + stream->write("VOYA", 4); + writeBinaryPOD(stream, version()); + writeBinaryPOD(stream, numDimensions); + writeBinaryPOD(stream, spaceType); + writeBinaryPOD(stream, storageDataType); + }; + + virtual void loadFromStream(std::shared_ptr stream) { + // Version has already been loaded before we get here! + readBinaryPOD(stream, numDimensions); + readBinaryPOD(stream, spaceType); + readBinaryPOD(stream, storageDataType); + }; + +private: + int numDimensions; + SpaceType spaceType; + StorageDataType storageDataType; +}; + +static std::unique_ptr +loadFromStream(std::shared_ptr inputStream) { + uint32_t header = inputStream->peek(); + if (header != 'AYOV') { + return nullptr; + } + + // Actually read instead of just peeking: + inputStream->read((char *)&header, sizeof(header)); + + int version; + readBinaryPOD(inputStream, version); + + switch (version) { + case 1: { + std::unique_ptr metadata = std::make_unique(); + metadata->loadFromStream(inputStream); + return metadata; + } + default: { + std::stringstream stream; + stream << std::hex << version; + std::string resultAsHex(stream.str()); + + std::string error = "Unable to parse version of Voyager index file; found " + "unsupported version \"0x" + + resultAsHex + "\"."; + + if (version < 20) { + error += " A newer version of the Voyager library may be able to read " + "this index."; + } else { + error += " This index may be corrupted (or not a Voyager index)."; + } + + throw std::domain_error(error); + } + } +}; + +} // namespace Metadata +}; // namespace voyager \ No newline at end of file diff --git a/cpp/StreamUtils.h b/cpp/StreamUtils.h index 9ead5802..c3d74148 100644 --- a/cpp/StreamUtils.h +++ b/cpp/StreamUtils.h @@ -21,6 +21,7 @@ #pragma once #include #include +#include #include #include #include @@ -41,11 +42,12 @@ class InputStream { virtual bool advanceBy(long long numBytes) { return setPosition(getPosition() + numBytes); } + virtual uint32_t peek() = 0; }; class FileInputStream : public InputStream { public: - FileInputStream(const std::string &filename) { + FileInputStream(const std::string &filename) : filename(filename) { handle = fopen(filename.c_str(), "r"); if (!handle) { throw std::runtime_error("Failed to open file for reading: " + filename); @@ -74,6 +76,19 @@ class FileInputStream : public InputStream { virtual bool advanceBy(long long bytes) { return fseek(handle, bytes, SEEK_CUR) == 0; } + virtual uint32_t peek() { + uint32_t result = 0; + long long lastPosition = getPosition(); + if (read((char *)&result, sizeof(result)) == sizeof(result)) { + setPosition(lastPosition); + return result; + } else { + throw std::runtime_error( + "Failed to peek " + std::to_string(sizeof(result)) + + " bytes from file \"" + filename + "\" at index " + + std::to_string(lastPosition) + "."); + } + } virtual ~FileInputStream() { if (handle) { @@ -85,6 +100,7 @@ class FileInputStream : public InputStream { protected: FileInputStream() {} FILE *handle = nullptr; + std::string filename; private: bool isRegularFile = false; @@ -143,4 +159,22 @@ class MemoryOutputStream : public OutputStream { private: std::ostringstream outputStream; -}; \ No newline at end of file +}; + +template +static void writeBinaryPOD(std::shared_ptr out, const T &podRef) { + if (!out->write((char *)&podRef, sizeof(T))) { + throw std::runtime_error("Failed to write " + std::to_string(sizeof(T)) + + " bytes to stream!"); + } +} + +template +static void readBinaryPOD(std::shared_ptr in, T &podRef) { + long long bytesRead = in->read((char *)&podRef, sizeof(T)); + if (bytesRead != sizeof(T)) { + throw std::runtime_error("Failed to read " + std::to_string(sizeof(T)) + + " bytes from stream! Got " + + std::to_string(bytesRead) + "."); + } +} \ No newline at end of file diff --git a/cpp/TypedIndex.h b/cpp/TypedIndex.h index 114b6b8c..510d7288 100644 --- a/cpp/TypedIndex.h +++ b/cpp/TypedIndex.h @@ -29,7 +29,9 @@ #include #include "E4M3.h" +#include "Enums.h" #include "Index.h" +#include "Metadata.h" #include "array_utils.h" #include "hnswlib.h" #include "std_utils.h" @@ -99,6 +101,7 @@ class TypedIndex : public Index { hnswlib::labeltype currentLabel; std::unique_ptr> algorithmImpl; std::unique_ptr> spaceImpl; + std::unique_ptr metadata; public: /** @@ -107,7 +110,10 @@ class TypedIndex : public Index { TypedIndex(const SpaceType space, const int dimensions, const size_t M = 12, const size_t efConstruction = 200, const size_t randomSeed = 1, const size_t maxElements = 1) - : space(space), dimensions(dimensions) { + : space(space), dimensions(dimensions), + metadata(std::make_unique( + dimensions, space, getStorageDataType())) { + switch (space) { case Euclidean: spaceImpl = std::make_unique< @@ -168,6 +174,18 @@ class TypedIndex : public Index { currentLabel = algorithmImpl->cur_element_count; } + /** + * Load an index from the given input stream, interpreting + * it as the given Space and number of dimensions. + */ + TypedIndex(std::unique_ptr metadata, + std::shared_ptr inputStream, bool searchOnly = false) + : TypedIndex(metadata->getSpaceType(), metadata->getNumDimensions()) { + algorithmImpl = std::make_unique>( + spaceImpl.get(), inputStream, 0, searchOnly); + currentLabel = algorithmImpl->cur_element_count; + } + int getNumDimensions() const { return dimensions; } SpaceType getSpace() const { return space; } @@ -215,7 +233,7 @@ class TypedIndex : public Index { * Save this index to the provided file path on disk. */ void saveIndex(const std::string &pathToIndex) { - algorithmImpl->saveIndex(pathToIndex); + saveIndex(std::make_shared(pathToIndex)); } /** @@ -224,6 +242,7 @@ class TypedIndex : public Index { * TypedIndex constructor to reload this index. */ void saveIndex(std::shared_ptr outputStream) { + metadata->serializeToStream(outputStream); algorithmImpl->saveIndex(outputStream); } @@ -572,3 +591,44 @@ class TypedIndex : public Index { size_t getM() const { return algorithmImpl->M_; } }; + +std::unique_ptr +loadTypedIndexFromStream(std::shared_ptr inputStream) { + std::unique_ptr metadata = + voyager::Metadata::loadFromStream(inputStream); + + if (!metadata) { + throw std::domain_error( + "The provided file contains no Voyager parameter metadata. Please " + "specify the number of dimensions, SpaceType, and StorageDataType that " + "this index contains."); + } else if (voyager::Metadata::V1 *v1 = + dynamic_cast(metadata.get())) { + // We have enough information to create a TypedIndex! + switch (v1->getStorageDataType()) { + case StorageDataType::Float32: + return std::make_unique>( + std::unique_ptr( + (voyager::Metadata::V1 *)metadata.release()), + inputStream); + break; + case StorageDataType::Float8: + return std::make_unique>>( + std::unique_ptr( + (voyager::Metadata::V1 *)metadata.release()), + inputStream); + break; + case StorageDataType::E4M3: + return std::make_unique>( + std::unique_ptr( + (voyager::Metadata::V1 *)metadata.release()), + inputStream); + break; + default: + throw std::domain_error("Unknown storage data type: " + + std::to_string((int)v1->getStorageDataType())); + } + } else { + throw std::domain_error("Unknown Voyager metadata format."); + } +} \ No newline at end of file diff --git a/cpp/hnswalg.h b/cpp/hnswalg.h index 7128569e..7f81fae1 100644 --- a/cpp/hnswalg.h +++ b/cpp/hnswalg.h @@ -712,8 +712,15 @@ class HierarchicalNSW : public AlgorithmInterface { if (inputStream->isSeekable()) { totalFileSize = inputStream->getTotalLength(); } - readBinaryPOD(inputStream, offsetLevel0_); + if (totalFileSize > 0 && offsetLevel0_ > totalFileSize) { + throw std::domain_error("Index appears to contain corrupted data; level " + "0 offset parameter (" + + std::to_string(offsetLevel0_) + + ") exceeded size of index file (" + + std::to_string(totalFileSize) + ")."); + } + readBinaryPOD(inputStream, max_elements_); readBinaryPOD(inputStream, cur_element_count); @@ -727,6 +734,15 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(inputStream, maxlevel_); readBinaryPOD(inputStream, enterpoint_node_); + if (enterpoint_node_ >= cur_element_count) { + throw std::runtime_error( + "Index seems to be corrupted or unsupported. " + "Entry point into HNSW data structure was at element index " + + std::to_string(enterpoint_node_) + ", but only " + + std::to_string(cur_element_count) + + " elements are present in the index."); + } + readBinaryPOD(inputStream, maxM_); readBinaryPOD(inputStream, maxM0_); readBinaryPOD(inputStream, M_); @@ -763,18 +779,37 @@ class HierarchicalNSW : public AlgorithmInterface { if (inputStream->getPosition() < 0 || inputStream->getPosition() >= totalFileSize) { throw std::runtime_error( - "Index seems to be corrupted or unsupported"); + "Index seems to be corrupted or unsupported. Seeked to " + + std::to_string(position + + (cur_element_count * size_data_per_element_) + + (sizeof(unsigned int) * i)) + + " bytes to read linked list, but resulting stream position was " + + std::to_string(inputStream->getPosition()) + + " (of total file size " + std::to_string(totalFileSize) + + " bytes)."); } unsigned int linkListSize; readBinaryPOD(inputStream, linkListSize); if (linkListSize != 0) { + if (inputStream->getPosition() + linkListSize > totalFileSize) { + throw std::runtime_error( + "Index seems to be corrupted or unsupported. Advancing to the " + "next linked list requires " + + std::to_string(linkListSize) + + " additional bytes (from position " + + std::to_string(inputStream->getPosition()) + + "), but index data only has " + std::to_string(totalFileSize) + + " bytes in total."); + } inputStream->advanceBy(linkListSize); } } if (inputStream->getPosition() != totalFileSize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); + throw std::runtime_error( + "Index seems to be corrupted or unsupported. After reading all " + "linked lists, extra data remained at the end of the index."); inputStream->setPosition(position); } @@ -885,6 +920,15 @@ class HierarchicalNSW : public AlgorithmInterface { } } + if (enterpoint_node_ > 0 && enterpoint_node_ != -1 && + !linkLists_[enterpoint_node_]) { + throw std::runtime_error( + "Index seems to be corrupted or unsupported. " + "Entry point into HNSW data structure was at element index " + + std::to_string(enterpoint_node_) + + ", but no linked list was present at that index."); + } + for (size_t i = 0; i < cur_element_count; i++) { if (isMarkedDeleted(i)) num_deleted_ += 1; @@ -903,7 +947,8 @@ class HierarchicalNSW : public AlgorithmInterface { tableint label_c; auto search = label_lookup_.find(label); if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { - throw std::runtime_error("Label not found"); + throw std::runtime_error("Label " + std::to_string(label) + + " not found in index."); } label_c = search->second; diff --git a/cpp/hnswlib.h b/cpp/hnswlib.h index 3f291a62..c1c876f4 100644 --- a/cpp/hnswlib.h +++ b/cpp/hnswlib.h @@ -72,24 +72,6 @@ template class pairGreater { bool operator()(const T &p1, const T &p2) { return p1.first > p2.first; } }; -template -static void writeBinaryPOD(std::shared_ptr out, const T &podRef) { - if (!out->write((char *)&podRef, sizeof(T))) { - throw std::runtime_error("Failed to write " + std::to_string(sizeof(T)) + - " bytes to stream!"); - } -} - -template -static void readBinaryPOD(std::shared_ptr in, T &podRef) { - long long bytesRead = in->read((char *)&podRef, sizeof(T)); - if (bytesRead != sizeof(T)) { - throw std::runtime_error("Failed to read " + std::to_string(sizeof(T)) + - " bytes from stream! Got " + - std::to_string(bytesRead) + "."); - } -} - template class AlgorithmInterface { public: virtual void addPoint(const data_t *datapoint, labeltype label) = 0; diff --git a/java/JavaInputStream.h b/java/JavaInputStream.h index e0e267d5..ac4d8d8f 100644 --- a/java/JavaInputStream.h +++ b/java/JavaInputStream.h @@ -19,7 +19,9 @@ */ #include +#include #include +#include class JavaInputStream : public InputStream { public: @@ -65,6 +67,16 @@ class JavaInputStream : public InputStream { std::to_string(bufferSize)); } + if (peekValue.size()) { + long long bytesToCopy = + std::min(bytesToRead, (long long)peekValue.size()); + std::memcpy(buffer, peekValue.data(), bytesToCopy); + for (int i = 0; i < bytesToCopy; i++) + peekValue.erase(peekValue.begin()); + bytesRead += bytesToCopy; + buffer += bytesToCopy; + } + while (bytesRead < bytesToRead) { int readResult = env->CallIntMethod( inputStream, readMethod, byteArray, 0, @@ -109,8 +121,27 @@ class JavaInputStream : public InputStream { virtual ~JavaInputStream() {} + virtual uint32_t peek() { + uint32_t result = 0; + long long lastPosition = getPosition(); + if (read((char *)&result, sizeof(result)) == sizeof(result)) { + char *resultAsCharacters = (char *)&result; + peekValue.push_back(resultAsCharacters[0]); + peekValue.push_back(resultAsCharacters[1]); + peekValue.push_back(resultAsCharacters[2]); + peekValue.push_back(resultAsCharacters[3]); + return result; + } else { + throw std::runtime_error("Failed to peek " + + std::to_string(sizeof(result)) + + " bytes from JavaInputStream at index " + + std::to_string(lastPosition) + "."); + } + } + private: JNIEnv *env; jobject inputStream; + std::vector peekValue; long long bytesRead = 0; }; \ No newline at end of file diff --git a/java/com_spotify_voyager_jni_Index.cpp b/java/com_spotify_voyager_jni_Index.cpp index c2e2a2a9..0cc7c61c 100644 --- a/java/com_spotify_voyager_jni_Index.cpp +++ b/java/com_spotify_voyager_jni_Index.cpp @@ -21,6 +21,7 @@ #include "com_spotify_voyager_jni_Index.h" #include "JavaInputStream.h" #include "JavaOutputStream.h" +#include #include #include @@ -697,31 +698,56 @@ void Java_com_spotify_voyager_jni_Index_saveIndex__Ljava_io_OutputStream_2( // Load Index //////////////////////////////////////////////////////////////////////////////////////////////////// // TODO: Convert these to static methods -void Java_com_spotify_voyager_jni_Index_nativeLoadFromFile( +void Java_com_spotify_voyager_jni_Index_nativeLoadFromFileWithParameters( JNIEnv *env, jobject self, jstring filename, jobject spaceType, jint numDimensions, jobject storageDataType) { try { + auto inputStream = + std::make_shared(toString(env, filename)); + std::unique_ptr metadata = + voyager::Metadata::loadFromStream(inputStream); + + if (metadata) { + if (metadata->getStorageDataType() != + toStorageDataType(env, storageDataType)) { + throw std::domain_error( + "Provided storage data type (" + + toString(toStorageDataType(env, storageDataType)) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); + } + if (metadata->getSpaceType() != toSpaceType(env, spaceType)) { + throw std::domain_error( + "Provided space type (" + toString(toSpaceType(env, spaceType)) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); + } + if (metadata->getNumDimensions() != numDimensions) { + throw std::domain_error( + "Provided number of dimensions (" + std::to_string(numDimensions) + + ") does not match the number of dimensions used in this file (" + + std::to_string(metadata->getNumDimensions()) + ")."); + } + } + switch (toStorageDataType(env, storageDataType)) { case StorageDataType::Float32: - setHandle( - env, self, - new TypedIndex( - std::make_shared(toString(env, filename)), - toSpaceType(env, spaceType), numDimensions)); + setHandle(env, self, + new TypedIndex(inputStream, + toSpaceType(env, spaceType), + numDimensions)); break; case StorageDataType::Float8: setHandle( env, self, new TypedIndex>( - std::make_shared(toString(env, filename)), - toSpaceType(env, spaceType), numDimensions)); + inputStream, toSpaceType(env, spaceType), numDimensions)); break; case StorageDataType::E4M3: - setHandle( - env, self, - new TypedIndex( - std::make_shared(toString(env, filename)), - toSpaceType(env, spaceType), numDimensions)); + setHandle(env, self, + new TypedIndex(inputStream, + toSpaceType(env, spaceType), + numDimensions)); break; } } catch (std::exception const &e) { @@ -731,28 +757,55 @@ void Java_com_spotify_voyager_jni_Index_nativeLoadFromFile( } } -void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStream( - JNIEnv *env, jobject self, jobject inputStream, jobject spaceType, +void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStreamWithParameters( + JNIEnv *env, jobject self, jobject jInputStream, jobject spaceType, jint numDimensions, jobject storageDataType) { try { + auto inputStream = std::make_shared(env, jInputStream); + std::unique_ptr metadata = + voyager::Metadata::loadFromStream(inputStream); + + if (metadata) { + if (metadata->getStorageDataType() != + toStorageDataType(env, storageDataType)) { + throw std::domain_error( + "Provided storage data type (" + + toString(toStorageDataType(env, storageDataType)) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); + } + if (metadata->getSpaceType() != toSpaceType(env, spaceType)) { + throw std::domain_error( + "Provided space type (" + toString(toSpaceType(env, spaceType)) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); + } + if (metadata->getNumDimensions() != numDimensions) { + throw std::domain_error( + "Provided number of dimensions (" + std::to_string(numDimensions) + + ") does not match the number of dimensions used in this file (" + + std::to_string(metadata->getNumDimensions()) + ")."); + } + } + switch (toStorageDataType(env, storageDataType)) { case StorageDataType::Float32: setHandle(env, self, - new TypedIndex( - std::make_shared(env, inputStream), - toSpaceType(env, spaceType), numDimensions)); + new TypedIndex(inputStream, + toSpaceType(env, spaceType), + numDimensions)); break; case StorageDataType::Float8: - setHandle(env, self, - new TypedIndex>( - std::make_shared(env, inputStream), - toSpaceType(env, spaceType), numDimensions)); + setHandle( + env, self, + new TypedIndex>( + inputStream, toSpaceType(env, spaceType), numDimensions)); break; case StorageDataType::E4M3: setHandle(env, self, - new TypedIndex( - std::make_shared(env, inputStream), - toSpaceType(env, spaceType), numDimensions)); + new TypedIndex(inputStream, + toSpaceType(env, spaceType), + numDimensions)); break; } } catch (std::exception const &e) { diff --git a/java/com_spotify_voyager_jni_Index.h b/java/com_spotify_voyager_jni_Index.h index 1dcbc744..25587917 100644 --- a/java/com_spotify_voyager_jni_Index.h +++ b/java/com_spotify_voyager_jni_Index.h @@ -18,23 +18,40 @@ JNIEXPORT void JNICALL Java_com_spotify_voyager_jni_Index_nativeConstructor( /* * Class: com_spotify_voyager_jni_Index - * Method: nativeLoadFromFile + * Method: nativeLoadFromFileWithParameters * Signature: * (Ljava/lang/String;Lcom/spotify/voyager/jni/Index/SpaceType;ILcom/spotify/voyager/jni/Index/StorageDataType;)V */ -JNIEXPORT void JNICALL Java_com_spotify_voyager_jni_Index_nativeLoadFromFile( +JNIEXPORT void JNICALL +Java_com_spotify_voyager_jni_Index_nativeLoadFromFileWithParameters( JNIEnv *, jobject, jstring, jobject, jint, jobject); /* * Class: com_spotify_voyager_jni_Index - * Method: nativeLoadFromInputStream + * Method: nativeLoadFromFile + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_com_spotify_voyager_jni_Index_nativeLoadFromFile( + JNIEnv *, jobject, jstring); + +/* + * Class: com_spotify_voyager_jni_Index + * Method: nativeLoadFromInputStreamWithParameters * Signature: * (Ljava/io/InputStream;Lcom/spotify/voyager/jni/Index/SpaceType;ILcom/spotify/voyager/jni/Index/StorageDataType;)V */ JNIEXPORT void JNICALL +Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStreamWithParameters( + JNIEnv *, jobject, jobject, jobject, jint, jobject); + +/* + * Class: com_spotify_voyager_jni_Index + * Method: nativeLoadFromInputStream + * Signature: (Ljava/io/InputStream;)V + */ +JNIEXPORT void JNICALL Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStream(JNIEnv *, jobject, - jobject, jobject, - jint, jobject); + jobject); /* * Class: com_spotify_voyager_jni_Index diff --git a/java/src/main/java/com/spotify/voyager/jni/Index.java b/java/src/main/java/com/spotify/voyager/jni/Index.java index e48baee1..13f4bb65 100644 --- a/java/src/main/java/com/spotify/voyager/jni/Index.java +++ b/java/src/main/java/com/spotify/voyager/jni/Index.java @@ -234,7 +234,24 @@ public Index( public static Index load( String filename, SpaceType space, int numDimensions, StorageDataType storageDataType) { Index index = new Index(); - index.nativeLoadFromFile(filename, space, numDimensions, storageDataType); + index.nativeLoadFromFileWithParameters(filename, space, numDimensions, storageDataType); + return index; + } + + /** + * Load a Voyager index file and create a new {@link Index} initialized with the data in that + * file. + * + * @param filename A filename to load. + * @return An {@link Index} whose contents have been initialized with the data provided by the + * file. + * @throws RuntimeException if the index cannot be loaded from the file, the file contains invalid + * data, or the file contains an older version of the Voyager file format that requires + * additional arguments to be provided. + */ + public static Index load(String filename) { + Index index = new Index(); + index.nativeLoadFromFile(filename); return index; } @@ -258,7 +275,26 @@ public static Index load( int numDimensions, StorageDataType storageDataType) { Index index = new Index(); - index.nativeLoadFromInputStream(inputStream, space, numDimensions, storageDataType); + index.nativeLoadFromInputStreamWithParameters( + inputStream, space, numDimensions, storageDataType); + return index; + } + + /** + * Interpret the contents of a {@code java.io.InputStream} as the contents of a Voyager index file + * and create a new {@link Index} initialized with the data provided by that stream. + * + * @param inputStream A {@link java.io.InputStream} that will provide the contents of a Voyager + * index. + * @return An {@link Index} whose contents have been initialized with the data provided by the + * input stream. + * @throws RuntimeException if the index cannot be loaded from the stream, or the stream contains + * invalid data, or the file contains an older version of the Voyager file format that + * requires additional arguments to be provided. + */ + public static Index load(InputStream inputStream) { + Index index = new Index(); + index.nativeLoadFromInputStream(inputStream); return index; } @@ -295,12 +331,16 @@ private native void nativeConstructor( long maxElements, StorageDataType storageDataType); - private native void nativeLoadFromFile( + private native void nativeLoadFromFileWithParameters( String filename, SpaceType space, int numDimensions, StorageDataType storageDataType); - private native void nativeLoadFromInputStream( + private native void nativeLoadFromFile(String filename); + + private native void nativeLoadFromInputStreamWithParameters( InputStream inputStream, SpaceType space, int numDimensions, StorageDataType storageDataType); + private native void nativeLoadFromInputStream(InputStream inputStream); + private native void nativeDestructor(); /** diff --git a/python/bindings.cpp b/python/bindings.cpp index b04c2ca9..fb5d68ad 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -905,20 +905,44 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.). const StorageDataType storageDataType) -> std::shared_ptr { py::gil_scoped_release release; + auto inputStream = std::make_shared(filename); + std::unique_ptr metadata = + voyager::Metadata::loadFromStream(inputStream); + + if (metadata) { + if (metadata->getStorageDataType() != storageDataType) { + throw std::domain_error( + "Provided storage data type (" + toString(storageDataType) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); + } + if (metadata->getSpaceType() != space) { + throw std::domain_error( + "Provided space type (" + toString(space) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); + } + if (metadata->getNumDimensions() != num_dimensions) { + throw std::domain_error( + "Provided number of dimensions (" + + std::to_string(num_dimensions) + + ") does not match the number of dimensions used in this file " + "(" + + std::to_string(metadata->getNumDimensions()) + ")."); + } + } + switch (storageDataType) { case StorageDataType::E4M3: - return std::make_shared>( - std::make_shared(filename), space, - num_dimensions); + return std::make_shared>(inputStream, space, + num_dimensions); case StorageDataType::Float8: return std::make_shared< - TypedIndex>>( - std::make_shared(filename), space, - num_dimensions); + TypedIndex>>(inputStream, space, + num_dimensions); case StorageDataType::Float32: - return std::make_shared>( - std::make_shared(filename), space, - num_dimensions); + return std::make_shared>(inputStream, space, + num_dimensions); default: throw std::runtime_error("Unknown storage data type received!"); } @@ -926,6 +950,16 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.). py::arg("filename"), py::arg("space"), py::arg("num_dimensions"), py::arg("storage_data_type") = StorageDataType::Float32, LOAD_DOCSTRING); + index.def_static( + "load", + [](const std::string filename) -> std::shared_ptr { + py::gil_scoped_release release; + + return loadTypedIndexFromStream( + std::make_shared(filename)); + }, + py::arg("filename"), LOAD_DOCSTRING); + index.def_static( "load", [](const py::object filelike, const SpaceType space, @@ -941,6 +975,32 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.). auto inputStream = std::make_shared(filelike); py::gil_scoped_release release; + std::unique_ptr metadata = + voyager::Metadata::loadFromStream(inputStream); + + if (metadata) { + if (metadata->getStorageDataType() != storageDataType) { + throw std::domain_error( + "Provided storage data type (" + toString(storageDataType) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); + } + if (metadata->getSpaceType() != space) { + throw std::domain_error( + "Provided space type (" + toString(space) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); + } + if (metadata->getNumDimensions() != num_dimensions) { + throw std::domain_error( + "Provided number of dimensions (" + + std::to_string(num_dimensions) + + ") does not match the number of dimensions used in this file " + "(" + + std::to_string(metadata->getNumDimensions()) + ")."); + } + } + switch (storageDataType) { case StorageDataType::E4M3: return std::make_shared>(inputStream, space, @@ -958,4 +1018,21 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.). }, py::arg("file_like"), py::arg("space"), py::arg("num_dimensions"), py::arg("storage_data_type") = StorageDataType::Float32, LOAD_DOCSTRING); + + index.def_static( + "load", + [](const py::object filelike) -> std::shared_ptr { + if (!isReadableFileLike(filelike)) { + throw py::type_error( + "Expected either a filename or a file-like object (with " + "read, seek, seekable, and tell methods), but received: " + + filelike.attr("__repr__")().cast()); + } + + auto inputStream = std::make_shared(filelike); + py::gil_scoped_release release; + + return loadTypedIndexFromStream(inputStream); + }, + py::arg("file_like"), LOAD_DOCSTRING); } diff --git a/python/src/PythonInputStream.h b/python/src/PythonInputStream.h index 5af041d3..6a1d8dd8 100644 --- a/python/src/PythonInputStream.h +++ b/python/src/PythonInputStream.h @@ -86,6 +86,16 @@ class PythonInputStream : public InputStream, PythonFileLike { long long bytesRead = 0; + if (peekValue.size()) { + long long bytesToCopy = + std::min(bytesToRead, (long long)peekValue.size()); + std::memcpy(buffer, peekValue.data(), bytesToCopy); + for (int i = 0; i < bytesToCopy; i++) + peekValue.erase(peekValue.begin()); + bytesRead += bytesToCopy; + buffer += bytesToCopy; + } + while (bytesRead < bytesToRead) { auto readResult = fileLike.attr("read")( std::min(MAX_BUFFER_SIZE, bytesToRead - bytesRead)); @@ -150,13 +160,13 @@ class PythonInputStream : public InputStream, PythonFileLike { return true; } - return fileLike.attr("tell")().cast() == getTotalLength(); + return getPosition() == getTotalLength(); } long long getPosition() { py::gil_scoped_acquire acquire; - return fileLike.attr("tell")().cast(); + return fileLike.attr("tell")().cast() - peekValue.size(); } bool setPosition(long long pos) { @@ -166,10 +176,29 @@ class PythonInputStream : public InputStream, PythonFileLike { fileLike.attr("seek")(pos); } - return fileLike.attr("tell")().cast() == pos; + return getPosition() == pos; + } + + uint32_t peek() { + uint32_t result = 0; + long long lastPosition = getPosition(); + if (read((char *)&result, sizeof(result)) == sizeof(result)) { + char *resultAsCharacters = (char *)&result; + peekValue.push_back(resultAsCharacters[0]); + peekValue.push_back(resultAsCharacters[1]); + peekValue.push_back(resultAsCharacters[2]); + peekValue.push_back(resultAsCharacters[3]); + return result; + } else { + throw std::runtime_error("Failed to peek " + + std::to_string(sizeof(result)) + + " bytes from file-like object at index " + + std::to_string(lastPosition) + "."); + } } private: long long totalLength = -1; + std::vector peekValue; bool lastReadWasSmallerThanExpected = false; }; \ No newline at end of file diff --git a/python/tests/indices/v0/cosine_128dim_e4m3.hnsw b/python/tests/indices/v0/cosine_128dim_e4m3.hnsw new file mode 100644 index 00000000..a31b9c93 Binary files /dev/null and b/python/tests/indices/v0/cosine_128dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/cosine_128dim_float32.hnsw b/python/tests/indices/v0/cosine_128dim_float32.hnsw new file mode 100644 index 00000000..938a96ec Binary files /dev/null and b/python/tests/indices/v0/cosine_128dim_float32.hnsw differ diff --git a/python/tests/indices/v0/cosine_128dim_float8.hnsw b/python/tests/indices/v0/cosine_128dim_float8.hnsw new file mode 100644 index 00000000..efb5e910 Binary files /dev/null and b/python/tests/indices/v0/cosine_128dim_float8.hnsw differ diff --git a/python/tests/indices/v0/cosine_16dim_e4m3.hnsw b/python/tests/indices/v0/cosine_16dim_e4m3.hnsw new file mode 100644 index 00000000..9cb475ad Binary files /dev/null and b/python/tests/indices/v0/cosine_16dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/cosine_16dim_float32.hnsw b/python/tests/indices/v0/cosine_16dim_float32.hnsw new file mode 100644 index 00000000..c205eaec Binary files /dev/null and b/python/tests/indices/v0/cosine_16dim_float32.hnsw differ diff --git a/python/tests/indices/v0/cosine_16dim_float8.hnsw b/python/tests/indices/v0/cosine_16dim_float8.hnsw new file mode 100644 index 00000000..d09cb9dd Binary files /dev/null and b/python/tests/indices/v0/cosine_16dim_float8.hnsw differ diff --git a/python/tests/indices/v0/cosine_4dim_e4m3.hnsw b/python/tests/indices/v0/cosine_4dim_e4m3.hnsw new file mode 100644 index 00000000..0a12066a Binary files /dev/null and b/python/tests/indices/v0/cosine_4dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/cosine_4dim_float32.hnsw b/python/tests/indices/v0/cosine_4dim_float32.hnsw new file mode 100644 index 00000000..8728101b Binary files /dev/null and b/python/tests/indices/v0/cosine_4dim_float32.hnsw differ diff --git a/python/tests/indices/v0/cosine_4dim_float8.hnsw b/python/tests/indices/v0/cosine_4dim_float8.hnsw new file mode 100644 index 00000000..930e14d7 Binary files /dev/null and b/python/tests/indices/v0/cosine_4dim_float8.hnsw differ diff --git a/python/tests/indices/v0/euclidean_128dim_e4m3.hnsw b/python/tests/indices/v0/euclidean_128dim_e4m3.hnsw new file mode 100644 index 00000000..bd602626 Binary files /dev/null and b/python/tests/indices/v0/euclidean_128dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/euclidean_128dim_float32.hnsw b/python/tests/indices/v0/euclidean_128dim_float32.hnsw new file mode 100644 index 00000000..b6b03390 Binary files /dev/null and b/python/tests/indices/v0/euclidean_128dim_float32.hnsw differ diff --git a/python/tests/indices/v0/euclidean_128dim_float8.hnsw b/python/tests/indices/v0/euclidean_128dim_float8.hnsw new file mode 100644 index 00000000..ffecbb2f Binary files /dev/null and b/python/tests/indices/v0/euclidean_128dim_float8.hnsw differ diff --git a/python/tests/indices/v0/euclidean_16dim_e4m3.hnsw b/python/tests/indices/v0/euclidean_16dim_e4m3.hnsw new file mode 100644 index 00000000..db80a128 Binary files /dev/null and b/python/tests/indices/v0/euclidean_16dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/euclidean_16dim_float32.hnsw b/python/tests/indices/v0/euclidean_16dim_float32.hnsw new file mode 100644 index 00000000..fafea806 Binary files /dev/null and b/python/tests/indices/v0/euclidean_16dim_float32.hnsw differ diff --git a/python/tests/indices/v0/euclidean_16dim_float8.hnsw b/python/tests/indices/v0/euclidean_16dim_float8.hnsw new file mode 100644 index 00000000..ad432d1c Binary files /dev/null and b/python/tests/indices/v0/euclidean_16dim_float8.hnsw differ diff --git a/python/tests/indices/v0/euclidean_4dim_e4m3.hnsw b/python/tests/indices/v0/euclidean_4dim_e4m3.hnsw new file mode 100644 index 00000000..2bf668b4 Binary files /dev/null and b/python/tests/indices/v0/euclidean_4dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/euclidean_4dim_float32.hnsw b/python/tests/indices/v0/euclidean_4dim_float32.hnsw new file mode 100644 index 00000000..859f4a3d Binary files /dev/null and b/python/tests/indices/v0/euclidean_4dim_float32.hnsw differ diff --git a/python/tests/indices/v0/euclidean_4dim_float8.hnsw b/python/tests/indices/v0/euclidean_4dim_float8.hnsw new file mode 100644 index 00000000..17d02b63 Binary files /dev/null and b/python/tests/indices/v0/euclidean_4dim_float8.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_128dim_e4m3.hnsw b/python/tests/indices/v0/innerproduct_128dim_e4m3.hnsw new file mode 100644 index 00000000..bd602626 Binary files /dev/null and b/python/tests/indices/v0/innerproduct_128dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_128dim_float32.hnsw b/python/tests/indices/v0/innerproduct_128dim_float32.hnsw new file mode 100644 index 00000000..b6b03390 Binary files /dev/null and b/python/tests/indices/v0/innerproduct_128dim_float32.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_128dim_float8.hnsw b/python/tests/indices/v0/innerproduct_128dim_float8.hnsw new file mode 100644 index 00000000..ffecbb2f Binary files /dev/null and b/python/tests/indices/v0/innerproduct_128dim_float8.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_16dim_e4m3.hnsw b/python/tests/indices/v0/innerproduct_16dim_e4m3.hnsw new file mode 100644 index 00000000..db80a128 Binary files /dev/null and b/python/tests/indices/v0/innerproduct_16dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_16dim_float32.hnsw b/python/tests/indices/v0/innerproduct_16dim_float32.hnsw new file mode 100644 index 00000000..fafea806 Binary files /dev/null and b/python/tests/indices/v0/innerproduct_16dim_float32.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_16dim_float8.hnsw b/python/tests/indices/v0/innerproduct_16dim_float8.hnsw new file mode 100644 index 00000000..ad432d1c Binary files /dev/null and b/python/tests/indices/v0/innerproduct_16dim_float8.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_4dim_e4m3.hnsw b/python/tests/indices/v0/innerproduct_4dim_e4m3.hnsw new file mode 100644 index 00000000..2bf668b4 Binary files /dev/null and b/python/tests/indices/v0/innerproduct_4dim_e4m3.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_4dim_float32.hnsw b/python/tests/indices/v0/innerproduct_4dim_float32.hnsw new file mode 100644 index 00000000..859f4a3d Binary files /dev/null and b/python/tests/indices/v0/innerproduct_4dim_float32.hnsw differ diff --git a/python/tests/indices/v0/innerproduct_4dim_float8.hnsw b/python/tests/indices/v0/innerproduct_4dim_float8.hnsw new file mode 100644 index 00000000..17d02b63 Binary files /dev/null and b/python/tests/indices/v0/innerproduct_4dim_float8.hnsw differ diff --git a/python/tests/indices/v1/cosine_128dim_e4m3.hnsw b/python/tests/indices/v1/cosine_128dim_e4m3.hnsw new file mode 100644 index 00000000..784167c2 Binary files /dev/null and b/python/tests/indices/v1/cosine_128dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/cosine_128dim_float32.hnsw b/python/tests/indices/v1/cosine_128dim_float32.hnsw new file mode 100644 index 00000000..79eed221 Binary files /dev/null and b/python/tests/indices/v1/cosine_128dim_float32.hnsw differ diff --git a/python/tests/indices/v1/cosine_128dim_float8.hnsw b/python/tests/indices/v1/cosine_128dim_float8.hnsw new file mode 100644 index 00000000..52ea9290 Binary files /dev/null and b/python/tests/indices/v1/cosine_128dim_float8.hnsw differ diff --git a/python/tests/indices/v1/cosine_16dim_e4m3.hnsw b/python/tests/indices/v1/cosine_16dim_e4m3.hnsw new file mode 100644 index 00000000..9d56e1d9 Binary files /dev/null and b/python/tests/indices/v1/cosine_16dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/cosine_16dim_float32.hnsw b/python/tests/indices/v1/cosine_16dim_float32.hnsw new file mode 100644 index 00000000..b2b8e615 Binary files /dev/null and b/python/tests/indices/v1/cosine_16dim_float32.hnsw differ diff --git a/python/tests/indices/v1/cosine_16dim_float8.hnsw b/python/tests/indices/v1/cosine_16dim_float8.hnsw new file mode 100644 index 00000000..975d8d38 Binary files /dev/null and b/python/tests/indices/v1/cosine_16dim_float8.hnsw differ diff --git a/python/tests/indices/v1/cosine_4dim_e4m3.hnsw b/python/tests/indices/v1/cosine_4dim_e4m3.hnsw new file mode 100644 index 00000000..2bcbdb39 Binary files /dev/null and b/python/tests/indices/v1/cosine_4dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/cosine_4dim_float32.hnsw b/python/tests/indices/v1/cosine_4dim_float32.hnsw new file mode 100644 index 00000000..cc47c79d Binary files /dev/null and b/python/tests/indices/v1/cosine_4dim_float32.hnsw differ diff --git a/python/tests/indices/v1/cosine_4dim_float8.hnsw b/python/tests/indices/v1/cosine_4dim_float8.hnsw new file mode 100644 index 00000000..272fd98d Binary files /dev/null and b/python/tests/indices/v1/cosine_4dim_float8.hnsw differ diff --git a/python/tests/indices/v1/euclidean_128dim_e4m3.hnsw b/python/tests/indices/v1/euclidean_128dim_e4m3.hnsw new file mode 100644 index 00000000..83d8f999 Binary files /dev/null and b/python/tests/indices/v1/euclidean_128dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/euclidean_128dim_float32.hnsw b/python/tests/indices/v1/euclidean_128dim_float32.hnsw new file mode 100644 index 00000000..148ad87a Binary files /dev/null and b/python/tests/indices/v1/euclidean_128dim_float32.hnsw differ diff --git a/python/tests/indices/v1/euclidean_128dim_float8.hnsw b/python/tests/indices/v1/euclidean_128dim_float8.hnsw new file mode 100644 index 00000000..c2289e4f Binary files /dev/null and b/python/tests/indices/v1/euclidean_128dim_float8.hnsw differ diff --git a/python/tests/indices/v1/euclidean_16dim_e4m3.hnsw b/python/tests/indices/v1/euclidean_16dim_e4m3.hnsw new file mode 100644 index 00000000..e1a047fb Binary files /dev/null and b/python/tests/indices/v1/euclidean_16dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/euclidean_16dim_float32.hnsw b/python/tests/indices/v1/euclidean_16dim_float32.hnsw new file mode 100644 index 00000000..489a05e2 Binary files /dev/null and b/python/tests/indices/v1/euclidean_16dim_float32.hnsw differ diff --git a/python/tests/indices/v1/euclidean_16dim_float8.hnsw b/python/tests/indices/v1/euclidean_16dim_float8.hnsw new file mode 100644 index 00000000..4372ca1d Binary files /dev/null and b/python/tests/indices/v1/euclidean_16dim_float8.hnsw differ diff --git a/python/tests/indices/v1/euclidean_4dim_e4m3.hnsw b/python/tests/indices/v1/euclidean_4dim_e4m3.hnsw new file mode 100644 index 00000000..d98de001 Binary files /dev/null and b/python/tests/indices/v1/euclidean_4dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/euclidean_4dim_float32.hnsw b/python/tests/indices/v1/euclidean_4dim_float32.hnsw new file mode 100644 index 00000000..d613009f Binary files /dev/null and b/python/tests/indices/v1/euclidean_4dim_float32.hnsw differ diff --git a/python/tests/indices/v1/euclidean_4dim_float8.hnsw b/python/tests/indices/v1/euclidean_4dim_float8.hnsw new file mode 100644 index 00000000..9b90f7f6 Binary files /dev/null and b/python/tests/indices/v1/euclidean_4dim_float8.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_128dim_e4m3.hnsw b/python/tests/indices/v1/innerproduct_128dim_e4m3.hnsw new file mode 100644 index 00000000..93dd1fd2 Binary files /dev/null and b/python/tests/indices/v1/innerproduct_128dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_128dim_float32.hnsw b/python/tests/indices/v1/innerproduct_128dim_float32.hnsw new file mode 100644 index 00000000..1275ae03 Binary files /dev/null and b/python/tests/indices/v1/innerproduct_128dim_float32.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_128dim_float8.hnsw b/python/tests/indices/v1/innerproduct_128dim_float8.hnsw new file mode 100644 index 00000000..a68b9def Binary files /dev/null and b/python/tests/indices/v1/innerproduct_128dim_float8.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_16dim_e4m3.hnsw b/python/tests/indices/v1/innerproduct_16dim_e4m3.hnsw new file mode 100644 index 00000000..a057599f Binary files /dev/null and b/python/tests/indices/v1/innerproduct_16dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_16dim_float32.hnsw b/python/tests/indices/v1/innerproduct_16dim_float32.hnsw new file mode 100644 index 00000000..963d2b1a Binary files /dev/null and b/python/tests/indices/v1/innerproduct_16dim_float32.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_16dim_float8.hnsw b/python/tests/indices/v1/innerproduct_16dim_float8.hnsw new file mode 100644 index 00000000..c8c1867f Binary files /dev/null and b/python/tests/indices/v1/innerproduct_16dim_float8.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_4dim_e4m3.hnsw b/python/tests/indices/v1/innerproduct_4dim_e4m3.hnsw new file mode 100644 index 00000000..3d7644f1 Binary files /dev/null and b/python/tests/indices/v1/innerproduct_4dim_e4m3.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_4dim_float32.hnsw b/python/tests/indices/v1/innerproduct_4dim_float32.hnsw new file mode 100644 index 00000000..831b27cc Binary files /dev/null and b/python/tests/indices/v1/innerproduct_4dim_float32.hnsw differ diff --git a/python/tests/indices/v1/innerproduct_4dim_float8.hnsw b/python/tests/indices/v1/innerproduct_4dim_float8.hnsw new file mode 100644 index 00000000..fcb898d7 Binary files /dev/null and b/python/tests/indices/v1/innerproduct_4dim_float8.hnsw differ diff --git a/python/tests/test_load_indices.py b/python/tests/test_load_indices.py new file mode 100644 index 00000000..1751c90d --- /dev/null +++ b/python/tests/test_load_indices.py @@ -0,0 +1,290 @@ +#! /usr/bin/env python +# +# Copyright 2022-2023 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import os +import struct +from io import BytesIO +import numpy as np +from glob import glob + +from voyager import Index, Space, StorageDataType + +INDEX_FIXTURE_DIR = os.path.join(os.path.dirname(__file__), "indices") + + +def detect_space_from_filename(filename: str): + if "cosine" in filename: + return Space.Cosine + elif "innerproduct" in filename: + return Space.InnerProduct + elif "euclidean" in filename: + return Space.Euclidean + else: + raise ValueError(f"Not sure which space type is used in {filename}") + + +def detect_num_dimensions_from_filename(filename: str) -> int: + return int(filename.split("_")[1].split("dim")[0]) + + +def detect_storage_datatype_from_filename(filename: str) -> int: + storage_data_type = filename.split("_")[-1].split(".")[0].lower() + if storage_data_type == "float32": + return StorageDataType.Float32 + elif storage_data_type == "float8": + return StorageDataType.Float8 + elif storage_data_type == "e4m3": + return StorageDataType.E4M3 + else: + raise ValueError(f"Not sure which storage data type is used in {filename}") + + +@pytest.mark.parametrize("load_from_stream", [False, True]) +@pytest.mark.parametrize( + "index_filename", + # Both V0 and V1 indices should be loadable with this interface: + list(glob(os.path.join(INDEX_FIXTURE_DIR, "v0", "*.hnsw"))) + + glob(os.path.join(INDEX_FIXTURE_DIR, "v1", "*.hnsw")), +) +def test_load_v0_indices(load_from_stream: bool, index_filename: str): + space = detect_space_from_filename(index_filename) + num_dimensions = detect_num_dimensions_from_filename(index_filename) + if load_from_stream: + with open(index_filename, "rb") as f: + print(f.read(8)) + f.seek(0) + index = Index.load( + f, + space=space, + num_dimensions=num_dimensions, + storage_data_type=detect_storage_datatype_from_filename(index_filename), + ) + else: + index = Index.load( + index_filename, + space=space, + num_dimensions=num_dimensions, + storage_data_type=detect_storage_datatype_from_filename(index_filename), + ) + + # All of these test indices are expected to contain exactly 0.0, 0.1, 0.2, 0.3, 0.4 + assert set(index.ids) == {0, 1, 2, 3, 4} + for _id in index.ids: + expected_vector = np.ones(num_dimensions) * (_id * 0.1) + if space == Space.Cosine and _id > 0: + # Voyager stores only normalized vectors in Cosine mode: + expected_vector = expected_vector / np.sqrt(np.sum(expected_vector**2)) + np.testing.assert_allclose(index[_id], expected_vector, atol=0.2) + + +@pytest.mark.parametrize("load_from_stream", [False, True]) +@pytest.mark.parametrize("index_filename", glob(os.path.join(INDEX_FIXTURE_DIR, "v1", "*.hnsw"))) +def test_load_v1_indices(load_from_stream: bool, index_filename: str): + space = detect_space_from_filename(index_filename) + num_dimensions = detect_num_dimensions_from_filename(index_filename) + if load_from_stream: + with open(index_filename, "rb") as f: + index = Index.load(f) + else: + index = Index.load(index_filename) + + # All of these test indices are expected to contain exactly 0.0, 0.1, 0.2, 0.3, 0.4 + assert set(index.ids) == {0, 1, 2, 3, 4} + for _id in index.ids: + expected_vector = np.ones(num_dimensions) * (_id * 0.1) + if space == Space.Cosine and _id > 0: + # Voyager stores only normalized vectors in Cosine mode: + expected_vector = expected_vector / np.sqrt(np.sum(expected_vector**2)) + np.testing.assert_allclose(index[_id], expected_vector, atol=0.2) + + +@pytest.mark.parametrize("load_from_stream", [False, True]) +@pytest.mark.parametrize("index_filename", glob(os.path.join(INDEX_FIXTURE_DIR, "v1", "*.hnsw"))) +def test_v1_indices_must_have_no_parameters_or_must_match( + load_from_stream: bool, index_filename: str +): + space = detect_space_from_filename(index_filename) + num_dimensions = detect_num_dimensions_from_filename(index_filename) + storage_data_type = detect_storage_datatype_from_filename(index_filename) + with pytest.raises(ValueError) as exception: + if load_from_stream: + with open(index_filename, "rb") as f: + Index.load( + f, + space=space, + num_dimensions=num_dimensions + 1, + storage_data_type=storage_data_type, + ) + else: + Index.load( + index_filename, + space=space, + num_dimensions=num_dimensions + 1, + storage_data_type=storage_data_type, + ) + assert "number of dimensions" in repr(exception) + assert f"({num_dimensions})" in repr(exception) + assert f"({num_dimensions + 1})" in repr(exception) + + +@pytest.mark.parametrize( + "data,should_pass", + [ + ( + b"VOYA" # Header + b"\x01\x00\x00\x00" # File version + b"\x0A\x00\x00\x00" # Number of dimensions (10) + b"\x00" # Space type + b"\x20", # Storage data type + False, + ), + ( + b"VOYA" # Header + b"\x01\x00\x00\x00" # File version + b"\x0A\x00\x00\x00" # Number of dimensions (10) + b"\x00" # Space type + b"\x20" # Storage data type + b"\x00\x00\x00\x00\x00\x00\x00\x00" # offsetLevel0_ + b"\x01\x00\x00\x00\x00\x00\x00\x00" # max_elements_ + b"\x01\x00\x00\x00\x00\x00\x00\x00" # cur_element_count + b"\x34\x00\x00\x00\x00\x00\x00\x00" # size_data_per_element_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # label_offset_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # offsetData_ + b"\x00\x00\x00\x00" # maxlevel_ + b"\x00\x00\x00\x00" # enterpoint_node_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # maxM_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # maxM0_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # M_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # mult_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # ef_construction_ + + (b"\x00" * 52) # one vector + + b"\x00\x00\x00\x00", # one linklist + True, + ), + ( + b"VOYA" # Header + b"\x01\x00\x00\x00" # File version + b"\x0A\x00\x00\x00" # Number of dimensions (10) + b"\x00" # Space type + b"\x20" # Storage data type + b"\x00\x00\x00\xFF\x00\x00\x00\x00" # offsetLevel0_ + b"\x01\x00\x00\x00\x00\x00\x00\x00" # max_elements_ + b"\x01\x00\x00\x00\x00\x00\x00\x00" # cur_element_count + b"\x34\x00\x00\x00\x00\x00\x00\x00" # size_data_per_element_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # label_offset_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # offsetData_ + b"\x00\x00\x00\x00" # maxlevel_ + b"\x00\x00\x00\x00" # enterpoint_node_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # maxM_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # maxM0_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # M_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # mult_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # ef_construction_ + + (b"\x00" * 52) # one vector + + (b"\x00\x00\x00\x00"), # one linklist + False, + ), + ( + b"VOYA" # Header + b"\x01\x00\x00\x00" # File version + b"\x0A\x00\x00\x00" # Number of dimensions (10) + b"\x00" # Space type + b"\x20" # Storage data type + b"\x05\x00\x00\x00\x00\x00\x00\x00" # offsetLevel0_ + b"\x02\x00\x00\x00\x00\x00\x00\x00" # max_elements_ + b"\x02\x00\x00\x00\x00\x00\x00\x00" # cur_element_count + b"\x48\x00\x00\x00\x00\x00\x00\x00" # size_data_per_element_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # label_offset_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # offsetData_ + b"\x05\x00\x00\x00" # maxlevel_ + b"\x05\x00\x00\x00" # enterpoint_node_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # maxM_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # maxM0_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # M_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # mult_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # ef_construction_ + + (b"\x01" * 72) # one vector + + (b"\x01\x00\x00\x00" * 20) # one linklist + + b"\x00", + False, + ), + ( + b"VOYA" # Header + b"\x01\x00\x00\x00" # File version + b"\x0A\x00\x00\x00" # Number of dimensions (10) + b"\x00" # Space type + b"\x20" # Storage data type + b"\x05\x00\x00\x00\x00\x00\x00\x00" # offsetLevel0_ + b"\x02\x00\x00\x00\x00\x00\x00\x00" # max_elements_ + b"\x02\x00\x00\x00\x00\x00\x00\x00" # cur_element_count + b"\x48\x00\x00\x00\x00\x00\x00\x00" # size_data_per_element_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # label_offset_ + b"\x00\x00\x00\x00\x00\x00\x00\x00" # offsetData_ + b"\x05\x00\x00\x00" # maxlevel_ + b"\x01\x00\x00\x00" # enterpoint_node_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # maxM_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # maxM0_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # M_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # mult_ + b"\x05\x00\x00\x00\x00\x00\x00\x00" # ef_construction_ + + (b"\x01" * 72) # one vector + + (b"\x01\x00\x00\x00" * 20) # one linklist + + b"\x00", + False, + ), + ], +) +def test_loading_invalid_data_cannot_crash(data: bytes, should_pass: bool): + if should_pass: + index = Index.load(BytesIO(data)) + assert len(index) == 1 + np.testing.assert_allclose(index[0], np.zeros(index.num_dimensions)) + else: + with pytest.raises(Exception): + index = Index.load(BytesIO(data)) + # We shoulnd't get here, but if we do: do we segfault? + for id in index.ids: + index.query(index[id]) + + +@pytest.mark.parametrize("seed", range(1000)) +@pytest.mark.parametrize( + "with_valid_header,offset_level_0", + [(True, 500_000), (True, None), (False, None)], +) +def test_fuzz(seed: int, with_valid_header: bool, offset_level_0: int): + """ + Send in 10,000 randomly-generated indices to ensure that the process doesn't crash + """ + np.random.seed(seed) + num_bytes = np.random.randint(1_000_000) + random_data = BytesIO((np.random.rand(num_bytes) * 255).astype(np.uint8).tobytes()) + if with_valid_header: + random_data.seek(0) + random_data.write( + b"VOYA" # Header + b"\x01\x00\x00\x00" # File version + b"\x0A\x00\x00\x00" # Number of dimensions (10) + b"\x00" # Space type + b"\x20" # Storage data type + ) + if offset_level_0: + random_data.write(struct.pack("=Q", offset_level_0)) + random_data.seek(0) + with pytest.raises(Exception): + Index.load(random_data)