diff --git a/README.md b/README.md
index b98a551780..5fbb317734 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,7 @@ See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main
| [OpenAI](https://docs.langchain4j.dev/integrations/language-models/open-ai) | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-open-ai) | | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/hugging-face) | | ✅ | | ✅ | | | | |
-| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | | ✅ | ✅ | | |
+| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | ✅ | ✅ | ✅ | | |
| [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ |
| [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | |
| [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistral-ai) | | ✅ | ✅ | ✅ | | | ✅ |
diff --git a/langchain4j-bedrock/pom.xml b/langchain4j-bedrock/pom.xml
index 2d341f2f16..c318e54621 100644
--- a/langchain4j-bedrock/pom.xml
+++ b/langchain4j-bedrock/pom.xml
@@ -55,6 +55,14 @@
test
+
+ dev.langchain4j
+ langchain4j-core
+ tests
+ test-jar
+ test
+
+
diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java
new file mode 100644
index 0000000000..465925eca3
--- /dev/null
+++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java
@@ -0,0 +1,33 @@
+package dev.langchain4j.model.bedrock;
+
+import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.experimental.SuperBuilder;
+
+@Getter
+@SuperBuilder
+public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel {
+ @Builder.Default
+ private final String model = BedrockAnthropicStreamingChatModel.Types.AnthropicClaudeV2.getValue();
+
+ @Override
+ protected String getModelId() {
+ return model;
+ }
+
+ @Getter
+ /**
+ * Bedrock Anthropic model ids
+ */
+ public enum Types {
+ AnthropicClaudeV2("anthropic.claude-v2"),
+ AnthropicClaudeV2_1("anthropic.claude-v2:1");
+
+ private final String value;
+
+ Types(String modelID) {
+ this.value = modelID;
+ }
+ }
+}
diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java
index 71a552fbb7..77fba2d57d 100644
--- a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java
+++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java
@@ -13,6 +13,7 @@
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
@@ -30,47 +31,14 @@
*/
@Getter
@SuperBuilder
-public abstract class AbstractBedrockChatModel implements ChatLanguageModel {
- private static final String HUMAN_PROMPT = "Human:";
- private static final String ASSISTANT_PROMPT = "Assistant:";
-
- @Builder.Default
- private final String humanPrompt = HUMAN_PROMPT;
- @Builder.Default
- private final String assistantPrompt = ASSISTANT_PROMPT;
- @Builder.Default
- private final Integer maxRetries = 5;
- @Builder.Default
- private final Region region = Region.US_EAST_1;
- @Builder.Default
- private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
- @Builder.Default
- private final int maxTokens = 300;
- @Builder.Default
- private final float temperature = 1;
- @Builder.Default
- private final float topP = 0.999f;
- @Builder.Default
- private final String[] stopSequences = new String[]{};
+public abstract class AbstractBedrockChatModel extends AbstractSharedBedrockChatModel implements ChatLanguageModel {
@Getter(lazy = true)
private final BedrockRuntimeClient client = initClient();
@Override
public Response generate(List messages) {
- final String context = messages.stream()
- .filter(message -> message.type() == ChatMessageType.SYSTEM)
- .map(ChatMessage::text)
- .collect(joining("\n"));
-
- final String userMessages = messages.stream()
- .filter(message -> message.type() != ChatMessageType.SYSTEM)
- .map(this::chatMessageToString)
- .collect(joining("\n"));
-
- final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
- final Map requestParameters = getRequestParameters(prompt);
- final String body = Json.toJson(requestParameters);
+ final String body = convertMessagesToAwsBody(messages);
InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries);
final String response = invokeModelResponse.body().asUtf8String();
@@ -81,26 +49,6 @@ public Response generate(List messages) {
result.getFinishReason());
}
- /**
- * Convert chat message to string
- *
- * @param message chat message
- * @return string
- */
- protected String chatMessageToString(ChatMessage message) {
- switch (message.type()) {
- case SYSTEM:
- return message.text();
- case USER:
- return humanPrompt + " " + message.text();
- case AI:
- return assistantPrompt + " " + message.text();
- case TOOL_EXECUTION_RESULT:
- throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
- }
-
- throw new IllegalArgumentException("Unknown message type: " + message.type());
- }
/**
* Get request parameters
@@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) {
*/
protected abstract Map getRequestParameters(final String prompt);
- /**
- * Get model id
- *
- * @return model id
- */
- protected abstract String getModelId();
-
/**
* Get response class type
diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java
new file mode 100644
index 0000000000..f71f8412b1
--- /dev/null
+++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java
@@ -0,0 +1,87 @@
+package dev.langchain4j.model.bedrock.internal;
+
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.internal.Json;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import lombok.Getter;
+import lombok.experimental.SuperBuilder;
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Bedrock Streaming chat model
+ */
+@Getter
+@SuperBuilder
+public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel {
+ @Getter
+ private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient();
+
+ class StreamingResponse {
+ public String completion;
+ }
+
+ @Override
+ public void generate(String userMessage, StreamingResponseHandler handler) {
+ List messages = new ArrayList<>();
+ messages.add(new UserMessage(userMessage));
+ generate(messages, handler);
+ }
+
+ @Override
+ public void generate(List messages, StreamingResponseHandler handler) {
+ InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
+ .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages)))
+ .modelId(getModelId())
+ .contentType("application/json")
+ .accept("application/json")
+ .build();
+
+ StringBuffer finalCompletion = new StringBuffer();
+
+ InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
+ .onChunk(chunk -> {
+ StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class);
+ finalCompletion.append(sr.completion);
+ handler.onNext(sr.completion);
+ })
+ .build();
+
+ InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder()
+ .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
+ .onComplete(() -> {
+ handler.onComplete(Response.from(new AiMessage(finalCompletion.toString())));
+ })
+ .onError(handler::onError)
+ .build();
+ asyncClient.invokeModelWithResponseStream(request, h).join();
+
+ }
+
+ /**
+ * Initialize async bedrock client
+ *
+ * @return async bedrock client
+ */
+ private BedrockRuntimeAsyncClient initAsyncClient() {
+ BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
+ .region(region)
+ .credentialsProvider(credentialsProvider)
+ .build();
+ return client;
+ }
+
+
+
+}
diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java
new file mode 100644
index 0000000000..681f4781f1
--- /dev/null
+++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java
@@ -0,0 +1,112 @@
+package dev.langchain4j.model.bedrock.internal;
+
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ChatMessageType;
+import dev.langchain4j.internal.Json;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.experimental.SuperBuilder;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static java.util.stream.Collectors.joining;
+
+@Getter
+@SuperBuilder
+public abstract class AbstractSharedBedrockChatModel {
+ // Claude requires you to enclose the prompt as follows:
+ // String enclosedPrompt = "Human: " + prompt + "\n\nAssistant:";
+ protected static final String HUMAN_PROMPT = "Human:";
+ protected static final String ASSISTANT_PROMPT = "Assistant:";
+ protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";
+
+ @Builder.Default
+ protected final String humanPrompt = HUMAN_PROMPT;
+ @Builder.Default
+ protected final String assistantPrompt = ASSISTANT_PROMPT;
+ @Builder.Default
+ protected final Integer maxRetries = 5;
+ @Builder.Default
+ protected final Region region = Region.US_EAST_1;
+ @Builder.Default
+ protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
+ @Builder.Default
+ protected final int maxTokens = 300;
+ @Builder.Default
+ protected final double temperature = 1;
+ @Builder.Default
+ protected final float topP = 0.999f;
+ @Builder.Default
+ protected final String[] stopSequences = new String[]{};
+ @Builder.Default
+ protected final int topK = 250;
+ @Builder.Default
+ protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;
+
+
+ /**
+ * Convert chat message to string
+ *
+ * @param message chat message
+ * @return string
+ */
+ protected String chatMessageToString(ChatMessage message) {
+ switch (message.type()) {
+ case SYSTEM:
+ return message.text();
+ case USER:
+ return humanPrompt + " " + message.text();
+ case AI:
+ return assistantPrompt + " " + message.text();
+ case TOOL_EXECUTION_RESULT:
+ throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
+ }
+
+ throw new IllegalArgumentException("Unknown message type: " + message.type());
+ }
+
+ protected String convertMessagesToAwsBody(List messages) {
+ final String context = messages.stream()
+ .filter(message -> message.type() == ChatMessageType.SYSTEM)
+ .map(ChatMessage::text)
+ .collect(joining("\n"));
+
+ final String userMessages = messages.stream()
+ .filter(message -> message.type() != ChatMessageType.SYSTEM)
+ .map(this::chatMessageToString)
+ .collect(joining("\n"));
+
+ final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
+ final Map requestParameters = getRequestParameters(prompt);
+ final String body = Json.toJson(requestParameters);
+ return body;
+ }
+
+ protected Map getRequestParameters(String prompt) {
+ final Map parameters = new HashMap<>(7);
+
+ parameters.put("prompt", prompt);
+ parameters.put("max_tokens_to_sample", getMaxTokens());
+ parameters.put("temperature", getTemperature());
+ parameters.put("top_k", topK);
+ parameters.put("top_p", getTopP());
+ parameters.put("stop_sequences", getStopSequences());
+ parameters.put("anthropic_version", anthropicVersion);
+
+ return parameters;
+ }
+
+ /**
+ * Get model id
+ *
+ * @return model id
+ */
+ protected abstract String getModelId();
+
+}
diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java
new file mode 100644
index 0000000000..d758ef45bb
--- /dev/null
+++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java
@@ -0,0 +1,39 @@
+package dev.langchain4j.model.bedrock;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.chat.TestStreamingResponseHandler;
+import dev.langchain4j.model.output.Response;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import software.amazon.awssdk.regions.Region;
+
+import static dev.langchain4j.data.message.UserMessage.userMessage;
+import static java.util.Collections.singletonList;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class BedrockStreamingChatModelIT {
+ @Test
+ @Disabled("To run this test, you must have provide your own access key, secret, region")
+ void testBedrockAnthropicStreamingChatModel() {
+ //given
+ BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel
+ .builder()
+ .temperature(0.5)
+ .maxTokens(300)
+ .region(Region.US_EAST_1)
+ .maxRetries(1)
+ .build();
+ UserMessage userMessage = userMessage("What's the capital of Poland?");
+
+ //when
+ TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>();
+ bedrockChatModel.generate(singletonList(userMessage), handler);
+ Response response = handler.get();
+
+ //then
+ assertThat(response.content().text()).contains("Warsaw");
+ }
+
+
+}