From fbf112a727b668dec62c2908462dc4dd3f494f3d Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Tue, 17 Dec 2024 17:13:00 -0500
Subject: [PATCH 1/4] Binary vector format for flat and hnsw vectors
---
lucene/core/src/java/module-info.java | 5 +-
.../lucene/codecs/lucene101/package-info.java | 418 +-------
.../lucene102/BinarizedByteVectorValues.java | 84 ++
.../Lucene102BinaryFlatVectorsScorer.java | 192 ++++
...Lucene102BinaryQuantizedVectorsFormat.java | 135 +++
...Lucene102BinaryQuantizedVectorsReader.java | 424 ++++++++
...Lucene102BinaryQuantizedVectorsWriter.java | 950 ++++++++++++++++++
...ne102HnswBinaryQuantizedVectorsFormat.java | 154 +++
.../OffHeapBinarizedVectorValues.java | 383 +++++++
.../lucene/codecs/lucene102/package-info.java | 436 ++++++++
.../DefaultVectorUtilSupport.java | 27 +
.../vectorization/VectorUtilSupport.java | 13 +
.../org/apache/lucene/util/VectorUtil.java | 15 +
.../OptimizedScalarQuantizer.java | 371 +++++++
.../PanamaVectorUtilSupport.java | 117 +++
.../org.apache.lucene.codecs.KnnVectorsFormat | 2 +
...Lucene102BinaryQuantizedVectorsFormat.java | 179 ++++
...ne102HnswBinaryQuantizedVectorsFormat.java | 137 +++
.../TestOptimizedScalarQuantizer.java | 149 +++
19 files changed, 3773 insertions(+), 418 deletions(-)
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java
create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java
create mode 100644 lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java
create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java
create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java
create mode 100644 lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java
diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java
index 85aff5722498..ac321b12e8fa 100644
--- a/lucene/core/src/java/module-info.java
+++ b/lucene/core/src/java/module-info.java
@@ -32,6 +32,7 @@
exports org.apache.lucene.codecs.lucene95;
exports org.apache.lucene.codecs.lucene99;
exports org.apache.lucene.codecs.lucene101;
+ exports org.apache.lucene.codecs.lucene102;
exports org.apache.lucene.codecs.perfield;
exports org.apache.lucene.codecs;
exports org.apache.lucene.document;
@@ -76,7 +77,9 @@
provides org.apache.lucene.codecs.KnnVectorsFormat with
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
- org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
+ org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
+ org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat,
+ org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java
index e582f12c3185..8aa1c3b43a0c 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java
@@ -15,421 +15,5 @@
* limitations under the License.
*/
-/**
- * Lucene 10.1 file format.
- *
- * <h2>Apache Lucene - Index File Formats</h2>
- *
- * <h3>Introduction</h3>
- *
- * <p>This document defines the index file formats used in this version of Lucene. If you are using
- * a different version of Lucene, please consult the copy of <code>docs/</code> that was distributed
- * with the version you are using.
- *
- * <p>This document attempts to provide a high-level definition of the Apache Lucene file formats.
- *
- * <h3>Definitions</h3>
- *
- * <p>The fundamental concepts in Lucene are index, document, field and term.
- *
- * <p>An index contains a sequence of documents.
- *
- * <ul>
- *   <li>A document is a sequence of fields.
- *   <li>A field is a named sequence of terms.
- *   <li>A term is a sequence of bytes.
- * </ul>
- *
- * <p>The same sequence of bytes in two different fields is considered a different term. Thus terms
- * are represented as a pair: the string naming the field, and the bytes within the field.
- *
- * <h4>Inverted Indexing</h4>
- *
- * <p>Lucene's index stores terms and statistics about those terms in order to make term-based
- * search more efficient. Lucene's terms index falls into the family of indexes known as an
- * <i>inverted index</i>. This is because it can list, for a term, the documents that contain it.
- * This is the inverse of the natural relationship, in which documents list terms.
- *
- * <h4>Types of Fields</h4>
- *
- * <p>In Lucene, fields may be <i>stored</i>, in which case their text is stored in the index
- * literally, in a non-inverted manner. Fields that are inverted are called <i>indexed</i>. A field
- * may be both stored and indexed.
- *
- * <p>The text of a field may be <i>tokenized</i> into terms to be indexed, or the text of a field
- * may be used literally as a term to be indexed. Most fields are tokenized, but sometimes it is
- * useful for certain identifier fields to be indexed literally.
- *
- * <p>See the {@link org.apache.lucene.document.Field Field} java docs for more information on
- * Fields.
- *
- * <h4>Segments</h4>
- *
- * <p>Lucene indexes may be composed of multiple sub-indexes, or <i>segments</i>. Each segment is a
- * fully independent index, which could be searched separately. Indexes evolve by:
- *
- * <ul>
- *   <li>Creating new segments for newly added documents.
- *   <li>Merging existing segments.
- * </ul>
- *
- * <p>Searches may involve multiple segments and/or multiple indexes, each index potentially
- * composed of a set of segments.
- *
- * <h4>Document Numbers</h4>
- *
- * <p>Internally, Lucene refers to documents by an integer <i>document number</i>. The first
- * document added to an index is numbered zero, and each subsequent document added gets a number one
- * greater than the previous.
- *
- * <p>Note that a document's number may change, so caution should be taken when storing these
- * numbers outside of Lucene. In particular, numbers may change in the following situations:
- *
- * <ul>
- *   <li>
- *       <p>The numbers stored in each segment are unique only within the segment, and must be
- *       converted before they can be used in a larger context. The standard technique is to
- *       allocate each segment a range of values, based on the range of numbers used in that
- *       segment. To convert a document number from a segment to an external value, the segment's
- *       <i>base</i> document number is added. To convert an external value back to a
- *       segment-specific value, the segment is identified by the range that the external value is
- *       in, and the segment's base value is subtracted. For example two five document segments
- *       might be combined, so that the first segment has a base value of zero, and the second of
- *       five. Document three from the second segment would have an external value of eight.
- *   <li>
- *       <p>When documents are deleted, gaps are created in the numbering. These are eventually
- *       removed as the index evolves through merging. Deleted documents are dropped when segments
- *       are merged. A freshly-merged segment thus has no gaps in its numbering.
- * </ul>
- *
- * <h3>Index Structure Overview</h3>
- *
- * <p>Each segment index maintains the following:
- *
- * - {@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment info}. This
- * contains metadata about a segment, such as the number of documents, what files it uses, and
- * information about how the segment is sorted
- *
- {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This
- * contains metadata about the set of named fields used in the index.
- *
- {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
- * This contains, for each document, a list of attribute-value pairs, where the attributes are
- * field names. These are used to store auxiliary information about the document, such as its
- * title, url, or an identifier to access a database. The set of stored fields are what is
- * returned for each hit when searching. This is keyed by document number.
- *
- {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term dictionary}. A
- * dictionary containing all of the terms used in all of the indexed fields of all of the
- * documents. The dictionary also contains the number of documents which contain the term, and
- * pointers to the term's frequency and proximity data.
- *
- {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Frequency data}. For
- * each term in the dictionary, the numbers of all the documents that contain that term, and
- * the frequency of the term in that document, unless frequencies are omitted ({@link
- * org.apache.lucene.index.IndexOptions#DOCS IndexOptions.DOCS})
- *
- {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Proximity data}. For
- * each term in the dictionary, the positions that the term occurs in each document. Note that
- * this will not exist if all fields in all documents omit position data.
- *
- {@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Normalization factors}. For
- * each field in each document, a value is stored that is multiplied into the score for hits
- * on that field.
- *
- {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vectors}. For each
- * field in each document, the term vector (sometimes called document vector) may be stored. A
- * term vector consists of term text and term frequency. To add Term Vectors to your index see
- * the {@link org.apache.lucene.document.Field Field} constructors
- *
- {@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-document values}. Like
- * stored values, these are also keyed by document number, but are generally intended to be
- * loaded into main memory for fast access. Whereas stored values are generally intended for
- * summary results from searches, per-document values are useful for things like scoring
- * factors.
- *
- {@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live documents}. An
- * optional file indicating which documents are live.
- *
- {@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}. Optional pair
- * of files, recording dimensionally indexed fields, to enable fast numeric range filtering
- * and large numeric values like BigInteger and BigDecimal (1D) and geographic shape
- * intersection (2D, 3D).
- *
- {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}. The
- * vector format stores numeric vectors in a format optimized for random access and
- * computation, supporting high-dimensional nearest-neighbor search.
- *
- *
- * <p>Details on each of these are provided in their linked pages.
- *
- * <h3>File Naming</h3>
- *
- * <p>All files belonging to a segment have the same name with varying extensions. The extensions
- * correspond to the different file formats described below. When using the Compound File format
- * (default for small segments) these files (except for the Segment info file, the Lock file, and
- * Deleted documents file) are collapsed into a single .cfs file (see below for details)
- *
- * <p>Typically, all segments in an index are stored in a single directory, although this is not
- * required.
- *
- * <p>File names are never re-used. That is, when any file is saved to the Directory it is given a
- * never before used filename. This is achieved using a simple generations approach. For example,
- * the first segments file is segments_1, then segments_2, etc. The generation is a sequential long
- * integer represented in alpha-numeric (base 36) form.
- *
- * <h3>Summary of File Extensions</h3>
- *
- * <p>The following table summarizes the names and extensions of the files in Lucene:
- *
- *
- * lucene filenames by extension
- *
- * Name |
- * Extension |
- * Brief Description |
- *
- *
- * {@link org.apache.lucene.index.SegmentInfos Segments File} |
- * segments_N |
- * Stores information about a commit point |
- *
- *
- * Lock File |
- * write.lock |
- * The Write lock prevents multiple IndexWriters from writing to the same
- * file. |
- *
- *
- * {@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment Info} |
- * .si |
- * Stores metadata about a segment |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat Compound File} |
- * .cfs, .cfe |
- * An optional "virtual" file consisting of all the other index files for
- * systems that frequently run out of file handles. |
- *
- *
- * {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields} |
- * .fnm |
- * Stores information about the fields |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Index} |
- * .fdx |
- * Contains pointers to field data |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Data} |
- * .fdt |
- * The stored fields for documents |
- *
- *
- * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Dictionary} |
- * .tim |
- * The term dictionary, stores term info |
- *
- *
- * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Index} |
- * .tip |
- * The index into the Term Dictionary |
- *
- *
- * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Frequencies} |
- * .doc |
- * Contains the list of docs which contain each term along with frequency |
- *
- *
- * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Positions} |
- * .pos |
- * Stores position information about where a term occurs in the index |
- *
- *
- * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Payloads} |
- * .pay |
- * Stores additional per-position metadata information such as character offsets and user payloads |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Norms} |
- * .nvd, .nvm |
- * Encodes length and boost factors for docs and fields |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-Document Values} |
- * .dvd, .dvm |
- * Encodes additional scoring factors or other per-document information. |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Index} |
- * .tvx |
- * Stores offset into the document data file |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Data} |
- * .tvd |
- * Contains term vector data. |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live Documents} |
- * .liv |
- * Info about what documents are live |
- *
- *
- * {@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values} |
- * .kdd, .kdi, .kdm |
- * Holds indexed points |
- *
- *
- * {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values} |
- * .vec, .vem, .veq, vex |
- * Holds indexed vectors; .vec files contain the raw vector data,
- * .vem the vector metadata, .veq the quantized vector data, and .vex the
- * hnsw graph data. |
- *
- *
- *
- *
- * <h3>Lock File</h3>
- *
- * The write lock, which is stored in the index directory by default, is named "write.lock". If the
- * lock directory is different from the index directory then the write lock will be named
- * "XXXX-write.lock" where XXXX is a unique prefix derived from the full path to the index
- * directory. When this file is present, a writer is currently modifying the index (adding or
- * removing documents). This lock file ensures that only one writer is modifying the index at a
- * time.
- *
- * <h3>History</h3>
- *
- * Compatibility notes are provided in this document, describing how file formats have changed
- * from prior versions:
- *
- * - In version 2.1, the file format was changed to allow lock-less commits (ie, no more commit
- * lock). The change is fully backwards compatible: you can open a pre-2.1 index for searching
- * or adding/deleting of docs. When the new segments file is saved (committed), it will be
- * written in the new file format (meaning no specific "upgrade" process is needed). But note
- * that once a commit has occurred, pre-2.1 Lucene will not be able to read the index.
- *
- In version 2.3, the file format was changed to allow segments to share a single set of doc
- * store (vectors & stored fields) files. This allows for faster indexing in certain
- * cases. The change is fully backwards compatible (in the same way as the lock-less commits
- * change in 2.1).
- *
- In version 2.4, Strings are now written as true UTF-8 byte sequence, not Java's modified
- * UTF-8. See LUCENE-510 for
- * details.
- *
- In version 2.9, an optional opaque Map<String,String> CommitUserData may be passed to
- * IndexWriter's commit methods (and later retrieved), which is recorded in the segments_N
- * file. See LUCENE-1382 for
- * details. Also, diagnostics were added to each segment written recording details about why
- * it was written (due to flush, merge; which OS/JRE was used; etc.). See issue LUCENE-1654 for details.
- *
- In version 3.0, compressed fields are no longer written to the index (they can still be
- * read, but on merge the new segment will write them, uncompressed). See issue LUCENE-1960 for details.
- *
- In version 3.1, segments records the code version that created them. See LUCENE-2720 for details.
- * Additionally segments track explicitly whether or not they have term vectors. See LUCENE-2811 for details.
- *
- In version 3.2, numeric fields are written as natively to stored fields file, previously
- * they were stored in text format only.
- *
- In version 3.4, fields can omit position data while still indexing term frequencies.
- *
- In version 4.0, the format of the inverted index became extensible via the {@link
- * org.apache.lucene.codecs.Codec Codec} api. Fast per-document storage ({@code DocValues})
- * was introduced. Normalization factors need no longer be a single byte, they can be any
- * {@link org.apache.lucene.index.NumericDocValues NumericDocValues}. Terms need not be
- * unicode strings, they can be any byte sequence. Term offsets can optionally be indexed into
- * the postings lists. Payloads can be stored in the term vectors.
- *
- In version 4.1, the format of the postings list changed to use either of FOR compression or
- * variable-byte encoding, depending upon the frequency of the term. Terms appearing only once
- * were changed to inline directly into the term dictionary. Stored fields are compressed by
- * default.
- *
- In version 4.2, term vectors are compressed by default. DocValues has a new multi-valued
- * type (SortedSet), that can be used for faceting/grouping/joining on multi-valued fields.
- *
- In version 4.5, DocValues were extended to explicitly represent missing values.
- *
- In version 4.6, FieldInfos were extended to support per-field DocValues generation, to
- * allow updating NumericDocValues fields.
- *
- In version 4.8, checksum footers were added to the end of each index file for improved data
- * integrity. Specifically, the last 8 bytes of every index file contain the zlib-crc32
- * checksum of the file.
- *
- In version 4.9, DocValues has a new multi-valued numeric type (SortedNumeric) that is
- * suitable for faceting/sorting/analytics.
- *
- In version 5.4, DocValues have been improved to store more information on disk: addresses
- * for binary fields and ord indexes for multi-valued fields.
- *
- In version 6.0, Points were added, for multi-dimensional range/distance search.
- *
- In version 6.2, new Segment info format that reads/writes the index sort, to support index
- * sorting.
- *
- In version 7.0, DocValues have been improved to better support sparse doc values thanks to
- * an iterator API.
- *
- In version 8.0, postings have been enhanced to record, for each block of doc ids, the (term
- * freq, normalization factor) pairs that may trigger the maximum score of the block. This
- * information is recorded alongside skip data in order to be able to skip blocks of doc ids
- * if they may not produce high enough scores. Additionally doc values and norms has been
- * extended with jump-tables to make access O(1) instead of O(n), where n is the number of
- * elements to skip when advancing in the data.
- *
- In version 8.4, postings, positions, offsets and payload lengths have move to a more
- * performant encoding that is vectorized.
- *
- In version 8.6, index sort serialization is delegated to the sorts themselves, to allow
- * user-defined sorts to be used
- *
- In version 8.6, points fields split the index tree and leaf data into separate files, to
- * allow for different access patterns to the different data structures
- *
- In version 8.7, stored fields compression became adaptive to better handle documents with
- * smaller stored fields.
- *
- In version 9.0, vector-valued fields were added.
- *
- In version 9.1, vector-valued fields were modified to add a graph hierarchy.
- *
- In version 9.2, docs of vector-valued fields were moved from .vem to .vec and encoded by
- * IndexDISI. ordToDoc mappings was added to .vem.
- *
- In version 9.5, HNSW graph connections were changed to be delta-encoded with vints.
- * Additionally, metadata file size improvements were made by delta-encoding nodes by graph
- * layer and not writing the node ids for the zeroth layer.
- *
- In version 9.9, Vector scalar quantization support was added. Allowing the HNSW vector
- * format to utilize int8 quantized vectors for float32 vector search.
- *
- In version 9.12, skip data was refactored to have only two levels: every 128 docs and every
- * 4,096 docs, and to be inlined in postings lists. This resulted in a speedup for queries that
- * need skipping, especially conjunctions.
- *
- In version 10.1, block encoding changed to be optimized for int[] storage instead of
- * long[].
- *
- *
- *
- *
- * <h3>Limitations</h3>
- *
- * <p>Lucene uses a Java <code>int</code> to refer to document numbers, and the index file format
- * uses an <code>Int32</code> on-disk to store document numbers. This is a limitation of both the
- * index file format and the current implementation. Eventually these should be replaced with either
- * <code>UInt64</code> values, or better yet, {@link org.apache.lucene.store.DataOutput#writeVInt
- * VInt} values which have no limit.
- */
+/** Lucene 10.1 file format. */
package org.apache.lucene.codecs.lucene101;
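Aside for reviewers: the "base document number" conversion described in the removed package javadoc above is easy to miss; the following is an illustrative, hedged sketch (not part of the patch, class name made up) that reproduces the javadoc's own example of two five-document segments.

```java
// Illustrative only: not part of this patch. Mirrors the "segment base" doc-number
// conversion described in the removed lucene101 package javadoc.
public class SegmentBaseExample {
  public static void main(String[] args) {
    // Two five-document segments: the first has base 0, the second base 5.
    int[] segmentBases = {0, 5};

    // Document three (local doc number 3) in the second segment...
    int segment = 1;
    int localDocId = 3;

    // ...maps to external doc number 8, exactly as in the javadoc example.
    int externalDocId = segmentBases[segment] + localDocId;
    System.out.println(externalDocId); // prints 8
  }
}
```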
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java
new file mode 100644
index 000000000000..62eff4d72ac4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/BinarizedByteVectorValues.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
+
+import java.io.IOException;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Binarized byte vector values */
+abstract class BinarizedByteVectorValues extends ByteVectorValues {
+
+ /**
+ * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
+ * distances, the corrective terms are, in order
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the dot-product of the non-centered vector with the centroid
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the l2norm of the centered vector
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * @param vectorOrd the vector ordinal
+ * @return the corrective terms
+ * @throws IOException if an I/O error occurs
+ */
+ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd)
+ throws IOException;
+
+ /**
+ * @return the quantizer used to quantize the vectors
+ */
+ public abstract OptimizedScalarQuantizer getQuantizer();
+
+ public abstract float[] getCentroid() throws IOException;
+
+ int discretizedDimensions() {
+ return discretize(dimension(), 64);
+ }
+
+ /**
+ * Return a {@link VectorScorer} for the given query vector.
+ *
+ * @param query the query vector
+ * @return a {@link VectorScorer} instance or null
+ */
+ public abstract VectorScorer scorer(float[] query) throws IOException;
+
+ @Override
+ public abstract BinarizedByteVectorValues copy() throws IOException;
+
+ float getCentroidDP() throws IOException {
+ // this only gets executed on-merge
+ float[] centroid = getCentroid();
+ return VectorUtil.dotProduct(centroid, centroid);
+ }
+}
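Sizing note for reviewers: discretizedDimensions() above pads the dimension count up to a whole number of 64-bit words before packing. A rough sketch of that math, assuming discretize(d, 64) rounds up to the next multiple of 64 (consistent with how it is used in this patch); illustrative only, not part of the change:

```java
// Illustrative sketch of the packed-code sizing implied by discretizedDimensions().
public class BinarizedSizingSketch {
  // Assumed behavior: round value up to the next multiple of bucket.
  static int discretize(int value, int bucket) {
    return ((value + (bucket - 1)) / bucket) * bucket;
  }

  public static void main(String[] args) {
    int dimension = 768;
    int discretized = discretize(dimension, 64); // 768 is already a multiple of 64
    int indexBytesPerVector = discretized / 8;    // 1 bit per dimension -> 96 bytes
    int queryBytesPerVector = 4 * discretized / 8; // 4-bit queries -> 384 bytes
    System.out.println(indexBytesPerVector + " index bytes, " + queryBytesPerVector + " query bytes");
  }
}
```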
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java
new file mode 100644
index 000000000000..a25fc0807563
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryFlatVectorsScorer.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult;
+
+/** Vector scorer over binarized vector values */
+public class Lucene102BinaryFlatVectorsScorer implements FlatVectorsScorer {
+ private final FlatVectorsScorer nonQuantizedDelegate;
+ private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
+
+ public Lucene102BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
+ this.nonQuantizedDelegate = nonQuantizedDelegate;
+ }
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
+ throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues) {
+ throw new UnsupportedOperationException(
+ "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format");
+ }
+ return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
+ throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
+ OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer();
+ float[] centroid = binarizedVectors.getCentroid();
+ // We make a copy as the quantization process mutates the input
+ float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
+ if (similarityFunction == COSINE) {
+ VectorUtil.l2normalize(copy);
+ }
+ target = copy;
+ byte[] initial = new byte[target.length];
+ byte[] quantized = new byte[QUERY_BITS * binarizedVectors.discretizedDimensions() / 8];
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections =
+ quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
+ transposeHalfByte(initial, quantized);
+ BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections);
+ return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
+ }
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
+ throws IOException {
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
+ BinarizedByteVectorValues targetVectors) {
+ return new BinarizedRandomVectorScorerSupplier(
+ scoringVectors, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
+ }
+
+ /** Vector scorer supplier over binarized vector values */
+ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
+ private final Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues
+ queryVectors;
+ private final BinarizedByteVectorValues targetVectors;
+ private final VectorSimilarityFunction similarityFunction;
+
+ BinarizedRandomVectorScorerSupplier(
+ Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
+ BinarizedByteVectorValues targetVectors,
+ VectorSimilarityFunction similarityFunction) {
+ this.queryVectors = queryVectors;
+ this.targetVectors = targetVectors;
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ byte[] vector = queryVectors.vectorValue(ord);
+ QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord);
+ BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms);
+ return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return new BinarizedRandomVectorScorerSupplier(
+ queryVectors.copy(), targetVectors.copy(), similarityFunction);
+ }
+ }
+
+ /** A binarized query representing its quantized form along with factors */
+ public record BinaryQueryVector(
+ byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {}
+
+ /** Vector scorer over binarized vector values */
+ public static class BinarizedRandomVectorScorer
+ extends RandomVectorScorer.AbstractRandomVectorScorer {
+ private final BinaryQueryVector queryVector;
+ private final BinarizedByteVectorValues targetVectors;
+ private final VectorSimilarityFunction similarityFunction;
+
+ public BinarizedRandomVectorScorer(
+ BinaryQueryVector queryVectors,
+ BinarizedByteVectorValues targetVectors,
+ VectorSimilarityFunction similarityFunction) {
+ super(targetVectors);
+ this.queryVector = queryVectors;
+ this.targetVectors = targetVectors;
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public float score(int targetOrd) throws IOException {
+ byte[] quantizedQuery = queryVector.vector();
+ byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+ float qcDist = VectorUtil.int4BitDotProduct(quantizedQuery, binaryCode);
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections =
+ queryVector.quantizationResult();
+ OptimizedScalarQuantizer.QuantizationResult indexCorrections =
+ targetVectors.getCorrectiveTerms(targetOrd);
+ float x1 = indexCorrections.quantizedComponentSum();
+ float ax = indexCorrections.lowerInterval();
+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+ float lx = indexCorrections.upperInterval() - ax;
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+ float y1 = queryCorrections.quantizedComponentSum();
+ float score =
+ ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ score =
+ queryCorrections.additionalCorrection()
+ + indexCorrections.additionalCorrection()
+ - 2 * score;
+ return Math.max(1 / (1f + score), 0);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ score +=
+ queryCorrections.additionalCorrection()
+ + indexCorrections.additionalCorrection()
+ - targetVectors.getCentroidDP();
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
+ }
+ return Math.max((1f + score) / 2f, 0);
+ }
+ }
+ }
+}
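Why the corrected score in BinarizedRandomVectorScorer estimates a dot product: if every stored component is reconstructed as lower interval plus interval width times the quantized value, expanding the product of two reconstructed vectors yields exactly the four terms combined in score(). A hedged, self-contained sketch (not part of the patch; toy values, and the interval widths here already include the 1/15 query scaling applied in the code):

```java
// Illustrative sketch: sum((ax + lx*x[i]) * (ay + ly*y[i]))
//   == ax*ay*d + ax*ly*sum(y) + ay*lx*sum(x) + lx*ly*dot(x, y)
public class CorrectedDotProductSketch {
  public static void main(String[] args) {
    int[] x = {1, 0, 1, 1};      // e.g. 1-bit doc codes
    int[] y = {3, 15, 7, 0};     // e.g. 4-bit query codes
    float ax = -0.2f, lx = 0.9f; // doc lower interval and per-step width
    float ay = -0.5f, ly = 0.1f; // query lower interval and per-step width

    float direct = 0, qcDist = 0, sumX = 0, sumY = 0;
    for (int i = 0; i < x.length; i++) {
      direct += (ax + lx * x[i]) * (ay + ly * y[i]); // dot product of reconstructed values
      qcDist += x[i] * y[i];
      sumX += x[i];
      sumY += y[i];
    }
    float corrected = ax * ay * x.length + ax * ly * sumY + ay * lx * sumX + lx * ly * qcDist;
    System.out.println(direct + " ~= " + corrected); // equal up to floating-point rounding
  }
}
```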
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..ae48a220b235
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+/**
+ * Codec for encoding/decoding binary quantized vectors. The binary quantization format used here
+ * is a per-vector optimized scalar quantization. Also see {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of the key features are:
+ *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid normalized distance. This
+ *       requires some additional corrective factors, but allows for centroid normalization to
+ *       occur.
+ *   <li>Optimized scalar quantization to single bit level of centroid normalized vectors.
+ *   <li>Asymmetric quantization of vectors, where query vectors are quantized to half-byte (4 bits)
+ *       precision (normalized to the centroid) and then compared directly against the single bit
+ *       quantized vectors in the index.
+ *   <li>Transforming the half-byte quantized query vectors in such a way that the comparison with
+ *       single bit vectors can be done with bit arithmetic.
+ * </ul>
+ *
+ * The format is stored in two files:
+ *
+ * <h2>.veb (vector data) file</h2>
+ *
+ * <p>Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <ul>
+ *   <li>For each vector:
+ *       <ul>
+ *         <li><b>[byte]</b> the binary quantized values, each byte holds 8 bits.
+ *         <li><b>[float]</b> the optimized quantiles and an additional similarity dependent
+ *             corrective factor.
+ *         <li><b>short</b> the sum of the quantized components
+ *       </ul>
+ *   <li>After the vectors, sparse vector information keeping track of monotonic blocks.
+ * </ul>
+ *
+ * <h2>.vemb (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
+ *
+ * <ul>
+ *   <li><b>int</b> the field number
+ *   <li><b>int</b> the vector encoding ordinal
+ *   <li><b>int</b> the vector similarity ordinal
+ *   <li><b>vint</b> the vector dimensions
+ *   <li><b>vlong</b> the offset to the vector data in the .veb file
+ *   <li><b>vlong</b> the length of the vector data in the .veb file
+ *   <li><b>vint</b> the number of vectors
+ *   <li><b>[float]</b> the centroid
+ *   <li><b>float</b> the centroid square magnitude
+ *   <li>The sparse vector information, if required, mapping vector ordinal to doc ID
+ * </ul>
+ */
+public class Lucene102BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
+
+ public static final byte QUERY_BITS = 4;
+ public static final byte INDEX_BITS = 1;
+
+ public static final String BINARIZED_VECTOR_COMPONENT = "BVEC";
+ public static final String NAME = "Lucene102BinaryQuantizedVectorsFormat";
+
+ static final int VERSION_START = 0;
+ static final int VERSION_CURRENT = VERSION_START;
+ static final String META_CODEC_NAME = "Lucene102BinaryQuantizedVectorsFormatMeta";
+ static final String VECTOR_DATA_CODEC_NAME = "Lucene102BinaryQuantizedVectorsFormatData";
+ static final String META_EXTENSION = "vemb";
+ static final String VECTOR_DATA_EXTENSION = "veb";
+ static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
+
+ private static final FlatVectorsFormat rawVectorFormat =
+ new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+
+ private static final Lucene102BinaryFlatVectorsScorer scorer =
+ new Lucene102BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+
+ /** Creates a new instance with the default configuration. */
+ public Lucene102BinaryQuantizedVectorsFormat() {
+ super(NAME);
+ }
+
+ @Override
+ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new Lucene102BinaryQuantizedVectorsWriter(
+ scorer, rawVectorFormat.fieldsWriter(state), state);
+ }
+
+ @Override
+ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene102BinaryQuantizedVectorsReader(
+ state, rawVectorFormat.fieldsReader(state), scorer);
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return 1024;
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene102BinaryQuantizedVectorsFormat(name="
+ + NAME
+ + ", flatVectorScorer="
+ + scorer
+ + ")";
+ }
+}
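A hedged usage sketch for reviewers (not part of the patch): wiring the companion HNSW binary quantized format into an index through the usual per-field codec override, assuming Lucene102HnswBinaryQuantizedVectorsFormat keeps a no-argument constructor and that Lucene101Codec exposes getKnnVectorsFormatForField as in earlier codec versions.

```java
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat;
import org.apache.lucene.index.IndexWriterConfig;

public class BinaryQuantizedIndexingExample {
  public static IndexWriterConfig config() {
    IndexWriterConfig iwc = new IndexWriterConfig();
    iwc.setCodec(
        new Lucene101Codec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            // Use the binary-quantized HNSW format for every vector field in this sketch.
            return new Lucene102HnswBinaryQuantizedVectorsFormat();
          }
        });
    return iwc;
  }
}
```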
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java
new file mode 100644
index 000000000000..a698ea9a7541
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java
@@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.CorruptIndexException;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.ChecksumIndexInput;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.ReadAdvice;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Reader for binary quantized vectors in the Lucene 10.2 format. */
+class Lucene102BinaryQuantizedVectorsReader extends FlatVectorsReader {
+
+ private static final long SHALLOW_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(Lucene102BinaryQuantizedVectorsReader.class);
+
+ private final Map<String, FieldEntry> fields = new HashMap<>();
+ private final IndexInput quantizedVectorData;
+ private final FlatVectorsReader rawVectorsReader;
+ private final Lucene102BinaryFlatVectorsScorer vectorScorer;
+
+ Lucene102BinaryQuantizedVectorsReader(
+ SegmentReadState state,
+ FlatVectorsReader rawVectorsReader,
+ Lucene102BinaryFlatVectorsScorer vectorsScorer)
+ throws IOException {
+ super(vectorsScorer);
+ this.vectorScorer = vectorsScorer;
+ this.rawVectorsReader = rawVectorsReader;
+ int versionMeta = -1;
+ String metaFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene102BinaryQuantizedVectorsFormat.META_EXTENSION);
+ boolean success = false;
+ try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
+ Throwable priorE = null;
+ try {
+ versionMeta =
+ CodecUtil.checkIndexHeader(
+ meta,
+ Lucene102BinaryQuantizedVectorsFormat.META_CODEC_NAME,
+ Lucene102BinaryQuantizedVectorsFormat.VERSION_START,
+ Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ readFields(meta, state.fieldInfos);
+ } catch (Throwable exception) {
+ priorE = exception;
+ } finally {
+ CodecUtil.checkFooter(meta, priorE);
+ }
+ quantizedVectorData =
+ openDataInput(
+ state,
+ versionMeta,
+ Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
+ Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ // Quantized vectors are accessed randomly from their node ID stored in the HNSW
+ // graph.
+ state.context.withReadAdvice(ReadAdvice.RANDOM));
+ success = true;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(this);
+ }
+ }
+ }
+
+ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
+ for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
+ FieldInfo info = infos.fieldInfo(fieldNumber);
+ if (info == null) {
+ throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
+ }
+ FieldEntry fieldEntry = readField(meta, info);
+ validateFieldEntry(info, fieldEntry);
+ fields.put(info.name, fieldEntry);
+ }
+ }
+
+ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
+ int dimension = info.getVectorDimension();
+ if (dimension != fieldEntry.dimension) {
+ throw new IllegalStateException(
+ "Inconsistent vector dimension for field=\""
+ + info.name
+ + "\"; "
+ + dimension
+ + " != "
+ + fieldEntry.dimension);
+ }
+
+ int binaryDims = discretize(dimension, 64) / 8;
+ long numQuantizedVectorBytes =
+ Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size);
+ if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
+ throw new IllegalStateException(
+ "Binarized vector data length "
+ + fieldEntry.vectorDataLength
+ + " not matching size = "
+ + fieldEntry.size
+ + " * (binaryBytes="
+ + binaryDims
+ + " + 14"
+ + ") = "
+ + numQuantizedVectorBytes);
+ }
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ return vectorScorer.getRandomVectorScorer(
+ fi.similarityFunction,
+ OffHeapBinarizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new OptimizedScalarQuantizer(fi.similarityFunction),
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData),
+ target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
+ return rawVectorsReader.getRandomVectorScorer(field, target);
+ }
+
+ @Override
+ public void checkIntegrity() throws IOException {
+ rawVectorsReader.checkIntegrity();
+ CodecUtil.checksumEntireFile(quantizedVectorData);
+ }
+
+ @Override
+ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ if (fi.vectorEncoding != VectorEncoding.FLOAT32) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fi.vectorEncoding
+ + " expected: "
+ + VectorEncoding.FLOAT32);
+ }
+ OffHeapBinarizedVectorValues bvv =
+ OffHeapBinarizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new OptimizedScalarQuantizer(fi.similarityFunction),
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData);
+ return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return rawVectorsReader.getByteVectorValues(field);
+ }
+
+ @Override
+ public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
+ throws IOException {
+ rawVectorsReader.search(field, target, knnCollector, acceptDocs);
+ }
+
+ @Override
+ public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
+ throws IOException {
+ if (knnCollector.k() == 0) return;
+ final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
+ if (scorer == null) return;
+ OrdinalTranslatedKnnCollector collector =
+ new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
+ Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
+ for (int i = 0; i < scorer.maxOrd(); i++) {
+ if (acceptedOrds == null || acceptedOrds.get(i)) {
+ collector.collect(i, scorer.score(i));
+ collector.incVisitedCount(1);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(quantizedVectorData, rawVectorsReader);
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size +=
+ RamUsageEstimator.sizeOfMap(
+ fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
+ size += rawVectorsReader.ramBytesUsed();
+ return size;
+ }
+
+ public float[] getCentroid(String field) {
+ FieldEntry fieldEntry = fields.get(field);
+ if (fieldEntry != null) {
+ return fieldEntry.centroid;
+ }
+ return null;
+ }
+
+ private static IndexInput openDataInput(
+ SegmentReadState state,
+ int versionMeta,
+ String fileExtension,
+ String codecName,
+ IOContext context)
+ throws IOException {
+ String fileName =
+ IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+ IndexInput in = state.directory.openInput(fileName, context);
+ boolean success = false;
+ try {
+ int versionVectorData =
+ CodecUtil.checkIndexHeader(
+ in,
+ codecName,
+ Lucene102BinaryQuantizedVectorsFormat.VERSION_START,
+ Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ if (versionMeta != versionVectorData) {
+ throw new CorruptIndexException(
+ "Format versions mismatch: meta="
+ + versionMeta
+ + ", "
+ + codecName
+ + "="
+ + versionVectorData,
+ in);
+ }
+ CodecUtil.retrieveChecksum(in);
+ success = true;
+ return in;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(in);
+ }
+ }
+ }
+
+ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
+ VectorEncoding vectorEncoding = readVectorEncoding(input);
+ VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
+ if (similarityFunction != info.getVectorSimilarityFunction()) {
+ throw new IllegalStateException(
+ "Inconsistent vector similarity function for field=\""
+ + info.name
+ + "\"; "
+ + similarityFunction
+ + " != "
+ + info.getVectorSimilarityFunction());
+ }
+ return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
+ }
+
+ private record FieldEntry(
+ VectorSimilarityFunction similarityFunction,
+ VectorEncoding vectorEncoding,
+ int dimension,
+ int discretizedDimension,
+ long vectorDataOffset,
+ long vectorDataLength,
+ int size,
+ float[] centroid,
+ float centroidDP,
+ OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) {
+
+ static FieldEntry create(
+ IndexInput input,
+ VectorEncoding vectorEncoding,
+ VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ int dimension = input.readVInt();
+ long vectorDataOffset = input.readVLong();
+ long vectorDataLength = input.readVLong();
+ int size = input.readVInt();
+ final float[] centroid;
+ float centroidDP = 0;
+ if (size > 0) {
+ centroid = new float[dimension];
+ input.readFloats(centroid, 0, dimension);
+ centroidDP = Float.intBitsToFloat(input.readInt());
+ } else {
+ centroid = null;
+ }
+ OrdToDocDISIReaderConfiguration conf =
+ OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
+ return new FieldEntry(
+ similarityFunction,
+ vectorEncoding,
+ dimension,
+ discretize(dimension, 64),
+ vectorDataOffset,
+ vectorDataLength,
+ size,
+ centroid,
+ centroidDP,
+ conf);
+ }
+ }
+
+ /** Binarized vector values holding row and quantized vector values */
+ protected static final class BinarizedVectorValues extends FloatVectorValues {
+ private final FloatVectorValues rawVectorValues;
+ private final BinarizedByteVectorValues quantizedVectorValues;
+
+ BinarizedVectorValues(
+ FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) {
+ this.rawVectorValues = rawVectorValues;
+ this.quantizedVectorValues = quantizedVectorValues;
+ }
+
+ @Override
+ public int dimension() {
+ return rawVectorValues.dimension();
+ }
+
+ @Override
+ public int size() {
+ return rawVectorValues.size();
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ return rawVectorValues.vectorValue(ord);
+ }
+
+ @Override
+ public BinarizedVectorValues copy() throws IOException {
+ return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return rawVectorValues.getAcceptOrds(acceptDocs);
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return rawVectorValues.ordToDoc(ord);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return rawVectorValues.iterator();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) throws IOException {
+ return quantizedVectorValues.scorer(query);
+ }
+
+ BinarizedByteVectorValues getQuantizedVectorValues() throws IOException {
+ return quantizedVectorValues;
+ }
+ }
+}
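For reviewers, the per-vector on-disk record that validateFieldEntry() checks is the packed 1-bit codes plus three float corrective terms and a short component sum. An illustrative sketch of that arithmetic (not part of the patch), assuming discretize(d, 64) rounds up to the next multiple of 64:

```java
// Illustrative sketch of the .veb data-length check performed by validateFieldEntry().
public class BinarizedRecordSizeSketch {
  static int discretize(int value, int bucket) { // assumed: round up to a multiple of bucket
    return ((value + (bucket - 1)) / bucket) * bucket;
  }

  public static void main(String[] args) {
    int dimension = 100;
    int numVectors = 1_000;
    int binaryBytes = discretize(dimension, 64) / 8;                // 128 bits -> 16 bytes
    long perVector = binaryBytes + (Float.BYTES * 3) + Short.BYTES; // 16 + 12 + 2 = 30 bytes
    long expectedDataLength = perVector * numVectors;
    System.out.println(expectedDataLength); // 30000 bytes of vector data for this field
  }
}
```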
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java
new file mode 100644
index 000000000000..e15a8dab31b4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsWriter.java
@@ -0,0 +1,950 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT;
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.FloatArrayList;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/** Writes binarized vector values and metadata to index segments. */
+public class Lucene102BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
+ private static final long SHALLOW_RAM_BYTES_USED =
+ shallowSizeOfInstance(Lucene102BinaryQuantizedVectorsWriter.class);
+
+ private final SegmentWriteState segmentWriteState;
+  private final List<FieldWriter> fields = new ArrayList<>();
+ private final IndexOutput meta, binarizedVectorData;
+ private final FlatVectorsWriter rawVectorDelegate;
+ private final Lucene102BinaryFlatVectorsScorer vectorsScorer;
+ private boolean finished;
+
+ /**
+ * Sole constructor
+ *
+ * @param vectorsScorer the scorer to use for scoring vectors
+ */
+ protected Lucene102BinaryQuantizedVectorsWriter(
+ Lucene102BinaryFlatVectorsScorer vectorsScorer,
+ FlatVectorsWriter rawVectorDelegate,
+ SegmentWriteState state)
+ throws IOException {
+ super(vectorsScorer);
+ this.vectorsScorer = vectorsScorer;
+ this.segmentWriteState = state;
+ String metaFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene102BinaryQuantizedVectorsFormat.META_EXTENSION);
+
+ String binarizedVectorDataFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
+ this.rawVectorDelegate = rawVectorDelegate;
+ boolean success = false;
+ try {
+ meta = state.directory.createOutput(metaFileName, state.context);
+ binarizedVectorData =
+ state.directory.createOutput(binarizedVectorDataFileName, state.context);
+
+ CodecUtil.writeIndexHeader(
+ meta,
+ Lucene102BinaryQuantizedVectorsFormat.META_CODEC_NAME,
+ Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ CodecUtil.writeIndexHeader(
+ binarizedVectorData,
+ Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ Lucene102BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ success = true;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(this);
+ }
+ }
+ }
+
+ @Override
+  public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+    FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ @SuppressWarnings("unchecked")
+ FieldWriter fieldWriter =
+          new FieldWriter(fieldInfo, (FlatFieldVectorsWriter<float[]>) rawVectorDelegate);
+ fields.add(fieldWriter);
+ return fieldWriter;
+ }
+ return rawVectorDelegate;
+ }
+
+ @Override
+ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+ rawVectorDelegate.flush(maxDoc, sortMap);
+ for (FieldWriter field : fields) {
+ // after raw vectors are written, normalize vectors for clustering and quantization
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ field.normalizeVectors();
+ }
+ final float[] clusterCenter;
+ int vectorCount = field.flatFieldVectorsWriter.getVectors().size();
+ clusterCenter = new float[field.dimensionSums.length];
+ if (vectorCount > 0) {
+ for (int i = 0; i < field.dimensionSums.length; i++) {
+ clusterCenter[i] = field.dimensionSums[i] / vectorCount;
+ }
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ VectorUtil.l2normalize(clusterCenter);
+ }
+ }
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ OptimizedScalarQuantizer quantizer =
+ new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction());
+ if (sortMap == null) {
+ writeField(field, clusterCenter, maxDoc, quantizer);
+ } else {
+ writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer);
+ }
+ field.finish();
+ }
+ }
+
+ private void writeField(
+ FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer)
+ throws IOException {
+ // write vector values
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ writeBinarizedVectors(fieldData, clusterCenter, quantizer);
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp =
+ fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
+
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ vectorDataLength,
+ clusterCenter,
+ centroidDp,
+ fieldData.getDocsWithFieldSet());
+ }
+
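+  // Each vector is written as a fixed-size record: the packed 1-bit codes
+  // (discretize(dimension, 64) / 8 bytes), then the lower interval, upper interval and additional
+  // correction as 4-byte floats, and finally the quantized component sum as an unsigned short.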
+ private void writeBinarizedVectors(
+ FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ int discreteDims = discretize(fieldData.fieldInfo.getVectorDimension(), 64);
+ byte[] quantizationScratch = new byte[discreteDims];
+ byte[] vector = new byte[discreteDims / 8];
+ for (int i = 0; i < fieldData.getVectors().size(); i++) {
+ float[] v = fieldData.getVectors().get(i);
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ scalarQuantizer.scalarQuantize(v, quantizationScratch, INDEX_BITS, clusterCenter);
+ packAsBinary(quantizationScratch, vector);
+ binarizedVectorData.writeBytes(vector, vector.length);
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ assert corrections.quantizedComponentSum() >= 0
+ && corrections.quantizedComponentSum() <= 0xffff;
+ binarizedVectorData.writeShort((short) corrections.quantizedComponentSum());
+ }
+ }
+
+ private void writeSortingField(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int maxDoc,
+ Sorter.DocMap sortMap,
+ OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ final int[] ordMap =
+ new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
+
+ DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
+ mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField);
+
+ // write vector values
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer);
+ long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+
+ float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter);
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ quantizedVectorLength,
+ clusterCenter,
+ centroidDp,
+ newDocsWithField);
+ }
+
+ private void writeSortedBinarizedVectors(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int[] ordMap,
+ OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ int discreteDims = discretize(fieldData.fieldInfo.getVectorDimension(), 64);
+ byte[] quantizationScratch = new byte[discreteDims];
+ byte[] vector = new byte[discreteDims / 8];
+ for (int ordinal : ordMap) {
+ float[] v = fieldData.getVectors().get(ordinal);
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ scalarQuantizer.scalarQuantize(v, quantizationScratch, INDEX_BITS, clusterCenter);
+ packAsBinary(quantizationScratch, vector);
+ binarizedVectorData.writeBytes(vector, vector.length);
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ assert corrections.quantizedComponentSum() >= 0
+ && corrections.quantizedComponentSum() <= 0xffff;
+ binarizedVectorData.writeShort((short) corrections.quantizedComponentSum());
+ }
+ }
+
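+  // Per-field metadata: field number, encoding and similarity ordinals, dimension, vector data
+  // offset and length, vector count, and (when the field is non-empty) the centroid as
+  // little-endian floats followed by its self dot-product; the ord-to-doc mapping follows via
+  // OrdToDocDISIReaderConfiguration.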
+ private void writeMeta(
+ FieldInfo field,
+ int maxDoc,
+ long vectorDataOffset,
+ long vectorDataLength,
+ float[] clusterCenter,
+ float centroidDp,
+ DocsWithFieldSet docsWithField)
+ throws IOException {
+ meta.writeInt(field.number);
+ meta.writeInt(field.getVectorEncoding().ordinal());
+ meta.writeInt(field.getVectorSimilarityFunction().ordinal());
+ meta.writeVInt(field.getVectorDimension());
+ meta.writeVLong(vectorDataOffset);
+ meta.writeVLong(vectorDataLength);
+ int count = docsWithField.cardinality();
+ meta.writeVInt(count);
+ if (count > 0) {
+ final ByteBuffer buffer =
+ ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES)
+ .order(ByteOrder.LITTLE_ENDIAN);
+ buffer.asFloatBuffer().put(clusterCenter);
+ meta.writeBytes(buffer.array(), buffer.array().length);
+ meta.writeInt(Float.floatToIntBits(centroidDp));
+ }
+ OrdToDocDISIReaderConfiguration.writeStoredMeta(
+ DIRECT_MONOTONIC_BLOCK_SHIFT, meta, binarizedVectorData, count, maxDoc, docsWithField);
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ throw new IllegalStateException("already finished");
+ }
+ finished = true;
+ rawVectorDelegate.finish();
+ if (meta != null) {
+ // write end of fields marker
+ meta.writeInt(-1);
+ CodecUtil.writeFooter(meta);
+ }
+ if (binarizedVectorData != null) {
+ CodecUtil.writeFooter(binarizedVectorData);
+ }
+ }
+
+ @Override
+ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ final float[] centroid;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+      // No need for random access to the raw vectors; the merged float vector values suffice
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ FloatVectorValues floatVectorValues =
+ MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ BinarizedFloatVectorValues binarizedVectorValues =
+ new BinarizedFloatVectorValues(
+ floatVectorValues,
+ new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()),
+ centroid);
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ DocsWithFieldSet docsWithField =
+ writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues);
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp =
+ docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ centroidDp,
+ docsWithField);
+ } else {
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ }
+ }
+
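+  // Writes two encodings of each merged vector: a packed 1-bit (INDEX_BITS) form for the
+  // index-side representation and a transposed 4-bit (QUERY_BITS) form used as the query-side
+  // representation while building the graph during merge.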
+ static DocsWithFieldSet writeBinarizedVectorAndQueryData(
+ IndexOutput binarizedVectorData,
+ IndexOutput binarizedQueryData,
+ FloatVectorValues floatVectorValues,
+ float[] centroid,
+ OptimizedScalarQuantizer binaryQuantizer)
+ throws IOException {
+ int discretizedDimension = discretize(floatVectorValues.dimension(), 64);
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()];
+ byte[] toIndex = new byte[discretizedDimension / 8];
+ byte[] toQuery = new byte[(discretizedDimension / 8) * QUERY_BITS];
+ KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write index vector
+ OptimizedScalarQuantizer.QuantizationResult[] r =
+ binaryQuantizer.multiScalarQuantize(
+ floatVectorValues.vectorValue(iterator.index()),
+ quantizationScratch,
+ new byte[] {INDEX_BITS, QUERY_BITS},
+ centroid);
+ // pack and store document bit vector
+ packAsBinary(quantizationScratch[0], toIndex);
+ binarizedVectorData.writeBytes(toIndex, toIndex.length);
+ binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection()));
+ assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff;
+ binarizedVectorData.writeShort((short) r[0].quantizedComponentSum());
+ docsWithField.add(docV);
+
+ // pack and store the 4bit query vector
+ transposeHalfByte(quantizationScratch[1], toQuery);
+ binarizedQueryData.writeBytes(toQuery, toQuery.length);
+ binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval()));
+ binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval()));
+ binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection()));
+ assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff;
+ binarizedQueryData.writeShort((short) r[1].quantizedComponentSum());
+ }
+ return docsWithField;
+ }
+
+ static DocsWithFieldSet writeBinarizedVectorData(
+ IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write vector
+ byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index());
+ output.writeBytes(binaryValue, binaryValue.length);
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ binarizedByteVectorValues.getCorrectiveTerms(iterator.index());
+ output.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ output.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ output.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ assert corrections.quantizedComponentSum() >= 0
+ && corrections.quantizedComponentSum() <= 0xffff;
+ output.writeShort((short) corrections.quantizedComponentSum());
+ docsWithField.add(docV);
+ }
+ return docsWithField;
+ }
+
+ @Override
+ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ final float[] centroid;
+ final float cDotC;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+
+      // No need for random access to the raw vectors; the merged float vector values suffice
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC);
+ }
+ return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
+ }
+
+ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ SegmentWriteState segmentWriteState,
+ FieldInfo fieldInfo,
+ MergeState mergeState,
+ float[] centroid,
+ float cDotC)
+ throws IOException {
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ final IndexOutput tempQuantizedVectorData =
+ segmentWriteState.directory.createTempOutput(
+ binarizedVectorData.getName(), "temp", segmentWriteState.context);
+ final IndexOutput tempScoreQuantizedVectorData =
+ segmentWriteState.directory.createTempOutput(
+ binarizedVectorData.getName(), "score_temp", segmentWriteState.context);
+ IndexInput binarizedDataInput = null;
+ IndexInput binarizedScoreDataInput = null;
+ boolean success = false;
+ OptimizedScalarQuantizer quantizer =
+ new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+ try {
+ FloatVectorValues floatVectorValues =
+ MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ DocsWithFieldSet docsWithField =
+ writeBinarizedVectorAndQueryData(
+ tempQuantizedVectorData,
+ tempScoreQuantizedVectorData,
+ floatVectorValues,
+ centroid,
+ quantizer);
+ CodecUtil.writeFooter(tempQuantizedVectorData);
+ IOUtils.close(tempQuantizedVectorData);
+ binarizedDataInput =
+ segmentWriteState.directory.openInput(
+ tempQuantizedVectorData.getName(), segmentWriteState.context);
+ binarizedVectorData.copyBytes(
+ binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength());
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ CodecUtil.retrieveChecksum(binarizedDataInput);
+ CodecUtil.writeFooter(tempScoreQuantizedVectorData);
+ IOUtils.close(tempScoreQuantizedVectorData);
+ binarizedScoreDataInput =
+ segmentWriteState.directory.openInput(
+ tempScoreQuantizedVectorData.getName(), segmentWriteState.context);
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ cDotC,
+ docsWithField);
+ success = true;
+ final IndexInput finalBinarizedDataInput = binarizedDataInput;
+ final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput;
+ OffHeapBinarizedVectorValues vectorValues =
+ new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ centroid,
+ cDotC,
+ quantizer,
+ fieldInfo.getVectorSimilarityFunction(),
+ vectorsScorer,
+ finalBinarizedDataInput);
+ RandomVectorScorerSupplier scorerSupplier =
+ vectorsScorer.getRandomVectorScorerSupplier(
+ fieldInfo.getVectorSimilarityFunction(),
+ new OffHeapBinarizedQueryVectorValues(
+ finalBinarizedScoreDataInput,
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality()),
+ vectorValues);
+ return new BinarizedCloseableRandomVectorScorerSupplier(
+ scorerSupplier,
+ vectorValues,
+ () -> {
+ IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput);
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory,
+ tempQuantizedVectorData.getName(),
+ tempScoreQuantizedVectorData.getName());
+ });
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(
+ tempQuantizedVectorData,
+ tempScoreQuantizedVectorData,
+ binarizedDataInput,
+ binarizedScoreDataInput);
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory,
+ tempQuantizedVectorData.getName(),
+ tempScoreQuantizedVectorData.getName());
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(meta, binarizedVectorData, rawVectorDelegate);
+ }
+
+ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) {
+ if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+ vectorsReader = candidateReader.getFieldReader(fieldName);
+ }
+ if (vectorsReader instanceof Lucene102BinaryQuantizedVectorsReader reader) {
+ return reader.getCentroid(fieldName);
+ }
+ return null;
+ }
+
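+  // Combines per-segment centroids into an average weighted by each segment's vector count when
+  // every segment has a stored centroid and no deleted docs; otherwise the centroid is recomputed
+  // from the merged float vectors.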
+ static int mergeAndRecalculateCentroids(
+ MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException {
+ boolean recalculate = false;
+ int totalVectorCount = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null
+ || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) {
+ continue;
+ }
+ float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name);
+ int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size();
+ if (vectorCount == 0) {
+ continue;
+ }
+ totalVectorCount += vectorCount;
+      // If a segment has no stored centroid, or if there are deleted docs, we must recalculate
+      // the centroid from the merged vectors
+ if (centroid == null || mergeState.liveDocs[i] != null) {
+ recalculate = true;
+ break;
+ }
+ for (int j = 0; j < centroid.length; j++) {
+ mergedCentroid[j] += centroid[j] * vectorCount;
+ }
+ }
+ if (recalculate) {
+ return calculateCentroid(mergeState, fieldInfo, mergedCentroid);
+ } else {
+ for (int j = 0; j < mergedCentroid.length; j++) {
+ mergedCentroid[j] = mergedCentroid[j] / totalVectorCount;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(mergedCentroid);
+ }
+ return totalVectorCount;
+ }
+ }
+
+ static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid)
+ throws IOException {
+ assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+ // clear out the centroid
+ Arrays.fill(centroid, 0);
+ int count = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null) continue;
+ FloatVectorValues vectorValues =
+ mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) {
+ continue;
+ }
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc();
+ doc != DocIdSetIterator.NO_MORE_DOCS;
+ doc = iterator.nextDoc()) {
+ ++count;
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ }
+ if (count == 0) {
+ return count;
+ }
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= count;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(centroid);
+ }
+ return count;
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long total = SHALLOW_RAM_BYTES_USED;
+ for (FieldWriter field : fields) {
+ // the field tracks the delegate field usage
+ total += field.ramBytesUsed();
+ }
+ return total;
+ }
+
+  static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
+ private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
+ private final FieldInfo fieldInfo;
+ private boolean finished;
+    private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;
+ private final float[] dimensionSums;
+ private final FloatArrayList magnitudes = new FloatArrayList();
+
+    FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter) {
+ this.fieldInfo = fieldInfo;
+ this.flatFieldVectorsWriter = flatFieldVectorsWriter;
+ this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+ }
+
+ @Override
+    public List<float[]> getVectors() {
+ return flatFieldVectorsWriter.getVectors();
+ }
+
+ public void normalizeVectors() {
+ for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) {
+ float[] vector = flatFieldVectorsWriter.getVectors().get(i);
+ float magnitude = magnitudes.get(i);
+ for (int j = 0; j < vector.length; j++) {
+ vector[j] /= magnitude;
+ }
+ }
+ }
+
+ @Override
+ public DocsWithFieldSet getDocsWithFieldSet() {
+ return flatFieldVectorsWriter.getDocsWithFieldSet();
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ return;
+ }
+ assert flatFieldVectorsWriter.isFinished();
+ finished = true;
+ }
+
+ @Override
+ public boolean isFinished() {
+ return finished && flatFieldVectorsWriter.isFinished();
+ }
+
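+    // For cosine similarity, contributions are accumulated from unit-length vectors and the
+    // original magnitudes are recorded so the raw vectors can be normalized in place later (see
+    // normalizeVectors); for other similarities the raw components are summed directly.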
+ @Override
+ public void addValue(int docID, float[] vectorValue) throws IOException {
+ flatFieldVectorsWriter.addValue(docID, vectorValue);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
+ float divisor = (float) Math.sqrt(dp);
+ magnitudes.add(divisor);
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += (vectorValue[i] / divisor);
+ }
+ } else {
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += vectorValue[i];
+ }
+ }
+ }
+
+ @Override
+ public float[] copyValue(float[] vectorValue) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size += flatFieldVectorsWriter.ramBytesUsed();
+ size += magnitudes.ramBytesUsed();
+ return size;
+ }
+ }
+
+  // When accessing the vectorValue method, targetOrd here means a row ordinal.
+ static class OffHeapBinarizedQueryVectorValues {
+ private final IndexInput slice;
+ private final int dimension;
+ private final int size;
+ protected final byte[] binaryValue;
+ protected final ByteBuffer byteBuffer;
+ private final int byteSize;
+ protected final float[] correctiveValues;
+ private int lastOrd = -1;
+ private int quantizedComponentSum;
+
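+    // Query-side records mirror the index-side layout but hold 4-bit codes, so the packed portion
+    // is QUERY_BITS times larger: packed bytes, then three corrective floats, then the quantized
+    // component sum as an unsigned short.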
+ OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) {
+ this.slice = data;
+ this.dimension = dimension;
+ this.size = size;
+ // 4x the quantized binary dimensions
+ int binaryDimensions = (discretize(dimension, 64) / 8) * QUERY_BITS;
+ this.byteBuffer = ByteBuffer.allocate(binaryDimensions);
+ this.binaryValue = byteBuffer.array();
+      // three corrective floats; the quantized component sum is stored separately as a short
+ this.correctiveValues = new float[3];
+ this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES;
+ }
+
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd)
+ throws IOException {
+ if (lastOrd == targetOrd) {
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
+ vectorValue(targetOrd);
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
+
+ public int size() {
+ return size;
+ }
+
+ public int dimension() {
+ return dimension;
+ }
+
+ public OffHeapBinarizedQueryVectorValues copy() throws IOException {
+ return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size);
+ }
+
+ public IndexInput getSlice() {
+ return slice;
+ }
+
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return binaryValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(binaryValue, 0, binaryValue.length);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
+ lastOrd = targetOrd;
+ return binaryValue;
+ }
+ }
+
+ static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
+ private OptimizedScalarQuantizer.QuantizationResult corrections;
+ private final byte[] binarized;
+ private final byte[] initQuantized;
+ private final float[] centroid;
+ private final FloatVectorValues values;
+ private final OptimizedScalarQuantizer quantizer;
+
+ private int lastOrd = -1;
+
+ BinarizedFloatVectorValues(
+ FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) {
+ this.values = delegate;
+ this.quantizer = quantizer;
+ this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
+ this.initQuantized = new byte[delegate.dimension()];
+ this.centroid = centroid;
+ }
+
+ @Override
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
+ if (ord != lastOrd) {
+ throw new IllegalStateException(
+ "attempt to retrieve corrective terms for different ord "
+ + ord
+ + " than the quantization was done for: "
+ + lastOrd);
+ }
+ return corrections;
+ }
+
+ @Override
+ public byte[] vectorValue(int ord) throws IOException {
+ if (ord != lastOrd) {
+ binarize(ord);
+ lastOrd = ord;
+ }
+ return binarized;
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[] getCentroid() throws IOException {
+ return centroid;
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BinarizedByteVectorValues copy() throws IOException {
+ return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid);
+ }
+
+ private void binarize(int ord) throws IOException {
+ corrections =
+ quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
+ packAsBinary(initQuantized, binarized);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+ }
+
+ static class BinarizedCloseableRandomVectorScorerSupplier
+ implements CloseableRandomVectorScorerSupplier {
+ private final RandomVectorScorerSupplier supplier;
+ private final KnnVectorValues vectorValues;
+ private final Closeable onClose;
+
+ BinarizedCloseableRandomVectorScorerSupplier(
+ RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) {
+ this.supplier = supplier;
+ this.onClose = onClose;
+ this.vectorValues = vectorValues;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ return supplier.scorer(ord);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return supplier.copy();
+ }
+
+ @Override
+ public void close() throws IOException {
+ onClose.close();
+ }
+
+ @Override
+ public int totalVectorCount() {
+ return vectorValues.size();
+ }
+ }
+
+ static final class NormalizedFloatVectorValues extends FloatVectorValues {
+ private final FloatVectorValues values;
+ private final float[] normalizedVector;
+
+ NormalizedFloatVectorValues(FloatVectorValues values) {
+ this.values = values;
+ this.normalizedVector = new float[values.dimension()];
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
+ VectorUtil.l2normalize(normalizedVector);
+ return normalizedVector;
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public NormalizedFloatVectorValues copy() throws IOException {
+ return new NormalizedFloatVectorValues(values.copy());
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..55fe8807356b
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.util.hnsw.HnswGraph;
+
+/**
+ * A vectors format that uses HNSW graph to store and search for vectors. But vectors are binary
+ * quantized using {@link Lucene102BinaryQuantizedVectorsFormat} before being stored in the graph.
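+ *
+ * <p>A minimal usage sketch (illustrative only; the codec name and parameters are arbitrary):
+ *
+ * <pre class="prettyprint">
+ * IndexWriterConfig config = new IndexWriterConfig();
+ * config.setCodec(
+ *     new FilterCodec("BinaryQuantizedCodec", Codec.getDefault()) {
+ *       &#64;Override
+ *       public KnnVectorsFormat knnVectorsFormat() {
+ *         return new Lucene102HnswBinaryQuantizedVectorsFormat(16, 100);
+ *       }
+ *     });
+ * </pre>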
+ */
+public class Lucene102HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat {
+
+ public static final String NAME = "Lucene102HnswBinaryQuantizedVectorsFormat";
+
+ /**
+ * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
+ * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
+ */
+ private final int maxConn;
+
+ /**
+ * The number of candidate neighbors to track while searching the graph for each newly inserted
+ * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
+ * for details.
+ */
+ private final int beamWidth;
+
+ /** The format for storing, reading, merging vectors on disk */
+ private static final FlatVectorsFormat flatVectorsFormat =
+ new Lucene102BinaryQuantizedVectorsFormat();
+
+ private final int numMergeWorkers;
+ private final TaskExecutor mergeExec;
+
+ /** Constructs a format using default graph construction parameters */
+ public Lucene102HnswBinaryQuantizedVectorsFormat() {
+ this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ */
+ public Lucene102HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
+ this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+   * Constructs a format using the given graph construction parameters and binary quantization.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
+ * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+ * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
+ * generated by this format to do the merge
+ */
+ public Lucene102HnswBinaryQuantizedVectorsFormat(
+ int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
+ super(NAME);
+ if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
+ throw new IllegalArgumentException(
+ "maxConn must be positive and less than or equal to "
+ + MAXIMUM_MAX_CONN
+ + "; maxConn="
+ + maxConn);
+ }
+ if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
+ throw new IllegalArgumentException(
+ "beamWidth must be positive and less than or equal to "
+ + MAXIMUM_BEAM_WIDTH
+ + "; beamWidth="
+ + beamWidth);
+ }
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ if (numMergeWorkers == 1 && mergeExec != null) {
+ throw new IllegalArgumentException(
+ "No executor service is needed as we'll use single thread to merge");
+ }
+ this.numMergeWorkers = numMergeWorkers;
+ if (mergeExec != null) {
+ this.mergeExec = new TaskExecutor(mergeExec);
+ } else {
+ this.mergeExec = null;
+ }
+ }
+
+ @Override
+ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new Lucene99HnswVectorsWriter(
+ state,
+ maxConn,
+ beamWidth,
+ flatVectorsFormat.fieldsWriter(state),
+ numMergeWorkers,
+ mergeExec);
+ }
+
+ @Override
+ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return 1024;
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene102HnswBinaryQuantizedVectorsFormat(name=Lucene102HnswBinaryQuantizedVectorsFormat, maxConn="
+ + maxConn
+ + ", beamWidth="
+ + beamWidth
+ + ", flatVectorFormat="
+ + flatVectorsFormat
+ + ")";
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java
new file mode 100644
index 000000000000..13eef8b263b5
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/OffHeapBinarizedVectorValues.java
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.packed.DirectMonotonicReader;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+/**
+ * Binarized vector values loaded from off-heap
+ *
+ * @lucene.internal
+ */
+public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues {
+
+ final int dimension;
+ final int size;
+ final int numBytes;
+ final VectorSimilarityFunction similarityFunction;
+ final FlatVectorsScorer vectorsScorer;
+
+ final IndexInput slice;
+ final byte[] binaryValue;
+ final ByteBuffer byteBuffer;
+ final int byteSize;
+ private int lastOrd = -1;
+ final float[] correctiveValues;
+ int quantizedComponentSum;
+ final OptimizedScalarQuantizer binaryQuantizer;
+ final float[] centroid;
+ final float centroidDp;
+ private final int discretizedDimensions;
+
+ OffHeapBinarizedVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ this.dimension = dimension;
+ this.size = size;
+ this.similarityFunction = similarityFunction;
+ this.vectorsScorer = vectorsScorer;
+ this.slice = slice;
+ this.centroid = centroid;
+ this.centroidDp = centroidDp;
+ this.numBytes = discretize(dimension, 64) / 8;
+ this.correctiveValues = new float[3];
+ this.byteSize = numBytes + (Float.BYTES * 3) + Short.BYTES;
+ this.byteBuffer = ByteBuffer.allocate(numBytes);
+ this.binaryValue = byteBuffer.array();
+ this.binaryQuantizer = quantizer;
+ this.discretizedDimensions = discretize(dimension, 64);
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
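+  // Records are fixed-size (byteSize bytes), so a vector's data begins at targetOrd * byteSize;
+  // the packed bits are read first, then the three corrective floats and the unsigned-short
+  // component sum. The last ord read is cached to avoid re-reading on repeated access.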
+ @Override
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return binaryValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
+ lastOrd = targetOrd;
+ return binaryValue;
+ }
+
+ @Override
+ public int discretizedDimensions() {
+ return discretizedDimensions;
+ }
+
+ @Override
+ public float getCentroidDP() {
+ return centroidDp;
+ }
+
+ @Override
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd)
+ throws IOException {
+ if (lastOrd == targetOrd) {
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
+ slice.seek(((long) targetOrd * byteSize) + numBytes);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ return binaryQuantizer;
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public int getVectorByteLength() {
+ return numBytes;
+ }
+
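+  // Chooses the concrete implementation from the stored configuration: empty when the field has
+  // no vectors, dense when every document has a vector (ords equal doc ids), and sparse otherwise
+  // (ords resolved through an IndexedDISI and a DirectMonotonicReader).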
+ static OffHeapBinarizedVectorValues load(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ OptimizedScalarQuantizer binaryQuantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ float[] centroid,
+ float centroidDp,
+ long quantizedVectorDataOffset,
+ long quantizedVectorDataLength,
+ IndexInput vectorData)
+ throws IOException {
+ if (configuration.isEmpty()) {
+ return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
+ }
+ assert centroid != null;
+ IndexInput bytesSlice =
+ vectorData.slice(
+ "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
+ if (configuration.isDense()) {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ } else {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ vectorData,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ }
+ }
+
+ /** Dense off-heap binarized vector values */
+ static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ DenseOffHeapVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer binaryQuantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ }
+
+ @Override
+ public DenseOffHeapVectorValues copy() throws IOException {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return acceptDocs;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ DenseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+ }
+
+ /** Sparse off-heap binarized vector values */
+ private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ private final DirectMonotonicReader ordToDoc;
+ private final IndexedDISI disi;
+    // dataIn is retained so copies can re-create the IndexedDISI and the ord-to-doc reader
+ private final IndexInput dataIn;
+ private final OrdToDocDISIReaderConfiguration configuration;
+
+ SparseOffHeapVectorValues(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer binaryQuantizer,
+ IndexInput dataIn,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice)
+ throws IOException {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ this.configuration = configuration;
+ this.dataIn = dataIn;
+ this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
+ this.disi = configuration.getIndexedDISI(dataIn);
+ }
+
+ @Override
+ public SparseOffHeapVectorValues copy() throws IOException {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ dataIn,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return (int) ordToDoc.get(ord);
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ if (acceptDocs == null) {
+ return null;
+ }
+ return new Bits() {
+ @Override
+ public boolean get(int index) {
+ return acceptDocs.get(ordToDoc(index));
+ }
+
+ @Override
+ public int length() {
+ return size;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return IndexedDISI.asDocIndexIterator(disi);
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ SparseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+ }
+
+ private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ EmptyOffHeapVectorValues(
+ int dimension,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer) {
+ super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+
+ @Override
+ public DenseOffHeapVectorValues copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return null;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) {
+ return null;
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java
new file mode 100644
index 000000000000..92d9de567a20
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java
@@ -0,0 +1,436 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Lucene 10.2 file format.
+ *
+ * <h2>Apache Lucene - Index File Formats</h2>
+ *
+ * <h3>Introduction</h3>
+ *
+ * <p>This document defines the index file formats used in this version of Lucene. If you are using
+ * a different version of Lucene, please consult the copy of <code>docs/</code> that was distributed
+ * with the version you are using.
+ *
+ * <p>This document attempts to provide a high-level definition of the Apache Lucene file formats.
+ *
+ * <h3>Definitions</h3>
+ *
+ * <p>The fundamental concepts in Lucene are index, document, field and term.
+ *
+ * <p>An index contains a sequence of documents.
+ *
+ * <ul>
+ *   <li>A document is a sequence of fields.
+ *   <li>A field is a named sequence of terms.
+ *   <li>A term is a sequence of bytes.
+ * </ul>
+ *
+ * <p>The same sequence of bytes in two different fields is considered a different term. Thus terms
+ * are represented as a pair: the string naming the field, and the bytes within the field.
+ *
+ * <h4>Inverted Indexing</h4>
+ *
+ * <p>Lucene's index stores terms and statistics about those terms in order to make term-based
+ * search more efficient. Lucene's terms index falls into the family of indexes known as an
+ * inverted index. This is because it can list, for a term, the documents that contain it.
+ * This is the inverse of the natural relationship, in which documents list terms.
+ *
+ * <h4>Types of Fields</h4>
+ *
+ * <p>In Lucene, fields may be stored, in which case their text is stored in the index
+ * literally, in a non-inverted manner. Fields that are inverted are called indexed. A field
+ * may be both stored and indexed.
+ *
+ * <p>The text of a field may be tokenized into terms to be indexed, or the text of a field
+ * may be used literally as a term to be indexed. Most fields are tokenized, but sometimes it is
+ * useful for certain identifier fields to be indexed literally.
+ *
+ * <p>See the {@link org.apache.lucene.document.Field Field} java docs for more information on
+ * Fields.
+ *
+ * <h4>Segments</h4>
+ *
+ * <p>Lucene indexes may be composed of multiple sub-indexes, or segments. Each segment is a
+ * fully independent index, which could be searched separately. Indexes evolve by:
+ *
+ * <ol>
+ *   <li>Creating new segments for newly added documents.
+ *   <li>Merging existing segments.
+ * </ol>
+ *
+ * <p>Searches may involve multiple segments and/or multiple indexes, each index potentially
+ * composed of a set of segments.
+ *
+ * <h4>Document Numbers</h4>
+ *
+ * <p>Internally, Lucene refers to documents by an integer document number. The first
+ * document added to an index is numbered zero, and each subsequent document added gets a number one
+ * greater than the previous.
+ *
+ * <p>Note that a document's number may change, so caution should be taken when storing these
+ * numbers outside of Lucene. In particular, numbers may change in the following situations:
+ *
+ * <ul>
+ *   <li>
+ *       <p>The numbers stored in each segment are unique only within the segment, and must be
+ *       converted before they can be used in a larger context. The standard technique is to
+ *       allocate each segment a range of values, based on the range of numbers used in that
+ *       segment. To convert a document number from a segment to an external value, the segment's
+ *       base document number is added. To convert an external value back to a
+ *       segment-specific value, the segment is identified by the range that the external value is
+ *       in, and the segment's base value is subtracted. For example two five document segments
+ *       might be combined, so that the first segment has a base value of zero, and the second of
+ *       five. Document three from the second segment would have an external value of eight.
+ *   <li>
+ *       <p>When documents are deleted, gaps are created in the numbering. These are eventually
+ *       removed as the index evolves through merging. Deleted documents are dropped when segments
+ *       are merged. A freshly-merged segment thus has no gaps in its numbering.
+ * </ul>
+ *
+ * Index Structure Overview
+ *
+ *
+ *
+ *
Each segment index maintains the following:
+ *
+ *
+ * - {@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment info}. This
+ * contains metadata about a segment, such as the number of documents, what files it uses, and
+ * information about how the segment is sorted
+ *
- {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This
+ * contains metadata about the set of named fields used in the index.
+ *
- {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
+ * This contains, for each document, a list of attribute-value pairs, where the attributes are
+ * field names. These are used to store auxiliary information about the document, such as its
+ * title, url, or an identifier to access a database. The set of stored fields are what is
+ * returned for each hit when searching. This is keyed by document number.
+ *
- {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term dictionary}. A
+ * dictionary containing all of the terms used in all of the indexed fields of all of the
+ * documents. The dictionary also contains the number of documents which contain the term, and
+ * pointers to the term's frequency and proximity data.
+ *
- {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Frequency data}. For
+ * each term in the dictionary, the numbers of all the documents that contain that term, and
+ * the frequency of the term in that document, unless frequencies are omitted ({@link
+ * org.apache.lucene.index.IndexOptions#DOCS IndexOptions.DOCS})
+ *
- {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Proximity data}. For
+ * each term in the dictionary, the positions that the term occurs in each document. Note that
+ * this will not exist if all fields in all documents omit position data.
+ *
- {@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Normalization factors}. For
+ * each field in each document, a value is stored that is multiplied into the score for hits
+ * on that field.
+ *
- {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vectors}. For each
+ * field in each document, the term vector (sometimes called document vector) may be stored. A
+ * term vector consists of term text and term frequency. To add Term Vectors to your index see
+ * the {@link org.apache.lucene.document.Field Field} constructors
+ *
+ * - {@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-document values}. Like
+ * stored values, these are also keyed by document number, but are generally intended to be
+ * loaded into main memory for fast access. Whereas stored values are generally intended for
+ * summary results from searches, per-document values are useful for things like scoring
+ * factors.
+ *
+ * - {@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live documents}. An
+ * optional file indicating which documents are live.
+ *
+ * - {@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values}. Optional pair
+ * of files, recording dimensionally indexed fields, to enable fast numeric range filtering
+ * and large numeric values like BigInteger and BigDecimal (1D) and geographic shape
+ * intersection (2D, 3D).
+ *
+ * - {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}. The
+ * vector format stores numeric vectors in a format optimized for random access and
+ * computation, supporting high-dimensional nearest-neighbor search.
+ *
+ *
+ *
+ * Details on each of these are provided in their linked pages.
+ *
+ * File Naming
+ *
+ *
+ *
+ *
+ * All files belonging to a segment have the same name with varying extensions. The extensions
+ * correspond to the different file formats described below. When using the Compound File format
+ * (default for small segments) these files (except for the Segment info file, the Lock file, and
+ * Deleted documents file) are collapsed into a single .cfs file (see below for details)
+ *
+ *
+ * Typically, all segments in an index are stored in a single directory, although this is not
+ * required.
+ *
+ *
+ * File names are never re-used. That is, when any file is saved to the Directory it is given a
+ * never before used filename. This is achieved using a simple generations approach. For example,
+ * the first segments file is segments_1, then segments_2, etc. The generation is a sequential long
+ * integer represented in alpha-numeric (base 36) form.
+ *
+ * Summary of File Extensions
+ *
+ *
+ *
+ *
+ * The following table summarizes the names and extensions of the files in Lucene:
+ *
+ *
+ * lucene filenames by extension
+ *
+ * Name |
+ * Extension |
+ * Brief Description |
+ *
+ *
+ * {@link org.apache.lucene.index.SegmentInfos Segments File} |
+ * segments_N |
+ * Stores information about a commit point |
+ *
+ *
+ * Lock File |
+ * write.lock |
+ * The Write lock prevents multiple IndexWriters from writing to the same
+ * file. |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene99.Lucene99SegmentInfoFormat Segment Info} |
+ * .si |
+ * Stores metadata about a segment |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat Compound File} |
+ * .cfs, .cfe |
+ * An optional "virtual" file consisting of all the other index files for
+ * systems that frequently run out of file handles. |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields} |
+ * .fnm |
+ * Stores information about the fields |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Index} |
+ * .fdx |
+ * Contains pointers to field data |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Field Data} |
+ * .fdt |
+ * The stored fields for documents |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Dictionary} |
+ * .tim |
+ * The term dictionary, stores term info |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Term Index} |
+ * .tip |
+ * The index into the Term Dictionary |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Frequencies} |
+ * .doc |
+ * Contains the list of docs which contain each term along with frequency |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Positions} |
+ * .pos |
+ * Stores position information about where a term occurs in the index |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat Payloads} |
+ * .pay |
+ * Stores additional per-position metadata information such as character offsets and user payloads |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90NormsFormat Norms} |
+ * .nvd, .nvm |
+ * Encodes length and boost factors for docs and fields |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat Per-Document Values} |
+ * .dvd, .dvm |
+ * Encodes additional scoring factors or other per-document information. |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Index} |
+ * .tvx |
+ * Stores offset into the document data file |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat Term Vector Data} |
+ * .tvd |
+ * Contains term vector data. |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat Live Documents} |
+ * .liv |
+ * Info about what documents are live |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene90.Lucene90PointsFormat Point values} |
+ * .kdd, .kdi, .kdm |
+ * Holds indexed points |
+ *
+ *
+ * {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values} |
+ * .vec, .vem, .veq, .vex |
+ * Holds indexed vectors; .vec files contain the raw vector data,
+ * .vem the vector metadata, .veq the quantized vector data, and .vex the
+ * hnsw graph data. |
+ *
+ *
+ *
+ *
+ *
+ * Lock File
+ *
+ * The write lock, which is stored in the index directory by default, is named "write.lock". If the
+ * lock directory is different from the index directory then the write lock will be named
+ * "XXXX-write.lock" where XXXX is a unique prefix derived from the full path to the index
+ * directory. When this file is present, a writer is currently modifying the index (adding or
+ * removing documents). This lock file ensures that only one writer is modifying the index at a
+ * time.
+ *
+ * History
+ *
+ * Compatibility notes are provided in this document, describing how file formats have changed
+ * from prior versions:
+ *
+ *
+ * - In version 2.1, the file format was changed to allow lock-less commits (ie, no more commit
+ * lock). The change is fully backwards compatible: you can open a pre-2.1 index for searching
+ * or adding/deleting of docs. When the new segments file is saved (committed), it will be
+ * written in the new file format (meaning no specific "upgrade" process is needed). But note
+ * that once a commit has occurred, pre-2.1 Lucene will not be able to read the index.
+ *
+ * - In version 2.3, the file format was changed to allow segments to share a single set of doc
+ * store (vectors & stored fields) files. This allows for faster indexing in certain
+ * cases. The change is fully backwards compatible (in the same way as the lock-less commits
+ * change in 2.1).
+ *
+ * - In version 2.4, Strings are now written as a true UTF-8 byte sequence, not Java's modified
+ * UTF-8. See LUCENE-510 for
+ * details.
+ *
+ * - In version 2.9, an optional opaque Map&lt;String,String&gt; CommitUserData may be passed to
+ * IndexWriter's commit methods (and later retrieved), which is recorded in the segments_N
+ * file. See LUCENE-1382 for
+ * details. Also, diagnostics were added to each segment written recording details about why
+ * it was written (due to flush, merge; which OS/JRE was used; etc.). See issue LUCENE-1654 for details.
+ *
+ * - In version 3.0, compressed fields are no longer written to the index (they can still be
+ * read, but on merge the new segment will write them, uncompressed). See issue LUCENE-1960 for details.
+ *
+ * - In version 3.1, segments record the code version that created them. See LUCENE-2720 for details.
+ * Additionally segments track explicitly whether or not they have term vectors. See LUCENE-2811 for details.
+ *
+ * - In version 3.2, numeric fields are written natively to the stored fields file; previously
+ * they were stored in text format only.
+ *
+ * - In version 3.4, fields can omit position data while still indexing term frequencies.
+ *
+ * - In version 4.0, the format of the inverted index became extensible via the {@link
+ * org.apache.lucene.codecs.Codec Codec} api. Fast per-document storage ({@code DocValues})
+ * was introduced. Normalization factors need no longer be a single byte, they can be any
+ * {@link org.apache.lucene.index.NumericDocValues NumericDocValues}. Terms need not be
+ * unicode strings, they can be any byte sequence. Term offsets can optionally be indexed into
+ * the postings lists. Payloads can be stored in the term vectors.
+ *
+ * - In version 4.1, the format of the postings list changed to use either of FOR compression or
+ * variable-byte encoding, depending upon the frequency of the term. Terms appearing only once
+ * were changed to inline directly into the term dictionary. Stored fields are compressed by
+ * default.
+ *
+ * - In version 4.2, term vectors are compressed by default. DocValues has a new multi-valued
+ * type (SortedSet), that can be used for faceting/grouping/joining on multi-valued fields.
+ *
+ * - In version 4.5, DocValues were extended to explicitly represent missing values.
+ *
+ * - In version 4.6, FieldInfos were extended to support per-field DocValues generation, to
+ * allow updating NumericDocValues fields.
+ *
+ * - In version 4.8, checksum footers were added to the end of each index file for improved data
+ * integrity. Specifically, the last 8 bytes of every index file contain the zlib-crc32
+ * checksum of the file.
+ *
+ * - In version 4.9, DocValues has a new multi-valued numeric type (SortedNumeric) that is
+ * suitable for faceting/sorting/analytics.
+ *
+ * - In version 5.4, DocValues have been improved to store more information on disk: addresses
+ * for binary fields and ord indexes for multi-valued fields.
+ *
+ * - In version 6.0, Points were added, for multi-dimensional range/distance search.
+ *
+ * - In version 6.2, a new Segment info format was added that reads/writes the index sort, to support index
+ * sorting.
+ *
+ * - In version 7.0, DocValues have been improved to better support sparse doc values thanks to
+ * an iterator API.
+ *
+ * - In version 8.0, postings have been enhanced to record, for each block of doc ids, the (term
+ * freq, normalization factor) pairs that may trigger the maximum score of the block. This
+ * information is recorded alongside skip data in order to be able to skip blocks of doc ids
+ * if they may not produce high enough scores. Additionally doc values and norms have been
+ * extended with jump-tables to make access O(1) instead of O(n), where n is the number of
+ * elements to skip when advancing in the data.
+ *
+ * - In version 8.4, postings, positions, offsets and payload lengths have moved to a more
+ * performant encoding that is vectorized.
+ *
+ * - In version 8.6, index sort serialization is delegated to the sorts themselves, to allow
+ * user-defined sorts to be used
+ *
+ * - In version 8.6, points fields split the index tree and leaf data into separate files, to
+ * allow for different access patterns to the different data structures
+ *
+ * - In version 8.7, stored fields compression became adaptive to better handle documents with
+ * smaller stored fields.
+ *
+ * - In version 9.0, vector-valued fields were added.
+ *
+ * - In version 9.1, vector-valued fields were modified to add a graph hierarchy.
+ *
+ * - In version 9.2, docs of vector-valued fields were moved from .vem to .vec and encoded by
+ * IndexDISI. ordToDoc mappings were added to .vem.
+ *
+ * - In version 9.5, HNSW graph connections were changed to be delta-encoded with vints.
+ * Additionally, metadata file size improvements were made by delta-encoding nodes by graph
+ * layer and not writing the node ids for the zeroth layer.
+ *
+ * - In version 9.9, vector scalar quantization support was added, allowing the HNSW vector
+ * format to use int8 quantized vectors for float32 vector search.
+ *
+ * - In version 9.12, skip data was refactored to have only two levels: every 128 docs and every
+ * 4,096 docs, and to be inlined in postings lists. This resulted in a speedup for queries that
+ * need skipping, especially conjunctions.
+ *
+ * - In version 10.1, block encoding changed to be optimized for int[] storage instead of
+ * long[].
+ *
+ * - In version 10.2, a new asymmetric binary vector quantization format was added for the HNSW
+ * and flat vector formats.
+ *
+ *
+ *
+ *
+ * Limitations
+ *
+ *
+ *
+ *
+ * Lucene uses a Java int to refer to document numbers, and the index file format uses an
+ * Int32 on-disk to store document numbers. This is a limitation of both the index file format
+ * and the current implementation. Eventually these should be replaced with either UInt64 values,
+ * or better yet, {@link org.apache.lucene.store.DataOutput#writeVInt VInt} values which have no
+ * limit.
+ */
+package org.apache.lucene.codecs.lucene102;
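
The formats in this package are enabled per field by overriding the KnnVectorsFormat on the codec. The following is a hedged sketch that mirrors how the tests added in this patch wire things up rather than any documented recipe; the class name is illustrative only:

    import org.apache.lucene.codecs.Codec;
    import org.apache.lucene.codecs.KnnVectorsFormat;
    import org.apache.lucene.codecs.lucene101.Lucene101Codec;
    import org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat;
    import org.apache.lucene.index.IndexWriterConfig;

    public class BinaryQuantizedCodecSketch {
      // Returns a writer config whose codec uses the HNSW + binary quantized format for every
      // vector field; per-field selection is possible by inspecting the `field` argument.
      public static IndexWriterConfig newConfig() {
        Codec codec =
            new Lucene101Codec() {
              @Override
              public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                return new Lucene102HnswBinaryQuantizedVectorsFormat();
              }
            };
        return new IndexWriterConfig().setCodec(codec);
      }
    }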
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
index 184403cf48b7..d118b58f321f 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
@@ -17,6 +17,7 @@
package org.apache.lucene.internal.vectorization;
+import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;
@@ -207,4 +208,30 @@ public int findNextGEQ(int[] buffer, int target, int from, int to) {
}
return to;
}
+
+ @Override
+ public long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized) {
+ return int4BitDotProductImpl(int4Quantized, binaryQuantized);
+ }
+
+ public static long int4BitDotProductImpl(byte[] q, byte[] d) {
+ assert q.length == d.length * 4;
+ long ret = 0;
+ int size = d.length;
+ for (int i = 0; i < 4; i++) {
+ int r = 0;
+ long subRet = 0;
+ for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
+ subRet +=
+ Integer.bitCount(
+ (int) BitUtil.VH_NATIVE_INT.get(q, i * size + r)
+ & (int) BitUtil.VH_NATIVE_INT.get(d, r));
+ }
+ for (; r < d.length; r++) {
+ subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF);
+ }
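+      // bit-plane i of the query contributes its popcounted AND with the doc bits, weighted by 2^i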
+ ret += subRet << i;
+ }
+ return ret;
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
index fb94b0e31736..82947d003434 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
@@ -52,4 +52,17 @@ public interface VectorUtilSupport {
* to} is returned.
*/
int findNextGEQ(int[] buffer, int target, int from, int to);
+
+ /**
+ * Compute the dot product between a quantized int4 vector and a binary quantized vector. It is
+ * assumed that the int4 quantized bits are packed in the byte array in the same way as the {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer#transposeHalfByte(byte[], byte[])}
+ * and that the binary bits are packed the same way as {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer#packAsBinary(byte[], byte[])}.
+ *
+ * @param int4Quantized half byte packed int4 quantized vector
+ * @param binaryQuantized byte packed binary quantized vector
+ * @return the dot product
+ */
+ long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized);
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index 250c65448703..75998c2f7fbb 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -213,6 +213,21 @@ public static int int4DotProductPacked(byte[] unpacked, byte[] packed) {
return IMPL.int4DotProduct(unpacked, false, packed, true);
}
+ /**
+ * Dot product computed over int4 (values between [0,15]) bytes and a binary vector.
+ *
+ * @param q the int4 query vector
+ * @param d the binary document vector
+ * @return the dot product
+ */
+ public static long int4BitDotProduct(byte[] q, byte[] d) {
+ if (q.length != d.length * 4) {
+ throw new IllegalArgumentException(
+          "vector dimensions incompatible: " + q.length + " != " + 4 + " x " + d.length);
+ }
+ return IMPL.int4BitDotProduct(q, d);
+ }
+
/**
* For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
* On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
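
For illustration, here is a minimal sketch (not part of the patch) that checks the packed layout consumed by VectorUtil.int4BitDotProduct against a plain per-dimension dot product. It relies on the OptimizedScalarQuantizer.transposeHalfByte and packAsBinary helpers introduced later in this patch, and assumes a dimension count that is a multiple of 64 so no discretization padding is involved:

    import java.util.Random;
    import org.apache.lucene.util.VectorUtil;
    import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

    public class Int4BitDotProductSketch {
      public static void main(String[] args) {
        int dims = 64; // multiple of 64: packed lengths divide evenly
        Random r = new Random(42);
        byte[] int4Query = new byte[dims]; // values in [0, 15]
        byte[] bitDoc = new byte[dims];    // values in {0, 1}
        long expected = 0;
        for (int i = 0; i < dims; i++) {
          int4Query[i] = (byte) r.nextInt(16);
          bitDoc[i] = (byte) r.nextInt(2);
          expected += int4Query[i] * bitDoc[i]; // naive asymmetric dot product
        }
        byte[] packedQuery = new byte[4 * dims / 8]; // four bit-planes, 8 dimensions per byte
        byte[] packedDoc = new byte[dims / 8];       // one bit per dimension
        OptimizedScalarQuantizer.transposeHalfByte(int4Query, packedQuery);
        OptimizedScalarQuantizer.packAsBinary(bitDoc, packedDoc);
        long packed = VectorUtil.int4BitDotProduct(packedQuery, packedDoc);
        System.out.println(expected + " == " + packed); // the two values agree
      }
    }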
diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java
new file mode 100644
index 000000000000..b93f2df1ea79
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.quantization;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.VectorUtil;
+
+/**
+ * OptimizedScalarQuantizer is a scalar quantizer that optimizes the quantization intervals for a
+ * given vector. This is done by optimizing the quantiles of the vector centered on a provided
+ * centroid. The optimization is done by minimizing the quantization loss via coordinate descent.
+ *
+ * @lucene.experimental
+ */
+public class OptimizedScalarQuantizer {
+ // The initial interval is set to the minimum MSE grid for each number of bits
+ // these starting points are derived from the optimal MSE grid for a uniform distribution
+ static final float[][] MINIMUM_MSE_GRID =
+ new float[][] {
+ {-0.798f, 0.798f},
+ {-1.493f, 1.493f},
+ {-2.051f, 2.051f},
+ {-2.514f, 2.514f},
+ {-2.916f, 2.916f},
+ {-3.278f, 3.278f},
+ {-3.611f, 3.611f},
+ {-3.922f, 3.922f}
+ };
+ private static final float DEFAULT_LAMBDA = 0.1f;
+ private static final int DEFAULT_ITERS = 5;
+ private final VectorSimilarityFunction similarityFunction;
+ private final float lambda;
+ private final int iters;
+
+ /**
+ * Create a new scalar quantizer with the given similarity function, lambda, and number of
+ * iterations.
+ *
+ * @param similarityFunction similarity function to use
+ * @param lambda lambda value to use
+ * @param iters number of iterations to use
+ */
+ public OptimizedScalarQuantizer(
+ VectorSimilarityFunction similarityFunction, float lambda, int iters) {
+ this.similarityFunction = similarityFunction;
+ this.lambda = lambda;
+ this.iters = iters;
+ }
+
+ /**
+ * Create a new scalar quantizer with the default lambda and number of iterations.
+ *
+ * @param similarityFunction similarity function to use
+ */
+ public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
+ this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
+ }
+
+ /**
+   * Quantization result containing the lower and upper interval bounds, the additional
+   * correction, and the sum of the quantized components.
+ *
+ * @param lowerInterval the lower interval bound
+ * @param upperInterval the upper interval bound
+ * @param additionalCorrection the additional correction
+ * @param quantizedComponentSum the sum of the quantized components
+ */
+ public record QuantizationResult(
+ float lowerInterval,
+ float upperInterval,
+ float additionalCorrection,
+ int quantizedComponentSum) {}
+
+ /**
+ * Quantize the vector to the multiple bit levels.
+ *
+ * @param vector raw vector
+ * @param destinations array of destinations to store the quantized vector
+ * @param bits array of bits to quantize the vector
+ * @param centroid centroid to center the vector
+ * @return array of quantization results
+ */
+ public QuantizationResult[] multiScalarQuantize(
+ float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert bits.length == destinations.length;
+ float[] intervalScratch = new float[2];
+ double vecMean = 0;
+ double vecVar = 0;
+ float norm2 = 0;
+ float centroidDot = 0;
+ float min = Float.MAX_VALUE;
+ float max = -Float.MAX_VALUE;
+ for (int i = 0; i < vector.length; ++i) {
+ if (similarityFunction != EUCLIDEAN) {
+ centroidDot += vector[i] * centroid[i];
+ }
+ vector[i] = vector[i] - centroid[i];
+ min = Math.min(min, vector[i]);
+ max = Math.max(max, vector[i]);
+ norm2 += (vector[i] * vector[i]);
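+      // running mean and variance of the centered vector (Welford's online algorithm)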
+ double delta = vector[i] - vecMean;
+ vecMean += delta / (i + 1);
+ vecVar += delta * (vector[i] - vecMean);
+ }
+ vecVar /= vector.length;
+ double vecStd = Math.sqrt(vecVar);
+ QuantizationResult[] results = new QuantizationResult[bits.length];
+ for (int i = 0; i < bits.length; ++i) {
+ assert bits[i] > 0 && bits[i] <= 8;
+ int points = (1 << bits[i]);
+ // Linearly scale the interval to the standard deviation of the vector, ensuring we are within
+ // the min/max bounds
+ intervalScratch[0] =
+ (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max);
+ intervalScratch[1] =
+ (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max);
+ optimizeIntervals(intervalScratch, vector, norm2, points);
+ float nSteps = ((1 << bits[i]) - 1);
+ float a = intervalScratch[0];
+ float b = intervalScratch[1];
+ float step = (b - a) / nSteps;
+ int sumQuery = 0;
+ // Now we have the optimized intervals, quantize the vector
+ for (int h = 0; h < vector.length; h++) {
+ float xi = (float) clamp(vector[h], a, b);
+ int assignment = Math.round((xi - a) / step);
+ sumQuery += assignment;
+ destinations[i][h] = (byte) assignment;
+ }
+ results[i] =
+ new QuantizationResult(
+ intervalScratch[0],
+ intervalScratch[1],
+ similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
+ sumQuery);
+ }
+ return results;
+ }
+
+ /**
+ * Quantize the vector to the given bit level.
+ *
+ * @param vector raw vector
+ * @param destination destination to store the quantized vector
+ * @param bits number of bits to quantize the vector
+ * @param centroid centroid to center the vector
+ * @return quantization result
+ */
+ public QuantizationResult scalarQuantize(
+ float[] vector, byte[] destination, byte bits, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert vector.length <= destination.length;
+ assert bits > 0 && bits <= 8;
+ float[] intervalScratch = new float[2];
+ int points = 1 << bits;
+ double vecMean = 0;
+ double vecVar = 0;
+ float norm2 = 0;
+ float centroidDot = 0;
+ float min = Float.MAX_VALUE;
+ float max = -Float.MAX_VALUE;
+ for (int i = 0; i < vector.length; ++i) {
+ if (similarityFunction != EUCLIDEAN) {
+ centroidDot += vector[i] * centroid[i];
+ }
+ vector[i] = vector[i] - centroid[i];
+ min = Math.min(min, vector[i]);
+ max = Math.max(max, vector[i]);
+ norm2 += (vector[i] * vector[i]);
+ double delta = vector[i] - vecMean;
+ vecMean += delta / (i + 1);
+ vecVar += delta * (vector[i] - vecMean);
+ }
+ vecVar /= vector.length;
+ double vecStd = Math.sqrt(vecVar);
+ // Linearly scale the interval to the standard deviation of the vector, ensuring we are within
+ // the min/max bounds
+ intervalScratch[0] =
+ (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max);
+ intervalScratch[1] =
+ (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max);
+ optimizeIntervals(intervalScratch, vector, norm2, points);
+ float nSteps = ((1 << bits) - 1);
+ // Now we have the optimized intervals, quantize the vector
+ float a = intervalScratch[0];
+ float b = intervalScratch[1];
+ float step = (b - a) / nSteps;
+ int sumQuery = 0;
+ for (int h = 0; h < vector.length; h++) {
+ float xi = (float) clamp(vector[h], a, b);
+ int assignment = Math.round((xi - a) / step);
+ sumQuery += assignment;
+ destination[h] = (byte) assignment;
+ }
+ return new QuantizationResult(
+ intervalScratch[0],
+ intervalScratch[1],
+ similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
+ sumQuery);
+ }
+
+ /**
+ * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a
+ * dequantized vector with the raw vector.
+ *
+ * @param vector raw vector
+ * @param interval interval to quantize the vector
+ * @param points number of quantization points
+ * @param norm2 squared norm of the vector
+ * @return the loss
+ */
+ private double loss(float[] vector, float[] interval, int points, float norm2) {
+ double a = interval[0];
+ double b = interval[1];
+ double step = ((b - a) / (points - 1.0F));
+ double stepInv = 1.0 / step;
+ double xe = 0.0;
+ double e = 0.0;
+ for (double xi : vector) {
+ // this is quantizing and then dequantizing the vector
+ double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv));
+ // how much does the de-quantized value differ from the original value
+ xe += xi * (xi - xiq);
+ e += (xi - xiq) * (xi - xiq);
+ }
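+    // loss = (1 - lambda) * (x . (x - x_q))^2 / ||x||^2 + lambda * ||x - x_q||^2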
+ return (1.0 - lambda) * xe * xe / norm2 + lambda * e;
+ }
+
+ /**
+ * Optimize the quantization interval for the given vector. This is done via a coordinate descent
+ * trying to minimize the quantization loss. Note, the loss is not always guaranteed to decrease,
+ * so we have a maximum number of iterations and will exit early if the loss increases.
+ *
+ * @param initInterval initial interval, the optimized interval will be stored here
+ * @param vector raw vector
+ * @param norm2 squared norm of the vector
+ * @param points number of quantization points
+ */
+ private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
+ double initialLoss = loss(vector, initInterval, points, norm2);
+ final float scale = (1.0f - lambda) / norm2;
+ if (Float.isFinite(scale) == false) {
+ return;
+ }
+ for (int i = 0; i < iters; ++i) {
+ float a = initInterval[0];
+ float b = initInterval[1];
+ float stepInv = (points - 1.0f) / (b - a);
+ // calculate the grid points for coordinate descent
+ double daa = 0;
+ double dab = 0;
+ double dbb = 0;
+ double dax = 0;
+ double dbx = 0;
+ for (float xi : vector) {
+ float k = Math.round((clamp(xi, a, b) - a) * stepInv);
+ float s = k / (points - 1);
+ daa += (1.0 - s) * (1.0 - s);
+ dab += (1.0 - s) * s;
+ dbb += s * s;
+ dax += xi * (1.0 - s);
+ dbx += xi * s;
+ }
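+      // solve the 2x2 linear system for the new interval endpoints [a, b] (Cramer's rule)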
+ double m0 = scale * dax * dax + lambda * daa;
+ double m1 = scale * dax * dbx + lambda * dab;
+ double m2 = scale * dbx * dbx + lambda * dbb;
+      // it's possible that the determinant is 0, in which case we can't update the interval
+ double det = m0 * m2 - m1 * m1;
+ if (det == 0) {
+ return;
+ }
+ float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
+ float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
+ // If there is no change in the interval, we can stop
+ if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
+ return;
+ }
+ double newLoss = loss(vector, new float[] {aOpt, bOpt}, points, norm2);
+ // If the new loss is worse, don't update the interval and exit
+ // This optimization, unlike kMeans, does not always converge to better loss
+ // So exit if we are getting worse
+ if (newLoss > initialLoss) {
+ return;
+ }
+ // Update the interval and go again
+ initInterval[0] = aOpt;
+ initInterval[1] = bOpt;
+ initialLoss = newLoss;
+ }
+ }
+
+ public static int discretize(int value, int bucket) {
+ return ((value + (bucket - 1)) / bucket) * bucket;
+ }
+
+ /**
+ * Transpose the query vector into a byte array allowing for efficient bitwise operations with the
+ * index bit vectors. The idea here is to organize the query vector bits such that the first bit
+   * of every dimension is in the first set of (dimensions/8) bytes. The second, third, and fourth
+   * bits are in the second, third, and fourth sets of (dimensions/8) bytes,
+ * respectively. This allows for direct bitwise comparisons with the stored index vectors through
+ * summing the bitwise results with the relative required bit shifts.
+ *
+ * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
+ * @param quantQueryByte the byte array to store the transposed query vector
+ */
+ public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) {
+ for (int i = 0; i < q.length; ) {
+ assert q[i] >= 0 && q[i] <= 15;
+ int lowerByte = 0;
+ int lowerMiddleByte = 0;
+ int upperMiddleByte = 0;
+ int upperByte = 0;
+ for (int j = 7; j >= 0 && i < q.length; j--) {
+ lowerByte |= (q[i] & 1) << j;
+ lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
+ upperMiddleByte |= ((q[i] >> 2) & 1) << j;
+ upperByte |= ((q[i] >> 3) & 1) << j;
+ i++;
+ }
+ int index = ((i + 7) / 8) - 1;
+ quantQueryByte[index] = (byte) lowerByte;
+ quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
+ quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
+ quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
+ }
+ }
+
+ /**
+ * Pack the vector as a binary array.
+ *
+ * @param vector the vector to pack
+ * @param packed the packed vector
+ */
+ public static void packAsBinary(byte[] vector, byte[] packed) {
+ for (int i = 0; i < vector.length; ) {
+ byte result = 0;
+ for (int j = 7; j >= 0 && i < vector.length; j--) {
+ assert vector[i] == 0 || vector[i] == 1;
+ result |= (byte) ((vector[i] & 1) << j);
+ ++i;
+ }
+ int index = ((i + 7) / 8) - 1;
+ assert index < packed.length;
+ packed[index] = result;
+ }
+ }
+
+ private static double clamp(double x, double a, double b) {
+ return Math.min(Math.max(x, a), b);
+ }
+}
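
A rough end-to-end sketch of the asymmetric flow this quantizer enables: documents quantized to one bit per dimension, queries to four bits, then compared with VectorUtil.int4BitDotProduct. The zero centroid and the fixed dimension count are simplifying assumptions (the real format derives the centroid from the segment's vectors), and the corrective terms returned by scalarQuantize, which the binary flat vector scorer combines with this raw dot product, are only printed here:

    import java.util.Random;
    import org.apache.lucene.index.VectorSimilarityFunction;
    import org.apache.lucene.util.VectorUtil;
    import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

    public class AsymmetricQuantizationSketch {
      public static void main(String[] args) {
        int dims = 64; // multiple of 64, so no discretization padding is needed
        Random r = new Random(0);
        float[] doc = randomVector(r, dims);
        float[] query = randomVector(r, dims);
        float[] centroid = new float[dims]; // assumed zero centroid for simplicity

        OptimizedScalarQuantizer osq =
            new OptimizedScalarQuantizer(VectorSimilarityFunction.EUCLIDEAN);

        // Index side: 1 bit per dimension, then packed 8 dimensions per byte.
        byte[] docCodes = new byte[dims];
        OptimizedScalarQuantizer.QuantizationResult docCorrections =
            osq.scalarQuantize(doc, docCodes, (byte) 1, centroid); // centers doc in place
        byte[] packedDoc = new byte[dims / 8];
        OptimizedScalarQuantizer.packAsBinary(docCodes, packedDoc);

        // Query side: 4 bits per dimension, transposed into four bit-planes.
        byte[] queryCodes = new byte[dims];
        OptimizedScalarQuantizer.QuantizationResult queryCorrections =
            osq.scalarQuantize(query, queryCodes, (byte) 4, centroid);
        byte[] packedQuery = new byte[4 * dims / 8];
        OptimizedScalarQuantizer.transposeHalfByte(queryCodes, packedQuery);

        long rawDot = VectorUtil.int4BitDotProduct(packedQuery, packedDoc);
        System.out.println(
            "raw dot = " + rawDot + ", doc = " + docCorrections + ", query = " + queryCorrections);
      }

      static float[] randomVector(Random r, int dims) {
        float[] v = new float[dims];
        for (int i = 0; i < dims; i++) {
          v[i] = r.nextFloat() - 0.5f;
        }
        return v;
      }
    }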
diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index 9273f7c5a813..37671b63e9f6 100644
--- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -29,6 +29,7 @@
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
@@ -58,6 +59,8 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
PanamaVectorConstants.PRERERRED_INT_SPECIES;
private static final VectorSpecies BYTE_SPECIES;
private static final VectorSpecies SHORT_SPECIES;
+  private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+  private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
static final int VECTOR_BITSIZE;
@@ -786,4 +789,118 @@ public int findNextGEQ(int[] buffer, int target, int from, int to) {
}
return to;
}
+
+ @Override
+ public long int4BitDotProduct(byte[] q, byte[] d) {
+ assert q.length == d.length * 4;
+ // 128 / 8 == 16
+ if (d.length >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+ if (VECTOR_BITSIZE >= 256) {
+ return int4BitDotProduct256(q, d);
+ } else if (VECTOR_BITSIZE == 128) {
+ return int4BitDotProduct128(q, d);
+ }
+ }
+ return DefaultVectorUtilSupport.int4BitDotProductImpl(q, d);
+ }
+
+ static long int4BitDotProduct256(byte[] q, byte[] d) {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+
+ if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
+ int limit = ByteVector.SPECIES_256.loopBound(d.length);
+ var sum0 = LongVector.zero(LongVector.SPECIES_256);
+ var sum1 = LongVector.zero(LongVector.SPECIES_256);
+ var sum2 = LongVector.zero(LongVector.SPECIES_256);
+ var sum3 = LongVector.zero(LongVector.SPECIES_256);
+ for (; i < limit; i += ByteVector.SPECIES_256.length()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs();
+ var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs();
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+
+ if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
+ var sum0 = LongVector.zero(LongVector.SPECIES_128);
+ var sum1 = LongVector.zero(LongVector.SPECIES_128);
+ var sum2 = LongVector.zero(LongVector.SPECIES_128);
+ var sum3 = LongVector.zero(LongVector.SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(d.length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs();
+ var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs();
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+ // tail as bytes
+ for (; i < d.length; i++) {
+ subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF);
+ subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF);
+ subRet2 += Integer.bitCount((q[i + 2 * d.length] & d[i]) & 0xFF);
+ subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+
+ public static long int4BitDotProduct128(byte[] q, byte[] d) {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+
+ var sum0 = IntVector.zero(IntVector.SPECIES_128);
+ var sum1 = IntVector.zero(IntVector.SPECIES_128);
+ var sum2 = IntVector.zero(IntVector.SPECIES_128);
+ var sum3 = IntVector.zero(IntVector.SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(d.length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts();
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts();
+ sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ // tail as bytes
+ for (; i < d.length; i++) {
+ int dValue = d[i];
+ subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
+ subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF);
+ subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF);
+ subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
}
diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
index cb5fee62aeec..0558fc8fef05 100644
--- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
+++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -16,3 +16,5 @@
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..668b7eee822d
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static java.lang.String.format;
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+import java.io.IOException;
+import java.util.Locale;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene101.Lucene101Codec;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+public class TestLucene102BinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+
+ @Override
+ protected Codec getCodec() {
+ return new Lucene101Codec() {
+ @Override
+ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ return new Lucene102BinaryQuantizedVectorsFormat();
+ }
+ };
+ }
+
+ public void testSearch() throws Exception {
+ String fieldName = "field";
+ int numVectors = random().nextInt(99, 500);
+ int dims = random().nextInt(4, 65);
+ float[] vector = randomVector(dims);
+ VectorSimilarityFunction similarityFunction = randomSimilarity();
+ KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+ IndexWriterConfig iwc = newIndexWriterConfig();
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(dir, iwc)) {
+ for (int i = 0; i < numVectors; i++) {
+ Document doc = new Document();
+ knnField.setVectorValue(randomVector(dims));
+ doc.add(knnField);
+ w.addDocument(doc);
+ }
+ w.commit();
+
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ final int k = random().nextInt(5, 50);
+ float[] queryVector = randomVector(dims);
+ Query q = new KnnFloatVectorQuery(fieldName, queryVector, k);
+ TopDocs collectedDocs = searcher.search(q, k);
+ assertEquals(k, collectedDocs.totalHits.value());
+ assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation());
+ }
+ }
+ }
+ }
+
+ public void testToString() {
+ FilterCodec customCodec =
+ new FilterCodec("foo", Codec.getDefault()) {
+ @Override
+ public KnnVectorsFormat knnVectorsFormat() {
+ return new Lucene102BinaryQuantizedVectorsFormat();
+ }
+ };
+ String expectedPattern =
+ "Lucene102BinaryQuantizedVectorsFormat("
+ + "name=Lucene102BinaryQuantizedVectorsFormat, "
+ + "flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))";
+ var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+ var memSegScorer =
+ format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+ assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+ }
+
+ @Override
+ public void testRandomWithUpdatesAndGraph() {
+ // graph not supported
+ }
+
+ @Override
+ public void testSearchWithVisitedLimit() {
+ // visited limit is not respected, as it is brute force search
+ }
+
+ public void testQuantizedVectorsWriteAndRead() throws IOException {
+ String fieldName = "field";
+ int numVectors = random().nextInt(99, 500);
+ int dims = random().nextInt(4, 65);
+
+ float[] vector = randomVector(dims);
+ VectorSimilarityFunction similarityFunction = randomSimilarity();
+ KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ for (int i = 0; i < numVectors; i++) {
+ Document doc = new Document();
+ knnField.setVectorValue(randomVector(dims));
+ doc.add(knnField);
+ w.addDocument(doc);
+ if (i % 101 == 0) {
+ w.commit();
+ }
+ }
+ w.commit();
+ w.forceMerge(1);
+
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
+ assertEquals(vectorValues.size(), numVectors);
+ BinarizedByteVectorValues qvectorValues =
+ ((Lucene102BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues)
+ .getQuantizedVectorValues();
+ float[] centroid = qvectorValues.getCentroid();
+ assertEquals(centroid.length, dims);
+
+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+ byte[] quantizedVector = new byte[dims];
+ byte[] expectedVector = new byte[discretize(dims, 64) / 8];
+ if (similarityFunction == VectorSimilarityFunction.COSINE) {
+ vectorValues =
+ new Lucene102BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);
+ }
+ KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+
+ while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+ OptimizedScalarQuantizer.QuantizationResult corrections =
+ quantizer.scalarQuantize(
+ vectorValues.vectorValue(docIndexIterator.index()),
+ quantizedVector,
+ INDEX_BITS,
+ centroid);
+ packAsBinary(quantizedVector, expectedVector);
+ assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index()));
+ assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index()));
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..b23804b1840e
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene102;
+
+import static java.lang.String.format;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+import java.util.Arrays;
+import java.util.Locale;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene101.Lucene101Codec;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.util.SameThreadExecutorService;
+
+public class TestLucene102HnswBinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+
+ @Override
+ protected Codec getCodec() {
+ return new Lucene101Codec() {
+ @Override
+ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ return new Lucene102HnswBinaryQuantizedVectorsFormat();
+ }
+ };
+ }
+
+ public void testToString() {
+ FilterCodec customCodec =
+ new FilterCodec("foo", Codec.getDefault()) {
+ @Override
+ public KnnVectorsFormat knnVectorsFormat() {
+ return new Lucene102HnswBinaryQuantizedVectorsFormat(10, 20, 1, null);
+ }
+ };
+ String expectedPattern =
+ "Lucene102HnswBinaryQuantizedVectorsFormat(name=Lucene102HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
+ + " flatVectorFormat=Lucene102BinaryQuantizedVectorsFormat(name=Lucene102BinaryQuantizedVectorsFormat,"
+ + " flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))";
+
+ var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+ var memSegScorer =
+ format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+ assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+ }
+
+ public void testSingleVectorCase() throws Exception {
+ float[] vector = randomVector(random().nextInt(12, 500));
+ for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+ try (Directory dir = newDirectory();
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField("f", vector, similarityFunction));
+ w.addDocument(doc);
+ w.commit();
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ FloatVectorValues vectorValues = r.getFloatVectorValues("f");
+ KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+ assert (vectorValues.size() == 1);
+ while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+ assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f);
+ }
+ float[] randomVector = randomVector(vector.length);
+ float trueScore = similarityFunction.compare(vector, randomVector);
+ TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE);
+ assertEquals(1, td.totalHits.value());
+ assertTrue(td.scoreDocs[0].score >= 0);
+ // When it's the only vector in a segment, the score should be very close to the true
+ // score
+ assertEquals(trueScore, td.scoreDocs[0].score, 0.01f);
+ }
+ }
+ }
+ }
+
+ public void testLimits() {
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene102HnswBinaryQuantizedVectorsFormat(-1, 20));
+ expectThrows(
+ IllegalArgumentException.class, () -> new Lucene102HnswBinaryQuantizedVectorsFormat(0, 20));
+ expectThrows(
+ IllegalArgumentException.class, () -> new Lucene102HnswBinaryQuantizedVectorsFormat(20, 0));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene102HnswBinaryQuantizedVectorsFormat(20, -1));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene102HnswBinaryQuantizedVectorsFormat(512 + 1, 20));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene102HnswBinaryQuantizedVectorsFormat(20, 3201));
+ expectThrows(
+ IllegalArgumentException.class,
+ () ->
+ new Lucene102HnswBinaryQuantizedVectorsFormat(
+ 20, 100, 1, new SameThreadExecutorService()));
+ }
+
+ // Ensures that all expected vector similarity functions are translatable in the format.
+ public void testVectorSimilarityFuncs() {
+ // This does not necessarily have to be all similarity functions, but
+ // differences should be considered carefully.
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
+ assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues);
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java
new file mode 100644
index 000000000000..63ac99d48d28
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.quantization;
+
+import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat;
+import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
+
+public class TestOptimizedScalarQuantizer extends LuceneTestCase {
+ static final byte[] ALL_BITS = new byte[] {1, 2, 3, 4, 5, 6, 7, 8};
+
+ public void testAbusiveEdgeCases() {
+ // large zero array
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ continue;
+ }
+ float[] vector = new float[4096];
+ float[] centroid = new float[4096];
+ OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
+ byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096];
+ OptimizedScalarQuantizer.QuantizationResult[] results =
+ osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
+ assertEquals(MINIMUM_MSE_GRID.length, results.length);
+ assertValidResults(results);
+ for (byte[] destination : destinations) {
+ assertArrayEquals(new byte[4096], destination);
+ }
+ byte[] destination = new byte[4096];
+ for (byte bit : ALL_BITS) {
+ OptimizedScalarQuantizer.QuantizationResult result =
+ osq.scalarQuantize(vector, destination, bit, centroid);
+ assertValidResults(result);
+ assertArrayEquals(new byte[4096], destination);
+ }
+ }
+
+ // single value array
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ float[] vector = new float[] {randomFloat()};
+ float[] centroid = new float[] {randomFloat()};
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(vector);
+ VectorUtil.l2normalize(centroid);
+ }
+ OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
+ byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1];
+ OptimizedScalarQuantizer.QuantizationResult[] results =
+ osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
+ assertEquals(MINIMUM_MSE_GRID.length, results.length);
+ assertValidResults(results);
+ for (int i = 0; i < ALL_BITS.length; i++) {
+ assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
+ }
+ for (byte bit : ALL_BITS) {
+ vector = new float[] {randomFloat()};
+ centroid = new float[] {randomFloat()};
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(vector);
+ VectorUtil.l2normalize(centroid);
+ }
+ byte[] destination = new byte[1];
+ OptimizedScalarQuantizer.QuantizationResult result =
+ osq.scalarQuantize(vector, destination, bit, centroid);
+ assertValidResults(result);
+ assertValidQuantizedRange(destination, bit);
+ }
+ }
+ }
+
+ public void testMathematicalConsistency() {
+ int dims = randomIntBetween(1, 4096);
+ float[] vector = new float[dims];
+ for (int i = 0; i < dims; ++i) {
+ vector[i] = randomFloat();
+ }
+ float[] centroid = new float[dims];
+ for (int i = 0; i < dims; ++i) {
+ centroid[i] = randomFloat();
+ }
+ float[] copy = new float[dims];
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ // copy the vector to avoid modifying it
+ System.arraycopy(vector, 0, copy, 0, dims);
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(copy);
+ VectorUtil.l2normalize(centroid);
+ }
+ OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
+ byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims];
+ OptimizedScalarQuantizer.QuantizationResult[] results =
+ osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid);
+ assertEquals(MINIMUM_MSE_GRID.length, results.length);
+ assertValidResults(results);
+ for (int i = 0; i < ALL_BITS.length; i++) {
+ assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
+ }
+ for (byte bit : ALL_BITS) {
+ byte[] destination = new byte[dims];
+ System.arraycopy(vector, 0, copy, 0, dims);
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(copy);
+ VectorUtil.l2normalize(centroid);
+ }
+ OptimizedScalarQuantizer.QuantizationResult result =
+ osq.scalarQuantize(copy, destination, bit, centroid);
+ assertValidResults(result);
+ assertValidQuantizedRange(destination, bit);
+ }
+ }
+ }
+
+ static void assertValidQuantizedRange(byte[] quantized, byte bits) {
+ for (byte b : quantized) {
+ if (bits < 8) {
+ assertTrue(b >= 0);
+ }
+ assertTrue(b < 1 << bits);
+ }
+ }
+
+ static void assertValidResults(OptimizedScalarQuantizer.QuantizationResult... results) {
+ for (OptimizedScalarQuantizer.QuantizationResult result : results) {
+ assertTrue(Float.isFinite(result.lowerInterval()));
+ assertTrue(Float.isFinite(result.upperInterval()));
+ assertTrue(result.lowerInterval() <= result.upperInterval());
+ assertTrue(Float.isFinite(result.additionalCorrection()));
+ assertTrue(result.quantizedComponentSum() >= 0);
+ }
+ }
+}
From 62cd45ba541c9b9b0569d877ca67f994dee5406b Mon Sep 17 00:00:00 2001
From: ChrisHegarty
Date: Wed, 18 Dec 2024 15:32:21 +0000
Subject: [PATCH 2/4] test default and panama impls return the same result
---
.../vectorization/TestVectorUtilSupport.java | 28 +++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
index 7064955cb5f3..6443de752bfa 100644
--- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
+++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
@@ -20,6 +20,7 @@
import java.util.Arrays;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
+import java.util.function.ToLongFunction;
import java.util.stream.IntStream;
public class TestVectorUtilSupport extends BaseVectorizationTestCase {
@@ -133,6 +134,27 @@ public void testInt4DotProductBoundaries() {
PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true));
}
+ public void testInt4BitDotProduct() {
+ var binaryQuantized = new byte[size];
+ var int4Quantized = new byte[size * 4];
+ random().nextBytes(binaryQuantized);
+ random().nextBytes(int4Quantized);
+ assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized));
+ }
+
+ public void testInt4BitDotProductBoundaries() {
+ var binaryQuantized = new byte[size];
+ var int4Quantized = new byte[size * 4];
+
+ Arrays.fill(binaryQuantized, Byte.MAX_VALUE);
+ Arrays.fill(int4Quantized, Byte.MAX_VALUE);
+ assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized));
+
+ Arrays.fill(binaryQuantized, Byte.MIN_VALUE);
+ Arrays.fill(int4Quantized, Byte.MIN_VALUE);
+ assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized));
+ }
+
static byte[] pack(byte[] unpacked) {
int len = (unpacked.length + 1) / 2;
var packed = new byte[len];
@@ -154,4 +176,10 @@ private void assertIntReturningProviders(ToIntFunction<VectorUtilSupport> func)
func.applyAsInt(LUCENE_PROVIDER.getVectorUtilSupport()),
func.applyAsInt(PANAMA_PROVIDER.getVectorUtilSupport()));
}
+
+ private void assertLongReturningProviders(ToLongFunction<VectorUtilSupport> func) {
+ assertEquals(
+ func.applyAsLong(LUCENE_PROVIDER.getVectorUtilSupport()),
+ func.applyAsLong(PANAMA_PROVIDER.getVectorUtilSupport()));
+ }
}
From db10587ba16790ad120a3feda180581985b6f417 Mon Sep 17 00:00:00 2001
From: ChrisHegarty
Date: Wed, 18 Dec 2024 15:43:11 +0000
Subject: [PATCH 3/4] add tests for int4BitDotProduct
---
.../BaseVectorizationTestCase.java | 12 +-
.../apache/lucene/util/TestVectorUtil.java | 129 ++++++++++++++++++
2 files changed, 139 insertions(+), 2 deletions(-)
diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java
index 34a0e5230022..ea6c183c7e7a 100644
--- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java
@@ -21,8 +21,8 @@
public abstract class BaseVectorizationTestCase extends LuceneTestCase {
- protected static final VectorizationProvider LUCENE_PROVIDER = new DefaultVectorizationProvider();
- protected static final VectorizationProvider PANAMA_PROVIDER = VectorizationProvider.lookup(true);
+ protected static final VectorizationProvider LUCENE_PROVIDER = defaultProvider();
+ protected static final VectorizationProvider PANAMA_PROVIDER = maybePanamaProvider();
@BeforeClass
public static void beforeClass() throws Exception {
@@ -30,4 +30,12 @@ public static void beforeClass() throws Exception {
"Test only works when JDK's vector incubator module is enabled.",
PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass());
}
+
+ public static VectorizationProvider defaultProvider() {
+ return new DefaultVectorizationProvider();
+ }
+
+ public static VectorizationProvider maybePanamaProvider() {
+ return VectorizationProvider.lookup(true);
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index 6e449a550028..30a2f6811c6e 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -16,8 +16,13 @@
*/
package org.apache.lucene.util;
+import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween;
+
+import java.util.Arrays;
import java.util.Random;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase;
+import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -384,4 +389,128 @@ private static int slowFindNextGEQ(int[] buffer, int length, int target, int fro
}
return length;
}
+
+ public void testInt4BitDotProductInvariants() {
+ int iterations = atLeast(10);
+ for (int i = 0; i < iterations; i++) {
+ int size = randomIntBetween(random(), 1, 10);
+ var d = new byte[size];
+ var q = new byte[size * 4 - 1];
+ expectThrows(IllegalArgumentException.class, () -> VectorUtil.int4BitDotProduct(q, d));
+ }
+ }
+
+ static final VectorizationProvider defaultedProvider =
+ BaseVectorizationTestCase.defaultProvider();
+ static final VectorizationProvider defOrPanamaProvider =
+ BaseVectorizationTestCase.maybePanamaProvider();
+
+ public void testBasicInt4BitDotProduct() {
+ testBasicInt4BitDotProductImpl(VectorUtil::int4BitDotProduct);
+ testBasicInt4BitDotProductImpl(defaultedProvider.getVectorUtilSupport()::int4BitDotProduct);
+ testBasicInt4BitDotProductImpl(defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct);
+ }
+
+ interface Int4BitDotProduct {
+ long apply(byte[] q, byte[] d);
+ }
+
+ void testBasicInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) {
+ assertEquals(15L, Int4BitDotProductFunc.apply(new byte[] {1, 1, 1, 1}, new byte[] {1}));
+ assertEquals(
+ 30L, Int4BitDotProductFunc.apply(new byte[] {1, 2, 1, 2, 1, 2, 1, 2}, new byte[] {1, 2}));
+
+ var d = new byte[] {1, 2, 3};
+ var q = new byte[] {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
+ assert scalarInt4BitDotProduct(q, d) == 60L; // 4 + 8 + 16 + 32
+ assertEquals(60L, Int4BitDotProductFunc.apply(q, d));
+
+ d = new byte[] {1, 2, 3, 4};
+ q = new byte[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
+ assert scalarInt4BitDotProduct(q, d) == 75L; // 5 + 10 + 20 + 40
+ assertEquals(75L, Int4BitDotProductFunc.apply(q, d));
+
+ d = new byte[] {1, 2, 3, 4, 5};
+ q = new byte[] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
+ assert scalarInt4BitDotProduct(q, d) == 105L; // 7 + 14 + 28 + 56
+ assertEquals(105L, Int4BitDotProductFunc.apply(q, d));
+
+ d = new byte[] {1, 2, 3, 4, 5, 6};
+ q = new byte[] {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
+ assert scalarInt4BitDotProduct(q, d) == 135L; // 9 + 18 + 36 + 72
+ assertEquals(135L, Int4BitDotProductFunc.apply(q, d));
+
+ d = new byte[] {1, 2, 3, 4, 5, 6, 7};
+ q =
+ new byte[] {
+ 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7
+ };
+ assert scalarInt4BitDotProduct(q, d) == 180L; // 12 + 24 + 48 + 96
+ assertEquals(180L, Int4BitDotProductFunc.apply(q, d));
+
+ d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8};
+ q =
+ new byte[] {
+ 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
+ 7, 8
+ };
+ assert scalarInt4BitDotProduct(q, d) == 195L; // 13 + 26 + 52 + 104
+ assertEquals(195L, Int4BitDotProductFunc.apply(q, d));
+
+ d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ q =
+ new byte[] {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3,
+ 4, 5, 6, 7, 8, 9
+ };
+ assert scalarInt4BitDotProduct(q, d) == 225L; // 15 + 30 + 60 + 120
+ assertEquals(225L, Int4BitDotProductFunc.apply(q, d));
+ }
+
+ public void testInt4BitDotProduct() {
+ testInt4BitDotProductImpl(VectorUtil::int4BitDotProduct);
+ testInt4BitDotProductImpl(defaultedProvider.getVectorUtilSupport()::int4BitDotProduct);
+ testInt4BitDotProductImpl(defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct);
+ }
+
+ void testInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) {
+ int iterations = atLeast(50);
+ for (int i = 0; i < iterations; i++) {
+ int size = random().nextInt(5000);
+ var d = new byte[size];
+ var q = new byte[size * 4];
+ random().nextBytes(d);
+ random().nextBytes(q);
+ assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
+
+ Arrays.fill(d, Byte.MAX_VALUE);
+ Arrays.fill(q, Byte.MAX_VALUE);
+ assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
+
+ Arrays.fill(d, Byte.MIN_VALUE);
+ Arrays.fill(q, Byte.MIN_VALUE);
+ assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
+ }
+ }
+
+ static int scalarInt4BitDotProduct(byte[] q, byte[] d) {
+ int res = 0;
+ for (int i = 0; i < 4; i++) {
+ res += (popcount(q, i * d.length, d, d.length) << i);
+ }
+ return res;
+ }
+
+ public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
+ int res = 0;
+ for (int j = 0; j < length; j++) {
+ int value = (a[aOffset + j] & b[j]) & 0xFF;
+ for (int k = 0; k < Byte.SIZE; k++) {
+ if ((value & (1 << k)) != 0) {
+ ++res;
+ }
+ }
+ }
+ return res;
+ }
}
From b76d616a90131563964cd575a118e846e639659d Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 16 Jan 2025 13:48:24 -0500
Subject: [PATCH 4/4] fixing tests, addressing pr comments
---
...Lucene102BinaryQuantizedVectorsFormat.java | 13 ++++---
.../lucene/codecs/lucene102/package-info.java | 7 ++--
.../OptimizedScalarQuantizer.java | 34 +++++++++++++++++--
...Lucene102BinaryQuantizedVectorsFormat.java | 14 ++++----
...ne102HnswBinaryQuantizedVectorsFormat.java | 14 ++++----
5 files changed, 55 insertions(+), 27 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
index ae48a220b235..6850e2d09d6a 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsFormat.java
@@ -26,9 +26,12 @@
import org.apache.lucene.index.SegmentWriteState;
/**
- * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 Codec for
- * encoding/decoding binary quantized vectors The binary quantization format used here is a
- * per-vector optimized scalar quantization. Also see {@link
+ * The binary quantization format used here is a per-vector optimized scalar quantization. These
+ * ideas are evolutions of LVQ proposed in Similarity
+ * search in the blink of an eye with compressed indices by Cecilia Aguerrebere et al. and the
+ * previous work on globally optimized scalar quantization.
+ *
+ * The format is stored in two files, described below. Also see {@link
+ * org.apache.lucene.util.quantization.OptimizedScalarQuantizer}. Some of the key features are:
*
*
@@ -43,8 +46,6 @@
* single bit vectors can be done with bit arithmetic.
*
*
- * The format is stored in two files:
- *
* .veb (vector data) file
*
* Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
@@ -130,6 +131,8 @@ public String toString() {
+ NAME
+ ", flatVectorScorer="
+ scorer
+ + ", rawVectorFormat="
+ + rawVectorFormat
+ ")";
}
}
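A minimal usage sketch of how this format can be selected at index time, following the per-field
override pattern the tests used before switching to TestUtil.alwaysKnnVectorsFormat; the class and
method names below are illustrative only.

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;

/** Sketch: a per-field codec override that selects the binary quantized vectors format. */
class BinaryQuantizedCodecSketch {
  static Codec binaryQuantizedCodec() {
    return new Lucene101Codec() {
      @Override
      public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
        // Flat binary quantized vectors; use Lucene102HnswBinaryQuantizedVectorsFormat
        // instead to also build an HNSW graph over the quantized vectors.
        return new Lucene102BinaryQuantizedVectorsFormat();
      }
    };
  }
}

The returned codec can then be installed on an IndexWriterConfig via setCodec; the HNSW tests
below exercise the graph-based variant in the same way.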
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java
index 92d9de567a20..8f6fcb2ef5bc 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/package-info.java
@@ -311,10 +311,11 @@
*
*
* {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values} |
- * .vec, .vem, .veq, vex |
+ * .vec, .vem, .veq, .vex, .veb, .vemb |
* Holds indexed vectors; .vec files contain the raw vector data,
- * .vem the vector metadata, .veq the quantized vector data, and .vex the
- * hnsw graph data. |
+ * .vem the vector metadata, .veq the quantized vector data, .vex the
+ * hnsw graph data, .veb the binary quantized vector data, and
+ * .vemb the binary quantized vector metadata
*
*
*
diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java
index b93f2df1ea79..7ae74e743993 100644
--- a/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java
+++ b/lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizer.java
@@ -23,9 +23,19 @@
import org.apache.lucene.util.VectorUtil;
/**
- * OptimizedScalarQuantizer is a scalar quantizer that optimizes the quantization intervals for a
- * given vector. This is done by optimizing the quantiles of the vector centered on a provided
- * centroid. The optimization is done by minimizing the quantization loss via coordinate descent.
+ * This is a scalar quantizer that optimizes the quantization intervals for a given vector. This is
+ * done by optimizing the quantiles of the vector centered on a provided centroid. The optimization
+ * is done by minimizing the quantization loss via coordinate descent.
+ *
+ * Local vector quantization parameters were originally proposed with LVQ in Similarity search in the blink of an eye with compressed
+ * indices. This technique builds on LVQ, but instead of taking the min/max values, a grid search
+ * over the centered vector is done to find the optimal quantization intervals, taking into account
+ * anisotropic loss.
+ *
+ *
+ * Anisotropic loss is first discussed in depth in Accelerating Large-Scale Inference with Anisotropic
+ * Vector Quantization by Ruiqi Guo, et al.
*
* @lucene.experimental
*/
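A usage sketch based on the calls exercised in TestOptimizedScalarQuantizer: the quantizer is
constructed with the field's similarity function and each vector is quantized against a centroid,
returning the per-vector correction terms that the format stores alongside the packed bits. The
helper class and method names here are illustrative only.

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

/** Sketch: quantize one vector to a single bit per dimension against a centroid. */
class OsqUsageSketch {
  static OptimizedScalarQuantizer.QuantizationResult quantizeToOneBit(
      float[] vector, float[] centroid) {
    OptimizedScalarQuantizer osq =
        new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
    // The tests copy the input vector before quantizing because quantization may center it
    // on the centroid in place.
    float[] copy = vector.clone();
    byte[] destination = new byte[copy.length];
    // bits = 1 yields the single-bit representation used for stored index vectors; the returned
    // result carries the lower/upper intervals, additional correction, and quantized component
    // sum used to correct scores at search time.
    return osq.scalarQuantize(copy, destination, (byte) 1, centroid);
  }
}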
@@ -43,10 +53,20 @@ public class OptimizedScalarQuantizer {
{-3.611f, 3.611f},
{-3.922f, 3.922f}
};
+ // the default lambda value
private static final float DEFAULT_LAMBDA = 0.1f;
+ // the default optimization iterations allowed
private static final int DEFAULT_ITERS = 5;
private final VectorSimilarityFunction similarityFunction;
+ // This determines how much emphasis we place on quantization errors perpendicular to the
+ // embedding as opposed to parallel to it.
+ // The smaller the value, the more we will allow the overall error to increase if it allows us
+ // to reduce error parallel to the vector.
+ // Parallel errors are important for nearest neighbor queries because the closest document
+ // vectors tend to be parallel to the query.
private final float lambda;
+ // the number of iterations to optimize the quantization intervals
private final int iters;
/**
@@ -320,6 +340,14 @@ public static int discretize(int value, int bucket) {
* respectively. This allows for direct bitwise comparisons with the stored index vectors through
* summing the bitwise results with the relative required bit shifts.
*
+ *
+ * This bit decomposition for fast bitwise SIMD operations was first proposed in:
+ *
+ *
+ * Gao, Jianyang, and Cheng Long. "RaBitQ: Quantizing High-
+ * Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search."
+ * Proceedings of the ACM on Management of Data 2, no. 3 (2024): 1-27.
+ *
+ *
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
* @param quantQueryByte the byte array to store the transposed query vector
*/
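To make the shifted-popcount idea above concrete, here is a self-contained sketch (independent of
this patch's classes and on-disk layout) of decomposing a 4-bit query into bit planes and scoring
it against a 1-bit document vector with AND + popcount. The packing order below is illustrative
and chosen only to be self-consistent within the sketch.

/** Sketch: dot product of a 4-bit quantized query with a 1-bit document vector. */
class BitPlaneDotProductSketch {

  /** Packs one bit per dimension: bit (i % 8) of byte (i / 8) is (values[i] >> plane) & 1. */
  static byte[] packPlane(int[] values, int plane) {
    byte[] packed = new byte[(values.length + 7) / 8];
    for (int i = 0; i < values.length; i++) {
      if (((values[i] >> plane) & 1) != 0) {
        packed[i >> 3] |= (byte) (1 << (i & 7));
      }
    }
    return packed;
  }

  /** q4 holds query components in [0, 15]; dBits holds the 0/1 document components. */
  static long dotProduct(int[] q4, int[] dBits) {
    byte[] d = packPlane(dBits, 0); // dBits are already 0/1, so plane 0 is the packed vector
    long result = 0;
    for (int plane = 0; plane < 4; plane++) {
      byte[] qPlane = packPlane(q4, plane);
      long matches = 0;
      for (int i = 0; i < d.length; i++) {
        matches += Integer.bitCount((qPlane[i] & d[i]) & 0xFF);
      }
      // Plane i carries weight 2^i of each 4-bit query component.
      result += matches << plane;
    }
    return result;
  }
}

This is the same reduction the scalar reference scalarInt4BitDotProduct in TestVectorUtil
performs on already-packed byte arrays.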
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java
index 668b7eee822d..216e645b3010 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102BinaryQuantizedVectorsFormat.java
@@ -29,7 +29,6 @@
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
-import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
@@ -47,18 +46,16 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
public class TestLucene102BinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+ private static final KnnVectorsFormat FORMAT = new Lucene102BinaryQuantizedVectorsFormat();
+
@Override
protected Codec getCodec() {
- return new Lucene101Codec() {
- @Override
- public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene102BinaryQuantizedVectorsFormat();
- }
- };
+ return TestUtil.alwaysKnnVectorsFormat(FORMAT);
}
public void testSearch() throws Exception {
@@ -103,7 +100,8 @@ public KnnVectorsFormat knnVectorsFormat() {
String expectedPattern =
"Lucene102BinaryQuantizedVectorsFormat("
+ "name=Lucene102BinaryQuantizedVectorsFormat, "
- + "flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))";
+ + "flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()), "
+ + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))";
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
var memSegScorer =
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java
index b23804b1840e..010c99eab41e 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java
@@ -26,7 +26,6 @@
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
-import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
@@ -40,18 +39,16 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.SameThreadExecutorService;
public class TestLucene102HnswBinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+ private static final KnnVectorsFormat FORMAT = new Lucene102HnswBinaryQuantizedVectorsFormat();
+
@Override
protected Codec getCodec() {
- return new Lucene101Codec() {
- @Override
- public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene102HnswBinaryQuantizedVectorsFormat();
- }
- };
+ return TestUtil.alwaysKnnVectorsFormat(FORMAT);
}
public void testToString() {
@@ -65,7 +62,8 @@ public KnnVectorsFormat knnVectorsFormat() {
String expectedPattern =
"Lucene102HnswBinaryQuantizedVectorsFormat(name=Lucene102HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
+ " flatVectorFormat=Lucene102BinaryQuantizedVectorsFormat(name=Lucene102BinaryQuantizedVectorsFormat,"
- + " flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))";
+ + " flatVectorScorer=Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()),"
+ + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))";
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
var memSegScorer =