diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 7648cd1196a..b29e5e21e95 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -35,6 +35,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "te EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama", "src\Autogen.Ollama\Autogen.Ollama.csproj", "{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}" @@ -107,6 +108,10 @@ Global {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.Build.0 = Release|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.Build.0 = Release|Any CPU {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.Build.0 = Debug|Any CPU {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -131,6 +136,7 @@ Global {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {C24FDE63-952D-4F8E-A807-AF31D43AD675} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 2bd9470ffa7..2925a43e16f 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -19,7 +19,6 @@ namespace AutoGen.OpenAI; /// - /// - /// - -/// - /// - where T is /// - where TMessage1 is and TMessage2 is /// @@ -27,6 +26,11 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa { private bool strictMode = false; + /// + /// Create a new instance of . + /// + /// If true, will throw an + /// When the message type is not supported. If false, it will ignore the unsupported message type. public OpenAIChatRequestMessageConnector(bool strictMode = false) { this.strictMode = strictMode; @@ -36,8 +40,7 @@ public OpenAIChatRequestMessageConnector(bool strictMode = false) public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { - var chatMessages = ProcessIncomingMessages(agent, context.Messages) - .Select(m => new MessageEnvelope(m)); + var chatMessages = ProcessIncomingMessages(agent, context.Messages); var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); @@ -49,8 +52,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var chatMessages = ProcessIncomingMessages(agent, context.Messages) - .Select(m => new MessageEnvelope(m)); + var chatMessages = ProcessIncomingMessages(agent, context.Messages); var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); string? currentToolName = null; await foreach (var reply in streamingReply) @@ -73,7 +75,14 @@ await foreach (var reply in streamingReply) } else { - yield return reply; + if (this.strictMode) + { + throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}"); + } + else + { + yield return reply; + } } } } @@ -82,16 +91,10 @@ public IMessage PostProcessMessage(IMessage message) { return message switch { - TextMessage => message, - ImageMessage => message, - MultiModalMessage => message, - ToolCallMessage => message, - ToolCallResultMessage => message, - Message => message, - AggregateMessage => message, - IMessage m => PostProcessMessage(m), - IMessage m => PostProcessMessage(m), - _ => throw new InvalidOperationException("The type of message is not supported. Must be one of TextMessage, ImageMessage, MultiModalMessage, ToolCallMessage, ToolCallResultMessage, Message, IMessage, AggregateMessage"), + IMessage m => PostProcessChatResponseMessage(m.Content, m.From), + IMessage m => PostProcessChatCompletions(m), + _ when strictMode is false => message, + _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"), }; } @@ -120,12 +123,7 @@ public IMessage PostProcessMessage(IMessage message) } } - private IMessage PostProcessMessage(IMessage message) - { - return PostProcessMessage(message.Content, message.From); - } - - private IMessage PostProcessMessage(IMessage message) + private IMessage PostProcessChatCompletions(IMessage message) { // throw exception if prompt filter results is not null if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered) @@ -133,12 +131,12 @@ private IMessage PostProcessMessage(IMessage message) throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input."); } - return PostProcessMessage(message.Content.Choices[0].Message, message.From); + return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From); } - private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, string? from) + private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from) { - if (chatResponseMessage.Content is string content) + if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content)) { return new TextMessage(Role.Assistant, content, from); } @@ -162,112 +160,41 @@ private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, str throw new InvalidOperationException("Invalid ChatResponseMessage"); } - public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) + public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) { - return messages.SelectMany(m => + return messages.SelectMany(m => { - if (m.From == null) + if (m is IMessage crm) { - return ProcessIncomingMessagesWithEmptyFrom(m); - } - else if (m.From == agent.Name) - { - return ProcessIncomingMessagesForSelf(m); + return [crm]; } else { - return ProcessIncomingMessagesForOther(m); + var chatRequestMessages = m switch + { + TextMessage textMessage => ProcessTextMessage(agent, textMessage), + ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage), + MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage), + ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage), + AggregateMessage aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage), + Message msg => ProcessMessage(agent, msg), + _ when strictMode is false => [], + _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"), + }; + + if (chatRequestMessages.Any()) + { + return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From)); + } + else + { + return [m]; + } } }); } - private IEnumerable ProcessIncomingMessagesForSelf(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessIncomingMessagesForSelf(textMessage), - ImageMessage imageMessage => ProcessIncomingMessagesForSelf(imageMessage), - MultiModalMessage multiModalMessage => ProcessIncomingMessagesForSelf(multiModalMessage), - ToolCallMessage toolCallMessage => ProcessIncomingMessagesForSelf(toolCallMessage), - ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForSelf(toolCallResultMessage), - Message msg => ProcessIncomingMessagesForSelf(msg), - IMessage crm => ProcessIncomingMessagesForSelf(crm), - AggregateMessage aggregateMessage => ProcessIncomingMessagesForSelf(aggregateMessage), - _ => throw new NotImplementedException(), - }; - } - - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessIncomingMessagesWithEmptyFrom(textMessage), - ImageMessage imageMessage => ProcessIncomingMessagesWithEmptyFrom(imageMessage), - MultiModalMessage multiModalMessage => ProcessIncomingMessagesWithEmptyFrom(multiModalMessage), - ToolCallMessage toolCallMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallMessage), - ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallResultMessage), - Message msg => ProcessIncomingMessagesWithEmptyFrom(msg), - IMessage crm => ProcessIncomingMessagesWithEmptyFrom(crm), - AggregateMessage aggregateMessage => ProcessIncomingMessagesWithEmptyFrom(aggregateMessage), - _ => throw new NotImplementedException(), - }; - } - - private IEnumerable ProcessIncomingMessagesForOther(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessIncomingMessagesForOther(textMessage), - ImageMessage imageMessage => ProcessIncomingMessagesForOther(imageMessage), - MultiModalMessage multiModalMessage => ProcessIncomingMessagesForOther(multiModalMessage), - ToolCallMessage toolCallMessage => ProcessIncomingMessagesForOther(toolCallMessage), - ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForOther(toolCallResultMessage), - Message msg => ProcessIncomingMessagesForOther(msg), - IMessage crm => ProcessIncomingMessagesForOther(crm), - AggregateMessage aggregateMessage => ProcessIncomingMessagesForOther(aggregateMessage), - _ => throw new NotImplementedException(), - }; - } - - private IEnumerable ProcessIncomingMessagesForSelf(TextMessage message) - { - if (message.Role == Role.System) - { - return new[] { new ChatRequestSystemMessage(message.Content) }; - } - else - { - return new[] { new ChatRequestAssistantMessage(message.Content) }; - } - } - - private IEnumerable ProcessIncomingMessagesForSelf(ImageMessage _) - { - return [new ChatRequestAssistantMessage("// Image Message is not supported")]; - } - - private IEnumerable ProcessIncomingMessagesForSelf(MultiModalMessage _) - { - return [new ChatRequestAssistantMessage("// MultiModal Message is not supported")]; - } - - private IEnumerable ProcessIncomingMessagesForSelf(ToolCallMessage message) - { - var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); - var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty); - foreach (var tc in toolCall) - { - chatRequestMessage.ToolCalls.Add(tc); - } - - return new[] { chatRequestMessage }; - } - - private IEnumerable ProcessIncomingMessagesForSelf(ToolCallResultMessage message) - { - return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); - } - private IEnumerable ProcessIncomingMessagesForSelf(Message message) { if (message.Role == Role.System) @@ -303,151 +230,145 @@ private IEnumerable ProcessIncomingMessagesForSelf(Message m } } - private IEnumerable ProcessIncomingMessagesForSelf(IMessage message) - { - return new[] { message.Content }; - } - - private IEnumerable ProcessIncomingMessagesForSelf(AggregateMessage aggregateMessage) + private IEnumerable ProcessIncomingMessagesForOther(Message message) { - var toolCallMessage1 = aggregateMessage.Message1; - var toolCallResultMessage = aggregateMessage.Message2; - - var assistantMessage = new ChatRequestAssistantMessage(string.Empty); - var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); - foreach (var tc in toolCalls) + if (message.Role == Role.System) { - assistantMessage.ToolCalls.Add(tc); + return [new ChatRequestSystemMessage(message.Content) { Name = message.From }]; } + else if (message.Content is string content && content is { Length: > 0 }) + { + if (message.FunctionName is not null) + { + return new[] { new ChatRequestToolMessage(content, message.FunctionName) }; + } - var toolCallResults = toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); - - // return assistantMessage and tool call result messages - var messages = new List { assistantMessage }; - messages.AddRange(toolCallResults); - - return messages; + return [new ChatRequestUserMessage(message.Content) { Name = message.From }]; + } + else if (message.FunctionName is string _) + { + return [new ChatRequestUserMessage("// Message type is not supported") { Name = message.From }]; + } + else + { + throw new InvalidOperationException("Invalid Message as message from other."); + } } - private IEnumerable ProcessIncomingMessagesForOther(TextMessage message) + private IEnumerable ProcessTextMessage(IAgent agent, TextMessage message) { if (message.Role == Role.System) { - return new[] { new ChatRequestSystemMessage(message.Content) }; + return [new ChatRequestSystemMessage(message.Content) { Name = message.From }]; + } + + if (agent.Name == message.From) + { + return [new ChatRequestAssistantMessage(message.Content) { Name = agent.Name }]; } else { - return new[] { new ChatRequestUserMessage(message.Content) }; + return message.From switch + { + null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)], + null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage(message.Content)], + null => throw new InvalidOperationException("Invalid Role"), + _ => [new ChatRequestUserMessage(message.Content) { Name = message.From }] + }; } } - private IEnumerable ProcessIncomingMessagesForOther(ImageMessage message) + private IEnumerable ProcessImageMessage(IAgent agent, ImageMessage message) { - return new[] { new ChatRequestUserMessage([ - new ChatMessageImageContentItem(new Uri(message.Url ?? message.BuildDataUri())), - ])}; + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("ImageMessage is not supported when message.From is the same with agent"); + } + + var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message); + return [new ChatRequestUserMessage([imageContentItem]) { Name = message.From }]; } - private IEnumerable ProcessIncomingMessagesForOther(MultiModalMessage message) + private IEnumerable ProcessMultiModalMessage(IAgent agent, MultiModalMessage message) { + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent"); + } + IEnumerable items = message.Content.Select(ci => ci switch { TextMessage text => new ChatMessageTextContentItem(text.Content), - ImageMessage image => new ChatMessageImageContentItem(new Uri(image.Url ?? image.BuildDataUri())), + ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image), _ => throw new NotImplementedException(), }); - return new[] { new ChatRequestUserMessage(items) }; + return [new ChatRequestUserMessage(items) { Name = message.From }]; } - private IEnumerable ProcessIncomingMessagesForOther(ToolCallMessage msg) + private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message) { - throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); + return message.Data is null + ? new ChatMessageImageContentItem(new Uri(message.Url)) + : new ChatMessageImageContentItem(message.Data, message.Data.MediaType); } - private IEnumerable ProcessIncomingMessagesForOther(ToolCallResultMessage message) + private IEnumerable ProcessToolCallMessage(IAgent agent, ToolCallMessage message) { - return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); - } - - private IEnumerable ProcessIncomingMessagesForOther(Message message) - { - if (message.Role == Role.System) + if (message.From is not null && message.From != agent.Name) { - return new[] { new ChatRequestSystemMessage(message.Content) }; + throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); } - else if (message.Content is string content && content is { Length: > 0 }) - { - if (message.FunctionName is not null) - { - return new[] { new ChatRequestToolMessage(content, message.FunctionName) }; - } - return new[] { new ChatRequestUserMessage(message.Content) }; - } - else if (message.FunctionName is string _) - { - return new[] - { - new ChatRequestUserMessage("// Message type is not supported"), - }; - } - else + var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From }; + foreach (var tc in toolCall) { - throw new InvalidOperationException("Invalid Message as message from other."); + chatRequestMessage.ToolCalls.Add(tc); } - } - private IEnumerable ProcessIncomingMessagesForOther(IMessage message) - { - return new[] { message.Content }; - } - - private IEnumerable ProcessIncomingMessagesForOther(AggregateMessage aggregateMessage) - { - // convert as user message - var resultMessage = aggregateMessage.Message2; - - return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); + return [chatRequestMessage]; } - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(TextMessage message) + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage message) { - return ProcessIncomingMessagesForOther(message); + return message.ToolCalls + .Where(tc => tc.Result is not null) + .Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); } - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ImageMessage message) + private IEnumerable ProcessMessage(IAgent agent, Message message) { - return ProcessIncomingMessagesForOther(message); - } - - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(MultiModalMessage message) - { - return ProcessIncomingMessagesForOther(message); - } - - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ToolCallMessage message) - { - return ProcessIncomingMessagesForSelf(message); + if (message.From is not null && message.From != agent.Name) + { + return ProcessIncomingMessagesForOther(message); + } + else + { + return ProcessIncomingMessagesForSelf(message); + } } - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ToolCallResultMessage message) + private IEnumerable ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage aggregateMessage) { - return ProcessIncomingMessagesForOther(message); - } + if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name) + { + // convert as user message + var resultMessage = aggregateMessage.Message2; - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(Message message) - { - return ProcessIncomingMessagesForOther(message); - } + return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result) { Name = aggregateMessage.From }); + } + else + { + var toolCallMessage1 = aggregateMessage.Message1; + var toolCallResultMessage = aggregateMessage.Message2; - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(IMessage message) - { - return new[] { message.Content }; - } + var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1); + var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage); - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(AggregateMessage aggregateMessage) - { - return ProcessIncomingMessagesForOther(aggregateMessage); + return assistantMessage.Concat(toolCallResults); + } } } diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt similarity index 94% rename from dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt rename to dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index 2cb58f4d88c..d17de56e129 100644 --- a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -3,6 +3,7 @@ "OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )", "ConvertedMessages": [ { + "Name": null, "Role": "system", "Content": "You are a helpful AI assistant" } @@ -14,6 +15,7 @@ { "Role": "user", "Content": "Hello", + "Name": "user", "MultiModaItem": null } ] @@ -24,6 +26,7 @@ { "Role": "assistant", "Content": "How can I help you?", + "Name": "assistant", "TooCall": [], "FunctionCallName": null, "FunctionCallArguments": null @@ -34,6 +37,7 @@ "OriginalMessage": "Message(system, You are a helpful AI assistant, , , )", "ConvertedMessages": [ { + "Name": null, "Role": "system", "Content": "You are a helpful AI assistant" } @@ -45,6 +49,7 @@ { "Role": "user", "Content": "Hello", + "Name": "user", "MultiModaItem": null } ] @@ -55,6 +60,7 @@ { "Role": "assistant", "Content": "How can I help you?", + "Name": null, "TooCall": [], "FunctionCallName": null, "FunctionCallArguments": null @@ -67,6 +73,7 @@ { "Role": "user", "Content": "result", + "Name": "user", "MultiModaItem": null } ] @@ -77,6 +84,7 @@ { "Role": "assistant", "Content": null, + "Name": null, "TooCall": [], "FunctionCallName": "functionName", "FunctionCallArguments": "functionArguments" @@ -89,6 +97,7 @@ { "Role": "user", "Content": null, + "Name": "user", "MultiModaItem": [ { "Type": "Image", @@ -107,6 +116,7 @@ { "Role": "user", "Content": null, + "Name": "user", "MultiModaItem": [ { "Type": "Text", @@ -129,6 +139,7 @@ { "Role": "assistant", "Content": "", + "Name": "assistant", "TooCall": [ { "Type": "Function", @@ -173,6 +184,7 @@ { "Role": "assistant", "Content": "", + "Name": "assistant", "TooCall": [ { "Type": "Function", @@ -198,6 +210,7 @@ { "Role": "assistant", "Content": "", + "Name": "assistant", "TooCall": [ { "Type": "Function", diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj new file mode 100644 index 00000000000..044975354b8 --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj @@ -0,0 +1,32 @@ + + + + $(TestTargetFramework) + false + True + + + + + + + + + + + + + + + + + + + + $([System.String]::Copy('%(FileName)').Split('.')[0]) + $(ProjectExt.Replace('proj', '')) + %(ParentFile)%(ParentExtension) + + + + diff --git a/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs b/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs new file mode 100644 index 00000000000..d66bf001ed5 --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GlobalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs similarity index 100% rename from dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs rename to dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs new file mode 100644 index 00000000000..a8c1d3f7860 --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs @@ -0,0 +1,612 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OpenAIMessageTests.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading.Tasks; +using ApprovalTests; +using ApprovalTests.Namers; +using ApprovalTests.Reporters; +using AutoGen.OpenAI; +using Azure.AI.OpenAI; +using FluentAssertions; +using Xunit; + +namespace AutoGen.Tests; + +public class OpenAIMessageTests +{ + private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions + { + WriteIndented = true, + IgnoreReadOnlyProperties = false, + }; + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("ApprovalTests")] + public void BasicMessageTest() + { + IMessage[] messages = [ + new TextMessage(Role.System, "You are a helpful AI assistant"), + new TextMessage(Role.User, "Hello", "user"), + new TextMessage(Role.Assistant, "How can I help you?", from: "assistant"), + new Message(Role.System, "You are a helpful AI assistant"), + new Message(Role.User, "Hello", "user"), + new Message(Role.Assistant, "How can I help you?", from: "assistant"), + new Message(Role.Function, "result", "user"), + new Message(Role.Assistant, null, "assistant") + { + FunctionName = "functionName", + FunctionArguments = "functionArguments", + }, + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"), + new ToolCallMessage("test", "test", "assistant"), + new ToolCallResultMessage("result", "test", "test", "user"), + new ToolCallResultMessage( + [ + new ToolCall("result", "test", "test"), + new ToolCall("result", "test", "test"), + ], "user"), + new ToolCallMessage( + [ + new ToolCall("test", "test"), + new ToolCall("test", "test"), + ], "assistant"), + new AggregateMessage( + message1: new ToolCallMessage("test", "test", "assistant"), + message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), + ]; + var openaiMessageConnectorMiddleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant"); + + var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); + VerifyOAIMessages(oaiMessages); + } + + [Fact] + public async Task ItProcessUserTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("Hello"); + chatRequestMessage.Name.Should().Be("user"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new TextMessage(Role.User, "Hello", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItShortcutChatRequestMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var userMessage = new ChatRequestUserMessage("hello"); + var chatRequestMessage = MessageEnvelope.Create(userMessage); + await agent.GenerateReplyAsync([chatRequestMessage]); + } + + [Fact] + public async Task ItShortcutMessageWhenStrictModelIsFalseAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + + var chatRequestMessage = ((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Should().Be("hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var userMessage = "hello"; + var chatRequestMessage = MessageEnvelope.Create(userMessage); + await agent.GenerateReplyAsync([chatRequestMessage]); + } + + [Fact] + public async Task ItThrowExceptionWhenStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // user message + var userMessage = "hello"; + var chatRequestMessage = MessageEnvelope.Create(userMessage); + Func action = async () => await agent.GenerateReplyAsync([chatRequestMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: MessageEnvelope`1"); + } + + [Fact] + public async Task ItProcessAssistantTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("How can I help you?"); + chatRequestMessage.Name.Should().Be("assistant"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // assistant message + IMessage message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessSystemTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("You are a helpful AI assistant"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // system message + IMessage message = new TextMessage(Role.System, "You are a helpful AI assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessImageMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("user"); + chatRequestMessage.MultimodalContentItems.Count().Should().Be(1); + chatRequestMessage.MultimodalContentItems.First().Should().BeOfType(); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingImageMessageFromSelfAndStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var imageMessage = new ImageMessage(Role.Assistant, "https://example.com/image.png", "assistant"); + Func action = async () => await agent.GenerateReplyAsync([imageMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: ImageMessage"); + } + + [Fact] + public async Task ItProcessMultiModalMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("user"); + chatRequestMessage.MultimodalContentItems.Count().Should().Be(2); + chatRequestMessage.MultimodalContentItems.First().Should().BeOfType(); + chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType(); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new MultiModalMessage( + Role.User, + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingMultiModalMessageFromSelfAndStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var multiModalMessage = new MultiModalMessage( + Role.Assistant, + [ + new TextMessage(Role.User, "Hello", "assistant"), + new ImageMessage(Role.User, "https://example.com/image.png", "assistant"), + ], "assistant"); + + Func action = async () => await agent.GenerateReplyAsync([multiModalMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: MultiModalMessage"); + } + + [Fact] + public async Task ItProcessToolCallMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("assistant"); + chatRequestMessage.ToolCalls.Count().Should().Be(1); + chatRequestMessage.ToolCalls.First().Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First(); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Arguments.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ToolCallMessage("test", "test", "assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(strictMode: true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var toolCallMessage = new ToolCallMessage("test", "test", "user"); + Func action = async () => await agent.GenerateReplyAsync([toolCallMessage]); + await action.Should().ThrowAsync().WithMessage("Invalid message type: ToolCallMessage"); + } + + [Fact] + public async Task ItProcessToolCallResultMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ToolCallResultMessage("result", "test", "test", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(1); + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.Name.Should().Be("user"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCallMessage = new ToolCallMessage("test", "test", "user"); + var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "user"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "user"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(2); + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + + var toolCallMessage = msgs.First(); + toolCallMessage!.Should().BeOfType>(); + var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)toolCallMessage!).Content; + toolCallRequestMessage.Content.Should().BeNullOrEmpty(); + toolCallRequestMessage.ToolCalls.Count().Should().Be(1); + toolCallRequestMessage.ToolCalls.First().Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First(); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Arguments.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCallMessage = new ToolCallMessage("test", "test", "assistant"); + var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "assistant"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "assistant"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItConvertChatResponseMessageToTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = CreateInstance(ChatRole.Assistant, "hello"); + var chatRequestMessage = MessageEnvelope.Create(textMessage); + + var message = await agent.GenerateReplyAsync([chatRequestMessage]); + message.Should().BeOfType(); + message.GetContent().Should().Be("hello"); + message.GetRole().Should().Be(Role.Assistant); + } + + [Fact] + public async Task ItConvertChatResponseMessageToToolCallMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // tool call message + var toolCallMessage = CreateInstance(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance(), new Dictionary()); + var chatRequestMessage = MessageEnvelope.Create(toolCallMessage); + var message = await agent.GenerateReplyAsync([chatRequestMessage]); + message.Should().BeOfType(); + message.GetToolCalls()!.Count().Should().Be(1); + message.GetToolCalls()!.First().FunctionName.Should().Be("test"); + message.GetToolCalls()!.First().FunctionArguments.Should().Be("test"); + } + + [Fact] + public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = "hello"; + var messageToSend = MessageEnvelope.Create(textMessage); + + var message = await agent.GenerateReplyAsync([messageToSend]); + message.Should().BeOfType>(); + } + + [Fact] + public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = new ChatRequestUserMessage("hello"); + var messageToSend = MessageEnvelope.Create(textMessage); + Func action = async () => await agent.GenerateReplyAsync([messageToSend]); + + await action.Should().ThrowAsync().WithMessage("Invalid return message type MessageEnvelope`1"); + } + + [Fact] + public void ToOpenAIChatRequestMessageShortCircuitTest() + { + var agent = new EchoAgent("assistant"); + var middleware = new OpenAIChatRequestMessageConnector(); + ChatRequestMessage[] messages = + [ + new ChatRequestUserMessage("Hello"), + new ChatRequestAssistantMessage("How can I help you?"), + new ChatRequestSystemMessage("You are a helpful AI assistant"), + new ChatRequestFunctionMessage("result", "functionName"), + new ChatRequestToolMessage("test", "test"), + ]; + + foreach (var oaiMessage in messages) + { + IMessage message = new MessageEnvelope(oaiMessage); + var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); + oaiMessages.Count().Should().Be(1); + //oaiMessages.First().Should().BeOfType>(); + if (oaiMessages.First() is IMessage chatRequestMessage) + { + chatRequestMessage.Content.Should().Be(oaiMessage); + } + else + { + // fail the test + Assert.True(false); + } + } + } + private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> messages) + { + var jsonObjects = messages.Select(pair => + { + var (originalMessage, ms) = pair; + var objs = new List(); + foreach (var m in ms) + { + object? obj = null; + var chatRequestMessage = (m as IMessage)?.Content; + if (chatRequestMessage is ChatRequestUserMessage userMessage) + { + obj = new + { + Role = userMessage.Role.ToString(), + Content = userMessage.Content, + Name = userMessage.Name, + MultiModaItem = userMessage.MultimodalContentItems?.Select(item => + { + return item switch + { + ChatMessageImageContentItem imageContentItem => new + { + Type = "Image", + ImageUrl = GetImageUrlFromContent(imageContentItem), + } as object, + ChatMessageTextContentItem textContentItem => new + { + Type = "Text", + Text = textContentItem.Text, + } as object, + _ => throw new System.NotImplementedException(), + }; + }), + }; + } + + if (chatRequestMessage is ChatRequestAssistantMessage assistantMessage) + { + obj = new + { + Role = assistantMessage.Role.ToString(), + Content = assistantMessage.Content, + Name = assistantMessage.Name, + TooCall = assistantMessage.ToolCalls.Select(tc => + { + return tc switch + { + ChatCompletionsFunctionToolCall functionToolCall => new + { + Type = "Function", + Name = functionToolCall.Name, + Arguments = functionToolCall.Arguments, + Id = functionToolCall.Id, + } as object, + _ => throw new System.NotImplementedException(), + }; + }), + FunctionCallName = assistantMessage.FunctionCall?.Name, + FunctionCallArguments = assistantMessage.FunctionCall?.Arguments, + }; + } + + if (chatRequestMessage is ChatRequestSystemMessage systemMessage) + { + obj = new + { + Name = systemMessage.Name, + Role = systemMessage.Role.ToString(), + Content = systemMessage.Content, + }; + } + + if (chatRequestMessage is ChatRequestFunctionMessage functionMessage) + { + obj = new + { + Role = functionMessage.Role.ToString(), + Content = functionMessage.Content, + Name = functionMessage.Name, + }; + } + + if (chatRequestMessage is ChatRequestToolMessage toolCallMessage) + { + obj = new + { + Role = toolCallMessage.Role.ToString(), + Content = toolCallMessage.Content, + ToolCallId = toolCallMessage.ToolCallId, + }; + } + + objs.Add(obj ?? throw new System.NotImplementedException()); + } + + return new + { + OriginalMessage = originalMessage.ToString(), + ConvertedMessages = objs, + }; + }); + + var json = JsonSerializer.Serialize(jsonObjects, this.jsonSerializerOptions); + Approvals.Verify(json); + } + + private object? GetImageUrlFromContent(ChatMessageImageContentItem content) + { + return content.GetType().GetProperty("ImageUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.GetValue(content); + } + + private static T CreateInstance(params object[] args) + { + var type = typeof(T); + var instance = type.Assembly.CreateInstance( + type.FullName!, false, + BindingFlags.Instance | BindingFlags.NonPublic, + null, args, null, null); + return (T)instance!; + } +} diff --git a/dotnet/test/AutoGen.Tests/EchoAgent.cs b/dotnet/test/AutoGen.Tests/EchoAgent.cs index 28a7b91bad5..9cead5ad251 100644 --- a/dotnet/test/AutoGen.Tests/EchoAgent.cs +++ b/dotnet/test/AutoGen.Tests/EchoAgent.cs @@ -3,12 +3,13 @@ using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace AutoGen.Tests { - internal class EchoAgent : IAgent + public class EchoAgent : IStreamingAgent { public EchoAgent(string name) { @@ -27,5 +28,14 @@ public EchoAgent(string name) return Task.FromResult(lastMessage); } + + public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var message in messages) + { + message.From = this.Name; + yield return message; + } + } } } diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs deleted file mode 100644 index 6e9cd28c4cb..00000000000 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// OpenAIMessageTests.cs - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using ApprovalTests; -using ApprovalTests.Namers; -using ApprovalTests.Reporters; -using AutoGen.OpenAI; -using Azure.AI.OpenAI; -using FluentAssertions; -using Xunit; - -namespace AutoGen.Tests; - -public class OpenAIMessageTests -{ - private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions - { - WriteIndented = true, - IgnoreReadOnlyProperties = false, - }; - - [Fact] - [UseReporter(typeof(DiffReporter))] - [UseApprovalSubdirectory("ApprovalTests")] - public void BasicMessageTest() - { - IMessage[] messages = [ - new TextMessage(Role.System, "You are a helpful AI assistant"), - new TextMessage(Role.User, "Hello", "user"), - new TextMessage(Role.Assistant, "How can I help you?", from: "assistant"), - new Message(Role.System, "You are a helpful AI assistant"), - new Message(Role.User, "Hello", "user"), - new Message(Role.Assistant, "How can I help you?", from: "assistant"), - new Message(Role.Function, "result", "user"), - new Message(Role.Assistant, null, "assistant") - { - FunctionName = "functionName", - FunctionArguments = "functionArguments", - }, - new ImageMessage(Role.User, "https://example.com/image.png", "user"), - new MultiModalMessage(Role.Assistant, - [ - new TextMessage(Role.User, "Hello", "user"), - new ImageMessage(Role.User, "https://example.com/image.png", "user"), - ], "user"), - new ToolCallMessage("test", "test", "assistant"), - new ToolCallResultMessage("result", "test", "test", "user"), - new ToolCallResultMessage( - [ - new ToolCall("result", "test", "test"), - new ToolCall("result", "test", "test"), - ], "user"), - new ToolCallMessage( - [ - new ToolCall("test", "test"), - new ToolCall("test", "test"), - ], "assistant"), - new AggregateMessage( - message1: new ToolCallMessage("test", "test", "assistant"), - message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), - ]; - var openaiMessageConnectorMiddleware = new OpenAIChatRequestMessageConnector(); - var agent = new EchoAgent("assistant"); - - var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); - VerifyOAIMessages(oaiMessages); - } - - [Fact] - public void ToOpenAIChatRequestMessageTest() - { - var agent = new EchoAgent("assistant"); - var middleware = new OpenAIChatRequestMessageConnector(); - - // user message - IMessage message = new TextMessage(Role.User, "Hello", "user"); - var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - var userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().Be("Hello"); - - // user message test 2 - // even if Role is assistant, it should be converted to user message because it is from the user - message = new TextMessage(Role.Assistant, "Hello", "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().Be("Hello"); - - // user message with multimodal content - // image - message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().BeNullOrEmpty(); - userMessage.MultimodalContentItems.Count().Should().Be(1); - userMessage.MultimodalContentItems.First().Should().BeOfType(); - - // text and image - message = new MultiModalMessage( - Role.User, - [ - new TextMessage(Role.User, "Hello", "user"), - new ImageMessage(Role.User, "https://example.com/image.png", "user"), - ], "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().BeNullOrEmpty(); - userMessage.MultimodalContentItems.Count().Should().Be(2); - userMessage.MultimodalContentItems.First().Should().BeOfType(); - - // assistant text message - message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - var assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().Be("How can I help you?"); - - // assistant text message with single tool call - message = new ToolCallMessage("test", "test", "assistant"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().BeNullOrEmpty(); - assistantMessage.ToolCalls.Count().Should().Be(1); - assistantMessage.ToolCalls.First().Should().BeOfType(); - - // user should not suppose to send tool call message - message = new ToolCallMessage("test", "test", "user"); - Func action = () => middleware.ProcessIncomingMessages(agent, [message]).First(); - action.Should().Throw().WithMessage("ToolCallMessage is not supported when message.From is not the same with agent"); - - // assistant text message with multiple tool calls - message = new ToolCallMessage( - toolCalls: - [ - new ToolCall("test", "test"), - new ToolCall("test", "test"), - ], "assistant"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().BeNullOrEmpty(); - assistantMessage.ToolCalls.Count().Should().Be(2); - - // tool call result message - message = new ToolCallResultMessage("result", "test", "test", "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - var toolCallMessage = (ChatRequestToolMessage)oaiMessages.First(); - toolCallMessage.Content.Should().Be("result"); - - // tool call result message with multiple tool calls - message = new ToolCallResultMessage( - toolCalls: - [ - new ToolCall("result", "test", "test"), - new ToolCall("result", "test", "test"), - ], "user"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(2); - oaiMessages.First().Should().BeOfType(); - toolCallMessage = (ChatRequestToolMessage)oaiMessages.First(); - toolCallMessage.Content.Should().Be("test"); - oaiMessages.Last().Should().BeOfType(); - toolCallMessage = (ChatRequestToolMessage)oaiMessages.Last(); - toolCallMessage.Content.Should().Be("test"); - - // aggregate message test - // aggregate message with tool call and tool call result will be returned by GPT agent if the tool call is automatically invoked inside agent - message = new AggregateMessage( - message1: new ToolCallMessage("test", "test", "assistant"), - message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(2); - oaiMessages.First().Should().BeOfType(); - assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().BeNullOrEmpty(); - assistantMessage.ToolCalls.Count().Should().Be(1); - - oaiMessages.Last().Should().BeOfType(); - toolCallMessage = (ChatRequestToolMessage)oaiMessages.Last(); - toolCallMessage.Content.Should().Be("result"); - - // aggregate message test 2 - // if the aggregate message is from user, it should be converted to user message - message = new AggregateMessage( - message1: new ToolCallMessage("test", "test", "user"), - message2: new ToolCallResultMessage("result", "test", "test", "user"), "user"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().Be("result"); - - // aggregate message test 3 - // if the aggregate message is from user and contains multiple tool call results, it should be converted to user message - message = new AggregateMessage( - message1: new ToolCallMessage( - toolCalls: - [ - new ToolCall("test", "test"), - new ToolCall("test", "test"), - ], from: "user"), - message2: new ToolCallResultMessage( - toolCalls: - [ - new ToolCall("result", "test", "test"), - new ToolCall("result", "test", "test"), - ], from: "user"), "user"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - oaiMessages.Count().Should().Be(2); - oaiMessages.First().Should().BeOfType(); - oaiMessages.Last().Should().BeOfType(); - - // system message - message = new TextMessage(Role.System, "You are a helpful AI assistant"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - } - - [Fact] - public void ToOpenAIChatRequestMessageShortCircuitTest() - { - var agent = new EchoAgent("assistant"); - var middleware = new OpenAIChatRequestMessageConnector(); - ChatRequestMessage[] messages = - [ - new ChatRequestUserMessage("Hello"), - new ChatRequestAssistantMessage("How can I help you?"), - new ChatRequestSystemMessage("You are a helpful AI assistant"), - new ChatRequestFunctionMessage("result", "functionName"), - new ChatRequestToolMessage("test", "test"), - ]; - - foreach (var oaiMessage in messages) - { - IMessage message = new MessageEnvelope(oaiMessage); - var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().Be(oaiMessage); - } - } - private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> messages) - { - var jsonObjects = messages.Select(pair => - { - var (originalMessage, ms) = pair; - var objs = new List(); - foreach (var m in ms) - { - object? obj = null; - if (m is ChatRequestUserMessage userMessage) - { - obj = new - { - Role = userMessage.Role.ToString(), - Content = userMessage.Content, - MultiModaItem = userMessage.MultimodalContentItems?.Select(item => - { - return item switch - { - ChatMessageImageContentItem imageContentItem => new - { - Type = "Image", - ImageUrl = GetImageUrlFromContent(imageContentItem), - } as object, - ChatMessageTextContentItem textContentItem => new - { - Type = "Text", - Text = textContentItem.Text, - } as object, - _ => throw new System.NotImplementedException(), - }; - }), - }; - } - - if (m is ChatRequestAssistantMessage assistantMessage) - { - obj = new - { - Role = assistantMessage.Role.ToString(), - Content = assistantMessage.Content, - TooCall = assistantMessage.ToolCalls.Select(tc => - { - return tc switch - { - ChatCompletionsFunctionToolCall functionToolCall => new - { - Type = "Function", - Name = functionToolCall.Name, - Arguments = functionToolCall.Arguments, - Id = functionToolCall.Id, - } as object, - _ => throw new System.NotImplementedException(), - }; - }), - FunctionCallName = assistantMessage.FunctionCall?.Name, - FunctionCallArguments = assistantMessage.FunctionCall?.Arguments, - }; - } - - if (m is ChatRequestSystemMessage systemMessage) - { - obj = new - { - Role = systemMessage.Role.ToString(), - Content = systemMessage.Content, - }; - } - - if (m is ChatRequestFunctionMessage functionMessage) - { - obj = new - { - Role = functionMessage.Role.ToString(), - Content = functionMessage.Content, - Name = functionMessage.Name, - }; - } - - if (m is ChatRequestToolMessage toolCallMessage) - { - obj = new - { - Role = toolCallMessage.Role.ToString(), - Content = toolCallMessage.Content, - ToolCallId = toolCallMessage.ToolCallId, - }; - } - - objs.Add(obj ?? throw new System.NotImplementedException()); - } - - return new - { - OriginalMessage = originalMessage.ToString(), - ConvertedMessages = objs, - }; - }); - - var json = JsonSerializer.Serialize(jsonObjects, this.jsonSerializerOptions); - Approvals.Verify(json); - } - - private object? GetImageUrlFromContent(ChatMessageImageContentItem content) - { - return content.GetType().GetProperty("ImageUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.GetValue(content); - } -}