From fbf112a727b668dec62c2908462dc4dd3f494f3d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:13:00 -0500 Subject: [PATCH 1/4] Binary vector format for flat and hnsw vectors --- lucene/core/src/java/module-info.java | 5 +- .../lucene/codecs/lucene101/package-info.java | 418 +------- .../lucene102/BinarizedByteVectorValues.java | 84 ++ .../Lucene102BinaryFlatVectorsScorer.java | 192 ++++ ...Lucene102BinaryQuantizedVectorsFormat.java | 135 +++ ...Lucene102BinaryQuantizedVectorsReader.java | 424 ++++++++ ...Lucene102BinaryQuantizedVectorsWriter.java | 950 ++++++++++++++++++ ...ne102HnswBinaryQuantizedVectorsFormat.java | 154 +++ .../OffHeapBinarizedVectorValues.java | 383 +++++++ .../lucene/codecs/lucene102/package-info.java | 436 ++++++++ .../DefaultVectorUtilSupport.java | 27 + .../vectorization/VectorUtilSupport.java | 13 + .../org/apache/lucene/util/VectorUtil.java | 15 + .../OptimizedScalarQuantizer.java | 371 +++++++ .../PanamaVectorUtilSupport.java | 117 +++ .../org.apache.lucene.codecs.KnnVectorsFormat | 2 + ...Lucene102BinaryQuantizedVectorsFormat.java | 179 ++++ ...ne102HnswBinaryQuantizedVectorsFormat.java | 137 +++ .../TestOptimizedScalarQuantizer.java | 149 +++ 19 files changed, 3773 insertions(+), 418 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java create mode 100644 lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 85aff5722498..ac321b12e8fa 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -32,6 +32,7 @@ exports org.apache.lucene.codecs.lucene95; exports org.apache.lucene.codecs.lucene99; exports org.apache.lucene.codecs.lucene101; + exports org.apache.lucene.codecs.lucene102; exports org.apache.lucene.codecs.perfield; exports org.apache.lucene.codecs; exports org.apache.lucene.document; @@ -76,7 +77,9 @@ provides org.apache.lucene.codecs.KnnVectorsFormat with org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat, org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat, - org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat; + 
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat; provides org.apache.lucene.index.SortFieldProvider with diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java index e582f12c3185..8aa1c3b43a0c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java @@ -15,421 +15,5 @@ * limitations under the License. */ -/** - * Lucene 10.1 file format. - * - *

- * <h1>Apache Lucene - Index File Formats</h1>
- *
- * <h2>Introduction</h2>
- *
- * <p>This document defines the index file formats used in this version of Lucene. If you are using
- * a different version of Lucene, please consult the copy of <code>docs/</code> that was distributed
- * with the version you are using.
- *
- * <p>This document attempts to provide a high-level definition of the Apache Lucene file formats.
- *
- * <h2>Definitions</h2>
- *
- * <p>The fundamental concepts in Lucene are index, document, field and term.
- *
- * <p>An index contains a sequence of documents.
- *
- * <ul>
- *   <li>A document is a sequence of fields.
- *   <li>A field is a named sequence of terms.
- *   <li>A term is a sequence of bytes.
- * </ul>
- *
- * <p>The same sequence of bytes in two different fields is considered a different term. Thus terms
- * are represented as a pair: the string naming the field, and the bytes within the field.
- *
- * <h3>Inverted Indexing</h3>
- *
- * <p>Lucene's index stores terms and statistics about those terms in order to make term-based
- * search more efficient. Lucene's terms index falls into the family of indexes known as an
- * <i>inverted index.</i> This is because it can list, for a term, the documents that contain it.
- * This is the inverse of the natural relationship, in which documents list terms.
- *
- * <h3>Types of Fields</h3>
- *
- * <p>In Lucene, fields may be <i>stored</i>, in which case their text is stored in the index
- * literally, in a non-inverted manner. Fields that are inverted are called <i>indexed</i>. A field
- * may be both stored and indexed.
- *
- * <p>The text of a field may be <i>tokenized</i> into terms to be indexed, or the text of a field
- * may be used literally as a term to be indexed. Most fields are tokenized, but sometimes it is
- * useful for certain identifier fields to be indexed literally.
- *
- * <p>See the {@link org.apache.lucene.document.Field Field} java docs for more information on
- * Fields.
- *
- * <h3>Segments</h3>
- *
- * <p>Lucene indexes may be composed of multiple sub-indexes, or <i>segments</i>. Each segment is a
- * fully independent index, which could be searched separately. Indexes evolve by:
- *
- * <ol>
- *   <li>Creating new segments for newly added documents.
- *   <li>Merging existing segments.
- * </ol>
- *
- * <p>Searches may involve multiple segments and/or multiple indexes, each index potentially
- * composed of a set of segments.
- *
- * <h3>Document Numbers</h3>
- *
- * <p>Internally, Lucene refers to documents by an integer <i>document number</i>. The first
- * document added to an index is numbered zero, and each subsequent document added gets a number one
- * greater than the previous.
- *
- * <p>Note that a document's number may change, so caution should be taken when storing these
- * numbers outside of Lucene. In particular, numbers may change in the following situations:
- *
- * <ul>
- *   <li>The numbers stored in each segment are unique only within the segment, and must be
- *       converted before they can be used in a larger context. The standard technique is to
- *       allocate each segment a range of values, based on the range of numbers used in that
- *       segment. To convert a document number from a segment to an external value, the segment's
- *       <i>base</i> document number is added. To convert an external value back to a
- *       segment-specific value, the segment is identified by the range that the external value is
- *       in, and the segment's base value is subtracted.
- *   <li>When documents are deleted, gaps are created in the numbering. These are eventually
- *       removed as the index evolves through merging. Deleted documents are dropped when segments
- *       are merged. A freshly-merged segment thus has no gaps in its numbering.
- * </ul>
- *
- * <h2>Index Structure Overview</h2>
- *
- * <p>Each segment index maintains the following:
- *
- * <p>Details on each of these are provided in their linked pages.
- *
- * <h2>File Naming</h2>
- *
- * <p>All files belonging to a segment have the same name with varying extensions. The extensions
- * correspond to the different file formats described below. When using the Compound File format
- * (default for small segments) these files (except for the Segment info file, the Lock file, and
- * Deleted documents file) are collapsed into a single .cfs file (see below for details).
- *
- * <p>Typically, all segments in an index are stored in a single directory, although this is not
- * required.
- *
- * <p>File names are never re-used. That is, when any file is saved to the Directory it is given a
- * never before used filename. This is achieved using a simple generations approach. For example,
- * the first segments file is segments_1, then segments_2, etc. The generation is a sequential long
- * integer represented in alpha-numeric (base 36) form.
- *
- * <h2>Summary of File Extensions</h2>
- *
- * <p>The following table summarizes the names and extensions of the files in Lucene:
- *
- * <table class="padding4" style="border-spacing: 1px; border-collapse: separate">
- * <caption>lucene filenames by extension</caption>
- * <tr><th>Name</th><th>Extension</th><th>Brief Description</th></tr>
- * <tr><td>{@link org.apache.lucene.index.SegmentInfos Segments File}</td><td>segments_N</td>
- * <td>Stores information about a commit point</td></tr>
- * <tr><td>Lock File</td><td>write.lock</td><td>The Write lock prevents multiple IndexWriters from
- * writing to the same file.</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment Info}</td>
- * <td>.si</td><td>Stores metadata about a segment</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat Compound File}</td>
- * <td>.cfs, .cfe</td><td>An optional "virtual" file consisting of all the other index files for
- * systems that frequently run out of file handles.</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}</td>
- * <td>.fnm</td><td>Stores information about the fields</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Index}</td>
- * <td>.fdx</td><td>Contains pointers to field data</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Data}</td>
- * <td>.fdt</td><td>The stored fields for documents</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Dictionary}</td>
- * <td>.tim</td><td>The term dictionary, stores term info</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Index}</td>
- * <td>.tip</td><td>The index into the Term Dictionary</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Frequencies}</td>
- * <td>.doc</td><td>Contains the list of docs which contain each term along with frequency</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Positions}</td>
- * <td>.pos</td><td>Stores position information about where a term occurs in the index</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Payloads}</td>
- * <td>.pay</td><td>Stores additional per-position metadata information such as character offsets
- * and user payloads</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Norms}</td>
- * <td>.nvd, .nvm</td><td>Encodes length and boost factors for docs and fields</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-Document Values}</td>
- * <td>.dvd, .dvm</td><td>Encodes additional scoring factors or other per-document
- * information.</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Index}</td>
- * <td>.tvx</td><td>Stores offset into the document data file</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Data}</td>
- * <td>.tvd</td><td>Contains term vector data.</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live Documents}</td>
- * <td>.liv</td><td>Info about what documents are live</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}</td>
- * <td>.kdd, .kdi, .kdm</td><td>Holds indexed points</td></tr>
- * <tr><td>{@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}</td>
- * <td>.vec, .vem, .veq, .vex</td><td>Holds indexed vectors; <code>.vec</code> files contain the
- * raw vector data, <code>.vem</code> the vector metadata, <code>.veq</code> the quantized vector
- * data, and <code>.vex</code> the hnsw graph data.</td></tr>
- * </table>
- *
- * <h2>Lock File</h2>
- *
- * <p>The write lock, which is stored in the index directory by default, is named "write.lock". If
- * the lock directory is different from the index directory then the write lock will be named
- * "XXXX-write.lock" where XXXX is a unique prefix derived from the full path to the index
- * directory. When this file is present, a writer is currently modifying the index (adding or
- * removing documents). This lock file ensures that only one writer is modifying the index at a
- * time.
- *
- * <h2>History</h2>
- *
- * <p>Compatibility notes are provided in this document, describing how file formats have changed
- * from prior versions:
- *
- * <h2>Limitations</h2>
- *
- * <p>Lucene uses a Java <code>int</code> to refer to document numbers, and the index file format
- * uses an <code>Int32</code> on-disk to store document numbers. This is a limitation of both the
- * index file format and the current implementation. Eventually these should be replaced with
- * either <code>UInt64</code> values, or better yet, {@link
- * org.apache.lucene.store.DataOutput#writeVInt VInt} values which have no limit.
- */ +/** Lucene 10.1 file format. */ package org.apache.lucene.codecs.lucene101; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java new file mode 100644 index 000000000000..62eff4d72ac4 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene102; + +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; + +import java.io.IOException; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Binarized byte vector values */ +abstract class BinarizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of + * distances, the corrective terms are, in order + * + * + * + * For euclidean: + * + * + * + * @param vectorOrd the vector ordinal + * @return the corrective terms + * @throws IOException if an I/O error occurs + */ + public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) + throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract OptimizedScalarQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + + int discretizedDimensions() { + return discretize(dimension(), 64); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ + public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract BinarizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java new file mode 100644 index 000000000000..a25fc0807563 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene102; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult; + +/** Vector scorer over binarized vector values */ +public class Lucene102BinaryFlatVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public Lucene102BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format"); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + byte[] initial = new byte[target.length]; + byte[] quantized = new byte[QUERY_BITS * binarizedVectors.discretizedDimensions() / 8]; + OptimizedScalarQuantizer.QuantizationResult queryCorrections = + quantizer.scalarQuantize(target, initial, (byte) 4, centroid); + transposeHalfByte(initial, quantized); 
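+      // transposeHalfByte lays the 4-bit query out bit-plane by bit-plane (the lowest bit of
+      // every dimension first, then the next bit position, and so on), so that
+      // VectorUtil#int4BitDotProduct can score it against the packed 1-bit index codes with
+      // four AND + popcount passes, one per bit position.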
+ BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors) { + return new BinarizedRandomVectorScorerSupplier( + scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return "Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues + queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord); + BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier( + queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + public record BinaryQueryVector( + byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {} + + /** Vector scorer over binarized vector values */ + public static class BinarizedRandomVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + public BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float qcDist = VectorUtil.int4BitDotProduct(quantizedQuery, binaryCode); + OptimizedScalarQuantizer.QuantizationResult queryCorrections = + 
queryVector.quantizationResult();
+      OptimizedScalarQuantizer.QuantizationResult indexCorrections =
+          targetVectors.getCorrectiveTerms(targetOrd);
+      float x1 = indexCorrections.quantizedComponentSum();
+      float ax = indexCorrections.lowerInterval();
+      // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+      float lx = indexCorrections.upperInterval() - ax;
+      float ay = queryCorrections.lowerInterval();
+      float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+      float y1 = queryCorrections.quantizedComponentSum();
+      float score =
+          ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
+      // For euclidean, we need to invert the score and apply the additional correction, which is
+      // assumed to be the squared l2norm of the centroid centered vectors.
+      if (similarityFunction == EUCLIDEAN) {
+        score =
+            queryCorrections.additionalCorrection()
+                + indexCorrections.additionalCorrection()
+                - 2 * score;
+        return Math.max(1 / (1f + score), 0);
+      } else {
+        // For cosine and max inner product, we need to apply the additional correction, which is
+        // assumed to be the non-centered dot-product between the vector and the centroid
+        score +=
+            queryCorrections.additionalCorrection()
+                + indexCorrections.additionalCorrection()
+                - targetVectors.getCentroidDP();
+        if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+          return VectorUtil.scaleMaxInnerProductScore(score);
+        }
+        return Math.max((1f + score) / 2f, 0);
+      }
+    }
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..ae48a220b235
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+/**
+ * Codec for encoding/decoding binary quantized vectors. The binary quantization format used here
+ * is a per-vector optimized scalar quantization. Also see {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}.
+ * Some of the key features are:
+ *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid-centered distance. This
+ *       requires some additional corrective factors, but allows centroid centering to occur.
+ *   <li>Optimized scalar quantization of the centered vectors to a single bit per dimension for
+ *       the index.
+ *   <li>Asymmetric quantization, where query vectors are quantized to half-byte (4 bit) precision
+ *       and then compared directly against the single-bit quantized vectors in the index.
+ *   <li>Transposing the half-byte query so the comparison reduces to fast bitwise operations.
+ * </ul>
+ *
+ * <p>The format is stored in two files:

+ * <h2>.veb (vector data) file</h2>
+ *
+ * <p>Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <h2>.vemb (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
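+ *
+ * <p>For reference, each vector record in the .veb file is laid out as: the packed single-bit
+ * codes ({@code discretize(dimension, 64) / 8} bytes), then three floats (the lower and upper
+ * optimized interval bounds and an additional similarity-dependent correction), then one short
+ * (the sum of the quantized components). This mirrors what {@link
+ * Lucene102BinaryQuantizedVectorsWriter} writes per vector.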

+ */ +public class Lucene102BinaryQuantizedVectorsFormat extends FlatVectorsFormat { + + public static final byte QUERY_BITS = 4; + public static final byte INDEX_BITS = 1; + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "Lucene102BinaryQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "Lucene102BinaryQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene102BinaryQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private static final Lucene102BinaryFlatVectorsScorer scorer = + new Lucene102BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + /** Creates a new instance with the default number of vectors per cluster. */ + public Lucene102BinaryQuantizedVectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene102BinaryQuantizedVectorsWriter( + scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene102BinaryQuantizedVectorsReader( + state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene102BinaryQuantizedVectorsFormat(name=" + + NAME + + ", flatVectorScorer=" + + scorer + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java new file mode 100644 index 000000000000..a698ea9a7541 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene102; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Reader for binary quantized vectors in the Lucene 10.2 format. */ +class Lucene102BinaryQuantizedVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(Lucene102BinaryQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final Lucene102BinaryFlatVectorsScorer vectorScorer; + + Lucene102BinaryQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + Lucene102BinaryFlatVectorsScorer vectorsScorer) + throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene102BinaryQuantizedVectorsFormat.META_EXTENSION); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene102BinaryQuantizedVectorsFormat.META_CODEC_NAME, + Lucene102BinaryQuantizedVectorsFormat.VERSION_START, + Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + quantizedVectorData = + openDataInput( + state, + versionMeta, + Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION, + Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. 
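+                // A sequential read-ahead hint would mostly prefetch bytes that a graph
+                // traversal never touches, so the RANDOM advice is deliberate here.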
+ state.context.withReadAdvice(ReadAdvice.RANDOM)); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + dimension + + " != " + + fieldEntry.dimension); + } + + int binaryDims = discretize(dimension, 64) / 8; + long numQuantizedVectorBytes = + Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Binarized vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (binaryBytes=" + + binaryDims + + " + 14" + + ") = " + + numQuantizedVectorBytes); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData), + target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fi.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + OffHeapBinarizedVectorValues bvv = + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + 
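// Scores every candidate ordinal against the query: a brute-force scan over the quantized + // vectors. Graph-guided search is layered on top by Lucene102HnswBinaryQuantizedVectorsFormat. +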
@Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + collector.collect(i, scorer.score(i)); + collector.incVisitedCount(1); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += + RamUsageEstimator.sizeOfMap( + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return fieldEntry.centroid; + } + return null; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene102BinaryQuantizedVectorsFormat.VERSION_START, + Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + int descritizedDimension, + long vectorDataOffset, + long vectorDataLength, + int size, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { + + static FieldEntry create( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + if (size > 0) { + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = 
Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = + OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf); + } + } + + /** Binarized vector values holding row and quantized vector values */ + protected static final class BinarizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final BinarizedByteVectorValues quantizedVectorValues; + + BinarizedVectorValues( + FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public BinarizedVectorValues copy() throws IOException { + return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java new file mode 100644 index 000000000000..e15a8dab31b4 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java @@ -0,0 +1,950 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene102; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT; +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ +public class Lucene102BinaryQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + shallowSizeOfInstance(Lucene102BinaryQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, binarizedVectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final Lucene102BinaryFlatVectorsScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + protected Lucene102BinaryQuantizedVectorsWriter( + Lucene102BinaryFlatVectorsScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state) + throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; 
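+ // The raw float vectors are still written by rawVectorDelegate; this writer only adds the + // binarized codes and their corrective terms alongside them.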
+ this.segmentWriteState = state; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene102BinaryQuantizedVectorsFormat.META_EXTENSION); + + String binarizedVectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + binarizedVectorData = + state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + Lucene102BinaryQuantizedVectorsFormat.META_CODEC_NAME, + Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + binarizedVectorData, + Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = + new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField( + FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeBinarizedVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + fieldData.getVectors().size() > 0 ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet()); + } + + private void writeBinarizedVectors( + FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, quantizationScratch, INDEX_BITS, clusterCenter); + packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 + && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + clusterCenter, + centroidDp, + newDocsWithField); + } + + private void writeSortedBinarizedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = + scalarQuantizer.scalarQuantize(v, quantizationScratch, INDEX_BITS, clusterCenter); + packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 + && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] 
clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = + ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, binarizedVectorData, count, maxDoc, docsWithField); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (binarizedVectorData != null) { + CodecUtil.writeFooter(binarizedVectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + BinarizedFloatVectorValues binarizedVectorValues = + new BinarizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + centroid); + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = + writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + docsWithField.cardinality() > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField); + } else { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + } + + static DocsWithFieldSet writeBinarizedVectorAndQueryData( + IndexOutput binarizedVectorData, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + float[] centroid, + OptimizedScalarQuantizer binaryQuantizer) + throws IOException { + int discretizedDimension = discretize(floatVectorValues.dimension(), 64); + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()]; + byte[] toIndex = new byte[discretizedDimension / 8]; + byte[] toQuery = new byte[(discretizedDimension / 8) * QUERY_BITS]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult[] r = + binaryQuantizer.multiScalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + new byte[] {INDEX_BITS, QUERY_BITS}, + centroid); + // pack and store document bit vector + packAsBinary(quantizationScratch[0], toIndex); + binarizedVectorData.writeBytes(toIndex, toIndex.length); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection())); + assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) r[0].quantizedComponentSum()); + docsWithField.add(docV); + + // pack and store the 4bit query vector + transposeHalfByte(quantizationScratch[1], toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection())); + assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff; + binarizedQueryData.writeShort((short) r[1].quantizedComponentSum()); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedVectorData( + IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = + binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 + && corrections.quantizedComponentSum() <= 0xffff; + output.writeShort((short) corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public 
CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC) + throws IOException { + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + final IndexOutput tempQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), "temp", segmentWriteState.context); + final IndexOutput tempScoreQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), "score_temp", segmentWriteState.context); + IndexInput binarizedDataInput = null; + IndexInput binarizedScoreDataInput = null; + boolean success = false; + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + FloatVectorValues floatVectorValues = + MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = + writeBinarizedVectorAndQueryData( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + floatVectorValues, + centroid, + quantizer); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + binarizedDataInput = + segmentWriteState.directory.openInput( + tempQuantizedVectorData.getName(), segmentWriteState.context); + binarizedVectorData.copyBytes( + binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(binarizedDataInput); + CodecUtil.writeFooter(tempScoreQuantizedVectorData); + IOUtils.close(tempScoreQuantizedVectorData); + binarizedScoreDataInput = + segmentWriteState.directory.openInput( + tempScoreQuantizedVectorData.getName(), segmentWriteState.context); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField); + success = true; + final IndexInput finalBinarizedDataInput = binarizedDataInput; + final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; + OffHeapBinarizedVectorValues vectorValues = + new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + quantizer, + fieldInfo.getVectorSimilarityFunction(), + 
vectorsScorer, + finalBinarizedDataInput); + RandomVectorScorerSupplier scorerSupplier = + vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBinarizedQueryVectorValues( + finalBinarizedScoreDataInput, + fieldInfo.getVectorDimension(), + docsWithField.cardinality()), + vectorValues); + return new BinarizedCloseableRandomVectorScorerSupplier( + scorerSupplier, + vectorValues, + () -> { + IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName()); + }); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + binarizedDataInput, + binarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName()); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof Lucene102BinaryQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids( + MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null + || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) + throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = + mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + KnnVectorValues.DocIndexIterator iterator = 
 vectorValues.iterator(); + for (int doc = iterator.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = iterator.nextDoc()) { + ++count; + float[] vector = vectorValues.vectorValue(iterator.index()); + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + } + if (count == 0) { + return count; + } + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } + 
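 // Editorial note (a sketch added in review, not part of the original patch): the records // written by this writer and decoded by the readers below share one fixed per-vector layout, // so record offsets reduce to simple arithmetic over the (discretized) dimension: // // int numBytes = discretize(dim, 64) / 8; // 1-bit packed index vector // int indexRecord = numBytes + 3 * Float.BYTES + Short.BYTES; // plus corrective terms // int queryRecord = numBytes * QUERY_BITS + 3 * Float.BYTES + Short.BYTES; // 4-bit query planes // // indexRecord matches byteSize in OffHeapBinarizedVectorValues, and queryRecord matches // byteSize in OffHeapBinarizedQueryVectorValues below. 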
 + // When accessing the vectorValue method, targetOrd here means a row ordinal. + static class OffHeapBinarizedQueryVectorValues { + private final IndexInput slice; + private final int dimension; + private final int size; + protected final byte[] binaryValue; + protected final ByteBuffer byteBuffer; + private final int byteSize; + protected final float[] correctiveValues; + private int lastOrd = -1; + private int quantizedComponentSum; + + OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { + this.slice = data; + this.dimension = dimension; + this.size = size; + // QUERY_BITS (4) bit planes over the packed binary dimensions + int binaryDimensions = (discretize(dimension, 64) / 8) * QUERY_BITS; + this.byteBuffer = ByteBuffer.allocate(binaryDimensions); + this.binaryValue = byteBuffer.array(); + // three corrective floats; the quantized component sum is read separately as a short + this.correctiveValues = new float[3]; + this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES; + } + + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) + throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + vectorValue(targetOrd); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + + public int size() { + return size; + } + + public int dimension() { + return dimension; + } + + public OffHeapBinarizedQueryVectorValues copy() throws IOException { + return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size); + } + + public IndexInput getSlice() { + return slice; + } + + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(binaryValue, 0, binaryValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + } + + static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private final float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + BinarizedFloatVectorValues( + FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + this.centroid = centroid; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public int size() { + return values.size(); + } + + 
@Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinarizedByteVectorValues copy() throws IOException { + return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void binarize(int ord) throws IOException { + corrections = + quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid); + packAsBinary(initQuantized, binarized); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class BinarizedCloseableRandomVectorScorerSupplier + implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + BinarizedCloseableRandomVectorScorerSupplier( + RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return supplier.scorer(ord); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..55fe8807356b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
 You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene102; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +/** + * A vectors format that uses an HNSW graph to store and search for vectors. Vectors are binary + * quantized using {@link Lucene102BinaryQuantizedVectorsFormat} before being stored in the graph. + */ +public class Lucene102HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "Lucene102HnswBinaryQuantizedVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, and merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = + new Lucene102BinaryQuantizedVectorsFormat(); + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public Lucene102HnswBinaryQuantizedVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public Lucene102HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and binary quantization. 
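 * * <p>As an editorial aside (a hypothetical sketch, not part of this patch; {@code analyzer} is * assumed to exist), the format can be selected per field by overriding the codec's KNN format * hook, for example: * * <pre> * IndexWriterConfig config = new IndexWriterConfig(analyzer); * config.setCodec( * new Lucene101Codec() { * public KnnVectorsFormat getKnnVectorsFormatForField(String field) { * return new Lucene102HnswBinaryQuantizedVectorsFormat(16, 100); * } * }); * </pre> 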
+ * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public Lucene102HnswBinaryQuantizedVectorsFormat( + int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + + MAXIMUM_MAX_CONN + + "; maxConn=" + + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + + MAXIMUM_BEAM_WIDTH + + "; beamWidth=" + + beamWidth); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene102HnswBinaryQuantizedVectorsFormat(name=Lucene102HnswBinaryQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java new file mode 100644 index 000000000000..13eef8b263b5 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene102; + +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +/** + * Binarized vector values loaded from off-heap + * + * @lucene.internal + */ +public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { + + final int dimension; + final int size; + final int numBytes; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] binaryValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer binaryQuantizer; + final float[] centroid; + final float centroidDp; + private final int discretizedDimensions; + + OffHeapBinarizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.numBytes = discretize(dimension, 64) / 8; + this.correctiveValues = new float[3]; + this.byteSize = numBytes + (Float.BYTES * 3) + Short.BYTES; + this.byteBuffer = ByteBuffer.allocate(numBytes); + this.binaryValue = byteBuffer.array(); + this.binaryQuantizer = quantizer; + this.discretizedDimensions = discretize(dimension, 64); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + + @Override + public int discretizedDimensions() { + return discretizedDimensions; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) + throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], 
 correctiveValues[1], correctiveValues[2], quantizedComponentSum); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + return binaryQuantizer; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + + static OffHeapBinarizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData) + throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = + vectorData.slice( + "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + bytesSlice); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); + } + } + + /** Dense off-heap binarized vector values */ + static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + super( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap binarized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn is used to init a new IndexedDISI for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) + throws IOException { + super( + dimension, 
size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues { + EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java new file mode 100644 index 000000000000..92d9de567a20 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Lucene 10.2 file format. + * + *

Apache Lucene - Index File Formats

+ * + * + * + *

Introduction

+ * + *
+ * + *

This document defines the index file formats used in this version of Lucene. If you are using + * a different version of Lucene, please consult the copy of docs/ that was distributed + * with the version you are using. + * + *

This document attempts to provide a high-level definition of the Apache Lucene file formats. + *

+ * + *

Definitions

+ * + *
+ * + *

The fundamental concepts in Lucene are index, document, field and term. + * + *

An index contains a sequence of documents. + * + *

    + *
  • A document is a sequence of fields. + *
  • A field is a named sequence of terms. + *
  • A term is a sequence of bytes. + *
+ * + *

The same sequence of bytes in two different fields is considered a different term. Thus terms + * are represented as a pair: the string naming the field, and the bytes within the field. + * + *

Inverted Indexing

+ * + *

Lucene's index stores terms and statistics about those terms in order to make term-based + * search more efficient. Lucene's terms index falls into the family of indexes known as an + * inverted index. This is because it can list, for a term, the documents that contain it. + * This is the inverse of the natural relationship, in which documents list terms. + * + *

Types of Fields

+ * + *

In Lucene, fields may be stored, in which case their text is stored in the index + * literally, in a non-inverted manner. Fields that are inverted are called indexed. A field + * may be both stored and indexed. + * + *

The text of a field may be tokenized into terms to be indexed, or the text of a field + * may be used literally as a term to be indexed. Most fields are tokenized, but sometimes it is + * useful for certain identifier fields to be indexed literally. + * + *

See the {@link org.apache.lucene.document.Field Field} java docs for more information on + * Fields. + * + *

Segments

+ * + *

Lucene indexes may be composed of multiple sub-indexes, or segments. Each segment is a + * fully independent index, which could be searched separately. Indexes evolve by: + * + *

    + *
  1. Creating new segments for newly added documents. + *
  2. Merging existing segments. + *
+ * + *

Searches may involve multiple segments and/or multiple indexes, each index potentially + * composed of a set of segments. + * + *

Document Numbers

+ * + *

Internally, Lucene refers to documents by an integer document number. The first + * document added to an index is numbered zero, and each subsequent document added gets a number one + * greater than the previous. + * + *

Note that a document's number may change, so caution should be taken when storing these + * numbers outside of Lucene. In particular, numbers may change in the following situations: + * + *

    + *
  • + *

 The numbers stored in each segment are unique only within the segment, and must be + * converted before they can be used in a larger context. The standard technique is to + * allocate each segment a range of values, based on the range of numbers used in that + * segment. To convert a document number from a segment to an external value, the segment's + * base document number is added. To convert an external value back to a + * segment-specific value, the segment is identified by the range that the external value is + * in, and the segment's base value is subtracted. For example, two five-document segments + * might be combined, so that the first segment has a base value of zero, and the second of + * five. Document three from the second segment would have an external value of eight (see + * the sketch after this list). 

  • + *

    When documents are deleted, gaps are created in the numbering. These are eventually + * removed as the index evolves through merging. Deleted documents are dropped when segments + * are merged. A freshly-merged segment thus has no gaps in its numbering. + *

+ * + *
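 * * <p>The conversion described in the first item above, as an editorial sketch (illustration * only, not a Lucene API; the numbers follow the two five-document segments example): * * <pre> * int secondSegmentBase = 5; // first segment has base 0 * int externalDoc = secondSegmentBase + 3; // segment doc 3 becomes external doc 8 * int segmentDoc = externalDoc - secondSegmentBase; // back to the segment-local number * </pre> 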
+ * + *

Index Structure Overview

+ * + *
+ * + *

Each segment index maintains the following: + * + *

    + *
  • {@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment info}. This + * contains metadata about a segment, such as the number of documents, what files it uses, and + * information about how the segment is sorted + *
  • {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This + * contains metadata about the set of named fields used in the index. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}. + * This contains, for each document, a list of attribute-value pairs, where the attributes are + * field names. These are used to store auxiliary information about the document, such as its + * title, url, or an identifier to access a database. The set of stored fields are what is + * returned for each hit when searching. This is keyed by document number. + *
  • {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term dictionary}. A + * dictionary containing all of the terms used in all of the indexed fields of all of the + * documents. The dictionary also contains the number of documents which contain the term, and + * pointers to the term's frequency and proximity data. + *
  • {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Frequency data}. For + * each term in the dictionary, the numbers of all the documents that contain that term, and + * the frequency of the term in that document, unless frequencies are omitted ({@link + * org.apache.lucene.index.IndexOptions#DOCS IndexOptions.DOCS}) + *
  • {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Proximity data}. For + * each term in the dictionary, the positions that the term occurs in each document. Note that + * this will not exist if all fields in all documents omit position data. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Normalization factors}. For + * each field in each document, a value is stored that is multiplied into the score for hits + * on that field. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vectors}. For each + * field in each document, the term vector (sometimes called document vector) may be stored. A + * term vector consists of term text and term frequency. To add Term Vectors to your index see + * the {@link org.apache.lucene.document.Field Field} constructors + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-document values}. Like + * stored values, these are also keyed by document number, but are generally intended to be + * loaded into main memory for fast access. Whereas stored values are generally intended for + * summary results from searches, per-document values are useful for things like scoring + * factors. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live documents}. An + * optional file indicating which documents are live. + *
  • {@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}. Optional pair + * of files, recording dimensionally indexed fields, to enable fast numeric range filtering + * and large numeric values like BigInteger and BigDecimal (1D) and geographic shape + * intersection (2D, 3D). + *
  • {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}. The + * vector format stores numeric vectors in a format optimized for random access and + * computation, supporting high-dimensional nearest-neighbor search. + *
+ * + *

Details on each of these are provided in their linked pages.

+ * + *

File Naming

+ * + *
+ * + *

All files belonging to a segment have the same name with varying extensions. The extensions + * correspond to the different file formats described below. When using the Compound File format + * (default for small segments) these files (except for the Segment info file, the Lock file, and + * Deleted documents file) are collapsed into a single .cfs file (see below for details) + * + *

Typically, all segments in an index are stored in a single directory, although this is not + * required. + * + *

File names are never re-used. That is, when any file is saved to the Directory it is given a + * never before used filename. This is achieved using a simple generations approach. For example, + * the first segments file is segments_1, then segments_2, etc. The generation is a sequential long + * integer represented in alpha-numeric (base 36) form.

+ * + *
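 * * <p>The base-36 naming described above corresponds to plain JDK radix conversion, as in this * editorial sketch (the generation value is hypothetical): * * <pre> * long generation = 1289; // Character.MAX_RADIX == 36 * String fileName = "segments_" + Long.toString(generation, Character.MAX_RADIX); // "segments_zt" * </pre> 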

Summary of File Extensions

+ * + *
+ * + *

The following table summarizes the names and extensions of the files in Lucene: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
lucene filenames by extension
NameExtensionBrief Description
{@link org.apache.lucene.index.SegmentInfos Segments File}segments_NStores information about a commit point
Lock Filewrite.lockThe Write lock prevents multiple IndexWriters from writing to the same + * file.
{@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment Info}.siStores metadata about a segment
{@link org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat Compound File}.cfs, .cfeAn optional "virtual" file consisting of all the other index files for + * systems that frequently run out of file handles.
{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}.fnmStores information about the fields
{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Index}.fdxContains pointers to field data
{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Data}.fdtThe stored fields for documents
{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Dictionary}.timThe term dictionary, stores term info
{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Index}.tipThe index into the Term Dictionary
{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Frequencies}.docContains the list of docs which contain each term along with frequency
{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Positions}.posStores position information about where a term occurs in the index
{@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Payloads}.payStores additional per-position metadata information such as character offsets and user payloads
{@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Norms}.nvd, .nvmEncodes length and boost factors for docs and fields
{@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-Document Values}.dvd, .dvmEncodes additional scoring factors or other per-document information.
{@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Index}.tvxStores offset into the document data file
{@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Data}.tvdContains term vector data.
{@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live Documents}.livInfo about what documents are live
{@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}.kdd, .kdi, .kdmHolds indexed points
 {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}.vec, .vem, .veq, .vexHolds indexed vectors; <code>.vec</code> files contain the raw vector data, + * <code>.vem</code> the vector metadata, <code>.veq</code> the quantized vector data, and <code>.vex</code> the + * HNSW graph data. 
+ * + *

+ * + *

Lock File

+ * + * The write lock, which is stored in the index directory by default, is named "write.lock". If the + * lock directory is different from the index directory then the write lock will be named + * "XXXX-write.lock" where XXXX is a unique prefix derived from the full path to the index + * directory. When this file is present, a writer is currently modifying the index (adding or + * removing documents). This lock file ensures that only one writer is modifying the index at a + * time. + * + *

History

+ * + *

Compatibility notes are provided in this document, describing how file formats have changed + * from prior versions: + * + *

    + *
  • In version 2.1, the file format was changed to allow lock-less commits (ie, no more commit + * lock). The change is fully backwards compatible: you can open a pre-2.1 index for searching + * or adding/deleting of docs. When the new segments file is saved (committed), it will be + * written in the new file format (meaning no specific "upgrade" process is needed). But note + * that once a commit has occurred, pre-2.1 Lucene will not be able to read the index. + *
  • In version 2.3, the file format was changed to allow segments to share a single set of doc + * store (vectors & stored fields) files. This allows for faster indexing in certain + * cases. The change is fully backwards compatible (in the same way as the lock-less commits + * change in 2.1). + *
 • In version 2.4, Strings are now written as a true UTF-8 byte sequence, not Java's modified + * UTF-8. See <a href="http://issues.apache.org/jira/browse/LUCENE-510">LUCENE-510</a> for + * details. + 
  • In version 2.9, an optional opaque Map<String,String> CommitUserData may be passed to + * IndexWriter's commit methods (and later retrieved), which is recorded in the segments_N + * file. See LUCENE-1382 for + * details. Also, diagnostics were added to each segment written recording details about why + * it was written (due to flush, merge; which OS/JRE was used; etc.). See issue LUCENE-1654 for details. + *
  • In version 3.0, compressed fields are no longer written to the index (they can still be + * read, but on merge the new segment will write them, uncompressed). See issue LUCENE-1960 for details. + *
 • In version 3.1, segments record the code version that created them. See <a href="https://issues.apache.org/jira/browse/LUCENE-2720">LUCENE-2720</a> for details. + * Additionally, segments track explicitly whether or not they have term vectors. See <a href="https://issues.apache.org/jira/browse/LUCENE-2811">LUCENE-2811</a> for details. + 
 • In version 3.2, numeric fields are written natively to the stored fields file; previously + * they were stored in text format only. + 
  • In version 3.4, fields can omit position data while still indexing term frequencies. + *
  • In version 4.0, the format of the inverted index became extensible via the {@link + * org.apache.lucene.codecs.Codec Codec} api. Fast per-document storage ({@code DocValues}) + * was introduced. Normalization factors need no longer be a single byte, they can be any + * {@link org.apache.lucene.index.NumericDocValues NumericDocValues}. Terms need not be + * unicode strings, they can be any byte sequence. Term offsets can optionally be indexed into + * the postings lists. Payloads can be stored in the term vectors. + *
 • In version 4.1, the format of the postings list changed to use either FOR compression or + * variable-byte encoding, depending upon the frequency of the term. Terms appearing only once + * were changed to be inlined directly into the term dictionary. Stored fields are compressed by + * default. + 
  • In version 4.2, term vectors are compressed by default. DocValues has a new multi-valued + * type (SortedSet), that can be used for faceting/grouping/joining on multi-valued fields. + *
  • In version 4.5, DocValues were extended to explicitly represent missing values. + *
  • In version 4.6, FieldInfos were extended to support per-field DocValues generation, to + * allow updating NumericDocValues fields. + *
  • In version 4.8, checksum footers were added to the end of each index file for improved data + * integrity. Specifically, the last 8 bytes of every index file contain the zlib-crc32 + * checksum of the file. + *
  • In version 4.9, DocValues has a new multi-valued numeric type (SortedNumeric) that is + * suitable for faceting/sorting/analytics. + *
  • In version 5.4, DocValues have been improved to store more information on disk: addresses + * for binary fields and ord indexes for multi-valued fields. + *
  • In version 6.0, Points were added, for multi-dimensional range/distance search. + *
  • In version 6.2, new Segment info format that reads/writes the index sort, to support index + * sorting. + *
  • In version 7.0, DocValues have been improved to better support sparse doc values thanks to + * an iterator API. + *
 • In version 8.0, postings have been enhanced to record, for each block of doc ids, the (term + * freq, normalization factor) pairs that may trigger the maximum score of the block. This + * information is recorded alongside skip data in order to be able to skip blocks of doc ids + * if they may not produce high enough scores. Additionally, doc values and norms have been + * extended with jump-tables to make access O(1) instead of O(n), where n is the number of + * elements to skip when advancing in the data. + 
 • In version 8.4, postings, positions, offsets and payload lengths have moved to a more + * performant encoding that is vectorized. + 
  • In version 8.6, index sort serialization is delegated to the sorts themselves, to allow + * user-defined sorts to be used + *
  • In version 8.6, points fields split the index tree and leaf data into separate files, to + * allow for different access patterns to the different data structures + *
  • In version 8.7, stored fields compression became adaptive to better handle documents with + * smaller stored fields. + *
  • In version 9.0, vector-valued fields were added. + *
  • In version 9.1, vector-valued fields were modified to add a graph hierarchy. + *
 • In version 9.2, docs of vector-valued fields were moved from .vem to .vec and encoded by + * IndexedDISI. ordToDoc mappings were added to .vem. + 
  • In version 9.5, HNSW graph connections were changed to be delta-encoded with vints. + * Additionally, metadata file size improvements were made by delta-encoding nodes by graph + * layer and not writing the node ids for the zeroth layer. + *
  • In version 9.9, Vector scalar quantization support was added. Allowing the HNSW vector + * format to utilize int8 quantized vectors for float32 vector search. + *
 • In version 9.12, skip data was refactored to have only two levels: every 128 docs and every + * 4,096 docs, and to be inlined in postings lists. This resulted in a speedup for queries that + * need skipping, especially conjunctions. + 
  • In version 10.1, block encoding changed to be optimized for int[] storage instead of + * long[]. + *
 • In version 10.2, asymmetric binary vector quantization was added for the HNSW and flat + * vector formats. + 
+ * + * + * + *

Limitations

+ * + *
+ * + *

Lucene uses a Java int to refer to document numbers, and the index file format + * uses an Int32 on-disk to store document numbers. This is a limitation of both the + * index file format and the current implementation. Eventually these should be replaced with either + * UInt64 values, or better yet, {@link org.apache.lucene.store.DataOutput#writeVInt + * VInt} values which have no limit.

+ */ +package org.apache.lucene.codecs.lucene102; diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index 184403cf48b7..d118b58f321f 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -17,6 +17,7 @@ package org.apache.lucene.internal.vectorization; +import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.SuppressForbidden; @@ -207,4 +208,30 @@ public int findNextGEQ(int[] buffer, int target, int from, int to) { } return to; } + + @Override + public long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized) { + return int4BitDotProductImpl(int4Quantized, binaryQuantized); + } + + public static long int4BitDotProductImpl(byte[] q, byte[] d) { + assert q.length == d.length * 4; + long ret = 0; + int size = d.length; + for (int i = 0; i < 4; i++) { + int r = 0; + long subRet = 0; + for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + subRet += + Integer.bitCount( + (int) BitUtil.VH_NATIVE_INT.get(q, i * size + r) + & (int) BitUtil.VH_NATIVE_INT.get(d, r)); + } + for (; r < d.length; r++) { + subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF); + } + ret += subRet << i; + } + return ret; + } } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index fb94b0e31736..82947d003434 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -52,4 +52,17 @@ public interface VectorUtilSupport { * to} is returned. */ int findNextGEQ(int[] buffer, int target, int from, int to); + + /** + * Compute the dot product between a quantized int4 vector and a binary quantized vector. It is + * assumed that the int4 quantized bits are packed in the byte array in the same way as the {@link + * org.apache.lucene.util.quantization.OptimizedScalarQuantizer#transposeHalfByte(byte[], byte[])} + * and that the binary bits are packed the same way as {@link + * org.apache.lucene.util.quantization.OptimizedScalarQuantizer#packAsBinary(byte[], byte[])}. + * + * @param int4Quantized half byte packed int4 quantized vector + * @param binaryQuantized byte packed binary quantized vector + * @return the dot product + */ + long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized); } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 250c65448703..75998c2f7fbb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -213,6 +213,21 @@ public static int int4DotProductPacked(byte[] unpacked, byte[] packed) { return IMPL.int4DotProduct(unpacked, false, packed, true); } + /** + * Dot product computed over int4 (values between [0,15]) bytes and a binary vector. 
+ * + * @param q the int4 query vector + * @param d the binary document vector + * @return the dot product + */ + public static long int4BitDotProduct(byte[] q, byte[] d) { + if (q.length != d.length * 4) { + throw new IllegalArgumentException( + "vector dimensions incompatible: " + q.length + "!= " + 4 + " x " + d.length); + } + return IMPL.int4BitDotProduct(q, d); + } + /** * For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. * On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java new file mode 100644 index 000000000000..b93f2df1ea79 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util.quantization; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; + +/** + * OptimizedScalarQuantizer is a scalar quantizer that optimizes the quantization intervals for a + * given vector. This is done by optimizing the quantiles of the vector centered on a provided + * centroid. The optimization is done by minimizing the quantization loss via coordinate descent. + * + * @lucene.experimental + */ +public class OptimizedScalarQuantizer { + // The initial interval is set to the minimum MSE grid for each number of bits + // these starting points are derived from the optimal MSE grid for a uniform distribution + static final float[][] MINIMUM_MSE_GRID = + new float[][] { + {-0.798f, 0.798f}, + {-1.493f, 1.493f}, + {-2.051f, 2.051f}, + {-2.514f, 2.514f}, + {-2.916f, 2.916f}, + {-3.278f, 3.278f}, + {-3.611f, 3.611f}, + {-3.922f, 3.922f} + }; + private static final float DEFAULT_LAMBDA = 0.1f; + private static final int DEFAULT_ITERS = 5; + private final VectorSimilarityFunction similarityFunction; + private final float lambda; + private final int iters; + + /** + * Create a new scalar quantizer with the given similarity function, lambda, and number of + * iterations. 
+ * + * @param similarityFunction similarity function to use + * @param lambda lambda value to use + * @param iters number of iterations to use + */ + public OptimizedScalarQuantizer( + VectorSimilarityFunction similarityFunction, float lambda, int iters) { + this.similarityFunction = similarityFunction; + this.lambda = lambda; + this.iters = iters; + } + + /** + * Create a new scalar quantizer with the default lambda and number of iterations. + * + * @param similarityFunction similarity function to use + */ + public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) { + this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS); + } + + /** + * Quantization result containing the lower and upper interval bounds, the additional correction + * + * @param lowerInterval the lower interval bound + * @param upperInterval the upper interval bound + * @param additionalCorrection the additional correction + * @param quantizedComponentSum the sum of the quantized components + */ + public record QuantizationResult( + float lowerInterval, + float upperInterval, + float additionalCorrection, + int quantizedComponentSum) {} + + /** + * Quantize the vector to the multiple bit levels. + * + * @param vector raw vector + * @param destinations array of destinations to store the quantized vector + * @param bits array of bits to quantize the vector + * @param centroid centroid to center the vector + * @return array of quantization results + */ + public QuantizationResult[] multiScalarQuantize( + float[] vector, byte[][] destinations, byte[] bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert bits.length == destinations.length; + float[] intervalScratch = new float[2]; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + QuantizationResult[] results = new QuantizationResult[bits.length]; + for (int i = 0; i < bits.length; ++i) { + assert bits[i] > 0 && bits[i] <= 8; + int points = (1 << bits[i]); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within + // the min/max bounds + intervalScratch[0] = + (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = + (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits[i]) - 1); + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + // Now we have the optimized intervals, quantize the vector + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destinations[i][h] = (byte) assignment; + } + results[i] = + new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? 
norm2 : centroidDot, + sumQuery); + } + return results; + } + + /** + * Quantize the vector to the given bit level. + * + * @param vector raw vector + * @param destination destination to store the quantized vector + * @param bits number of bits to quantize the vector + * @param centroid centroid to center the vector + * @return quantization result + */ + public QuantizationResult scalarQuantize( + float[] vector, byte[] destination, byte bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert vector.length <= destination.length; + assert bits > 0 && bits <= 8; + float[] intervalScratch = new float[2]; + int points = 1 << bits; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within + // the min/max bounds + intervalScratch[0] = + (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = + (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits) - 1); + // Now we have the optimized intervals, quantize the vector + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destination[h] = (byte) assignment; + } + return new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + sumQuery); + } + + /** + * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a + * dequantized vector with the raw vector. + * + * @param vector raw vector + * @param interval interval to quantize the vector + * @param points number of quantization points + * @param norm2 squared norm of the vector + * @return the loss + */ + private double loss(float[] vector, float[] interval, int points, float norm2) { + double a = interval[0]; + double b = interval[1]; + double step = ((b - a) / (points - 1.0F)); + double stepInv = 1.0 / step; + double xe = 0.0; + double e = 0.0; + for (double xi : vector) { + // this is quantizing and then dequantizing the vector + double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv)); + // how much does the de-quantized value differ from the original value + xe += xi * (xi - xiq); + e += (xi - xiq) * (xi - xiq); + } + return (1.0 - lambda) * xe * xe / norm2 + lambda * e; + } + + /** + * Optimize the quantization interval for the given vector. This is done via a coordinate descent + * trying to minimize the quantization loss. 
Note, the loss is not always guaranteed to decrease, + * so we have a maximum number of iterations and will exit early if the loss increases. + * + * @param initInterval initial interval, the optimized interval will be stored here + * @param vector raw vector + * @param norm2 squared norm of the vector + * @param points number of quantization points + */ + private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) { + double initialLoss = loss(vector, initInterval, points, norm2); + final float scale = (1.0f - lambda) / norm2; + if (Float.isFinite(scale) == false) { + return; + } + for (int i = 0; i < iters; ++i) { + float a = initInterval[0]; + float b = initInterval[1]; + float stepInv = (points - 1.0f) / (b - a); + // calculate the grid points for coordinate descent + double daa = 0; + double dab = 0; + double dbb = 0; + double dax = 0; + double dbx = 0; + for (float xi : vector) { + float k = Math.round((clamp(xi, a, b) - a) * stepInv); + float s = k / (points - 1); + daa += (1.0 - s) * (1.0 - s); + dab += (1.0 - s) * s; + dbb += s * s; + dax += xi * (1.0 - s); + dbx += xi * s; + } + double m0 = scale * dax * dax + lambda * daa; + double m1 = scale * dax * dbx + lambda * dab; + double m2 = scale * dbx * dbx + lambda * dbb; + // its possible that the determinant is 0, in which case we can't update the interval + double det = m0 * m2 - m1 * m1; + if (det == 0) { + return; + } + float aOpt = (float) ((m2 * dax - m1 * dbx) / det); + float bOpt = (float) ((m0 * dbx - m1 * dax) / det); + // If there is no change in the interval, we can stop + if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) { + return; + } + double newLoss = loss(vector, new float[] {aOpt, bOpt}, points, norm2); + // If the new loss is worse, don't update the interval and exit + // This optimization, unlike kMeans, does not always converge to better loss + // So exit if we are getting worse + if (newLoss > initialLoss) { + return; + } + // Update the interval and go again + initInterval[0] = aOpt; + initInterval[1] = bOpt; + initialLoss = newLoss; + } + } + + public static int discretize(int value, int bucket) { + return ((value + (bucket - 1)) / bucket) * bucket; + } + + /** + * Transpose the query vector into a byte array allowing for efficient bitwise operations with the + * index bit vectors. The idea here is to organize the query vector bits such that the first bit + * of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second, + * third, and fourth bits are in the second, third, and fourth set of dimensions bits, + * respectively. This allows for direct bitwise comparisons with the stored index vectors through + * summing the bitwise results with the relative required bit shifts. 
+ * + * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 + * @param quantQueryByte the byte array to store the transposed query vector + */ + public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) { + for (int i = 0; i < q.length; ) { + assert q[i] >= 0 && q[i] <= 15; + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; j >= 0 && i < q.length; j--) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + i++; + } + int index = ((i + 7) / 8) - 1; + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + } + + /** + * Pack the vector as a binary array. + * + * @param vector the vector to pack + * @param packed the packed vector + */ + public static void packAsBinary(byte[] vector, byte[] packed) { + for (int i = 0; i < vector.length; ) { + byte result = 0; + for (int j = 7; j >= 0 && i < vector.length; j--) { + assert vector[i] == 0 || vector[i] == 1; + result |= (byte) ((vector[i] & 1) << j); + ++i; + } + int index = ((i + 7) / 8) - 1; + assert index < packed.length; + packed[index] = result; + } + } + + private static double clamp(double x, double a, double b) { + return Math.min(Math.max(x, a), b); + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 9273f7c5a813..37671b63e9f6 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -29,6 +29,7 @@ import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.LongVector; import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.Vector; import jdk.incubator.vector.VectorMask; @@ -58,6 +59,8 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { PanamaVectorConstants.PRERERRED_INT_SPECIES; private static final VectorSpecies BYTE_SPECIES; private static final VectorSpecies SHORT_SPECIES; + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; static final int VECTOR_BITSIZE; @@ -786,4 +789,118 @@ public int findNextGEQ(int[] buffer, int target, int from, int to) { } return to; } + + @Override + public long int4BitDotProduct(byte[] q, byte[] d) { + assert q.length == d.length * 4; + // 128 / 8 == 16 + if (d.length >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { + if (VECTOR_BITSIZE >= 256) { + return int4BitDotProduct256(q, d); + } else if (VECTOR_BITSIZE == 128) { + return int4BitDotProduct128(q, d); + } + } + return DefaultVectorUtilSupport.int4BitDotProductImpl(q, d); + } + + static long int4BitDotProduct256(byte[] q, byte[] d) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + + if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(d.length); + var sum0 = 
LongVector.zero(LongVector.SPECIES_256); + var sum1 = LongVector.zero(LongVector.SPECIES_256); + var sum2 = LongVector.zero(LongVector.SPECIES_256); + var sum3 = LongVector.zero(LongVector.SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs(); + var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LongVector.SPECIES_128); + var sum1 = LongVector.zero(LongVector.SPECIES_128); + var sum2 = LongVector.zero(LongVector.SPECIES_128); + var sum3 = LongVector.zero(LongVector.SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(d.length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs(); + var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + for (; i < d.length; i++) { + subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF); + subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * d.length] & d[i]) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + public static long int4BitDotProduct128(byte[] q, byte[] d) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + + var sum0 = IntVector.zero(IntVector.SPECIES_128); + var sum1 = IntVector.zero(IntVector.SPECIES_128); + var sum2 = IntVector.zero(IntVector.SPECIES_128); + var sum3 = IntVector.zero(IntVector.SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(d.length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = 
ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + for (; i < d.length; i++) { + int dValue = d[i]; + subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); + subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } } diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index cb5fee62aeec..0558fc8fef05 100644 --- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -16,3 +16,5 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat +org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat +org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..668b7eee822d --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene102; + +import static java.lang.String.format; +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene101.Lucene101Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +public class TestLucene102BinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Lucene101Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene102BinaryQuantizedVectorsFormat(); + } + }; + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene102BinaryQuantizedVectorsFormat(); + } + }; + String expectedPattern = + "Lucene102BinaryQuantizedVectorsFormat(" + + "name=Lucene102BinaryQuantizedVectorsFormat, " + + 
"flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + BinarizedByteVectorValues qvectorValues = + ((Lucene102BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] quantizedVector = new byte[dims]; + byte[] expectedVector = new byte[discretize(dims, 64) / 8]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = + new Lucene102BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = + quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + quantizedVector, + INDEX_BITS, + centroid); + packAsBinary(quantizedVector, expectedVector); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + } + } + } + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..b23804b1840e --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene102; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene101.Lucene101Codec; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.util.SameThreadExecutorService; + +public class TestLucene102HnswBinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Lucene101Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene102HnswBinaryQuantizedVectorsFormat(); + } + }; + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene102HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "Lucene102HnswBinaryQuantizedVectorsFormat(name=Lucene102HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=Lucene102BinaryQuantizedVectorsFormat(name=Lucene102BinaryQuantizedVectorsFormat," + + " flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + 
assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true + // score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + public void testLimits() { + expectThrows( + IllegalArgumentException.class, + () -> new Lucene102HnswBinaryQuantizedVectorsFormat(-1, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene102HnswBinaryQuantizedVectorsFormat(0, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene102HnswBinaryQuantizedVectorsFormat(20, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene102HnswBinaryQuantizedVectorsFormat(20, -1)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene102HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene102HnswBinaryQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> + new Lucene102HnswBinaryQuantizedVectorsFormat( + 20, 100, 1, new SameThreadExecutorService())); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java new file mode 100644 index 000000000000..63ac99d48d28 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.util.quantization; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.MINIMUM_MSE_GRID; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; + +public class TestOptimizedScalarQuantizer extends LuceneTestCase { + static final byte[] ALL_BITS = new byte[] {1, 2, 3, 4, 5, 6, 7, 8}; + + public void testAbusiveEdgeCases() { + // large zero array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + continue; + } + float[] vector = new float[4096]; + float[] centroid = new float[4096]; + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096]; + OptimizedScalarQuantizer.QuantizationResult[] results = + osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (byte[] destination : destinations) { + assertArrayEquals(new byte[4096], destination); + } + byte[] destination = new byte[4096]; + for (byte bit : ALL_BITS) { + OptimizedScalarQuantizer.QuantizationResult result = + osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertArrayEquals(new byte[4096], destination); + } + } + + // single value array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + float[] vector = new float[] {randomFloat()}; + float[] centroid = new float[] {randomFloat()}; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1]; + OptimizedScalarQuantizer.QuantizationResult[] results = + osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + vector = new float[] {randomFloat()}; + centroid = new float[] {randomFloat()}; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + byte[] destination = new byte[1]; + OptimizedScalarQuantizer.QuantizationResult result = + osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + } + + public void testMathematicalConsistency() { + int dims = randomIntBetween(1, 4096); + float[] vector = new float[dims]; + for (int i = 0; i < dims; ++i) { + vector[i] = randomFloat(); + } + float[] centroid = new float[dims]; + for (int i = 0; i < dims; ++i) { + centroid[i] = randomFloat(); + } + float[] copy = new float[dims]; + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + // copy the vector to avoid modifying it + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == 
VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims]; + OptimizedScalarQuantizer.QuantizationResult[] results = + osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + byte[] destination = new byte[dims]; + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer.QuantizationResult result = + osq.scalarQuantize(copy, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + } + + static void assertValidQuantizedRange(byte[] quantized, byte bits) { + for (byte b : quantized) { + if (bits < 8) { + assertTrue(b >= 0); + } + assertTrue(b < 1 << bits); + } + } + + static void assertValidResults(OptimizedScalarQuantizer.QuantizationResult... results) { + for (OptimizedScalarQuantizer.QuantizationResult result : results) { + assertTrue(Float.isFinite(result.lowerInterval())); + assertTrue(Float.isFinite(result.upperInterval())); + assertTrue(result.lowerInterval() <= result.upperInterval()); + assertTrue(Float.isFinite(result.additionalCorrection())); + assertTrue(result.quantizedComponentSum() >= 0); + } + } +} From 62cd45ba541c9b9b0569d877ca67f994dee5406b Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Wed, 18 Dec 2024 15:32:21 +0000 Subject: [PATCH 2/4] test default and panama impls return the same result --- .../vectorization/TestVectorUtilSupport.java | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 7064955cb5f3..6443de752bfa 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.function.ToDoubleFunction; import java.util.function.ToIntFunction; +import java.util.function.ToLongFunction; import java.util.stream.IntStream; public class TestVectorUtilSupport extends BaseVectorizationTestCase { @@ -133,6 +134,27 @@ public void testInt4DotProductBoundaries() { PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); } + public void testInt4BitDotProduct() { + var binaryQuantized = new byte[size]; + var int4Quantized = new byte[size * 4]; + random().nextBytes(binaryQuantized); + random().nextBytes(int4Quantized); + assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized)); + } + + public void testInt4BitDotProductBoundaries() { + var binaryQuantized = new byte[size]; + var int4Quantized = new byte[size * 4]; + + Arrays.fill(binaryQuantized, Byte.MAX_VALUE); + Arrays.fill(int4Quantized, Byte.MAX_VALUE); + assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized)); + + Arrays.fill(binaryQuantized, Byte.MIN_VALUE); + Arrays.fill(int4Quantized, Byte.MIN_VALUE); + 
assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized)); + } + static byte[] pack(byte[] unpacked) { int len = (unpacked.length + 1) / 2; var packed = new byte[len]; @@ -154,4 +176,10 @@ private void assertIntReturningProviders(ToIntFunction func) func.applyAsInt(LUCENE_PROVIDER.getVectorUtilSupport()), func.applyAsInt(PANAMA_PROVIDER.getVectorUtilSupport())); } + + private void assertLongReturningProviders(ToLongFunction func) { + assertEquals( + func.applyAsLong(LUCENE_PROVIDER.getVectorUtilSupport()), + func.applyAsLong(PANAMA_PROVIDER.getVectorUtilSupport())); + } } From db10587ba16790ad120a3feda180581985b6f417 Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Wed, 18 Dec 2024 15:43:11 +0000 Subject: [PATCH 3/4] add tests for int4BitDotProdut --- .../BaseVectorizationTestCase.java | 12 +- .../apache/lucene/util/TestVectorUtil.java | 129 ++++++++++++++++++ 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java index 34a0e5230022..ea6c183c7e7a 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java @@ -21,8 +21,8 @@ public abstract class BaseVectorizationTestCase extends LuceneTestCase { - protected static final VectorizationProvider LUCENE_PROVIDER = new DefaultVectorizationProvider(); - protected static final VectorizationProvider PANAMA_PROVIDER = VectorizationProvider.lookup(true); + protected static final VectorizationProvider LUCENE_PROVIDER = defaultProvider(); + protected static final VectorizationProvider PANAMA_PROVIDER = maybePanamaProvider(); @BeforeClass public static void beforeClass() throws Exception { @@ -30,4 +30,12 @@ public static void beforeClass() throws Exception { "Test only works when JDK's vector incubator module is enabled.", PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass()); } + + public static VectorizationProvider defaultProvider() { + return new DefaultVectorizationProvider(); + } + + public static VectorizationProvider maybePanamaProvider() { + return VectorizationProvider.lookup(true); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index 6e449a550028..30a2f6811c6e 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -16,8 +16,13 @@ */ package org.apache.lucene.util; +import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween; + +import java.util.Arrays; import java.util.Random; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase; +import org.apache.lucene.internal.vectorization.VectorizationProvider; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -384,4 +389,128 @@ private static int slowFindNextGEQ(int[] buffer, int length, int target, int fro } return length; } + + public void testInt4BitDotProductInvariants() { + int iterations = atLeast(10); + for (int i = 0; i < iterations; i++) { + int size = randomIntBetween(random(), 1, 10); + var d = new byte[size]; + var q = new byte[size * 4 - 1]; + 
expectThrows(IllegalArgumentException.class, () -> VectorUtil.int4BitDotProduct(q, d)); + } + } + + static final VectorizationProvider defaultedProvider = + BaseVectorizationTestCase.defaultProvider(); + static final VectorizationProvider defOrPanamaProvider = + BaseVectorizationTestCase.maybePanamaProvider(); + + public void testBasicInt4BitDotProduct() { + testBasicInt4BitDotProductImpl(VectorUtil::int4BitDotProduct); + testBasicInt4BitDotProductImpl(defaultedProvider.getVectorUtilSupport()::int4BitDotProduct); + testBasicInt4BitDotProductImpl(defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct); + } + + interface Int4BitDotProduct { + long apply(byte[] q, byte[] d); + } + + void testBasicInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) { + assertEquals(15L, Int4BitDotProductFunc.apply(new byte[] {1, 1, 1, 1}, new byte[] {1})); + assertEquals( + 30L, Int4BitDotProductFunc.apply(new byte[] {1, 2, 1, 2, 1, 2, 1, 2}, new byte[] {1, 2})); + + var d = new byte[] {1, 2, 3}; + var q = new byte[] {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}; + assert scalarInt4BitDotProduct(q, d) == 60L; // 4 + 8 + 16 + 32 + assertEquals(60L, Int4BitDotProductFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4}; + q = new byte[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}; + assert scalarInt4BitDotProduct(q, d) == 75L; // 5 + 10 + 20 + 40 + assertEquals(75L, Int4BitDotProductFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5}; + q = new byte[] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; + assert scalarInt4BitDotProduct(q, d) == 105L; // 7 + 14 + 28 + 56 + assertEquals(105L, Int4BitDotProductFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6}; + q = new byte[] {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}; + assert scalarInt4BitDotProduct(q, d) == 135L; // 9 + 18 + 36 + 72 + assertEquals(135L, Int4BitDotProductFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6, 7}; + q = + new byte[] { + 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 + }; + assert scalarInt4BitDotProduct(q, d) == 180L; // 12 + 24 + 48 + 96 + assertEquals(180L, Int4BitDotProductFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8}; + q = + new byte[] { + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, + 7, 8 + }; + assert scalarInt4BitDotProduct(q, d) == 195L; // 13 + 26 + 52 + 104 + assertEquals(195L, Int4BitDotProductFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}; + q = + new byte[] { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, + 4, 5, 6, 7, 8, 9 + }; + assert scalarInt4BitDotProduct(q, d) == 225L; // 15 + 30 + 60 + 120 + assertEquals(225L, Int4BitDotProductFunc.apply(q, d)); + } + + public void testInt4BitDotProduct() { + testInt4BitDotProductImpl(VectorUtil::int4BitDotProduct); + testInt4BitDotProductImpl(defaultedProvider.getVectorUtilSupport()::int4BitDotProduct); + testInt4BitDotProductImpl(defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct); + } + + void testInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) { + int iterations = atLeast(50); + for (int i = 0; i < iterations; i++) { + int size = random().nextInt(5000); + var d = new byte[size]; + var q = new byte[size * 4]; + random().nextBytes(d); + random().nextBytes(q); + assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d)); + + Arrays.fill(d, Byte.MAX_VALUE); + Arrays.fill(q, Byte.MAX_VALUE); + 
assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d)); + + Arrays.fill(d, Byte.MIN_VALUE); + Arrays.fill(q, Byte.MIN_VALUE); + assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d)); + } + } + + static int scalarInt4BitDotProduct(byte[] q, byte[] d) { + int res = 0; + for (int i = 0; i < 4; i++) { + res += (popcount(q, i * d.length, d, d.length) << i); + } + return res; + } + + public static int popcount(byte[] a, int aOffset, byte[] b, int length) { + int res = 0; + for (int j = 0; j < length; j++) { + int value = (a[aOffset + j] & b[j]) & 0xFF; + for (int k = 0; k < Byte.SIZE; k++) { + if ((value & (1 << k)) != 0) { + ++res; + } + } + } + return res; + } } From b76d616a90131563964cd575a118e846e639659d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:48:24 -0500 Subject: [PATCH 4/4] fixing tests, addressing pr comments --- ...Lucene102BinaryQuantizedVectorsFormat.java | 13 ++++--- .../lucene/codecs/lucene102/package-info.java | 7 ++-- .../OptimizedScalarQuantizer.java | 34 +++++++++++++++++-- ...Lucene102BinaryQuantizedVectorsFormat.java | 14 ++++---- ...ne102HnswBinaryQuantizedVectorsFormat.java | 14 ++++---- 5 files changed, 55 insertions(+), 27 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java index ae48a220b235..6850e2d09d6a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java @@ -26,9 +26,12 @@ import org.apache.lucene.index.SegmentWriteState; /** - * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 Codec for - * encoding/decoding binary quantized vectors The binary quantization format used here is a - * per-vector optimized scalar quantization. Also see {@link + * The binary quantization format used here is a per-vector optimized scalar quantization. These + * ideas are evolutions of LVQ proposed in Similarity + * search in the blink of an eye with compressed indices by Cecilia Aguerrebere et al. and the + * previous work on globally optimized scalar + * + *

The format is stored in two files, described below. Also see {@link * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of the key features are: * *

    @@ -43,8 +46,6 @@ * single bit vectors can be done with bit arithmetic. *
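The bit arithmetic mentioned in the feature above amounts to AND plus popcount: the dot product of two single-bit vectors is simply the number of positions where both bits are set. A minimal sketch, with a long-packed layout assumed for illustration:

    // Dot product of two 1-bit vectors packed into longs: AND selects the
    // positions where both bits are 1, and bitCount sums them.
    static long bitDotProduct(long[] a, long[] b) {
      long sum = 0;
      for (int i = 0; i < a.length; i++) {
        sum += Long.bitCount(a[i] & b[i]);
      }
      return sum;
    }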
* - * The format is stored in two files: - * *

.veb (vector data) file

* *

Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's @@ -130,6 +131,8 @@ public String toString() { + NAME + ", flatVectorScorer=" + scorer + + ", rawVectorFormat=" + + rawVectorFormat + ")"; } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java index 92d9de567a20..8f6fcb2ef5bc 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java @@ -311,10 +311,11 @@ * * * {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values} - * .vec, .vem, .veq, vex + * .vec, .vem, .veq, .vex, .veb, .vemb * Holds indexed vectors; .vec files contain the raw vector data, - * .vem the vector metadata, .veq the quantized vector data, and .vex the - * hnsw graph data. + * .vem the vector metadata, .veq the quantized vector data, .vex the + * hnsw graph data, .veb the binary quantized vector data, and + * .vemb the binary quantized vector metadata. * * * diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java index b93f2df1ea79..7ae74e743993 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java @@ -23,9 +23,19 @@ import org.apache.lucene.util.VectorUtil; /** - * OptimizedScalarQuantizer is a scalar quantizer that optimizes the quantization intervals for a - * given vector. This is done by optimizing the quantiles of the vector centered on a provided - * centroid. The optimization is done by minimizing the quantization loss via coordinate descent. + * This is a scalar quantizer that optimizes the quantization intervals for a given vector. This is + * done by optimizing the quantiles of the vector centered on a provided centroid. The optimization + * is done by minimizing the quantization loss via coordinate descent. + * + *
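To make the coordinate-descent idea concrete, here is a deliberately simplified sketch that alternates between snapping values to grid points and refitting the interval by least squares. This is a plain Lloyd-style MSE loop under assumed names, not the anisotropic optimization the patch actually implements:

    // Simplified sketch, not the patch's algorithm: alternate an assignment step
    // and a least-squares refit of the interval [a, b] for a centered vector.
    static float[] optimizeInterval(float[] centered, int bits, int iters) {
      int points = (1 << bits) - 1;
      float a = Float.POSITIVE_INFINITY;
      float b = Float.NEGATIVE_INFINITY;
      for (float v : centered) {
        a = Math.min(a, v);
        b = Math.max(b, v);
      }
      int[] assignments = new int[centered.length];
      for (int iter = 0; iter < iters; iter++) {
        float step = (b - a) / points;
        if (step <= 0) {
          break; // degenerate interval, nothing to optimize
        }
        // assignment step: snap each value to the nearest of the (points + 1) levels
        for (int i = 0; i < centered.length; i++) {
          assignments[i] = Math.max(0, Math.min(points, Math.round((centered[i] - a) / step)));
        }
        // update step: regress centered[i] on assignments[i] / points; the intercept
        // becomes the new lower bound, the slope the new interval width
        double sx = 0, sxx = 0, sy = 0, sxy = 0;
        for (int i = 0; i < centered.length; i++) {
          double x = (double) assignments[i] / points;
          sx += x;
          sxx += x * x;
          sy += centered[i];
          sxy += x * centered[i];
        }
        int n = centered.length;
        double det = n * sxx - sx * sx;
        if (det == 0) {
          break; // all values landed on a single level
        }
        double width = (n * sxy - sx * sy) / det;
        a = (float) ((sy - width * sx) / n);
        b = (float) (a + width);
      }
      return new float[] {a, b};
    }

The real implementation instead scores candidate intervals with the lambda-weighted anisotropic loss described below.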

Local vector quantization parameters were originally proposed with LVQ in Similarity search in the blink of an eye with compressed + * indices. This technique builds on LVQ, but instead of taking the min/max values, a grid search + * over the centered vector is done to find the optimal quantization intervals, taking into account + * anisotropic loss. + * + *
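For intuition, an anisotropic loss splits the quantization residual into a component parallel to the original vector and one perpendicular to it, and weights the two differently. The exact weighting below is an assumption for illustration; see the lambda field added by this patch for how the actual implementation balances them:

    // Hedged sketch: residual r = x - reconstructed, decomposed by projecting
    // onto x. A small lambda makes perpendicular error cheap, pushing the
    // optimizer to shrink the parallel error that matters most for ranking.
    static double anisotropicLoss(float[] x, float[] reconstructed, float lambda) {
      double dot = 0;
      double norm2 = 0;
      for (int i = 0; i < x.length; i++) {
        dot += (x[i] - reconstructed[i]) * (double) x[i];
        norm2 += (double) x[i] * x[i];
      }
      double coeff = norm2 == 0 ? 0 : dot / norm2;
      double parallel2 = 0;
      double perpendicular2 = 0;
      for (int i = 0; i < x.length; i++) {
        double r = x[i] - reconstructed[i];
        double par = coeff * x[i]; // projection of the residual onto x
        double perp = r - par;
        parallel2 += par * par;
        perpendicular2 += perp * perp;
      }
      return parallel2 + lambda * perpendicular2;
    }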

Anisotropic loss was first discussed in depth in Accelerating Large-Scale Inference with Anisotropic + * Vector Quantization by Ruiqi Guo, et al. * * @lucene.experimental */ @@ -43,10 +53,20 @@ public class OptimizedScalarQuantizer { {-3.611f, 3.611f}, {-3.922f, 3.922f} }; + // the default lambda value private static final float DEFAULT_LAMBDA = 0.1f; + // the default optimization iterations allowed private static final int DEFAULT_ITERS = 5; private final VectorSimilarityFunction similarityFunction; + // This determines how much emphasis we place on quantization errors perpendicular to the + // embedding + // as opposed to parallel to it. + // The smaller the value the more we will allow the overall error to increase if it allows us to + // reduce error parallel to the vector. + // Parallel errors are important for nearest neighbor queries because the closest document vectors + // tend to be parallel to the query. private final float lambda; + // the number of iterations to optimize the quantization intervals private final int iters; /** @@ -320,6 +340,14 @@ public static int discretize(int value, int bucket) { * respectively. This allows for direct bitwise comparisons with the stored index vectors through * summing the bitwise results with the relative required bit shifts. + *

This bit decomposition for fast bitwise SIMD operations was first proposed in: + * + *

+   *   Gao, Jianyang, and Cheng Long. "RaBitQ: Quantizing High-
+   *   Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search."
+   *   Proceedings of the ACM on Management of Data 2, no. 3 (2024): 1-27.
+   *   
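Concretely, the transposition turns each 4-bit query value into one bit in each of four bit planes, so the asymmetric dot product becomes four AND-plus-popcount passes combined with shifts, mirroring scalarInt4BitDotProduct in the test earlier in this patch. A sketch with an assumed LSB-first bit layout (the real on-disk layout may differ):

    // Hypothetical sketch: scatter the 4 bits of each quantized query value into
    // four consecutive bit planes of ceil(dims / 8) bytes each.
    static byte[] transposeHalfByte(byte[] quantizedQuery) {
      int bytesPerPlane = (quantizedQuery.length + 7) / 8;
      byte[] planes = new byte[4 * bytesPerPlane];
      for (int i = 0; i < quantizedQuery.length; i++) {
        for (int bit = 0; bit < 4; bit++) {
          if ((quantizedQuery[i] & (1 << bit)) != 0) {
            planes[bit * bytesPerPlane + (i >> 3)] |= (byte) (1 << (i & 7));
          }
        }
      }
      return planes;
    }

Scoring then sums the per-plane popcounts with the matching shifts, e.g. score += popcount(planes, p * bytesPerPlane, docBits, bytesPerPlane) << p for p in 0..3, exactly as in the scalar reference above.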
+ * * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 * @param quantQueryByte the byte array to store the transposed query vector */ diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java index 668b7eee822d..216e645b3010 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java @@ -29,7 +29,6 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; @@ -47,18 +46,16 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; public class TestLucene102BinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + private static final KnnVectorsFormat FORMAT = new Lucene102BinaryQuantizedVectorsFormat(); + @Override protected Codec getCodec() { - return new Lucene101Codec() { - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene102BinaryQuantizedVectorsFormat(); - } - }; + return TestUtil.alwaysKnnVectorsFormat(FORMAT); } public void testSearch() throws Exception { @@ -103,7 +100,8 @@ public KnnVectorsFormat knnVectorsFormat() { String expectedPattern = "Lucene102BinaryQuantizedVectorsFormat(" + "name=Lucene102BinaryQuantizedVectorsFormat, " - + "flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + + "flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()), " + + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))"; var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java index b23804b1840e..010c99eab41e 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java @@ -26,7 +26,6 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; @@ -40,18 +39,16 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; import 
org.apache.lucene.util.SameThreadExecutorService; public class TestLucene102HnswBinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + private static final KnnVectorsFormat FORMAT = new Lucene102HnswBinaryQuantizedVectorsFormat(); + @Override protected Codec getCodec() { - return new Lucene101Codec() { - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene102HnswBinaryQuantizedVectorsFormat(); - } - }; + return TestUtil.alwaysKnnVectorsFormat(FORMAT); } public void testToString() { @@ -65,7 +62,8 @@ public KnnVectorsFormat knnVectorsFormat() { String expectedPattern = "Lucene102HnswBinaryQuantizedVectorsFormat(name=Lucene102HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + " flatVectorFormat=Lucene102BinaryQuantizedVectorsFormat(name=Lucene102BinaryQuantizedVectorsFormat," - + " flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + " flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())," + + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))"; var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); var memSegScorer =