Skip to content

Commit

Permalink
Ensure negative scores are not returned from scalar quantization sco…
Browse files Browse the repository at this point in the history
…rer (#13356)

Depending on how we quantize and then scale, we can edge down below 0 for dot-product scores.

This is exceptionally rare, I have only seen it in extreme circumstances in tests (with random data and low dimensionality).
  • Loading branch information
benwtrent committed May 13, 2024
1 parent 3a4e4e3 commit fd98698
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 4 deletions.
4 changes: 4 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ Bug Fixes

* GITHUB#13206: Subtract deleted file size from the cache size of NRTCachingDirectory. (Jean-François Boeuf)

* GITHUB#12966: Aggregation facets no longer assume that aggregation values are positive. (Stefan Vodita)

* GITHUB#13356: Ensure negative scores are not returned from scalar quantization scorer. (Ben Trent)

Build
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,11 @@ static RandomVectorScorer fromVectorSimilarity(
case COSINE:
case DOT_PRODUCT:
return dotProductFactory(
targetBytes, offsetCorrection, sim, constMultiplier, values, f -> (1 + f) / 2);
targetBytes, offsetCorrection, constMultiplier, values, f -> Math.max((1 + f) / 2, 0f));
case MAXIMUM_INNER_PRODUCT:
return dotProductFactory(
targetBytes,
offsetCorrection,
sim,
constMultiplier,
values,
VectorUtil::scaleMaxInnerProductScore);
Expand All @@ -122,7 +121,6 @@ static RandomVectorScorer fromVectorSimilarity(
private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(
byte[] targetBytes,
float offsetCorrection,
VectorSimilarityFunction sim,
float constMultiplier,
RandomAccessQuantizedByteVectorValues values,
FloatToFloatFunction scoreAdjustmentFunction) {
Expand Down Expand Up @@ -187,6 +185,8 @@ public float score(int vectorOrdinal) throws IOException {
byte[] storedVector = values.vectorValue(vectorOrdinal);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}
Expand Down Expand Up @@ -224,6 +224,8 @@ public float score(int vectorOrdinal) throws IOException {
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}
Expand Down Expand Up @@ -255,6 +257,8 @@ public float score(int vectorOrdinal) throws IOException {
byte[] storedVector = values.vectorValue(vectorOrdinal);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ public DotProduct(float constMultiplier, ByteVectorComparator comparator) {
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = comparator.compare(storedVector, queryVector);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
return (1 + adjustedDistance) / 2;
return Math.max((1 + adjustedDistance) / 2, 0);
}
}

Expand All @@ -111,6 +113,8 @@ public MaximumInnerProduct(float constMultiplier, ByteVectorComparator comparato
// Scores one stored quantized vector against the query: integer dot product,
// rescaled by the quantization constant, corrected by the per-vector offsets,
// then mapped into Lucene's non-negative maximum-inner-product score range.
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = comparator.compare(storedVector, queryVector);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
// scaleMaxInnerProductScore maps any real-valued distance to a non-negative score,
// so no extra clamping is needed here (unlike the DOT_PRODUCT case).
return scaleMaxInnerProductScore(adjustedDistance);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.util.SameThreadExecutorService;
Expand Down Expand Up @@ -78,6 +81,41 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
};
}

// Regression test for GH#13356: with DOT_PRODUCT similarity, quantization error plus
// score scaling can push a score slightly below 0 (the third vector is the exact
// negation of the query). Verifies the scorer clamps all returned scores to >= 0.
public void testQuantizationScoringEdgeCase() throws Exception {
float[][] vectors = new float[][] {{0.6f, 0.8f}, {0.8f, 0.6f}, {-0.6f, -0.8f}};
try (Directory dir = newDirectory();
IndexWriter w =
new IndexWriter(
dir,
newIndexWriterConfig()
.setCodec(
new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
// 7-bit quantization with a fixed 0.9 confidence interval to make
// the quantization behavior deterministic for this test
return new Lucene99HnswScalarQuantizedVectorsFormat(
16, 100, 1, (byte) 7, false, 0.9f, null);
}
}))) {
for (float[] vector : vectors) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
// commit per document so quantization happens across segments before the merge
w.commit();
}
w.forceMerge(1);
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
// k=5 > 3 docs, so all three vectors (including the negated one) are collected
TopKnnCollector topKnnCollector = new TopKnnCollector(5, Integer.MAX_VALUE);
r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null);
TopDocs topDocs = topKnnCollector.topDocs();
assertEquals(3, topDocs.totalHits.value);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
assertTrue(scoreDoc.score >= 0f);
}
}
}
}

public void testQuantizedVectorsWriteAndRead() throws Exception {
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

package org.apache.lucene.codecs.lucene99;

import static org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues.compressBytes;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -32,9 +37,14 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {

Expand All @@ -54,6 +64,95 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
};
}

// Runs the non-negative-score check across every supported quantization
// configuration: 4-bit and 7-bit, each with compression on and off.
public void testNonZeroScores() throws IOException {
final int[] bitCounts = {4, 7};
final boolean[] compressOptions = {true, false};
for (int bitCount : bitCounts) {
for (boolean compressOption : compressOptions) {
vectorNonZeroScoringTest(bitCount, compressOption);
}
}
}

// Builds a minimal on-disk fixture of two all-zero quantized vectors whose score
// correction constants are a large negative value (-50), which forces the raw
// adjusted distance below zero for every similarity function. Asserts that the
// scalar quantized scorer never surfaces a negative score.
private void vectorNonZeroScoringTest(int bits, boolean compress) throws IOException {
try (Directory dir = newDirectory()) {
// keep vecs `0` so dot product is `0`
byte[] vec1 = new byte[32];
byte[] vec2 = new byte[32];
if (compress && bits == 4) {
// 4-bit values are packed two-per-byte when compression is enabled
byte[] vec1Compressed = new byte[16];
byte[] vec2Compressed = new byte[16];
compressBytes(vec1, vec1Compressed);
compressBytes(vec2, vec2Compressed);
vec1 = vec1Compressed;
vec2 = vec2Compressed;
}
String fileName = getTestName() + "-32";
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
// large negative offset to override any query score correction and
// ensure negative values that need to be snapped to `0`
var negativeOffset = floatToByteArray(-50f);
// on-disk layout mirrors the real format: vector bytes followed by the
// little-endian float correction constant, for each of the two vectors
byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset);
out.writeBytes(bytes, 0, bytes.length);
}
ScalarQuantizer scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) bits);
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
Lucene99ScalarQuantizedVectorScorer scorer =
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
// hand-rolled values implementation backed by the slice written above;
// vectorValue/getScoreCorrectionConstant return fixed data for both ordinals
RandomAccessQuantizedByteVectorValues values =
new RandomAccessQuantizedByteVectorValues() {
@Override
public int dimension() {
return 32;
}

@Override
public int getVectorByteLength() {
// compressed 4-bit vectors occupy half the bytes
return compress && bits == 4 ? 16 : 32;
}

@Override
public int size() {
return 2;
}

@Override
public byte[] vectorValue(int ord) {
return new byte[32];
}

@Override
public float getScoreCorrectionConstant(int ord) {
// matches the negative offset written to the file
return -50;
}

@Override
public RandomAccessQuantizedByteVectorValues copy() throws IOException {
return this;
}

@Override
public IndexInput getSlice() {
return in;
}

@Override
public ScalarQuantizer getScalarQuantizer() {
return scalarQuantizer;
}
};
float[] queryVector = new float[32];
for (int i = 0; i < 32; i++) {
queryVector[i] = i * 0.1f;
}
// every similarity function must clamp to a non-negative score
for (VectorSimilarityFunction function : VectorSimilarityFunction.values()) {
RandomVectorScorer randomScorer =
scorer.getRandomVectorScorer(function, values, queryVector);
assertTrue(randomScorer.score(0) >= 0f);
assertTrue(randomScorer.score(1) >= 0f);
}
}
}
}

public void testScoringCompressedInt4() throws Exception {
vectorScoringTest(4, true);
}
Expand Down Expand Up @@ -153,4 +252,17 @@ private static void indexVectors(
writer.forceMerge(1);
}
}

// Encodes the float as its 4-byte IEEE-754 little-endian representation,
// least-significant byte first (same layout as ByteBuffer LITTLE_ENDIAN putFloat).
private static byte[] floatToByteArray(float value) {
int bits = Float.floatToRawIntBits(value);
return new byte[] {
(byte) bits, (byte) (bits >> 8), (byte) (bits >> 16), (byte) (bits >> 24)
};
}

// Joins the given byte arrays, in order, into one newly allocated array.
// The IOException in the signature is kept for call-site compatibility; this
// implementation never throws it.
private static byte[] concat(byte[]... arrays) throws IOException {
int totalLength = 0;
for (byte[] chunk : arrays) {
totalLength += chunk.length;
}
byte[] joined = new byte[totalLength];
int writePos = 0;
for (byte[] chunk : arrays) {
System.arraycopy(chunk, 0, joined, writePos, chunk.length);
writePos += chunk.length;
}
return joined;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@

public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {

// GH#13356 regression: even with a randomly negated multiplier and negative score
// correction offsets, the quantized similarity must never return a negative score.
public void testNonZeroScores() {
// two all-zero 32-byte quantized vectors, so the raw dot product is always 0
byte[][] quantized = new byte[2][32];
for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
float multiplier = random().nextFloat();
if (random().nextBoolean()) {
// randomly flip the sign so both positive and negative multipliers are exercised
multiplier = -multiplier;
}
for (byte bits : new byte[] {4, 7}) {
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, multiplier, bits);
// negative offsets in [-10, 0) push the adjusted distance below zero
float negativeOffsetA = -(random().nextFloat() * (random().nextInt(10) + 1));
float negativeOffsetB = -(random().nextFloat() * (random().nextInt(10) + 1));
float score =
quantizedSimilarity.score(quantized[0], negativeOffsetA, quantized[1], negativeOffsetB);
assertTrue(score >= 0);
}
}
}

public void testToEuclidean() throws IOException {
int dims = 128;
int numVecs = 100;
Expand Down

0 comments on commit fd98698

Please sign in to comment.