Skip to content

Commit

Permalink
Allow to programmatically define system and user messages in the AiSe…
Browse files Browse the repository at this point in the history
…rvices (#862)

This pull request is a variation and simplification of what I proposed
[here](#861). As
anticipated there instead of making the AiServices aware of the concept
of a state machine I simply added the possibility of making the system
and user messages configurable in a programmatic way. Currently this
pull request has the same limitation of the other one: lack of the
possibility of configuring messages for each user/conversation and of
support of tools, but I believe that both things could be added once we
agree on the main direction.
  • Loading branch information
mariofusco committed Apr 10, 2024
1 parent dcbc562 commit 9c69d96
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 56 deletions.
7 changes: 7 additions & 0 deletions langchain4j/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@
<artifactId>slf4j-api</artifactId>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

public class AiServiceContext {

private static final Function<Object, Optional<String>> DEFAULT_MESSAGE_PROVIDER = x -> Optional.empty();

public final Class<?> aiServiceClass;

public ChatLanguageModel chatModel;
Expand All @@ -29,6 +33,10 @@ public class AiServiceContext {

public RetrievalAugmentor retrievalAugmentor;

public Function<Object, Optional<String>> userMessagesProvider = DEFAULT_MESSAGE_PROVIDER;

public Function<Object, Optional<String>> systemMessagesProvider = DEFAULT_MESSAGE_PROVIDER;

public AiServiceContext(Class<?> aiServiceClass) {
this.aiServiceClass = aiServiceClass;
}
Expand All @@ -40,4 +48,12 @@ public boolean hasChatMemory() {
public ChatMemory chatMemory(Object memoryId) {
return chatMemories.computeIfAbsent(memoryId, ignored -> chatMemoryProvider.get(memoryId));
}

public boolean hasUserMessagesProvider() {
return userMessagesProvider != DEFAULT_MESSAGE_PROVIDER;
}

public boolean hasSystemMessagesProvider() {
return systemMessagesProvider != DEFAULT_MESSAGE_PROVIDER;
}
}
28 changes: 21 additions & 7 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
package dev.langchain4j.service;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.Function;

import dev.langchain4j.agent.tool.DefaultToolExecutor;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
Expand All @@ -14,20 +25,13 @@
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.spi.services.AiServicesFactory;

import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
Expand Down Expand Up @@ -171,6 +175,16 @@ public static <T> AiServices<T> builder(Class<T> aiService) {
return new DefaultAiServices<>(context);
}

public AiServices<T> systemMessageProvider(Function<Object, String> systemMessageProvider) {
context.systemMessagesProvider = systemMessageProvider.andThen(Optional::ofNullable);
return this;
}

public AiServices<T> userMessageProvider(Function<Object, String> userMessageProvider) {
context.userMessagesProvider = userMessageProvider.andThen(Optional::ofNullable);
return this;
}

/**
* Configures chat model that will be used under the hood of the AI Service.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

validateParameters(method);

Optional<SystemMessage> systemMessage = prepareSystemMessage(method, args);
UserMessage userMessage = prepareUserMessage(method, args);

Object memoryId = memoryId(method, args).orElse(DEFAULT);

Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(memoryId, method, args);

if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
Expand Down Expand Up @@ -193,72 +193,65 @@ private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMes
return (T) proxyInstance;
}

private Optional<SystemMessage> prepareSystemMessage(Method method, Object[] args) {

Parameter[] parameters = method.getParameters();
Map<String, Object> variables = getPromptTemplateVariables(args, parameters);
private Optional<SystemMessage> prepareSystemMessage(Object memoryId, Method method, Object[] args) {
return prepareSystemMessageTemplate(memoryId, method)
.map(template -> PromptTemplate.from(template)
.apply(getPromptTemplateVariables(args, method.getParameters()))
.toSystemMessage());
}

private Optional<String> prepareSystemMessageTemplate(Object memoryId, Method method) {
dev.langchain4j.service.SystemMessage annotation = method.getAnnotation(dev.langchain4j.service.SystemMessage.class);
if (annotation != null) {

String systemMessageTemplate = getPromptText(
if (context.hasSystemMessagesProvider()) {
throw illegalConfiguration(
"Error: the system message has been configured both via annotation and using the systemMessagesProvider. " +
"Please choose only one of the two.");
}

return Optional.of(getPromptText(
method,
"System",
annotation.fromResource(),
annotation.value(),
annotation.delimiter()
);

Prompt prompt = PromptTemplate.from(systemMessageTemplate).apply(variables);
return Optional.of(prompt.toSystemMessage());
));
}

return Optional.empty();
return context.systemMessagesProvider.apply(memoryId);
}

private static UserMessage prepareUserMessage(Method method, Object[] args) {
private UserMessage prepareUserMessage(Object memoryId, Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
Map<String, Object> variables = getPromptTemplateVariables(args, parameters);


String userName = getUserName(parameters, args);

dev.langchain4j.service.UserMessage annotation = method.getAnnotation(dev.langchain4j.service.UserMessage.class);
if (annotation != null) {
String userMessageTemplate = getPromptText(
method,
"User",
annotation.fromResource(),
annotation.value(),
annotation.delimiter()
);

if (userMessageTemplate.contains("{{it}}")) {
if (parameters.length != 1) {
throw illegalConfiguration("Error: The {{it}} placeholder is present but the method does not have exactly one parameter. " +
"Please ensure that methods using the {{it}} placeholder have exactly one parameter.");
}
return prepareUserMessageTemplate(memoryId, method)
.map(template -> prepareUserMessageFromTemplate(args, template, parameters, getPromptTemplateVariables(args, parameters), userName))
.orElse(prepareUserMessage(args, parameters, userName));
}

variables = singletonMap("it", toString(args[0]));
}
private UserMessage prepareUserMessageFromTemplate(Object[] args, String userMessageTemplate, Parameter[] parameters, Map<String, Object> variables, String userName) {
if (userMessageTemplate.contains("{{it}}")) {
String it = parameters.length == 1 ? toString(args[0]) :
findUserMessageFromAnnotation(args, parameters).orElseThrow(() -> illegalConfiguration(

This comment has been minimized.

Copy link
@daixi98

daixi98 Apr 12, 2024

Contributor

If I understand it correctly, this will replace {{it}} with the parameter annotated by @UserMessage. Is it expected?

This comment has been minimized.

Copy link
@langchain4j

langchain4j Apr 12, 2024

Owner

Could you please elaborate?

"Error: The {{it}} placeholder is present but the method does not have exactly one parameter. " +
"Please ensure that methods using the {{it}} placeholder have exactly one parameter."));
variables = singletonMap("it", it);
}

Prompt prompt = PromptTemplate.from(userMessageTemplate).apply(variables);
if (userName != null) {
return userMessage(userName, prompt.text());
} else {
return prompt.toUserMessage();
}
Prompt prompt = PromptTemplate.from(userMessageTemplate).apply(variables);
if (userName != null) {
return userMessage(userName, prompt.text());
} else {
return prompt.toUserMessage();
}
}

for (int i = 0; i < parameters.length; i++) {
if (parameters[i].isAnnotationPresent(dev.langchain4j.service.UserMessage.class)) {
String text = toString(args[i]);
if (userName != null) {
return userMessage(userName, text);
} else {
return userMessage(text);
}
}
private UserMessage prepareUserMessage(Object[] args, Parameter[] parameters, String userName) {
Optional<String> userMessage = findUserMessageFromAnnotation(args, parameters);
if (userMessage.isPresent()) {
return userName != null ? userMessage(userName, userMessage.get()) : userMessage(userMessage.get());
}

if (args == null || args.length == 0) {
Expand All @@ -277,6 +270,37 @@ private static UserMessage prepareUserMessage(Method method, Object[] args) {
throw illegalConfiguration("For methods with multiple parameters, each parameter must be annotated with @V, @UserMessage, @UserName or @MemoryId");
}

private Optional<String> findUserMessageFromAnnotation(Object[] args, Parameter[] parameters) {
for (int i = 0; i < parameters.length; i++) {
if (parameters[i].isAnnotationPresent(dev.langchain4j.service.UserMessage.class)) {
return Optional.of(toString(args[i]));
}
}
return Optional.empty();
}

private Optional<String> prepareUserMessageTemplate(Object memoryId, Method method) {
dev.langchain4j.service.UserMessage annotation = method.getAnnotation(dev.langchain4j.service.UserMessage.class);
if (annotation != null) {

if (context.hasUserMessagesProvider()) {
throw illegalConfiguration(
"Error: the user message has been configured both via annotation and using the userMessagesProvider. " +
"Please choose only one of the two.");
}

return Optional.of(getPromptText(
method,
"User",
annotation.fromResource(),
annotation.value(),
annotation.delimiter()
));
}

return context.userMessagesProvider.apply(memoryId);
}

private static String getPromptText(Method method, String type, String resource, String[] value, String delimiter) {
String messageTemplate;
if (!resource.trim().isEmpty()) {
Expand Down

0 comments on commit 9c69d96

Please sign in to comment.