Skip to content

Commit

Permalink
Deprecate COSINE VectorSimilarity function
Browse files (browse the repository at this point in the history)
  • Loading branch information
Pulkitg64 committed Apr 15, 2024
1 parent 0345fca commit d500c89
Show file tree
Hide file tree
Showing 18 changed files with 69 additions and 367 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@
* <ul>
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
* </ul>
* </ul>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.Version;

public class TestBasicBackwardsCompatibility extends BackwardsCompatibilityTestBase {
Expand All @@ -98,8 +99,14 @@ public class TestBasicBackwardsCompatibility extends BackwardsCompatibilityTestB
private static final int KNN_VECTOR_MIN_SUPPORTED_VERSION = LUCENE_9_0_0.major;
private static final String KNN_VECTOR_FIELD = "knn_field";
private static final FieldType KNN_VECTOR_FIELD_TYPE =
KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.COSINE);
KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.DOT_PRODUCT);
private static final float[] KNN_VECTOR = {0.2f, -0.1f, 0.1f};
private static final float[] NORMALIZED_KNN_VECTOR = new float[KNN_VECTOR.length];

static {
  // Build a unit-length copy of KNN_VECTOR once. Per the VectorSimilarityFunction
  // javadoc, DOT_PRODUCT callers should pre-normalize their vectors.
  for (int i = 0; i < KNN_VECTOR.length; i++) {
    NORMALIZED_KNN_VECTOR[i] = KNN_VECTOR[i];
  }
  VectorUtil.l2normalize(NORMALIZED_KNN_VECTOR);
}

/**
* A parameter constructor for {@link com.carrotsearch.randomizedtesting.RandomizedRunner}. See
Expand Down Expand Up @@ -235,6 +242,7 @@ static void addDoc(IndexWriter writer, int id) throws IOException {
doc.add(new Field("content6", "here is more content with aaa aaa aaa", customType4));

float[] vector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * id};
VectorUtil.l2normalize(vector);
doc.add(new KnnFloatVectorField(KNN_VECTOR_FIELD, vector, KNN_VECTOR_FIELD_TYPE));

// TODO:
Expand Down Expand Up @@ -479,6 +487,7 @@ public static void searchIndex(
assertEquals(KNN_VECTOR_FIELD_TYPE.vectorDimension(), values.dimension());
for (int doc = values.nextDoc(); doc != NO_MORE_DOCS; doc = values.nextDoc()) {
float[] expectedVector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * cnt};
VectorUtil.l2normalize(expectedVector);
assertArrayEquals(
"vectors do not match for doc=" + cnt, expectedVector, values.vectorValue(), 0);
cnt++;
Expand All @@ -488,7 +497,7 @@ public static void searchIndex(
assertEquals(DOCS_COUNT, cnt);

// test KNN search
ScoreDoc[] scoreDocs = assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
ScoreDoc[] scoreDocs = assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 10, 10, "0");
for (int i = 0; i < scoreDocs.length; i++) {
int id = Integer.parseInt(storedFields.document(scoreDocs[i].doc).get("id"));
int expectedId = i < DELETED_ID ? i : i + 1;
Expand Down Expand Up @@ -559,14 +568,12 @@ public void changeIndexWithAdds(Random random, Directory dir, Version nameVersio

if (nameVersion.major >= KNN_VECTOR_MIN_SUPPORTED_VERSION) {
// make sure KNN search sees all hits (graph may not be used if k is big)
assertKNNSearch(searcher, KNN_VECTOR, 1000, 44, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 1000, 44, "0");
// make sure KNN search using HNSW graph sees newly added docs
assertKNNSearch(
searcher,
new float[] {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * 44},
10,
10,
"44");
float[] normalizedQueryVector =
new float[] {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * 44};
VectorUtil.l2normalize(normalizedQueryVector);
assertKNNSearch(searcher, normalizedQueryVector, 10, 10, "44");
}
reader.close();

Expand All @@ -593,12 +600,10 @@ public void changeIndexWithAdds(Random random, Directory dir, Version nameVersio
// make sure KNN search sees all hits
assertKNNSearch(searcher, KNN_VECTOR, 1000, 44, "0");
// make sure KNN search using HNSW graph sees newly added docs
assertKNNSearch(
searcher,
new float[] {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * 44},
10,
10,
"44");
float[] normalizedQueryVector =
new float[] {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * 44};
VectorUtil.l2normalize(normalizedQueryVector);
assertKNNSearch(searcher, normalizedQueryVector, 10, 10, "44");
}
reader.close();
}
Expand All @@ -615,9 +620,9 @@ public void changeIndexNoAdds(Random random, Directory dir, Version nameVersion)

if (nameVersion.major >= KNN_VECTOR_MIN_SUPPORTED_VERSION) {
// make sure KNN search sees all hits
assertKNNSearch(searcher, KNN_VECTOR, 1000, 34, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 1000, 34, "0");
// make sure KNN search using HNSW graph retrieves correct results
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 10, 10, "0");
}
reader.close();

Expand All @@ -639,9 +644,9 @@ public void changeIndexNoAdds(Random random, Directory dir, Version nameVersion)
// make sure searching sees right # hits for KNN search
if (nameVersion.major >= KNN_VECTOR_MIN_SUPPORTED_VERSION) {
// make sure KNN search sees all hits
assertKNNSearch(searcher, KNN_VECTOR, 1000, 34, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 1000, 34, "0");
// make sure KNN search using HNSW graph retrieves correct results
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 10, 10, "0");
}
reader.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.Version;

public class TestInt8HnswBackwardsCompatibility extends BackwardsCompatibilityTestBase {
Expand All @@ -50,8 +51,14 @@ public class TestInt8HnswBackwardsCompatibility extends BackwardsCompatibilityTe
private static final String KNN_VECTOR_FIELD = "knn_field";
private static final int DOC_COUNT = 30;
private static final FieldType KNN_VECTOR_FIELD_TYPE =
KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.COSINE);
KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.DOT_PRODUCT);
private static final float[] KNN_VECTOR = {0.2f, -0.1f, 0.1f};
private static final float[] NORMALIZED_KNN_VECTOR = new float[KNN_VECTOR.length];

static {
  // Copy KNN_VECTOR into NORMALIZED_KNN_VECTOR and normalize the copy to unit
  // length. With VectorSimilarityFunction.DOT_PRODUCT (which replaced COSINE in
  // this commit), query vectors are expected to be normalized in advance, so the
  // original KNN_VECTOR is kept intact and the normalized copy is used for search.
  System.arraycopy(KNN_VECTOR, 0, NORMALIZED_KNN_VECTOR, 0, KNN_VECTOR.length);
  VectorUtil.l2normalize(NORMALIZED_KNN_VECTOR);
}

public TestInt8HnswBackwardsCompatibility(Version version, String pattern) {
super(version, pattern);
Expand Down Expand Up @@ -104,8 +111,8 @@ public void testInt8HnswIndexAndSearch() throws Exception {
writer.commit();
try (IndexReader reader = DirectoryReader.open(directory)) {
IndexSearcher searcher = new IndexSearcher(reader);
assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT + 10, "0");
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 1000, DOC_COUNT + 10, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 10, 10, "0");
}
}
// This will confirm the docs are really sorted
Expand All @@ -127,14 +134,15 @@ protected void createIndex(Directory dir) throws IOException {
}
try (DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 1000, DOC_COUNT, "0");
assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 10, 10, "0");
}
}

private static Document knnDocument(int id) {
Document doc = new Document();
float[] vector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * id};
VectorUtil.l2normalize(vector);
doc.add(new KnnFloatVectorField(KNN_VECTOR_FIELD, vector, KNN_VECTOR_FIELD_TYPE));
doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
return doc;
Expand All @@ -143,8 +151,8 @@ private static Document knnDocument(int id) {
/**
 * Opens the pre-built backwards-compatibility index and verifies KNN search returns
 * the expected hit counts (all DOC_COUNT docs for a large k, top-10 for k=10) with
 * the best hit having id "0".
 */
public void testReadOldIndices() throws Exception {
  try (DirectoryReader reader = DirectoryReader.open(directory)) {
    IndexSearcher searcher = new IndexSearcher(reader);
    // NOTE(review): the next four lines appear to be the pre-change (KNN_VECTOR)
    // and post-change (NORMALIZED_KNN_VECTOR) sides of a diff rendered without
    // +/- markers; in the post-commit code only the NORMALIZED_KNN_VECTOR pair
    // should remain — confirm against commit d500c89.
    assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
    assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
    assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 1000, DOC_COUNT, "0");
    assertKNNSearch(searcher, NORMALIZED_KNN_VECTOR, 10, 10, "0");
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@
* <ul>
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
* <li>3: MAXIMUM_INNER_PRODUCT similarity. ({@link
* <li>2: MAXIMUM_INNER_PRODUCT similarity. ({@link
* VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT})
* </ul>
* </ul>
Expand Down Expand Up @@ -302,7 +301,6 @@ private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) th
List.of(
VectorSimilarityFunction.EUCLIDEAN,
VectorSimilarityFunction.DOT_PRODUCT,
VectorSimilarityFunction.COSINE,
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);

static VectorSimilarityFunction distOrdToFunc(byte i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
List.of(
VectorSimilarityFunction.EUCLIDEAN,
VectorSimilarityFunction.DOT_PRODUCT,
VectorSimilarityFunction.COSINE,
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);

public static VectorSimilarityFunction readSimilarityFunction(DataInput input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
Expand Down Expand Up @@ -321,14 +320,8 @@ private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
fieldData.fieldInfo.getVectorDimension(), bits)
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (float[] v : fieldData.floatVectors) {
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
v = copy;
}

for (float[] v : fieldData.floatVectors) {
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
if (compressedVector != null) {
Expand Down Expand Up @@ -393,14 +386,8 @@ private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) th
fieldData.fieldInfo.getVectorDimension(), bits)
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (int ordinal : ordMap) {
float[] v = fieldData.floatVectors.get(ordinal);
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
v = copy;
}
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
if (compressedVector != null) {
Expand Down Expand Up @@ -703,7 +690,6 @@ static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
private final byte bits;
private final boolean compress;
private final InfoStream infoStream;
private final boolean normalize;
private float minQuantile = Float.POSITIVE_INFINITY;
private float maxQuantile = Float.NEGATIVE_INFINITY;
private boolean finished;
Expand All @@ -721,7 +707,6 @@ static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.fieldInfo = fieldInfo;
this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
this.floatVectors = new ArrayList<>();
this.infoStream = infoStream;
this.docsWithField = new DocsWithFieldSet();
Expand All @@ -736,7 +721,7 @@ void finish() throws IOException {
finished = true;
return;
}
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors);
ScalarQuantizer quantizer =
confidenceInterval == null
? ScalarQuantizer.fromVectorsAutoInterval(
Expand Down Expand Up @@ -796,14 +781,10 @@ public float[] copyValue(float[] vectorValue) {

static class FloatVectorWrapper extends FloatVectorValues {
private final List<float[]> vectorList;
private final float[] copy;
private final boolean normalize;
protected int curDoc = -1;

FloatVectorWrapper(List<float[]> vectorList, boolean normalize) {
FloatVectorWrapper(List<float[]> vectorList) {
this.vectorList = vectorList;
this.copy = new float[vectorList.get(0).length];
this.normalize = normalize;
}

@Override
Expand All @@ -821,11 +802,6 @@ public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= vectorList.size()) {
throw new IOException("Current doc not set or too many iterations");
}
if (normalize) {
System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
return copy;
}
return vectorList.get(curDoc);
}

Expand Down Expand Up @@ -977,7 +953,6 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
private final FloatVectorValues values;
private final ScalarQuantizer quantizer;
private final byte[] quantizedVector;
private final float[] normalizedVector;
private float offsetValue = 0f;

private final VectorSimilarityFunction vectorSimilarityFunction;
Expand All @@ -990,11 +965,6 @@ public QuantizedFloatVectorValues(
this.quantizer = quantizer;
this.quantizedVector = new byte[values.dimension()];
this.vectorSimilarityFunction = vectorSimilarityFunction;
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
this.normalizedVector = new float[values.dimension()];
} else {
this.normalizedVector = null;
}
}

@Override
Expand Down Expand Up @@ -1041,15 +1011,8 @@ public int advance(int target) throws IOException {
}

/**
 * Quantizes the current vector from {@code values} into {@code quantizedVector},
 * recording the scalar offset correction in {@code offsetValue}.
 */
private void quantize() throws IOException {
  // NOTE(review): this span interleaves the removed COSINE-normalization branch
  // and its replacement from the diff (no +/- markers survived the scrape). In
  // the post-commit code only the final unconditional quantize call remains,
  // since COSINE was dropped from VectorSimilarityFunction — confirm against
  // commit d500c89.
  if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
    // Old path: copy and l2-normalize before quantizing so the stored bytes
    // reflect unit-length vectors.
    System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
    VectorUtil.l2normalize(normalizedVector);
    offsetValue =
        quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction);
  } else {
    offsetValue =
        quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
  }
  offsetValue =
      quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.lucene.index;

import static org.apache.lucene.util.VectorUtil.cosine;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.dotProductScore;
import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore;
Expand Down Expand Up @@ -61,24 +60,6 @@ public float compare(byte[] v1, byte[] v2) {
}
},

/**
 * Cosine similarity. NOTE: the preferred way to perform cosine similarity is to normalize all
 * vectors to unit length and instead use {@link VectorSimilarityFunction#DOT_PRODUCT}. You
 * should only use this function if you need to preserve the original vectors and cannot
 * normalize them in advance. The similarity score is normalized to ensure it is positive.
 */
COSINE {
  @Override
  public float compare(float[] v1, float[] v2) {
    // Map cosine in [-1, 1] onto [0, 1]; the outer max presumably guards against
    // floating-point rounding producing a slightly negative score — TODO confirm.
    return Math.max((1 + cosine(v1, v2)) / 2, 0);
  }

  @Override
  public float compare(byte[] v1, byte[] v2) {
    // Same [0, 1] mapping as the float overload, without the clamp.
    return (1 + cosine(v1, v2)) / 2;
  }
},

/**
* Maximum inner product. This is like {@link VectorSimilarityFunction#DOT_PRODUCT}, but does not
* require normalization of the inputs. Should be used when the embedding vectors store useful
Expand Down
Loading

0 comments on commit d500c89

Please sign in to comment.