diff --git a/cpp/Enums.h b/cpp/Enums.h index 735d6dd6..95d11b74 100644 --- a/cpp/Enums.h +++ b/cpp/Enums.h @@ -22,3 +22,29 @@ enum class StorageDataType : unsigned char { // allowing representation of values from 2e-9 to 448. E4M3 = 3 << 4, }; + +inline const std::string toString(StorageDataType sdt) { + switch (sdt) { + case StorageDataType::Float8: + return "Float8"; + case StorageDataType::Float32: + return "Float32"; + case StorageDataType::E4M3: + return "E4M3"; + default: + return "Unknown storage data type (value " + std::to_string((int)sdt) + ")"; + } +} + +inline const std::string toString(SpaceType space) { + switch (space) { + case SpaceType::Euclidean: + return "Euclidean"; + case SpaceType::Cosine: + return "Cosine"; + case SpaceType::InnerProduct: + return "InnerProduct"; + default: + return "Unknown space type (value " + std::to_string((int)space) + ")"; + } +} \ No newline at end of file diff --git a/cpp/TypedIndex.h b/cpp/TypedIndex.h index 719bacd7..510d7288 100644 --- a/cpp/TypedIndex.h +++ b/cpp/TypedIndex.h @@ -29,6 +29,7 @@ #include #include "E4M3.h" +#include "Enums.h" #include "Index.h" #include "Metadata.h" #include "array_utils.h" diff --git a/java/com_spotify_voyager_jni_Index.cpp b/java/com_spotify_voyager_jni_Index.cpp index 029ca6f0..dcbae44d 100644 --- a/java/com_spotify_voyager_jni_Index.cpp +++ b/java/com_spotify_voyager_jni_Index.cpp @@ -21,6 +21,7 @@ #include "com_spotify_voyager_jni_Index.h" #include "JavaInputStream.h" #include "JavaOutputStream.h" +#include #include #include @@ -709,16 +710,23 @@ void Java_com_spotify_voyager_jni_Index_nativeLoadFromFileWithParameters( if (metadata) { if (metadata->getStorageDataType() != toStorageDataType(env, storageDataType)) { - throw std::domain_error("Provided storage data type does not match " - "the data type used in this file."); + throw std::domain_error( + "Provided storage data type (" + + toString(toStorageDataType(env, storageDataType)) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); } if (metadata->getSpaceType() != toSpaceType(env, spaceType)) { - throw std::domain_error("Provided space type does not match " - "the space type used in this file."); + throw std::domain_error( + "Provided space type (" + toString(toSpaceType(env, spaceType)) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); } if (metadata->getNumDimensions() != numDimensions) { - throw std::domain_error("Provided number of dimensions does not match " - "the number of dimensions used in this file."); + throw std::domain_error( + "Provided number of dimensions (" + std::to_string(numDimensions) + + ") does not match the number of dimensions used in this file (" + + std::to_string(metadata->getNumDimensions()) + ")."); } } @@ -760,16 +768,23 @@ void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStreamWithParameters( if (metadata) { if (metadata->getStorageDataType() != toStorageDataType(env, storageDataType)) { - throw std::domain_error("Provided storage data type does not match " - "the data type used in this file."); + throw std::domain_error( + "Provided storage data type (" + + toString(toStorageDataType(env, storageDataType)) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); } if (metadata->getSpaceType() != toSpaceType(env, spaceType)) { - throw std::domain_error("Provided space type does not match " - "the space type used in this file."); + throw std::domain_error( + "Provided space type (" + toString(toSpaceType(env, spaceType)) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); } if (metadata->getNumDimensions() != numDimensions) { - throw std::domain_error("Provided number of dimensions does not match " - "the number of dimensions used in this file."); + throw std::domain_error( + "Provided number of dimensions (" + std::to_string(numDimensions) + + ") does not match the number of dimensions used in this file (" + + std::to_string(metadata->getNumDimensions()) + ")."); } } diff --git a/python/bindings.cpp b/python/bindings.cpp index 17d89c4c..fb5d68ad 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -911,17 +911,24 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.). if (metadata) { if (metadata->getStorageDataType() != storageDataType) { - throw std::domain_error("Provided storage data type does not match " - "the data type used in this file."); + throw std::domain_error( + "Provided storage data type (" + toString(storageDataType) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); } if (metadata->getSpaceType() != space) { - throw std::domain_error("Provided space type does not match " - "the space type used in this file."); + throw std::domain_error( + "Provided space type (" + toString(space) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); } if (metadata->getNumDimensions() != num_dimensions) { throw std::domain_error( - "Provided number of dimensions does not match " - "the number of dimensions used in this file."); + "Provided number of dimensions (" + + std::to_string(num_dimensions) + + ") does not match the number of dimensions used in this file " + "(" + + std::to_string(metadata->getNumDimensions()) + ")."); } } @@ -973,17 +980,24 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.). if (metadata) { if (metadata->getStorageDataType() != storageDataType) { - throw std::domain_error("Provided storage data type does not match " - "the data type used in this file."); + throw std::domain_error( + "Provided storage data type (" + toString(storageDataType) + + ") does not match the data type used in this file (" + + toString(metadata->getStorageDataType()) + ")."); } if (metadata->getSpaceType() != space) { - throw std::domain_error("Provided space type does not match " - "the space type used in this file."); + throw std::domain_error( + "Provided space type (" + toString(space) + + ") does not match the space type used in this file (" + + toString(metadata->getSpaceType()) + ")."); } if (metadata->getNumDimensions() != num_dimensions) { throw std::domain_error( - "Provided number of dimensions does not match " - "the number of dimensions used in this file."); + "Provided number of dimensions (" + + std::to_string(num_dimensions) + + ") does not match the number of dimensions used in this file " + "(" + + std::to_string(metadata->getNumDimensions()) + ")."); } } diff --git a/python/tests/test_load_indices.py b/python/tests/test_load_indices.py index e9516a70..3c17701e 100644 --- a/python/tests/test_load_indices.py +++ b/python/tests/test_load_indices.py @@ -112,6 +112,35 @@ def test_load_v1_indices(load_from_stream: bool, index_filename: str): np.testing.assert_allclose(index[_id], expected_vector, atol=0.2) +@pytest.mark.parametrize("load_from_stream", [False, True]) +@pytest.mark.parametrize("index_filename", glob(os.path.join(INDEX_FIXTURE_DIR, "v1", "*.hnsw"))) +def test_v1_indices_must_have_no_parameters_or_must_match( + load_from_stream: bool, index_filename: str +): + space = detect_space_from_filename(index_filename) + num_dimensions = detect_num_dimensions_from_filename(index_filename) + storage_data_type = detect_storage_datatype_from_filename(index_filename) + with pytest.raises(ValueError) as exception: + if load_from_stream: + with open(index_filename, "rb") as f: + Index.load( + f, + space=space, + num_dimensions=num_dimensions + 1, + storage_data_type=storage_data_type, + ) + else: + Index.load( + index_filename, + space=space, + num_dimensions=num_dimensions + 1, + storage_data_type=storage_data_type, + ) + assert "number of dimensions" in repr(exception) + assert f"({num_dimensions})" in repr(exception) + assert f"({num_dimensions + 1})" in repr(exception) + + @pytest.mark.parametrize( "data,should_pass", [