Skip to content

Commit

Permalink
Add feedback to error messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
psobot committed Oct 3, 2023
1 parent 53cddf7 commit 306f157
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 24 deletions.
26 changes: 26 additions & 0 deletions cpp/Enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,29 @@ enum class StorageDataType : unsigned char {
// allowing representation of values from 2e-9 to 448.
E4M3 = 3 << 4,
};

inline const std::string toString(StorageDataType sdt) {
switch (sdt) {
case StorageDataType::Float8:
return "Float8";
case StorageDataType::Float32:
return "Float32";
case StorageDataType::E4M3:
return "E4M3";
default:
return "Unknown storage data type (value " + std::to_string((int)sdt) + ")";
}
}

inline const std::string toString(SpaceType space) {
switch (space) {
case SpaceType::Euclidean:
return "Euclidean";
case SpaceType::Cosine:
return "Cosine";
case SpaceType::InnerProduct:
return "InnerProduct";
default:
return "Unknown space type (value " + std::to_string((int)space) + ")";
}
}
1 change: 1 addition & 0 deletions cpp/TypedIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <ratio>

#include "E4M3.h"
#include "Enums.h"
#include "Index.h"
#include "Metadata.h"
#include "array_utils.h"
Expand Down
39 changes: 27 additions & 12 deletions java/com_spotify_voyager_jni_Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "com_spotify_voyager_jni_Index.h"
#include "JavaInputStream.h"
#include "JavaOutputStream.h"
#include <Enums.h>
#include <Index.h>
#include <TypedIndex.h>

Expand Down Expand Up @@ -709,16 +710,23 @@ void Java_com_spotify_voyager_jni_Index_nativeLoadFromFileWithParameters(
if (metadata) {
if (metadata->getStorageDataType() !=
toStorageDataType(env, storageDataType)) {
throw std::domain_error("Provided storage data type does not match "
"the data type used in this file.");
throw std::domain_error(
"Provided storage data type (" +
toString(toStorageDataType(env, storageDataType)) +
") does not match the data type used in this file (" +
toString(metadata->getStorageDataType()) + ").");
}
if (metadata->getSpaceType() != toSpaceType(env, spaceType)) {
throw std::domain_error("Provided space type does not match "
"the space type used in this file.");
throw std::domain_error(
"Provided space type (" + toString(toSpaceType(env, spaceType)) +
") does not match the space type used in this file (" +
toString(metadata->getSpaceType()) + ").");
}
if (metadata->getNumDimensions() != numDimensions) {
throw std::domain_error("Provided number of dimensions does not match "
"the number of dimensions used in this file.");
throw std::domain_error(
"Provided number of dimensions (" + std::to_string(numDimensions) +
") does not match the number of dimensions used in this file (" +
std::to_string(metadata->getNumDimensions()) + ").");
}
}

Expand Down Expand Up @@ -760,16 +768,23 @@ void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStreamWithParameters(
if (metadata) {
if (metadata->getStorageDataType() !=
toStorageDataType(env, storageDataType)) {
throw std::domain_error("Provided storage data type does not match "
"the data type used in this file.");
throw std::domain_error(
"Provided storage data type (" +
toString(toStorageDataType(env, storageDataType)) +
") does not match the data type used in this file (" +
toString(metadata->getStorageDataType()) + ").");
}
if (metadata->getSpaceType() != toSpaceType(env, spaceType)) {
throw std::domain_error("Provided space type does not match "
"the space type used in this file.");
throw std::domain_error(
"Provided space type (" + toString(toSpaceType(env, spaceType)) +
") does not match the space type used in this file (" +
toString(metadata->getSpaceType()) + ").");
}
if (metadata->getNumDimensions() != numDimensions) {
throw std::domain_error("Provided number of dimensions does not match "
"the number of dimensions used in this file.");
throw std::domain_error(
"Provided number of dimensions (" + std::to_string(numDimensions) +
") does not match the number of dimensions used in this file (" +
std::to_string(metadata->getNumDimensions()) + ").");
}
}

Expand Down
38 changes: 26 additions & 12 deletions python/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,17 +911,24 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.).

if (metadata) {
if (metadata->getStorageDataType() != storageDataType) {
throw std::domain_error("Provided storage data type does not match "
"the data type used in this file.");
throw std::domain_error(
"Provided storage data type (" + toString(storageDataType) +
") does not match the data type used in this file (" +
toString(metadata->getStorageDataType()) + ").");
}
if (metadata->getSpaceType() != space) {
throw std::domain_error("Provided space type does not match "
"the space type used in this file.");
throw std::domain_error(
"Provided space type (" + toString(space) +
") does not match the space type used in this file (" +
toString(metadata->getSpaceType()) + ").");
}
if (metadata->getNumDimensions() != num_dimensions) {
throw std::domain_error(
"Provided number of dimensions does not match "
"the number of dimensions used in this file.");
"Provided number of dimensions (" +
std::to_string(num_dimensions) +
") does not match the number of dimensions used in this file "
"(" +
std::to_string(metadata->getNumDimensions()) + ").");
}
}

Expand Down Expand Up @@ -973,17 +980,24 @@ binary data (i.e.: ``open(..., \"rb\")`` or ``io.BinaryIO``, etc.).

if (metadata) {
if (metadata->getStorageDataType() != storageDataType) {
throw std::domain_error("Provided storage data type does not match "
"the data type used in this file.");
throw std::domain_error(
"Provided storage data type (" + toString(storageDataType) +
") does not match the data type used in this file (" +
toString(metadata->getStorageDataType()) + ").");
}
if (metadata->getSpaceType() != space) {
throw std::domain_error("Provided space type does not match "
"the space type used in this file.");
throw std::domain_error(
"Provided space type (" + toString(space) +
") does not match the space type used in this file (" +
toString(metadata->getSpaceType()) + ").");
}
if (metadata->getNumDimensions() != num_dimensions) {
throw std::domain_error(
"Provided number of dimensions does not match "
"the number of dimensions used in this file.");
"Provided number of dimensions (" +
std::to_string(num_dimensions) +
") does not match the number of dimensions used in this file "
"(" +
std::to_string(metadata->getNumDimensions()) + ").");
}
}

Expand Down
29 changes: 29 additions & 0 deletions python/tests/test_load_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,35 @@ def test_load_v1_indices(load_from_stream: bool, index_filename: str):
np.testing.assert_allclose(index[_id], expected_vector, atol=0.2)


@pytest.mark.parametrize("load_from_stream", [False, True])
@pytest.mark.parametrize("index_filename", glob(os.path.join(INDEX_FIXTURE_DIR, "v1", "*.hnsw")))
def test_v1_indices_must_have_no_parameters_or_must_match(
load_from_stream: bool, index_filename: str
):
space = detect_space_from_filename(index_filename)
num_dimensions = detect_num_dimensions_from_filename(index_filename)
storage_data_type = detect_storage_datatype_from_filename(index_filename)
with pytest.raises(ValueError) as exception:
if load_from_stream:
with open(index_filename, "rb") as f:
Index.load(
f,
space=space,
num_dimensions=num_dimensions + 1,
storage_data_type=storage_data_type,
)
else:
Index.load(
index_filename,
space=space,
num_dimensions=num_dimensions + 1,
storage_data_type=storage_data_type,
)
assert "number of dimensions" in repr(exception)
assert f"({num_dimensions})" in repr(exception)
assert f"({num_dimensions + 1})" in repr(exception)


@pytest.mark.parametrize(
"data,should_pass",
[
Expand Down

0 comments on commit 306f157

Please sign in to comment.