Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#435: Add metadata support (read/write) to pinecone embedded store #955

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -34,6 +34,7 @@ void should_add_embedding_with_segment_with_metadata() {
}

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to add this function to be able to use Awaitility with pinecone client to wait for records to be retrievable.
Pinecone is eventually consistent, so records are/may not be available just after write

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff! What do you think about changing existing awaitUntilPersisted() instead of introducing new overloaded method? I guess the implementation can also be moved to the EmbeddingStoreIT so that all implementations can benefit from using awaitility?


List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
assertThat(relevant).hasSize(1);
Expand Down
Expand Up @@ -42,6 +42,7 @@ void should_add_embedding() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand All @@ -68,6 +69,7 @@ void should_add_embedding_with_id() {
embeddingStore().add(id, embedding);

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -95,6 +97,7 @@ void should_add_embedding_with_segment() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -125,6 +128,7 @@ void should_add_multiple_embeddings() {
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -171,6 +175,7 @@ void should_add_multiple_embeddings_with_segments() {
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -210,6 +215,7 @@ void should_find_with_min_score() {
embeddingStore().add(secondId, secondEmbedding);

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -282,6 +288,7 @@ void should_return_correct_score() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

Embedding referenceEmbedding = embeddingModel().embed("hi").content();

Expand All @@ -304,4 +311,8 @@ void should_return_correct_score() {
protected void awaitUntilPersisted() {
// not waiting by default
}

protected void awaitUntilPersisted(Embedding firstEmbedding, int expectedSize) {
// not waiting by default
}
}
9 changes: 8 additions & 1 deletion langchain4j-pinecone/pom.xml
Expand Up @@ -83,6 +83,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>


<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
Expand All @@ -91,4 +98,4 @@

</dependencies>

</project>
</project>
Expand Up @@ -2,6 +2,7 @@

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
Expand All @@ -13,10 +14,10 @@
import io.pinecone.PineconeConnection;
import io.pinecone.PineconeConnectionConfig;
import io.pinecone.proto.*;
import io.pinecone.proto.Vector;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Collections.emptyList;
Expand All @@ -41,13 +42,13 @@ public class PineconeEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* Creates an instance of PineconeEmbeddingStore.
*
* @param apiKey The Pinecone API key.
* @param environment The environment (e.g., "northamerica-northeast1-gcp").
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes.
* @param index The name of the index (e.g., "test").
* @param nameSpace (Optional) Namespace. If not provided, "default" will be used.
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
* @param apiKey The Pinecone API key.
* @param environment The environment (e.g., "northamerica-northeast1-gcp").
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: <a href="https://app.pinecone.io/organizations/.../projects/">...</a>...:{projectId}/indexes.
* @param index The name of the index (e.g., "test").
* @param nameSpace (Optional) Namespace. If not provided, "default" will be used.
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
*/
public PineconeEmbeddingStore(String apiKey,
String environment,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that environment and projectId are not used any more. Does Pinecone resolve them from api key now?
I think we should mark there params as @Deprecated in the builder then, WDYT?

Expand Down Expand Up @@ -136,12 +137,16 @@ private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<T
vectorBuilder.setMetadata(Struct.newBuilder()
.putFields(metadataTextKey, Value.newBuilder()
.setStringValue(textSegments.get(i).text())
.build()));
.build())
.putAllFields(textSegments.get(i).metadata().asMap().entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> Value.newBuilder().setStringValue(e.getValue()).build()))));
}

upsertRequestBuilder.addVectors(vectorBuilder.build());
}

//noinspection ResultOfMethodCallIgnored
connection.getBlockingStub().upsert(upsertRequestBuilder.build());
}

Expand Down Expand Up @@ -184,22 +189,47 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
return matches;
}


private EmbeddingMatch<TextSegment> toEmbeddingMatch(Vector vector, Embedding referenceEmbedding) {
Value textSegmentValue = vector.getMetadata()
Struct metadataStruct = vector.getMetadata();

Value textSegmentValue = metadataStruct
.getFieldsMap()
.get(metadataTextKey);

boolean filterOutMetadataTextKey = true;
Map<String, String> metadataMap = structToMap(metadataStruct, filterOutMetadataTextKey);
Metadata metadata = Metadata.from(metadataMap);

Embedding embedding = Embedding.from(vector.getValuesList());
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);

return new EmbeddingMatch<>(
RelevanceScore.fromCosineSimilarity(cosineSimilarity),
vector.getId(),
embedding,
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue())
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue(), metadata)
);
}

private Map<String, String> structToMap(Struct struct, boolean filterOutMetadataTextKey) {
Map<String, String> result = new HashMap<>();
Map<String, Value> fields = struct.getFieldsMap();

for (Map.Entry<String, Value> entry : fields.entrySet()) {
if (filterOutMetadataTextKey && isMetadataTextKey(entry.getKey())) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when retreiving I am filtering out metadataTextKey from metadata, otherwise tests were failing.
I assumed metadataTextKey is only a technical thing to let us store original content in metadata, but it is not something that should be exposed

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We store original embedded text under text_segment (or whatever is defined in metadataTextKey property) key in Pinecone metadata, and it should be returned back inside TextSegment.text().

The problem can happen in such case:
TextSegment textSegment = TextSegment.from("hello", new Metadata().put("text_segment", "bye"))
Since metadata key matches key of the text, text will be overriden and no metadata when retrieveing it back:
TextSegment { text = "bye" metadata = {} }

One option is to prepend all metadata keys with metadata_ prefix and then removing this prefix when retreiving back, so this TextSegment.from("hello", new Metadata().put("text_segment", "bye")) will become:

text_segment -> "hello"
metadata_text_segment -> "bye"

in Pinecone's metadata.

WDYT?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just making sure your not going to have a key collision and throwing an exception if it happens?
My reasoning being is some people are going to have multiple systems feeding into their store, and they will have to go modify other systems to support this paradigm.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfotex makes sense!

continue;
}
result.put(entry.getKey(), entry.getValue().getStringValue());
}

return result;
}

private boolean isMetadataTextKey(String key) {
return metadataTextKey.equals(key);
}

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -231,7 +261,7 @@ public Builder environment(String environment) {

/**
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes.
* The ID can be found in the Pinecone URL: <a href="https://app.pinecone.io/organizations/.../projects/">...</a>...:{projectId}/indexes.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

*/
public Builder projectId(String projectId) {
this.projectId = projectId;
Expand All @@ -257,6 +287,7 @@ public Builder nameSpace(String nameSpace) {
/**
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
*/
@SuppressWarnings("unused")
public Builder metadataTextKey(String metadataTextKey) {
this.metadataTextKey = metadataTextKey;
return this;
Expand Down
@@ -1,16 +1,20 @@
package dev.langchain4j.store.embedding.pinecone;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import java.time.Duration;

import static dev.langchain4j.internal.Utils.randomUUID;
import static org.awaitility.Awaitility.await;

@EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+")
class PineconeEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
class PineconeEmbeddingStoreIT extends EmbeddingStoreIT {

EmbeddingStore<TextSegment> embeddingStore = PineconeEmbeddingStore.builder()
.apiKey(System.getenv("PINECONE_API_KEY"))
Expand All @@ -31,4 +35,11 @@ protected EmbeddingStore<TextSegment> embeddingStore() {
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
}

@Override
protected void awaitUntilPersisted(Embedding embedding, int expectedSize) {
await()
.timeout(Duration.ofSeconds(15))
.until(() -> embeddingStore.findRelevant(embedding, expectedSize).size() == expectedSize);
}
}