Skip to content

Commit

Permalink
fix bug: AiMessage text content is not copied when toolCalls are pres…
Browse files Browse the repository at this point in the history
  • Loading branch information
hrhrng committed May 8, 2024
1 parent e36ef57 commit 4267df9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
Expand Up @@ -15,6 +15,7 @@
import dev.langchain4j.model.output.TokenUsage;

import java.util.Collection;
import java.util.Collections;
import java.util.List;

import static dev.ai4j.openai4j.chat.ContentType.IMAGE_URL;
Expand All @@ -25,6 +26,7 @@
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.lang.String.format;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

public class InternalOpenAiHelper {
Expand Down Expand Up @@ -204,14 +206,17 @@ private static dev.ai4j.openai4j.chat.Parameters toOpenAiParameters(ToolParamete

public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
AssistantMessage assistantMessage = response.choices().get(0).message();
String text = assistantMessage.content();

List<ToolCall> toolCalls = assistantMessage.toolCalls();
if (!isNullOrEmpty(toolCalls)) {
List<ToolExecutionRequest> toolExecutionRequests = toolCalls.stream()
.filter(toolCall -> toolCall.type() == FUNCTION)
.map(InternalOpenAiHelper::toToolExecutionRequest)
.collect(toList());
return aiMessage(toolExecutionRequests);
return isNullOrEmpty(text) ?
aiMessage(toolExecutionRequests) :
new AiMessage(text, toolExecutionRequests);
}

FunctionCall functionCall = assistantMessage.functionCall();
Expand All @@ -220,10 +225,12 @@ public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
.name(functionCall.name())
.arguments(functionCall.arguments())
.build();
return aiMessage(toolExecutionRequest);
return isNullOrEmpty(text) ?
aiMessage(toolExecutionRequest) :
new AiMessage(text, singletonList(toolExecutionRequest));
}

return aiMessage(assistantMessage.content());
return aiMessage(text);
}

private static ToolExecutionRequest toToolExecutionRequest(ToolCall toolCall) {
Expand Down
Expand Up @@ -47,7 +47,6 @@ void should_return_ai_message_with_toolExecutionRequests_when_function_is_presen
ChatCompletionResponse response = ChatCompletionResponse.builder()
.choices(singletonList(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.content("unexpected text")
.functionCall(FunctionCall.builder()
.name(functionName)
.arguments(functionArguments)
Expand All @@ -60,7 +59,6 @@ void should_return_ai_message_with_toolExecutionRequests_when_function_is_presen
AiMessage aiMessage = aiMessageFrom(response);

// then
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
.builder()
.name(functionName)
Expand All @@ -79,7 +77,6 @@ void should_return_ai_message_with_toolExecutionRequests_when_tool_calls_are_pre
ChatCompletionResponse response = ChatCompletionResponse.builder()
.choices(singletonList(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.content("unexpected text")
.toolCalls(ToolCall.builder()
.type(FUNCTION)
.function(FunctionCall.builder()
Expand All @@ -95,7 +92,41 @@ void should_return_ai_message_with_toolExecutionRequests_when_tool_calls_are_pre
AiMessage aiMessage = aiMessageFrom(response);

// then
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
.builder()
.name(functionName)
.arguments(functionArguments)
.build()
);
}

@Test
void should_return_ai_message_with_toolExecutionRequests_and_text_when_tool_calls_and_content_are_both_present() {

// given
String functionName = "current_time";
String functionArguments = "{}";

ChatCompletionResponse response = ChatCompletionResponse.builder()
.choices(singletonList(ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.content("Hello")
.toolCalls(ToolCall.builder()
.type(FUNCTION)
.function(FunctionCall.builder()
.name(functionName)
.arguments(functionArguments)
.build())
.build())
.build())
.build()))
.build();

// when
AiMessage aiMessage = aiMessageFrom(response);

// then
assertThat(aiMessage.text()).isEqualTo("Hello");
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
.builder()
.name(functionName)
Expand Down

0 comments on commit 4267df9

Please sign in to comment.