Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Fix 5796 function calling enum params #5998

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dotnet/samples/GettingStarted/Step7_Observability.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void MyInvokedHandler(object? sender, FunctionInvokedEventArgs e)
{
if (e.Result.Metadata is not null && e.Result.Metadata.ContainsKey("Usage"))
{
Console.WriteLine($"Token usage: {e.Result.Metadata?["Usage"]?.AsJson()}");
Console.WriteLine("Token usage: {0}", e.Result.Metadata?["Usage"]?.AsJson());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void ItCanCreateValidGeminiFunctionManualForPlugin()
// Assert
Assert.NotNull(result);
Assert.Equal(
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"type":"string","enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
JsonSerializer.Serialize(result.Parameters)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public void ItCanCreateValidOpenAIFunctionManualForPlugin()
// Assert
Assert.NotNull(result);
Assert.Equal(
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"type":"string","enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
result.Parameters.ToString()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Xunit;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI;
namespace SemanticKernel.IntegrationTests.Connectors.Google;

public sealed class EmbeddingGenerationTests(ITestOutputHelper output) : TestsBase(output)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using Xunit;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI.Gemini;
namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini;

public sealed class GeminiChatCompletionTests(ITestOutputHelper output) : TestsBase(output)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
using System.ComponentModel;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Time.Testing;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
using xRetry;
using Xunit;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI.Gemini;
namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini;

public sealed class GeminiFunctionCallingTests(ITestOutputHelper output) : TestsBase(output)
{
Expand Down Expand Up @@ -291,6 +292,64 @@ public async Task ChatStreamingAutoInvokeTwoPluginsShouldGetDateAndReturnTasksBy
Assert.Contains("5", content, StringComparison.OrdinalIgnoreCase);
}

[RetryTheory]
[InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")]
[InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")]
public async Task ChatGenerationAutoInvokeShouldCallFunctionWithEnumParameterAndReturnResponseAsync(ServiceType serviceType)
{
// Arrange
var kernel = new Kernel();
var timeProvider = new FakeTimeProvider();
timeProvider.SetUtcNow(new DateTimeOffset(new DateTime(2024, 4, 24))); // Wednesday
var timePlugin = new TimePlugin(timeProvider);
kernel.ImportPluginFromObject(timePlugin, nameof(TimePlugin));
var sut = this.GetChatService(serviceType);
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("When was last friday? Show the date in format DD.MM.YYYY for example: 15.07.2019");
var executionSettings = new GeminiPromptExecutionSettings()
{
MaxTokens = 2000,
ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions,
};

// Act
var response = await sut.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);

// Assert
this.Output.WriteLine(response.Content);
Assert.Contains("19.04.2024", response.Content, StringComparison.OrdinalIgnoreCase);
}

[RetryTheory]
[InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")]
[InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")]
public async Task ChatStreamingAutoInvokeShouldCallFunctionWithEnumParameterAndReturnResponseAsync(ServiceType serviceType)
{
// Arrange
var kernel = new Kernel();
var timeProvider = new FakeTimeProvider();
timeProvider.SetUtcNow(new DateTimeOffset(new DateTime(2024, 4, 24))); // Wednesday
var timePlugin = new TimePlugin(timeProvider);
kernel.ImportPluginFromObject(timePlugin, nameof(TimePlugin));
var sut = this.GetChatService(serviceType);
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("When was last friday? Show the date in format DD.MM.YYYY for example: 15.07.2019");
var executionSettings = new GeminiPromptExecutionSettings()
{
MaxTokens = 2000,
ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions,
};

// Act
var responses = await sut.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel)
.ToListAsync();

// Assert
string content = string.Concat(responses.Select(c => c.Content));
this.Output.WriteLine(content);
Assert.Contains("19.04.2024", content, StringComparison.OrdinalIgnoreCase);
}

public sealed class CustomerPlugin
{
[KernelFunction(nameof(GetCustomers))]
Expand Down Expand Up @@ -343,6 +402,37 @@ public DateTime GetDate()
}
}

public sealed class TimePlugin
{
private readonly TimeProvider _timeProvider;

public TimePlugin(TimeProvider timeProvider)
{
this._timeProvider = timeProvider;
}

[KernelFunction]
[Description("Get the date of the last day matching the supplied week day name in English. Example: Che giorno era 'Martedi' scorso -> dateMatchingLastDayName 'Tuesday' => Tuesday, 16 May, 2023")]
public string DateMatchingLastDayName(
[Description("The day name to match")] DayOfWeek input,
IFormatProvider? formatProvider = null)
{
DateTimeOffset dateTime = this._timeProvider.GetUtcNow();

// Walk backwards from the previous day for up to a week to find the matching day
for (int i = 1; i <= 7; ++i)
{
dateTime = dateTime.AddDays(-1);
if (dateTime.DayOfWeek == input)
{
break;
}
}

return dateTime.ToString("D", formatProvider);
}
}

public sealed class MathPlugin
{
[KernelFunction(nameof(Sum))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using Microsoft.SemanticKernel.Embeddings;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI;
namespace SemanticKernel.IntegrationTests.Connectors.Google;

public abstract class TestsBase(ITestOutputHelper output)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Time.Testing;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
Expand Down Expand Up @@ -112,6 +113,27 @@ public async Task CanAutoInvokeKernelFunctionsWithPrimitiveTypeParametersAsync()
Assert.Contains("10", result.GetValue<string>(), StringComparison.InvariantCulture);
}

[Fact(Skip = "OpenAI is throttling requests. Switch this test to use Azure OpenAI.")]
public async Task CanAutoInvokeKernelFunctionsWithEnumTypeParametersAsync()
{
// Arrange
Kernel kernel = this.InitializeKernel();
var timeProvider = new FakeTimeProvider();
timeProvider.SetUtcNow(new DateTimeOffset(new DateTime(2024, 4, 24))); // Wednesday
var timePlugin = new TimePlugin(timeProvider);
kernel.ImportPluginFromObject(timePlugin, nameof(TimePlugin));

// Act
OpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
var result = await kernel.InvokePromptAsync(
"When was last friday? Show the date in format DD.MM.YYYY for example: 15.07.2019",
new(settings));

// Assert
Assert.NotNull(result);
Assert.Contains("19.04.2024", result.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
}

[Fact]
public async Task CanAutoInvokeKernelFunctionFromPromptAsync()
{
Expand Down Expand Up @@ -550,4 +572,35 @@ private sealed class FakeFunctionFilter : IFunctionInvocationFilter
}

#endregion

public sealed class TimePlugin
{
private readonly TimeProvider _timeProvider;

public TimePlugin(TimeProvider timeProvider)
{
this._timeProvider = timeProvider;
}

[KernelFunction]
[Description("Get the date of the last day matching the supplied week day name in English. Example: Che giorno era 'Martedi' scorso -> dateMatchingLastDayName 'Tuesday' => Tuesday, 16 May, 2023")]
public string DateMatchingLastDayName(
[Description("The day name to match")] DayOfWeek input,
IFormatProvider? formatProvider = null)
{
DateTimeOffset dateTime = this._timeProvider.GetUtcNow();

// Walk backwards from the previous day for up to a week to find the matching day
for (int i = 1; i <= 7; ++i)
{
dateTime = dateTime.AddDays(-1);
if (dateTime.DayOfWeek == input)
{
break;
}
}

return dateTime.ToString("D", formatProvider);
}
}
}
1 change: 1 addition & 0 deletions dotnet/src/IntegrationTests/IntegrationTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
<PackageReference Include="Microsoft.Extensions.Configuration.UserSecrets" />
<PackageReference Include="Microsoft.Extensions.Http" />
<PackageReference Include="Microsoft.Extensions.Http.Resilience" />
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="System.Linq.Async" />
<PackageReference Include="xRetry" />
Expand Down
44 changes: 19 additions & 25 deletions dotnet/src/InternalUtilities/src/Schema/JsonSchemaMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
string? title = null,
string? description = null,
bool isNullableReferenceType = false,
bool isNullableOfTElement = false,
JsonConverter? customConverter = null,
bool hasDefaultValue = false,
JsonNode? defaultValue = null,
Expand All @@ -186,7 +187,7 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
JsonConverter effectiveConverter = customConverter ?? typeInfo.Converter;
JsonNumberHandling? effectiveNumberHandling = customNumberHandling ?? typeInfo.NumberHandling;
bool emitsTypeDiscriminator = derivedTypeDiscriminator?.Value is not null;
bool isCacheable = !emitsTypeDiscriminator && description is null && !hasDefaultValue;
bool isCacheable = !emitsTypeDiscriminator && description is null && !hasDefaultValue && !isNullableOfTElement;

if (!IsBuiltInConverter(effectiveConverter))
{
Expand Down Expand Up @@ -220,7 +221,8 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
defaultValue: defaultValue,
customNumberHandling: customNumberHandling,
customConverter: customConverter,
parentNullableOfT: type);
parentNullableOfT: type,
isNullableOfTElement: true);
}

if (isCacheable && typeInfo.Kind != JsonTypeInfoKind.None)
Expand Down Expand Up @@ -319,23 +321,15 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
}
else if (type.IsEnum)
{
if (TryGetStringEnumConverterValues(typeInfo, effectiveConverter, out JsonArray? values))
if (TryGetStringEnumConverterValues(typeInfo, effectiveConverter, out enumValues))
{
if (values is null)
{
// enum declared with the flags attribute -- do not surface enum values in the JSON schema.
schemaType = JsonSchemaType.String;
}
else
schemaType = JsonSchemaType.String;

if (enumValues != null && isNullableOfTElement)
{
if (parentNullableOfT is not null)
{
// We're generating the schema for a nullable
// enum type. Append null to the "enum" array.
values.Add(null);
}

enumValues = values;
// We're generating the schema for a nullable
// enum type. Append null to the "enum" array.
enumValues.Add(null);
}
}
else
Expand Down Expand Up @@ -417,15 +411,15 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals

state.Push(property.Name);
JsonObject propertySchema = MapJsonSchemaCore(
propertyTypeInfo,
ref state,
typeInfo: propertyTypeInfo,
state: ref state,
title: null,
propertyDescription,
isPropertyNullableReferenceType,
property.CustomConverter,
propertyHasDefaultValue,
propertyDefaultValue,
propertyNumberHandling);
description: propertyDescription,
isNullableReferenceType: isPropertyNullableReferenceType,
customConverter: property.CustomConverter,
hasDefaultValue: propertyHasDefaultValue,
defaultValue: propertyDefaultValue,
customNumberHandling: propertyNumberHandling);

state.Pop();

Expand Down