Skip to content

Commit

Permalink
Merge pull request #8 from showpune/main
Browse files Browse the repository at this point in the history
Support Azure OpenAI and AI Search
  • Loading branch information
langchain4j committed May 2, 2024
2 parents 40cf8c2 + 489aa41 commit 2b95082
Show file tree
Hide file tree
Showing 16 changed files with 850 additions and 0 deletions.
75 changes: 75 additions & 0 deletions langchain4j-azure-ai-search-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-spring</artifactId>
<version>0.31.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<artifactId>langchain4j-azure-ai-search-spring-boot-starter</artifactId>
<name>LangChain4j Spring Boot starter for Azure AI Search</name>
<packaging>jar</packaging>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-ai-search</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-autoconfigure-processor</artifactId>
<optional>true</optional>
</dependency>

<!-- should be listed before spring-boot-configuration-processor -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<!-- needed to generate automatic metadata about available config properties -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package dev.langchain4j.azure.aisearch.spring;

import com.azure.search.documents.indexes.models.SearchIndex;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchEmbeddingStore;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.lang.Nullable;

import static dev.langchain4j.azure.aisearch.spring.Properties.PREFIX;

@AutoConfiguration
@EnableConfigurationProperties(Properties.class)
public class AutoConfig {
@Bean
@ConditionalOnProperty(PREFIX + ".content-retriever.api-key")
public AzureAiSearchContentRetriever azureAiSearchContentRetriever(Properties properties, @Nullable EmbeddingModel embeddingModel, @Nullable SearchIndex index) {
Properties.NestedProperties nestedProperties = properties.getContentRetriever();
return AzureAiSearchContentRetriever.builder()
.endpoint(nestedProperties.getEndpoint())
.apiKey(nestedProperties.getApiKey())
.createOrUpdateIndex(nestedProperties.getCreateOrUpdateIndex())
.embeddingModel(embeddingModel)
.dimensions(nestedProperties.getDimensions() == null ? 0 : nestedProperties.getDimensions())
.index(index)
.maxResults(nestedProperties.getMaxResults())
.minScore(nestedProperties.getMinScore() == null ? 0.0 : nestedProperties.getMinScore())
.queryType(nestedProperties.getQueryType())
.build();
}

@Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.api-key")
public AzureAiSearchEmbeddingStore azureAiSearchEmbeddingStore(Properties properties, @Nullable EmbeddingModel embeddingModel, @Nullable SearchIndex index) {
Properties.NestedProperties nestedProperties = properties.getEmbeddingStore();
return AzureAiSearchEmbeddingStore.builder()
.endpoint(nestedProperties.getEndpoint())
.apiKey(nestedProperties.getApiKey())
.createOrUpdateIndex(nestedProperties.getCreateOrUpdateIndex())
.dimensions(nestedProperties.getDimensions())
.index(index)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package dev.langchain4j.azure.aisearch.spring;

import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchQueryType;
import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

@Getter
@Setter
@ConfigurationProperties(prefix = Properties.PREFIX)
public class Properties {

static final String PREFIX = "langchain4j.azure.ai-search";

@NestedConfigurationProperty
NestedProperties contentRetriever;

@NestedConfigurationProperty
NestedProperties embeddingStore;

@Getter
@Setter
public static class NestedProperties {
String endpoint;
String apiKey;
Integer dimensions;
Boolean createOrUpdateIndex;
String indexName;
Integer maxResults = 3;
Double minScore;
AzureAiSearchQueryType queryType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
dev.langchain4j.azure.aisearch.spring.AutoConfig
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dev.langchain4j.azure.aisearch.spring.AutoConfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package dev.langchain4j.azure.aisearch.spring;

import com.azure.core.credential.AzureKeyCredential;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.SearchIndex;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchContentRetriever;
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchQueryType;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchEmbeddingStore;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.util.List;

import static dev.langchain4j.store.embedding.azure.search.AbstractAzureAiSearchEmbeddingStore.INDEX_NAME;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;

class AutoConfigIT {

private static final Logger log = LoggerFactory.getLogger(AutoConfigIT.class);

private static final String AZURE_SEARCH_KEY = System.getenv("AZURE_SEARCH_KEY");
private static final String AZURE_SEARCH_ENDPOINT = System.getenv("AZURE_SEARCH_ENDPOINT");

private final EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
private final int dimensions = embeddingModel.embed("test").content().dimension();

private final SearchIndexClient searchIndexClient = new SearchIndexClientBuilder()
.endpoint(System.getenv("AZURE_SEARCH_ENDPOINT"))
.credential(new AzureKeyCredential(System.getenv("AZURE_SEARCH_KEY")))
.buildClient();
private final SearchIndex index = new SearchIndex(INDEX_NAME);

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AutoConfig.class));

@Test
void should_provide_ai_search_retriever() {

searchIndexClient.deleteIndex(INDEX_NAME);

contextRunner
.withPropertyValues(
Properties.PREFIX + ".content-retriever.api-key=" + AZURE_SEARCH_KEY,
Properties.PREFIX + ".content-retriever.endpoint=" + AZURE_SEARCH_ENDPOINT,
Properties.PREFIX + ".content-retriever.dimensions=" + dimensions,
Properties.PREFIX + ".content-retriever.create-or-update-index=" + "true",
Properties.PREFIX + ".content-retriever.query-type=" + "VECTOR"
).withBean(EmbeddingModel.class, () -> embeddingModel)
.run(context -> {
ContentRetriever contentRetriever = context.getBean(ContentRetriever.class);
assertThat(contentRetriever).isInstanceOf(AzureAiSearchContentRetriever.class);
AzureAiSearchContentRetriever azureAiSearchContentRetriever = (AzureAiSearchContentRetriever) contentRetriever;

String content1 = "This book is about politics";
String content2 = "Cats sleeps a lot.";
String content3 = "Sandwiches taste good.";
String content4 = "The house is open";
List<String> contents = asList(content1, content2, content3, content4);

for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
azureAiSearchContentRetriever.add(embedding, textSegment);
}

awaitUntilPersisted();
});

String content = "house";
Query query = Query.from(content);

contextRunner
.withPropertyValues(
Properties.PREFIX + ".content-retriever.api-key=" + AZURE_SEARCH_KEY,
Properties.PREFIX + ".content-retriever.endpoint=" + AZURE_SEARCH_ENDPOINT,
Properties.PREFIX + ".content-retriever.create-or-update-index=" + "false",
Properties.PREFIX + ".content-retriever.max-results=" + "3",
Properties.PREFIX + ".content-retriever.min-score=" + "0.6",
Properties.PREFIX + ".content-retriever.query-type=" + AzureAiSearchQueryType.VECTOR
).withBean(SearchIndex.class, () -> index)
.withBean(EmbeddingModel.class, () -> embeddingModel)
.run(context -> {
ContentRetriever contentRetriever = context.getBean(ContentRetriever.class);
assertThat(contentRetriever).isInstanceOf(AzureAiSearchContentRetriever.class);
AzureAiSearchContentRetriever contentRetrieverWithVector = (AzureAiSearchContentRetriever) contentRetriever;
log.info("Testing Vector Search");
List<Content> relevant = contentRetrieverWithVector.retrieve(query);
assertThat(relevant).hasSizeGreaterThan(0);
assertThat(relevant.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant.get(0).textSegment().text());
});

contextRunner
.withPropertyValues(
Properties.PREFIX + ".content-retriever.api-key=" + AZURE_SEARCH_KEY,
Properties.PREFIX + ".content-retriever.endpoint=" + AZURE_SEARCH_ENDPOINT,
Properties.PREFIX + ".content-retriever.create-or-update-index=" + "false",
Properties.PREFIX + ".content-retriever.query-type=" + AzureAiSearchQueryType.FULL_TEXT
)
.run(context -> {
ContentRetriever contentRetriever = context.getBean(ContentRetriever.class);
assertThat(contentRetriever).isInstanceOf(AzureAiSearchContentRetriever.class);
AzureAiSearchContentRetriever contentRetrieverWithFullText = (AzureAiSearchContentRetriever) contentRetriever;
log.info("Testing Full Text Search");
// This uses the same storage as the vector search, so we don't need to add the content again
List<Content> relevant2 = contentRetrieverWithFullText.retrieve(query);
assertThat(relevant2).hasSizeGreaterThan(0);
assertThat(relevant2.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant2.get(0).textSegment().text());
});

contextRunner
.withPropertyValues(
Properties.PREFIX + ".content-retriever.api-key=" + AZURE_SEARCH_KEY,
Properties.PREFIX + ".content-retriever.endpoint=" + AZURE_SEARCH_ENDPOINT,
Properties.PREFIX + ".content-retriever.create-or-update-index=" + "false",
Properties.PREFIX + ".content-retriever.query-type=" + AzureAiSearchQueryType.HYBRID
).withBean(SearchIndex.class, () -> index)
.withBean(EmbeddingModel.class, () -> embeddingModel)
.run(context -> {
ContentRetriever contentRetriever = context.getBean(ContentRetriever.class);
assertThat(contentRetriever).isInstanceOf(AzureAiSearchContentRetriever.class);
AzureAiSearchContentRetriever contentRetrieverWithHybrid = (AzureAiSearchContentRetriever) contentRetriever;
log.info("Testing Hybrid Search");
List<Content> relevant3 = contentRetrieverWithHybrid.retrieve(query);
assertThat(relevant3).hasSizeGreaterThan(0);
assertThat(relevant3.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant3.get(0).textSegment().text());
});

contextRunner
.withPropertyValues(
Properties.PREFIX + ".content-retriever.api-key=" + AZURE_SEARCH_KEY,
Properties.PREFIX + ".content-retriever.endpoint=" + AZURE_SEARCH_ENDPOINT,
Properties.PREFIX + ".content-retriever.create-or-update-index=" + "false",
Properties.PREFIX + ".content-retriever.max-results=" + "3",
Properties.PREFIX + ".content-retriever.min-score=" + "0.4",
Properties.PREFIX + ".content-retriever.query-type=" + AzureAiSearchQueryType.HYBRID_WITH_RERANKING
).withBean(SearchIndex.class, () -> index)
.withBean(EmbeddingModel.class, () -> embeddingModel)
.run(context -> {
ContentRetriever contentRetriever = context.getBean(ContentRetriever.class);
assertThat(contentRetriever).isInstanceOf(AzureAiSearchContentRetriever.class);
AzureAiSearchContentRetriever contentRetrieverWithHybridAndReranking = (AzureAiSearchContentRetriever) contentRetriever;
log.info("Testing Hybrid Search with Reranking");
List<Content> relevant4 = contentRetrieverWithHybridAndReranking.retrieve(query);
assertThat(relevant4).hasSizeGreaterThan(0);
assertThat(relevant4.get(0).textSegment().text()).isEqualTo("The house is open");
log.info("#1 relevant item: {}", relevant4.get(0).textSegment().text());
});
}

@Test
void should_provide_ai_search_embedding_store() {

searchIndexClient.deleteIndex(INDEX_NAME);

contextRunner
.withPropertyValues(
Properties.PREFIX + ".embedding-store.api-key=" + AZURE_SEARCH_KEY,
Properties.PREFIX + ".embedding-store.endpoint=" + AZURE_SEARCH_ENDPOINT,
Properties.PREFIX + ".embedding-store.dimensions=" + 384,
Properties.PREFIX + ".embedding-store.create-or-update-index=" + "true"
).withBean(EmbeddingModel.class, () -> embeddingModel)
.run(context -> {
EmbeddingStore<TextSegment> embeddingStore = context.getBean(EmbeddingStore.class);
assertThat(embeddingStore).isInstanceOf(AzureAiSearchEmbeddingStore.class);
assertThat(context.getBean(AzureAiSearchEmbeddingStore.class)).isSameAs(embeddingStore);

String content1 = "banana";
String content2 = "computer";
String content3 = "apple";
String content4 = "pizza";
String content5 = "strawberry";
String content6 = "chess";
List<String> contents = asList(content1, content2, content3, content4, content5, content6);

for (String content : contents) {
TextSegment textSegment = TextSegment.from(content);
Embedding embedding = embeddingModel.embed(content).content();
embeddingStore.add(embedding, textSegment);
}
Embedding relevantEmbedding = embeddingModel.embed("fruit").content();
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(relevantEmbedding, 3);
assertThat(relevant).hasSize(3);
assertThat(relevant.get(0).embedding()).isNotNull();
assertThat(relevant.get(0).embedded().text()).isIn(content1, content3, content5);
});
}

protected void awaitUntilPersisted() {
try {
Thread.sleep(1_000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

0 comments on commit 2b95082

Please sign in to comment.