Skip to content

Commit

Permalink
feat(langchain4j-milvus): MilvusEmbeddingStore supports configure req…
Browse files Browse the repository at this point in the history
…uired index parameters

Fix: [langchain4j#860](langchain4j#860)

Added support for configure common custom index parameters

BREAKING CHANGE: The constructor of MilvusEmbeddingStore now require a parameter of type IndexParam.
  • Loading branch information
Glarme committed Apr 12, 2024
1 parent 71c9ef3 commit d7f1b05
Show file tree
Hide file tree
Showing 15 changed files with 750 additions and 2 deletions.
@@ -1,5 +1,6 @@
package dev.langchain4j.store.embedding.milvus;

import dev.langchain4j.store.embedding.milvus.parameter.IndexParam;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.FlushResponse;
Expand Down Expand Up @@ -81,12 +82,14 @@ static void dropCollection(MilvusServiceClient milvusClient, String collectionNa
static void createIndex(MilvusServiceClient milvusClient,
String collectionName,
IndexType indexType,
IndexParam indexParam,
MetricType metricType) {

CreateIndexParam request = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName(VECTOR_FIELD_NAME)
.withIndexType(indexType)
.withExtraParam(indexParam.toExtraParam())
.withMetricType(metricType)
.build();

Expand Down
Expand Up @@ -2,6 +2,7 @@

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.*;
import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest;
import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds;
Expand All @@ -17,10 +18,11 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.milvus.parameter.IndexParam;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.param.ConnectParam;
Expand Down Expand Up @@ -59,6 +61,7 @@ public MilvusEmbeddingStore(
String collectionName,
Integer dimension,
IndexType indexType,
IndexParam indexParam,
MetricType metricType,
String uri,
String token,
Expand Down Expand Up @@ -87,8 +90,17 @@ public MilvusEmbeddingStore(
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);

if (!hasCollection(milvusClient, this.collectionName)) {
indexType = getOrDefault(indexType, FLAT);
if (indexParam == null) {
if (IndexParam.isIndexParamNullable(indexType)) {
indexParam = IndexParam.EMPTY_INSTANCE;
}
}
ensureNotNull(indexParam, "IndexParam is required for indexType " + indexType);
ensureTrue(indexParam.support(indexType), String.format("IndexParam %s does not support IndexType %s", indexParam.getClass(), indexType));
// validate IndexParam before creating the collection to prevent exceptions caused by invalid indices
createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
createIndex(milvusClient, this.collectionName, indexType, indexParam, this.metricType);
}

loadCollectionInMemory(milvusClient, collectionName);
Expand Down Expand Up @@ -185,6 +197,7 @@ public static class Builder {
private String collectionName;
private Integer dimension;
private IndexType indexType;
private IndexParam indexParam;
private MetricType metricType;
private String uri;
private String token;
Expand Down Expand Up @@ -245,6 +258,17 @@ public Builder indexType(IndexType indexType) {
return this;
}

/**
* This parameter is required except for indexType {@link IndexType#FLAT} and {@link IndexType#BIN_FLAT}.
*
* @param indexParam The parameters of the index.
* @return builder
*/
public Builder indexParam(IndexParam indexParam) {
this.indexParam = indexParam;
return this;
}

/**
* @param metricType The type of the metric used for similarity search.
* Default value: COSINE.
Expand Down Expand Up @@ -332,6 +356,7 @@ public MilvusEmbeddingStore build() {
collectionName,
dimension,
indexType,
indexParam,
metricType,
uri,
token,
Expand Down
@@ -0,0 +1,25 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

/**
* for more information, see <a href="https://milvus.io/docs/index.md#BIN_FLAT">Index#BIN_FLAT</a>
*/
public class BinFlatIndexParam extends IndexParam{
public BinFlatIndexParam() {
super(IndexType.BIN_FLAT);
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {
public Builder() {
}

public BinFlatIndexParam build() {
return new BinFlatIndexParam();
}
}
}
@@ -0,0 +1,48 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;

/**
* for more information, see <a href="https://milvus.io/docs/index.md#BIN_IVF_FLAT">Index#BIN_IVF_FLAT</a>
*/
public class BinIvfFlatIndexParam extends IndexParam {

/**
* Number of cluster units, Range: [1, 65536]
*/
private final Integer nlist;

public BinIvfFlatIndexParam(Integer nlist) {
super(IndexType.BIN_IVF_FLAT);
this.nlist = nlist;
ensureBetween(nlist, 1, 65536, "nlist must be in range [1,65536]");
}

public static Builder builder() {
return new Builder();
}

public Integer getNlist() {
return nlist;
}


public static final class Builder {
private Integer nlist;

public Builder() {
}

public Builder nlist(Integer nlist) {
this.nlist = nlist;
return this;
}


public BinIvfFlatIndexParam build() {
return new BinIvfFlatIndexParam(nlist);
}
}
}
@@ -0,0 +1,25 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

/**
* for more information, see <a href="https://milvus.io/docs/disk_index.md#Index-and-search-settings">Disk Index#DISKANN</a>
*/
public class DiskannIndexParam extends IndexParam {
public DiskannIndexParam() {
super(IndexType.DISKANN);
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {
public Builder() {
}

public DiskannIndexParam build() {
return new DiskannIndexParam();
}
}
}
@@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

/**
* empty placeholder class
* for more information, see <a href="https://milvus.io/docs/index.md#FLAT">Index#FLAT</a>
*/
public class FlatIndexParam extends IndexParam {
public FlatIndexParam() {
super(IndexType.FLAT);
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {
public Builder() {
}

public FlatIndexParam build() {
return new FlatIndexParam();
}
}
}
@@ -0,0 +1,42 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;

/**
* for more information, see <a href="https://milvus.io/docs/index-with-gpu.md#Prepare-index-parameters">GPU Index</a>
* parameter same as <a href="https://milvus.io/docs/index.md#IVF_FLAT">Index#IVF_FLAT</a>
*/
public class GpuIvfFlatIndexParam extends IndexParam {
/**
* Number of cluster units, Range: [1, 65536]
*/
private final Integer nlist;

public GpuIvfFlatIndexParam(Integer nlist) {
super(IndexType.GPU_IVF_FLAT);
ensureBetween(nlist, 1, 65536, "nlist must be in range [1,65536]");
this.nlist = nlist;
}

public Integer getNlist() {
return nlist;
}

public static final class Builder {
private Integer nlist;

public Builder() {
}

public Builder nlist(Integer nlist) {
this.nlist = nlist;
return this;
}

public GpuIvfFlatIndexParam build() {
return new GpuIvfFlatIndexParam(nlist);
}
}
}
@@ -0,0 +1,86 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
* for more information, see <a href="https://milvus.io/docs/index-with-gpu.md#Prepare-index-parameters">GPU Index</a>
* parameter same as <a href="https://milvus.io/docs/index.md#IVF_PQ">Index#IVF_PQ</a>
*/
public class GpuIvfPqIndexParam extends IndexParam {

/**
* Number of cluster units, Range: [1, 65536]
*/
private final Integer nlist;
/**
* Number of factors of product quantization, Range: dim mod m == 0
*/
private final Integer m;
/**
* [Optional] Number of bits in which each low-dimensional vector is stored. Range: [1, 16], Default: 8
*/
private final Integer nbits;

public GpuIvfPqIndexParam(Integer nlist, Integer m) {
this(nlist, m, 8);
}

public GpuIvfPqIndexParam(Integer nlist, Integer m, Integer nbits) {
super(IndexType.GPU_IVF_PQ);
ensureBetween(nlist, 1, 65536, "nlist must be between in range [1,65536]");
ensureNotNull(m, "m must not be null, value range is dim mod m == 0");
if (nbits != null) {
ensureBetween(nbits, 1, 16, "nbits must be in rnage [1,16]");
}
this.nlist = nlist;
this.m = m;
this.nbits = nbits;
}

public static Builder builder() {
return new Builder();
}

public Integer getNlist() {
return nlist;
}

public Integer getM() {
return m;
}

public Integer getNbits() {
return nbits;
}

public static final class Builder {
private Integer nbits;
private Integer m;
private Integer nlist = 8;

public Builder() {
}

public Builder nbits(Integer nbits) {
this.nbits = nbits;
return this;
}

public Builder m(Integer m) {
this.m = m;
return this;
}

public Builder nlist(Integer nlist) {
this.nlist = nlist;
return this;
}

public GpuIvfPqIndexParam build() {
return new GpuIvfPqIndexParam(nlist, m, nbits);
}
}
}

0 comments on commit d7f1b05

Please sign in to comment.