From e4439a8656c2f76c9be8a9e890959f5b6f527122 Mon Sep 17 00:00:00 2001 From: Michal Kozminski Date: Tue, 13 Feb 2024 17:38:47 +0100 Subject: [PATCH] Freature #1005 - Add streaming API for Bedrock Anthropics --- README.md | 2 +- langchain4j-bedrock/pom.xml | 8 ++ .../BedrockAnthropicStreamingChatModel.java | 33 ++++++ .../internal/AbstractBedrockChatModel.java | 65 +--------- .../AbstractBedrockStreamingChatModel.java | 87 ++++++++++++++ .../AbstractSharedBedrockChatModel.java | 112 ++++++++++++++++++ .../bedrock/BedrockStreamingChatModelIT.java | 39 ++++++ 7 files changed, 283 insertions(+), 63 deletions(-) create mode 100644 langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java create mode 100644 langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java create mode 100644 langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java create mode 100644 langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java diff --git a/README.md b/README.md index b98a551780..5fbb317734 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main | [OpenAI](https://docs.langchain4j.dev/integrations/language-models/open-ai) | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-open-ai) | | ✅ | ✅ | ✅ | ✅ | | ✅ | | [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/hugging-face) | | ✅ | | ✅ | | | | | -| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | | ✅ | ✅ | | | +| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | ✅ | ✅ | ✅ | | | | [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ | | [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | | | [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistral-ai) | | ✅ | ✅ | ✅ | | | ✅ | diff --git a/langchain4j-bedrock/pom.xml b/langchain4j-bedrock/pom.xml index 2d341f2f16..c318e54621 100644 --- a/langchain4j-bedrock/pom.xml +++ b/langchain4j-bedrock/pom.xml @@ -55,6 +55,14 @@ test + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java new file mode 100644 index 0000000000..465925eca3 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java @@ -0,0 +1,33 @@ +package dev.langchain4j.model.bedrock; + +import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.SuperBuilder; + +@Getter +@SuperBuilder +public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { + @Builder.Default + private final String model = BedrockAnthropicStreamingChatModel.Types.AnthropicClaudeV2.getValue(); + + @Override + protected String getModelId() { + return model; + } + + @Getter + /** + * Bedrock Anthropic model ids + */ + public enum Types { + AnthropicClaudeV2("anthropic.claude-v2"), + AnthropicClaudeV2_1("anthropic.claude-v2:1"); + + private final String value; + + Types(String modelID) { + this.value = modelID; + } + } +} diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java index 71a552fbb7..77fba2d57d 100644 --- a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java @@ -13,6 +13,7 @@ import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; @@ -30,47 +31,14 @@ */ @Getter @SuperBuilder -public abstract class AbstractBedrockChatModel implements ChatLanguageModel { - private static final String HUMAN_PROMPT = "Human:"; - private static final String ASSISTANT_PROMPT = "Assistant:"; - - @Builder.Default - private final String humanPrompt = HUMAN_PROMPT; - @Builder.Default - private final String assistantPrompt = ASSISTANT_PROMPT; - @Builder.Default - private final Integer maxRetries = 5; - @Builder.Default - private final Region region = Region.US_EAST_1; - @Builder.Default - private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); - @Builder.Default - private final int maxTokens = 300; - @Builder.Default - private final float temperature = 1; - @Builder.Default - private final float topP = 0.999f; - @Builder.Default - private final String[] stopSequences = new String[]{}; +public abstract class AbstractBedrockChatModel extends AbstractSharedBedrockChatModel implements ChatLanguageModel { @Getter(lazy = true) private final BedrockRuntimeClient client = initClient(); @Override public Response generate(List messages) { - final String context = messages.stream() - .filter(message -> message.type() == ChatMessageType.SYSTEM) - .map(ChatMessage::text) - .collect(joining("\n")); - - final String userMessages = messages.stream() - .filter(message -> message.type() != ChatMessageType.SYSTEM) - .map(this::chatMessageToString) - .collect(joining("\n")); - - final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); - final Map requestParameters = getRequestParameters(prompt); - final String body = Json.toJson(requestParameters); + final String body = convertMessagesToAwsBody(messages); InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries); final String response = invokeModelResponse.body().asUtf8String(); @@ -81,26 +49,6 @@ public Response generate(List messages) { result.getFinishReason()); } - /** - * Convert chat message to string - * - * @param message chat message - * @return string - */ - protected String chatMessageToString(ChatMessage message) { - switch (message.type()) { - case SYSTEM: - return message.text(); - case USER: - return humanPrompt + " " + message.text(); - case AI: - return assistantPrompt + " " + message.text(); - case TOOL_EXECUTION_RESULT: - throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); - } - - throw new IllegalArgumentException("Unknown message type: " + message.type()); - } /** * Get request parameters @@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) { */ protected abstract Map getRequestParameters(final String prompt); - /** - * Get model id - * - * @return model id - */ - protected abstract String getModelId(); - /** * Get response class type diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java new file mode 100644 index 0000000000..f71f8412b1 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java @@ -0,0 +1,87 @@ +package dev.langchain4j.model.bedrock.internal; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.internal.Json; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Bedrock Streaming chat model + */ +@Getter +@SuperBuilder +public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel { + @Getter + private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); + + class StreamingResponse { + public String completion; + } + + @Override + public void generate(String userMessage, StreamingResponseHandler handler) { + List messages = new ArrayList<>(); + messages.add(new UserMessage(userMessage)); + generate(messages, handler); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() + .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) + .modelId(getModelId()) + .contentType("application/json") + .accept("application/json") + .build(); + + StringBuffer finalCompletion = new StringBuffer(); + + InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() + .onChunk(chunk -> { + StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); + finalCompletion.append(sr.completion); + handler.onNext(sr.completion); + }) + .build(); + + InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() + .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) + .onComplete(() -> { + handler.onComplete(Response.from(new AiMessage(finalCompletion.toString()))); + }) + .onError(handler::onError) + .build(); + asyncClient.invokeModelWithResponseStream(request, h).join(); + + } + + /** + * Initialize async bedrock client + * + * @return async bedrock client + */ + private BedrockRuntimeAsyncClient initAsyncClient() { + BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .build(); + return client; + } + + + +} diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java new file mode 100644 index 0000000000..681f4781f1 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java @@ -0,0 +1,112 @@ +package dev.langchain4j.model.bedrock.internal; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.internal.Json; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.stream.Collectors.joining; + +@Getter +@SuperBuilder +public abstract class AbstractSharedBedrockChatModel { + // Claude requires you to enclose the prompt as follows: + // String enclosedPrompt = "Human: " + prompt + "\n\nAssistant:"; + protected static final String HUMAN_PROMPT = "Human:"; + protected static final String ASSISTANT_PROMPT = "Assistant:"; + protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; + + @Builder.Default + protected final String humanPrompt = HUMAN_PROMPT; + @Builder.Default + protected final String assistantPrompt = ASSISTANT_PROMPT; + @Builder.Default + protected final Integer maxRetries = 5; + @Builder.Default + protected final Region region = Region.US_EAST_1; + @Builder.Default + protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); + @Builder.Default + protected final int maxTokens = 300; + @Builder.Default + protected final double temperature = 1; + @Builder.Default + protected final float topP = 0.999f; + @Builder.Default + protected final String[] stopSequences = new String[]{}; + @Builder.Default + protected final int topK = 250; + @Builder.Default + protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; + + + /** + * Convert chat message to string + * + * @param message chat message + * @return string + */ + protected String chatMessageToString(ChatMessage message) { + switch (message.type()) { + case SYSTEM: + return message.text(); + case USER: + return humanPrompt + " " + message.text(); + case AI: + return assistantPrompt + " " + message.text(); + case TOOL_EXECUTION_RESULT: + throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); + } + + throw new IllegalArgumentException("Unknown message type: " + message.type()); + } + + protected String convertMessagesToAwsBody(List messages) { + final String context = messages.stream() + .filter(message -> message.type() == ChatMessageType.SYSTEM) + .map(ChatMessage::text) + .collect(joining("\n")); + + final String userMessages = messages.stream() + .filter(message -> message.type() != ChatMessageType.SYSTEM) + .map(this::chatMessageToString) + .collect(joining("\n")); + + final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); + final Map requestParameters = getRequestParameters(prompt); + final String body = Json.toJson(requestParameters); + return body; + } + + protected Map getRequestParameters(String prompt) { + final Map parameters = new HashMap<>(7); + + parameters.put("prompt", prompt); + parameters.put("max_tokens_to_sample", getMaxTokens()); + parameters.put("temperature", getTemperature()); + parameters.put("top_k", topK); + parameters.put("top_p", getTopP()); + parameters.put("stop_sequences", getStopSequences()); + parameters.put("anthropic_version", anthropicVersion); + + return parameters; + } + + /** + * Get model id + * + * @return model id + */ + protected abstract String getModelId(); + +} diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java new file mode 100644 index 0000000000..d758ef45bb --- /dev/null +++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java @@ -0,0 +1,39 @@ +package dev.langchain4j.model.bedrock; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.regions.Region; + +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +public class BedrockStreamingChatModelIT { + @Test + @Disabled("To run this test, you must have provide your own access key, secret, region") + void testBedrockAnthropicStreamingChatModel() { + //given + BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel + .builder() + .temperature(0.5) + .maxTokens(300) + .region(Region.US_EAST_1) + .maxRetries(1) + .build(); + UserMessage userMessage = userMessage("What's the capital of Poland?"); + + //when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + bedrockChatModel.generate(singletonList(userMessage), handler); + Response response = handler.get(); + + //then + assertThat(response.content().text()).contains("Warsaw"); + } + + +}