Skip to content

Commit

Permalink
Add Bedrock Cohere Command R model support.
Browse files Browse the repository at this point in the history
  • Loading branch information
wmz7year committed May 2, 2024
1 parent 7252ba1 commit 9a0feaa
Show file tree
Hide file tree
Showing 15 changed files with 1,962 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.cohere;

import java.util.List;

import org.springframework.ai.bedrock.BedrockUsage;
import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest;
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;

import reactor.core.publisher.Flux;

/**
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockCohereCommandRChatClient implements ChatClient, StreamingChatClient {

private final CohereCommandRChatBedrockApi chatApi;

private final BedrockCohereCommandRChatOptions defaultOptions;

public BedrockCohereCommandRChatClient(CohereCommandRChatBedrockApi chatApi) {
this(chatApi, BedrockCohereCommandRChatOptions.builder().build());
}

public BedrockCohereCommandRChatClient(CohereCommandRChatBedrockApi chatApi,
BedrockCohereCommandRChatOptions options) {
Assert.notNull(chatApi, "CohereCommandRChatBedrockApi must not be null");
Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null");

this.chatApi = chatApi;
this.defaultOptions = options;
}

@Override
public ChatResponse call(Prompt prompt) {
CohereCommandRChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt));

Generation generation = new Generation(response.text());

return new ChatResponse(List.of(generation));
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.chatApi.chatCompletionStream(this.createRequest(prompt)).map(g -> {
if (g.isFinished()) {
String finishReason = g.finishReason().name();
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
return new ChatResponse(List
.of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage))));
}
return new ChatResponse(List.of(new Generation(g.text())));
});
}

CohereCommandRChatRequest createRequest(Prompt prompt) {
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

var request = CohereCommandRChatRequest.builder(promptValue)
.withSearchQueriesOnly(this.defaultOptions.getSearchQueriesOnly())
.withPreamble(this.defaultOptions.getPreamble())
.withMaxTokens(this.defaultOptions.getMaxTokens())
.withTemperature(this.defaultOptions.getTemperature())
.withTopP(this.defaultOptions.getTopP())
.withTopK(this.defaultOptions.getTopK())
.withPromptTruncation(this.defaultOptions.getPromptTruncation())
.withFrequencyPenalty(this.defaultOptions.getFrequencyPenalty())
.withPresencePenalty(this.defaultOptions.getPresencePenalty())
.withSeed(this.defaultOptions.getSeed())
.withReturnPrompt(this.defaultOptions.getReturnPrompt())
.withStopSequences(this.defaultOptions.getStopSequences())
.withRawPrompting(this.defaultOptions.getRawPrompting())
.build();

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, BedrockCohereCommandRChatOptions.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereCommandRChatRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}

return request;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.cohere;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonInclude.Include;

import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation;
import org.springframework.ai.chat.prompt.ChatOptions;

/**
* @author Wei Jiang
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class BedrockCohereCommandRChatOptions implements ChatOptions {

// @formatter:off
/**
* (optional) When enabled, it will only generate potential search queries without performing
* searches or providing a response.
*/
@JsonProperty("search_queries_only") Boolean searchQueriesOnly;
/**
* (optional) Overrides the default preamble for search query generation.
*/
@JsonProperty("preamble") String preamble;
/**
* (optional) Specify the maximum number of tokens to use in the generated response.
*/
@JsonProperty("max_tokens") Integer maxTokens;
/**
* (optional) Use a lower value to decrease randomness in the response.
*/
@JsonProperty("temperature") Float temperature;
/**
* Top P. Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable.
*/
@JsonProperty("p") Float topP;
/**
* Top K. Specify the number of token choices the model uses to generate the next token.
*/
@JsonProperty("k") Integer topK;
/**
* (optional) Dictates how the prompt is constructed.
*/
@JsonProperty("prompt_truncation") PromptTruncation promptTruncation;
/**
* (optional) Used to reduce repetitiveness of generated tokens.
*/
@JsonProperty("frequency_penalty") Float frequencyPenalty;
/**
* (optional) Used to reduce repetitiveness of generated tokens.
*/
@JsonProperty("presence_penalty") Float presencePenalty;
/**
* (optional) Specify the best effort to sample tokens deterministically.
*/
@JsonProperty("seed") Integer seed;
/**
* (optional) Specify true to return the full prompt that was sent to the model.
*/
@JsonProperty("return_prompt") Boolean returnPrompt;
/**
* (optional) A list of stop sequences.
*/
@JsonProperty("stop_sequences") List<String> stopSequences;
/**
* (optional) Specify true, to send the user’s message to the model without any preprocessing.
*/
@JsonProperty("raw_prompting") Boolean rawPrompting;
// @formatter:on

public static Builder builder() {
return new Builder();
}

public static class Builder {

private final BedrockCohereCommandRChatOptions options = new BedrockCohereCommandRChatOptions();

public Builder withSearchQueriesOnly(Boolean searchQueriesOnly) {
options.setSearchQueriesOnly(searchQueriesOnly);
return this;
}

public Builder withPreamble(String preamble) {
options.setPreamble(preamble);
return this;
}

public Builder withMaxTokens(Integer maxTokens) {
options.setMaxTokens(maxTokens);
return this;
}

public Builder withTemperature(Float temperature) {
options.setTemperature(temperature);
return this;
}

public Builder withTopP(Float topP) {
options.setTopP(topP);
return this;
}

public Builder withTopK(Integer topK) {
options.setTopK(topK);
return this;
}

public Builder withPromptTruncation(PromptTruncation promptTruncation) {
options.setPromptTruncation(promptTruncation);
return this;
}

public Builder withFrequencyPenalty(Float frequencyPenalty) {
options.setFrequencyPenalty(frequencyPenalty);
return this;
}

public Builder withPresencePenalty(Float presencePenalty) {
options.setPresencePenalty(presencePenalty);
return this;
}

public Builder withSeed(Integer seed) {
options.setSeed(seed);
return this;
}

public Builder withReturnPrompt(Boolean returnPrompt) {
options.setReturnPrompt(returnPrompt);
return this;
}

public Builder withStopSequences(List<String> stopSequences) {
options.setStopSequences(stopSequences);
return this;
}

public Builder withRawPrompting(Boolean rawPrompting) {
options.setRawPrompting(rawPrompting);
return this;
}

public BedrockCohereCommandRChatOptions build() {
return this.options;
}

}

public Boolean getSearchQueriesOnly() {
return searchQueriesOnly;
}

public void setSearchQueriesOnly(Boolean searchQueriesOnly) {
this.searchQueriesOnly = searchQueriesOnly;
}

public String getPreamble() {
return preamble;
}

public void setPreamble(String preamble) {
this.preamble = preamble;
}

public Integer getMaxTokens() {
return maxTokens;
}

public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}

@Override
public Float getTemperature() {
return temperature;
}

public void setTemperature(Float temperature) {
this.temperature = temperature;
}

@Override
public Float getTopP() {
return topP;
}

public void setTopP(Float topP) {
this.topP = topP;
}

@Override
public Integer getTopK() {
return topK;
}

public void setTopK(Integer topK) {
this.topK = topK;
}

public PromptTruncation getPromptTruncation() {
return promptTruncation;
}

public void setPromptTruncation(PromptTruncation promptTruncation) {
this.promptTruncation = promptTruncation;
}

public Float getFrequencyPenalty() {
return frequencyPenalty;
}

public void setFrequencyPenalty(Float frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}

public Float getPresencePenalty() {
return presencePenalty;
}

public void setPresencePenalty(Float presencePenalty) {
this.presencePenalty = presencePenalty;
}

public Integer getSeed() {
return seed;
}

public void setSeed(Integer seed) {
this.seed = seed;
}

public Boolean getReturnPrompt() {
return returnPrompt;
}

public void setReturnPrompt(Boolean returnPrompt) {
this.returnPrompt = returnPrompt;
}

public List<String> getStopSequences() {
return stopSequences;
}

public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

public Boolean getRawPrompting() {
return rawPrompting;
}

public void setRawPrompting(Boolean rawPrompting) {
this.rawPrompting = rawPrompting;
}

}
Loading

0 comments on commit 9a0feaa

Please sign in to comment.