Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net - DESIGN REVIEW: Agent Chat Filters #5901

Draft
wants to merge 41 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
629b0b5
Checkpoint
crickman Apr 17, 2024
79431fc
Typos
crickman Apr 17, 2024
0d8b4c2
Cleanup
crickman Apr 17, 2024
ba6616a
Comment
crickman Apr 17, 2024
b628c07
Merge branch 'main' into feature_agent_filters
crickman Apr 17, 2024
86c6fe4
Merge branch 'main' into feature_agent_filters
crickman Apr 17, 2024
a5cc2a7
Merge branch 'main' into feature_agent_filters
crickman Apr 17, 2024
fc48d03
Merge branch 'feature_agent_filters' of https://github.com/microsoft/…
crickman Apr 17, 2024
b2a43d3
Refactor merge
crickman Apr 17, 2024
274b61f
Tuned
crickman Apr 18, 2024
3fdfbbe
Merge from main
crickman Apr 22, 2024
5fbd41b
Example
crickman Apr 22, 2024
7cf233d
Warning
crickman Apr 22, 2024
57e9329
Merge branch 'main' into feature_agent_filters
crickman Apr 23, 2024
98d0f38
Merge branch 'main' into feature_agent_filters
crickman Apr 23, 2024
783aa33
Merge branch 'feature_agent_filters' of https://github.com/microsoft/…
crickman Apr 23, 2024
9bf0ad7
Merge from main
crickman Apr 24, 2024
eb55834
Coverage
crickman Apr 24, 2024
7c99c5a
Branch coverage
crickman Apr 24, 2024
30697bf
Merge branch 'main' into feature_agent_filters
crickman Apr 24, 2024
e41e19b
Merge branch 'main' into feature_agent_filters
crickman Apr 25, 2024
59608b2
Merge branch 'main' into feature_agent_filters
crickman Apr 25, 2024
8a96c30
Merge branch 'main' into feature_agent_filters
crickman Apr 26, 2024
31290e9
Merge branch 'main' into feature_agent_filters
crickman Apr 30, 2024
f03cfe2
Fix merge
crickman Apr 30, 2024
6b783d4
Merge branch 'main' into feature_agent_filters
crickman Apr 30, 2024
e90656c
Merge branch 'main' into feature_agent_filters
crickman May 1, 2024
bf8dc66
Merge branch 'main' into feature_agent_filters
crickman May 1, 2024
f66594e
Merge branch 'main' into feature_agent_filters
crickman May 1, 2024
be782e0
Merge branch 'main' into feature_agent_filters
crickman May 1, 2024
4505326
Merge branch 'main' into feature_agent_filters
crickman May 2, 2024
9a9b743
Merge from main
crickman May 2, 2024
18165bf
Fix assignment
crickman May 2, 2024
4fdaeb2
Merge branch 'main' into feature_agent_filters
crickman May 6, 2024
e0fe9ae
Merge branch 'main' into feature_agent_filters
crickman May 7, 2024
e3841b7
Merge branch 'main' into feature_agent_filters
crickman May 8, 2024
5f161d4
Merge branch 'main' into feature_agent_filters
crickman May 30, 2024
ef3e791
Merge branch 'main' into feature_agent_filters
crickman Jun 12, 2024
75927d8
Merge branch 'main' into feature_agent_filters
crickman Jun 12, 2024
55cee9d
Merge branch 'main' into feature_agent_filters
crickman Jun 14, 2024
e7e3a01
Merge branch 'main' into feature_agent_filters
crickman Jun 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 107 additions & 0 deletions dotnet/samples/GettingStartedWithAgents/Step6_Filters.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.Chat;
using Microsoft.SemanticKernel.Agents.Filters;
using Microsoft.SemanticKernel.ChatCompletion;

namespace GettingStarted;

/// <summary>
/// Demonstrate usage of <see cref="AgentChat"/> with <see cref="IAgentChatFilter"/>.
/// </summary>
public class Step6_Filters(ITestOutputHelper output) : BaseTest(output)
{
private const string ReviewerName = "ArtDirector";
private const string ReviewerInstructions =
"""
You are an art director who has opinions about copywriting born of a love for David Ogilvy.
The goal is to determine is the given copy is acceptable to print.
If so, state that it is approved.
If not, provide insight on how to refine suggested copy without example.
""";

private const string CopyWriterName = "Writer";
private const string CopyWriterInstructions =
"""
You are a copywriter with ten years of experience and are known for brevity and a dry humor.
You're laser focused on the goal at hand. Don't waste time with chit chat.
The goal is to refine and decide on the single best copy as an expert in the field.
Consider suggestions when refining an idea.
""";

[Fact]
public async Task RunAsync()
{
// Define the agents
ChatCompletionAgent agentReviewer =
new()
{
Instructions = ReviewerInstructions,
Name = ReviewerName,
Kernel = this.CreateKernelWithChatCompletion(),
};

ChatCompletionAgent agentWriter =
new()
{
Instructions = CopyWriterInstructions,
Name = CopyWriterName,
Kernel = this.CreateKernelWithChatCompletion(),
};

// Create a chat for agent interaction.
AgentGroupChat chat =
new(agentWriter, agentReviewer)
{
ExecutionSettings =
new()
{
// Here a TerminationStrategy subclass is used that will terminate when
// an assistant message contains the term "approve".
TerminationStrategy =
new ApprovalTerminationStrategy()
{
// Only the art-director may approve.
Agents = [agentReviewer],
}
},
Filters =
{
new ExampleChatFilter(base.Output)
}
};

// Invoke chat and display messages.
string input = "concept: maps made out of egg cartons.";
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input));
this.WriteLine($"# {AuthorRole.User}: '{input}'");

await foreach (var content in chat.InvokeAsync())
{
this.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'");
}

this.WriteLine($"# IS COMPLETE: {chat.IsComplete}");
}

private sealed class ApprovalTerminationStrategy : TerminationStrategy
{
// Terminate when the final message contains the term "approve"
protected override Task<bool> ShouldAgentTerminateAsync(Agent agent, IReadOnlyList<ChatMessageContent> history, CancellationToken cancellationToken)
=> Task.FromResult(history[history.Count - 1].Content?.Contains("approve", StringComparison.OrdinalIgnoreCase) ?? false);
}

private sealed class ExampleChatFilter(ITestOutputHelper output) : IAgentChatFilter
{
public void OnAgentInvoked(AgentChatFilterInvokedContext context)
{
output.WriteLine($"$ FILTER - {context.Agent.Name}: {nameof(OnAgentInvoked)} #{context.Message.Content?.Length ?? 0}");
}

public void OnAgentInvoking(AgentChatFilterInvokingContext context)
{
output.WriteLine($"$ FILTER - {context.Agent.Name}: {nameof(OnAgentInvoking)}");
}
}
}
56 changes: 53 additions & 3 deletions dotnet/src/Agents/Abstractions/AgentChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.Agents.Extensions;
using Microsoft.SemanticKernel.Agents.Filters;
using Microsoft.SemanticKernel.Agents.Internal;
using Microsoft.SemanticKernel.ChatCompletion;

Expand All @@ -26,6 +27,7 @@ public abstract class AgentChat
private readonly Dictionary<Agent, string> _channelMap; // Map agent to its channel-hash: one entry per agent.

private int _isActive;
private List<IAgentChatFilter>? _filters;
private ILogger? _logger;

/// <summary>
Expand All @@ -44,6 +46,11 @@ public abstract class AgentChat
/// </summary>
protected ILogger Logger => this._logger ??= this.LoggerFactory.CreateLogger(this.GetType());

/// <summary>
/// %%%
/// </summary>
public IList<IAgentChatFilter> Filters => this._filters ??= [];

/// <summary>
/// Exposes the internal history to subclasses.
/// </summary>
Expand Down Expand Up @@ -209,6 +216,9 @@ public void AddChatMessages(IReadOnlyList<ChatMessageContent> messages)

try
{
// %%%
this.OnAgentInvokingFilter(agent, this.History);

// Get or create the required channel and block until channel is synchronized.
// Will throw exception when propagating a processing failure.
AgentChannel channel = await GetOrCreateChannelAsync().ConfigureAwait(false);
Expand All @@ -217,11 +227,17 @@ public void AddChatMessages(IReadOnlyList<ChatMessageContent> messages)
List<ChatMessageContent> messages = [];
await foreach (ChatMessageContent message in channel.InvokeAsync(agent, cancellationToken).ConfigureAwait(false))
{
// Invoke filter
AgentChatFilterInvokedContext? context = this.OnAgentInvokedFilter(agent, this.History, message);

// Capture potential message replacement
ChatMessageContent effectiveMessage = context?.Message ?? message;

this.Logger.LogTrace("[{MethodName}] Agent message {AgentType}: {Message}", nameof(InvokeAgentAsync), agent.GetType(), message);

// Add to primary history
this.History.Add(message);
messages.Add(message);
this.History.Add(effectiveMessage);
messages.Add(effectiveMessage);

// Don't expose function-call and function-result messages to caller.
if (message.Items.All(i => i is FunctionCallContent || i is FunctionResultContent))
Expand All @@ -230,7 +246,7 @@ await foreach (ChatMessageContent message in channel.InvokeAsync(agent, cancella
}

// Yield message to caller
yield return message;
yield return effectiveMessage;
}

// Broadcast message to other channels (in parallel)
Expand Down Expand Up @@ -328,6 +344,40 @@ private string GetAgentHash(Agent agent)
return channel;
}

private AgentChatFilterInvokingContext? OnAgentInvokingFilter(Agent agent, IReadOnlyList<ChatMessageContent> history)
{
AgentChatFilterInvokingContext? context = null;

if (this._filters is { Count: > 0 })
{
context = new(agent, history);

for (int i = 0; i < this._filters.Count; i++)
{
this._filters[i].OnAgentInvoking(context);
}
}

return context;
}

private AgentChatFilterInvokedContext? OnAgentInvokedFilter(Agent agent, IReadOnlyList<ChatMessageContent> history, ChatMessageContent message)
{
AgentChatFilterInvokedContext? context = null;

if (this._filters is { Count: > 0 })
{
context = new(agent, history, message);

for (int i = 0; i < this._filters.Count; i++)
{
this._filters[i].OnAgentInvoked(context);
}
}

return context;
}

/// <summary>
/// Initializes a new instance of the <see cref="AgentChat"/> class.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;

namespace Microsoft.SemanticKernel.Agents.Filters;

/// <summary>
/// Base class with data related to <see cref="IAgentChatFilter"/>.
/// </summary>
public abstract class AgentChatFilterContext
{
/// <summary>
/// Gets the <see cref="Agent"/> with which this filter is associated.
/// </summary>
public Agent Agent { get; }

/// <summary>
/// Gets the message history with which this filter is associated.
/// </summary>
public IReadOnlyList<ChatMessageContent> History { get; }

/// <summary>
/// Initializes a new instance of the <see cref="AgentChatFilterContext"/> class.
/// </summary>
/// <param name="agent"></param>
/// <param name="history"></param>
internal AgentChatFilterContext(Agent agent, IReadOnlyList<ChatMessageContent> history)
{
this.Agent = agent;
this.History = history;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;

namespace Microsoft.SemanticKernel.Agents.Filters;

/// <summary>
/// Context associated with to <see cref="IAgentChatFilter.OnAgentInvoked"/>.
/// </summary>
public sealed class AgentChatFilterInvokedContext : AgentChatFilterContext
{
/// <summary>
/// %%%
/// </summary>
public ChatMessageContent Message { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="AgentChatFilterInvokedContext"/> class.
/// </summary>
/// <param name="agent"></param>
/// <param name="history"></param>
/// <param name="message"></param>
internal AgentChatFilterInvokedContext(Agent agent, IReadOnlyList<ChatMessageContent> history, ChatMessageContent message)
: base(agent, history)
{
this.Message = message;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;

namespace Microsoft.SemanticKernel.Agents.Filters;

/// <summary>
/// Context associated with to <see cref="IAgentChatFilter.OnAgentInvoking"/>.
/// </summary>
public sealed class AgentChatFilterInvokingContext : AgentChatFilterContext
{
/// <summary>
/// Initializes a new instance of the <see cref="AgentChatFilterInvokingContext"/> class.
/// </summary>
/// <param name="agent"></param>
/// <param name="history"></param>
internal AgentChatFilterInvokingContext(Agent agent, IReadOnlyList<ChatMessageContent> history)
: base(agent, history)
{ }
}
20 changes: 20 additions & 0 deletions dotnet/src/Agents/Abstractions/Filters/IAgentChatFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.SemanticKernel.Agents.Filters;

/// <summary>
/// Interface for filtering actions during agent chat.
/// </summary>
public interface IAgentChatFilter
{
/// <summary>
/// Method which is executed before invoking agent.
/// </summary>
/// <param name="context">Data related to agent before invoking.</param>
void OnAgentInvoking(AgentChatFilterInvokingContext context);

/// <summary>
/// Method which is executed after invoking agent as each response is processed.
/// </summary>
/// <param name="context">Data related to agent after invoking.</param>
void OnAgentInvoked(AgentChatFilterInvokedContext context);
}
32 changes: 32 additions & 0 deletions dotnet/src/Agents/UnitTests/AgentChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.Filters;
using Microsoft.SemanticKernel.ChatCompletion;
using Moq;
using Xunit;

namespace SemanticKernel.Agents.UnitTests;
Expand Down Expand Up @@ -109,6 +111,36 @@ async Task SynchronizedInvokeAsync()
}
}

/// <summary>
/// Verify behavior of <see cref="AgentChat"/> usage of <see cref="IAgentChatFilter"/>.
/// </summary>
[Fact]
public async Task VerifyAgentChatFiltersAsync()
{
// Create a filter
Mock<IAgentChatFilter> mockFilter = new();

// Create chat
TestChat chat = new()
{
Filters =
{
mockFilter.Object
}
};

// Verify initial state
mockFilter.Verify(f => f.OnAgentInvoking(It.IsAny<AgentChatFilterInvokingContext>()), Times.Never);
mockFilter.Verify(f => f.OnAgentInvoked(It.IsAny<AgentChatFilterInvokedContext>()), Times.Never);

// Invoke with input & verify (agent joins chat)
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, "hi"));
await chat.InvokeAsync().ToArrayAsync();
Assert.Equal(1, chat.Agent.InvokeCount);
mockFilter.Verify(f => f.OnAgentInvoking(It.IsAny<AgentChatFilterInvokingContext>()), Times.Once);
mockFilter.Verify(f => f.OnAgentInvoked(It.IsAny<AgentChatFilterInvokedContext>()), Times.Once);
}

private async Task VerifyHistoryAsync(int expectedCount, IAsyncEnumerable<ChatMessageContent> history)
{
if (expectedCount == 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.Filters;
using Microsoft.SemanticKernel.ChatCompletion;
using Moq;
using Xunit;

namespace SemanticKernel.Agents.UnitTests.OpenAI;

/// <summary>
/// Unit testing of <see cref="AgentChatFilterInvokedContext"/>.
/// </summary>
public class AgentChatFilterInvokedContextTests
{
/// <summary>
/// Verify initial state.
/// </summary>
[Fact]
public void VerifyAgentChatFilterInvokedContextState()
{
Mock<Agent> agent = new();
ChatHistory history = [];
ChatMessageContent message = new(AuthorRole.User, "hi");

AgentChatFilterInvokedContext context = new(agent.Object, history, message);

Assert.NotNull(context.Agent);
Assert.NotNull(context.History);
Assert.NotNull(context.Message);
}
}