Skip to content

Commit

Permalink
Feature langchain4j#1005 - Add streaming API for Bedrock Anthropics
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkozminski committed Apr 23, 2024
1 parent f492df6 commit e4439a8
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 63 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -179,7 +179,7 @@ See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main
| [OpenAI](https://docs.langchain4j.dev/integrations/language-models/open-ai) |||||| ||
| [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-open-ai) | ||||| ||
| [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/hugging-face) | || || | | | |
| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | || ||| | |
| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | || ||| | |
| [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | ||| || ||
| [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) ||| ||| | |
| [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistral-ai) | |||| | ||
Expand Down
8 changes: 8 additions & 0 deletions langchain4j-bedrock/pom.xml
Expand Up @@ -55,6 +55,14 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
Expand Down
@@ -0,0 +1,33 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;

/**
 * Bedrock Anthropic (Claude v2.x) streaming chat model, backed by the
 * Bedrock {@code InvokeModelWithResponseStream} API.
 * <p>
 * The model id defaults to Claude v2 and can be overridden via the builder.
 */
@Getter
@SuperBuilder
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel {

    // Bedrock model id used for invocations; defaults to Claude v2.
    @Builder.Default
    private final String model = Types.AnthropicClaudeV2.getValue();

    @Override
    protected String getModelId() {
        return model;
    }

    /**
     * Bedrock Anthropic model ids.
     */
    @Getter
    public enum Types {
        AnthropicClaudeV2("anthropic.claude-v2"),
        AnthropicClaudeV2_1("anthropic.claude-v2:1");

        private final String value;

        Types(String modelID) {
            this.value = modelID;
        }
    }
}
Expand Up @@ -13,6 +13,7 @@
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
Expand All @@ -30,47 +31,14 @@
*/
@Getter
@SuperBuilder
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> implements ChatLanguageModel {
private static final String HUMAN_PROMPT = "Human:";
private static final String ASSISTANT_PROMPT = "Assistant:";

@Builder.Default
private final String humanPrompt = HUMAN_PROMPT;
@Builder.Default
private final String assistantPrompt = ASSISTANT_PROMPT;
@Builder.Default
private final Integer maxRetries = 5;
@Builder.Default
private final Region region = Region.US_EAST_1;
@Builder.Default
private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
@Builder.Default
private final int maxTokens = 300;
@Builder.Default
private final float temperature = 1;
@Builder.Default
private final float topP = 0.999f;
@Builder.Default
private final String[] stopSequences = new String[]{};
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> extends AbstractSharedBedrockChatModel implements ChatLanguageModel {
@Getter(lazy = true)
private final BedrockRuntimeClient client = initClient();

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {

final String context = messages.stream()
.filter(message -> message.type() == ChatMessageType.SYSTEM)
.map(ChatMessage::text)
.collect(joining("\n"));

final String userMessages = messages.stream()
.filter(message -> message.type() != ChatMessageType.SYSTEM)
.map(this::chatMessageToString)
.collect(joining("\n"));

final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
final Map<String, Object> requestParameters = getRequestParameters(prompt);
final String body = Json.toJson(requestParameters);
final String body = convertMessagesToAwsBody(messages);

InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries);
final String response = invokeModelResponse.body().asUtf8String();
Expand All @@ -81,26 +49,6 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
result.getFinishReason());
}

/**
* Convert chat message to string
*
* @param message chat message
* @return string
*/
protected String chatMessageToString(ChatMessage message) {
switch (message.type()) {
case SYSTEM:
return message.text();
case USER:
return humanPrompt + " " + message.text();
case AI:
return assistantPrompt + " " + message.text();
case TOOL_EXECUTION_RESULT:
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
}

throw new IllegalArgumentException("Unknown message type: " + message.type());
}

/**
* Get request parameters
Expand All @@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) {
*/
protected abstract Map<String, Object> getRequestParameters(final String prompt);

/**
* Get model id
*
* @return model id
*/
protected abstract String getModelId();


/**
* Get response class type
Expand Down
@@ -0,0 +1,87 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Json;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Bedrock streaming chat model, backed by the AWS Bedrock
 * {@code InvokeModelWithResponseStream} API.
 * <p>
 * Concrete subclasses supply the model id and request parameters
 * (see {@link AbstractSharedBedrockChatModel}).
 */
@Getter
@SuperBuilder
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel {

    // Async client built eagerly from the builder-provided region/credentials.
    // The class-level @Getter already exposes it; no field-level @Getter needed.
    private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient();

    /**
     * Shape of one streamed JSON chunk from Anthropic models: {"completion": "..."}.
     * Declared static: JSON deserializers cannot instantiate a non-static inner
     * class, because it carries a hidden reference to the enclosing instance.
     */
    static class StreamingResponse {
        public String completion;
    }

    @Override
    public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) {
        // Convenience overload: wrap the raw text as a single user message.
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(new UserMessage(userMessage));
        generate(messages, handler);
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
                .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages)))
                .modelId(getModelId())
                .contentType("application/json")
                .accept("application/json")
                .build();

        // StringBuffer (synchronized): chunk callbacks are delivered by SDK
        // event-loop threads, not necessarily the calling thread.
        StringBuffer finalCompletion = new StringBuffer();

        InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
                .onChunk(chunk -> {
                    // Each chunk is a small JSON document carrying the next completion fragment.
                    StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class);
                    finalCompletion.append(sr.completion);
                    handler.onNext(sr.completion);
                })
                .build();

        InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder()
                .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
                .onComplete(() -> handler.onComplete(Response.from(new AiMessage(finalCompletion.toString()))))
                .onError(handler::onError)
                .build();

        // NOTE(review): join() blocks the calling thread until the stream finishes,
        // so this "streaming" call is synchronous for the caller — confirm intended.
        asyncClient.invokeModelWithResponseStream(request, h).join();
    }

    /**
     * Initialize the async Bedrock runtime client.
     *
     * @return async Bedrock client configured with this model's region and credentials
     */
    private BedrockRuntimeAsyncClient initAsyncClient() {
        return BedrockRuntimeAsyncClient.builder()
                .region(region)
                .credentialsProvider(credentialsProvider)
                .build();
    }
}
@@ -0,0 +1,112 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.internal.Json;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.joining;

/**
 * Shared configuration and prompt-building logic for Bedrock chat models
 * (both the synchronous and the streaming variants).
 */
@Getter
@SuperBuilder
public abstract class AbstractSharedBedrockChatModel {
    // Claude requires you to enclose the prompt as follows:
    // String enclosedPrompt = "Human: " + prompt + "\n\nAssistant:";
    protected static final String HUMAN_PROMPT = "Human:";
    protected static final String ASSISTANT_PROMPT = "Assistant:";
    protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";

    @Builder.Default
    protected final String humanPrompt = HUMAN_PROMPT;
    @Builder.Default
    protected final String assistantPrompt = ASSISTANT_PROMPT;
    @Builder.Default
    protected final Integer maxRetries = 5;
    @Builder.Default
    protected final Region region = Region.US_EAST_1;
    @Builder.Default
    protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
    @Builder.Default
    protected final int maxTokens = 300;
    @Builder.Default
    protected final double temperature = 1;
    @Builder.Default
    protected final float topP = 0.999f;
    @Builder.Default
    protected final String[] stopSequences = new String[]{};
    @Builder.Default
    protected final int topK = 250;
    @Builder.Default
    protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;


    /**
     * Convert a chat message to its prompt-string form.
     *
     * @param message chat message
     * @return string
     * @throws IllegalArgumentException for tool-execution results or unknown message types
     */
    protected String chatMessageToString(ChatMessage message) {
        switch (message.type()) {
            case SYSTEM:
                return message.text();
            case USER:
                return humanPrompt + " " + message.text();
            case AI:
                return assistantPrompt + " " + message.text();
            case TOOL_EXECUTION_RESULT:
                throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
        }

        throw new IllegalArgumentException("Unknown message type: " + message.type());
    }

    /**
     * Build the JSON request body for a Bedrock invocation: system messages are
     * joined first as context, then all remaining messages, terminated with the
     * assistant prompt so the model answers next.
     *
     * @param messages conversation messages
     * @return JSON request body
     */
    protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
        final String context = messages.stream()
                .filter(message -> message.type() == ChatMessageType.SYSTEM)
                .map(ChatMessage::text)
                .collect(joining("\n"));

        final String userMessages = messages.stream()
                .filter(message -> message.type() != ChatMessageType.SYSTEM)
                .map(this::chatMessageToString)
                .collect(joining("\n"));

        final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
        return Json.toJson(getRequestParameters(prompt));
    }

    /**
     * Build the Anthropic request parameters for the given prompt.
     *
     * @param prompt full prompt text
     * @return map of Anthropic request parameters
     */
    protected Map<String, Object> getRequestParameters(String prompt) {
        // Default capacity (16): an initial capacity of 7 would rehash once
        // seven entries are inserted (threshold 7 * 0.75 = 5).
        final Map<String, Object> parameters = new HashMap<>();

        parameters.put("prompt", prompt);
        parameters.put("max_tokens_to_sample", getMaxTokens());
        parameters.put("temperature", getTemperature());
        parameters.put("top_k", topK);
        parameters.put("top_p", getTopP());
        parameters.put("stop_sequences", getStopSequences());
        parameters.put("anthropic_version", anthropicVersion);

        return parameters;
    }

    /**
     * Get model id
     *
     * @return model id
     */
    protected abstract String getModelId();

}
@@ -0,0 +1,39 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.regions.Region;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;

public class BedrockStreamingChatModelIT {

    @Test
    @Disabled("To run this test, you must have provide your own access key, secret, region")
    void testBedrockAnthropicStreamingChatModel() {
        // given: a Claude streaming model with explicit sampling settings
        BedrockAnthropicStreamingChatModel model = BedrockAnthropicStreamingChatModel.builder()
                .temperature(0.5)
                .maxTokens(300)
                .region(Region.US_EAST_1)
                .maxRetries(1)
                .build();

        // when: streaming an answer to a simple factual question
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        model.generate(singletonList(userMessage("What's the capital of Poland?")), handler);
        Response<AiMessage> response = handler.get();

        // then: the aggregated streamed answer mentions Warsaw
        assertThat(response.content().text()).contains("Warsaw");
    }
}

0 comments on commit e4439a8

Please sign in to comment.