diff --git a/cpp/src/TypedIndex.h b/cpp/src/TypedIndex.h index c0f5debc..63066a40 100644 --- a/cpp/src/TypedIndex.h +++ b/cpp/src/TypedIndex.h @@ -53,29 +53,6 @@ template <> const std::string storageDataTypeName() { return "Float8"; } template <> const std::string storageDataTypeName() { return "Float32"; } template <> const std::string storageDataTypeName() { return "E4M3"; } -template -dist_t ensureNotNegative(dist_t distance, hnswlib::labeltype label) { - if constexpr (std::is_same_v) { - // Allow for a very slight negative distance if using E4M3 - if (distance < 0 && distance >= -0.14) { - return 0; - } - } - - if (distance < 0) { - if (distance >= -0.00001) { - return 0; - } - - throw std::runtime_error( - "Potential candidate (with label '" + std::to_string(label) + - "') had negative distance " + std::to_string(distance) + - ". This may indicate a corrupted index file."); - } - - return distance; -} - /** * A C++ wrapper class for a typed HNSW index. * @@ -402,7 +379,7 @@ class TypedIndex : public Index { floatToDataType(&inputArray[startIndex], &convertedArray[startIndex], actualDimensions); - size_t id = ids.size() ? ids.at(row) : (currentLabel + row); + size_t id = ids.size() ? ids.at(row) : (currentLabel.fetch_add(1)); try { algorithmImpl->addPoint(convertedArray.data() + startIndex, id); } catch (IndexFullError &e) { @@ -438,7 +415,7 @@ class TypedIndex : public Index { normalizeVector( &inputArray[startIndex], &normalizedArray[startIndex], actualDimensions); - size_t id = ids.size() ? ids.at(row) : (currentLabel + row); + size_t id = ids.size() ? ids.at(row) : (currentLabel.fetch_add(1)); try { algorithmImpl->addPoint(normalizedArray.data() + startIndex, id); @@ -629,8 +606,7 @@ class TypedIndex : public Index { dist_t distance = result_tuple.first; hnswlib::labeltype label = result_tuple.second; - distancePointer[row * k + i] = - ensureNotNegative(distance, label); + distancePointer[row * k + i] = distance; labelPointer[row * k + i] = label; result.pop(); } @@ -704,8 +680,7 @@ class TypedIndex : public Index { for (int i = k - 1; i >= 0; i--) { auto &result_tuple = result.top(); - distancePointer[i] = ensureNotNegative( - result_tuple.first, result_tuple.second); + distancePointer[i] = result_tuple.first; labelPointer[i] = result_tuple.second; result.pop(); }