Skip to content

Commit

Permalink
.Net: Fix 5796 function calling enum params (#5998)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Fixes #5796

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Type for enum wasn't correctly set in JsonSchemaMapper, it didn't matter
for OpenAI but gemini was throwing exception if type isn't specified.
Fixed that with `string` type.
Added new unit tests for Gemini and OpenAI. Both passed.

@RogerBarreto 
@SergeyMenshykh

DataHelper and BertOnyx was updated automatically by formatter.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Roger Barreto <[email protected]>
  • Loading branch information
Krzysztof318 and RogerBarreto committed May 13, 2024
1 parent f53c98e commit 8a8cd95
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 32 deletions.
2 changes: 1 addition & 1 deletion dotnet/samples/GettingStarted/Step7_Observability.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void MyInvokedHandler(object? sender, FunctionInvokedEventArgs e)
{
if (e.Result.Metadata is not null && e.Result.Metadata.ContainsKey("Usage"))
{
Console.WriteLine($"Token usage: {e.Result.Metadata?["Usage"]?.AsJson()}");
Console.WriteLine("Token usage: {0}", e.Result.Metadata?["Usage"]?.AsJson());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void ItCanCreateValidGeminiFunctionManualForPlugin()
// Assert
Assert.NotNull(result);
Assert.Equal(
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"type":"string","enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
JsonSerializer.Serialize(result.Parameters)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public void ItCanCreateValidOpenAIFunctionManualForPlugin()
// Assert
Assert.NotNull(result);
Assert.Equal(
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
"""{"type":"object","required":["parameter1","parameter2","parameter3"],"properties":{"parameter1":{"type":"string","description":"String parameter"},"parameter2":{"type":"string","enum":["Value1","Value2"],"description":"Enum parameter"},"parameter3":{"type":"string","format":"date-time","description":"DateTime parameter"}}}""",
result.Parameters.ToString()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Xunit;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI;
namespace SemanticKernel.IntegrationTests.Connectors.Google;

public sealed class EmbeddingGenerationTests(ITestOutputHelper output) : TestsBase(output)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using Xunit;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI.Gemini;
namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini;

public sealed class GeminiChatCompletionTests(ITestOutputHelper output) : TestsBase(output)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
using System.ComponentModel;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Time.Testing;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
using xRetry;
using Xunit;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI.Gemini;
namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini;

public sealed class GeminiFunctionCallingTests(ITestOutputHelper output) : TestsBase(output)
{
Expand Down Expand Up @@ -291,6 +292,64 @@ public async Task ChatStreamingAutoInvokeTwoPluginsShouldGetDateAndReturnTasksBy
Assert.Contains("5", content, StringComparison.OrdinalIgnoreCase);
}

[RetryTheory]
[InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")]
[InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")]
public async Task ChatGenerationAutoInvokeShouldCallFunctionWithEnumParameterAndReturnResponseAsync(ServiceType serviceType)
{
// Arrange
var kernel = new Kernel();
var timeProvider = new FakeTimeProvider();
timeProvider.SetUtcNow(new DateTimeOffset(new DateTime(2024, 4, 24))); // Wednesday
var timePlugin = new TimePlugin(timeProvider);
kernel.ImportPluginFromObject(timePlugin, nameof(TimePlugin));
var sut = this.GetChatService(serviceType);
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("When was last friday? Show the date in format DD.MM.YYYY for example: 15.07.2019");
var executionSettings = new GeminiPromptExecutionSettings()
{
MaxTokens = 2000,
ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions,
};

// Act
var response = await sut.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);

// Assert
this.Output.WriteLine(response.Content);
Assert.Contains("19.04.2024", response.Content, StringComparison.OrdinalIgnoreCase);
}

[RetryTheory]
[InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")]
[InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")]
public async Task ChatStreamingAutoInvokeShouldCallFunctionWithEnumParameterAndReturnResponseAsync(ServiceType serviceType)
{
// Arrange
var kernel = new Kernel();
var timeProvider = new FakeTimeProvider();
timeProvider.SetUtcNow(new DateTimeOffset(new DateTime(2024, 4, 24))); // Wednesday
var timePlugin = new TimePlugin(timeProvider);
kernel.ImportPluginFromObject(timePlugin, nameof(TimePlugin));
var sut = this.GetChatService(serviceType);
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("When was last friday? Show the date in format DD.MM.YYYY for example: 15.07.2019");
var executionSettings = new GeminiPromptExecutionSettings()
{
MaxTokens = 2000,
ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions,
};

// Act
var responses = await sut.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel)
.ToListAsync();

// Assert
string content = string.Concat(responses.Select(c => c.Content));
this.Output.WriteLine(content);
Assert.Contains("19.04.2024", content, StringComparison.OrdinalIgnoreCase);
}

public sealed class CustomerPlugin
{
[KernelFunction(nameof(GetCustomers))]
Expand Down Expand Up @@ -343,6 +402,37 @@ public DateTime GetDate()
}
}

public sealed class TimePlugin
{
private readonly TimeProvider _timeProvider;

public TimePlugin(TimeProvider timeProvider)
{
this._timeProvider = timeProvider;
}

[KernelFunction]
[Description("Get the date of the last day matching the supplied week day name in English. Example: Che giorno era 'Martedi' scorso -> dateMatchingLastDayName 'Tuesday' => Tuesday, 16 May, 2023")]
public string DateMatchingLastDayName(
[Description("The day name to match")] DayOfWeek input,
IFormatProvider? formatProvider = null)
{
DateTimeOffset dateTime = this._timeProvider.GetUtcNow();

// Walk backwards from the previous day for up to a week to find the matching day
for (int i = 1; i <= 7; ++i)
{
dateTime = dateTime.AddDays(-1);
if (dateTime.DayOfWeek == input)
{
break;
}
}

return dateTime.ToString("D", formatProvider);
}
}

public sealed class MathPlugin
{
[KernelFunction(nameof(Sum))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using Microsoft.SemanticKernel.Embeddings;
using Xunit.Abstractions;

namespace SemanticKernel.IntegrationTests.Connectors.GoogleVertexAI;
namespace SemanticKernel.IntegrationTests.Connectors.Google;

public abstract class TestsBase(ITestOutputHelper output)
{
Expand Down
53 changes: 53 additions & 0 deletions dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Time.Testing;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
Expand Down Expand Up @@ -112,6 +113,27 @@ public async Task CanAutoInvokeKernelFunctionsWithPrimitiveTypeParametersAsync()
Assert.Contains("10", result.GetValue<string>(), StringComparison.InvariantCulture);
}

[Fact(Skip = "OpenAI is throttling requests. Switch this test to use Azure OpenAI.")]
public async Task CanAutoInvokeKernelFunctionsWithEnumTypeParametersAsync()
{
// Arrange
Kernel kernel = this.InitializeKernel();
var timeProvider = new FakeTimeProvider();
timeProvider.SetUtcNow(new DateTimeOffset(new DateTime(2024, 4, 24))); // Wednesday
var timePlugin = new TimePlugin(timeProvider);
kernel.ImportPluginFromObject(timePlugin, nameof(TimePlugin));

// Act
OpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
var result = await kernel.InvokePromptAsync(
"When was last friday? Show the date in format DD.MM.YYYY for example: 15.07.2019",
new(settings));

// Assert
Assert.NotNull(result);
Assert.Contains("19.04.2024", result.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
}

[Fact]
public async Task CanAutoInvokeKernelFunctionFromPromptAsync()
{
Expand Down Expand Up @@ -550,4 +572,35 @@ private sealed class FakeFunctionFilter : IFunctionInvocationFilter
}

#endregion

public sealed class TimePlugin
{
private readonly TimeProvider _timeProvider;

public TimePlugin(TimeProvider timeProvider)
{
this._timeProvider = timeProvider;
}

[KernelFunction]
[Description("Get the date of the last day matching the supplied week day name in English. Example: Che giorno era 'Martedi' scorso -> dateMatchingLastDayName 'Tuesday' => Tuesday, 16 May, 2023")]
public string DateMatchingLastDayName(
[Description("The day name to match")] DayOfWeek input,
IFormatProvider? formatProvider = null)
{
DateTimeOffset dateTime = this._timeProvider.GetUtcNow();

// Walk backwards from the previous day for up to a week to find the matching day
for (int i = 1; i <= 7; ++i)
{
dateTime = dateTime.AddDays(-1);
if (dateTime.DayOfWeek == input)
{
break;
}
}

return dateTime.ToString("D", formatProvider);
}
}
}
1 change: 1 addition & 0 deletions dotnet/src/IntegrationTests/IntegrationTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
<PackageReference Include="Microsoft.Extensions.Configuration.UserSecrets" />
<PackageReference Include="Microsoft.Extensions.Http" />
<PackageReference Include="Microsoft.Extensions.Http.Resilience" />
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="System.Linq.Async" />
<PackageReference Include="xRetry" />
Expand Down
44 changes: 19 additions & 25 deletions dotnet/src/InternalUtilities/src/Schema/JsonSchemaMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
string? title = null,
string? description = null,
bool isNullableReferenceType = false,
bool isNullableOfTElement = false,
JsonConverter? customConverter = null,
bool hasDefaultValue = false,
JsonNode? defaultValue = null,
Expand All @@ -186,7 +187,7 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
JsonConverter effectiveConverter = customConverter ?? typeInfo.Converter;
JsonNumberHandling? effectiveNumberHandling = customNumberHandling ?? typeInfo.NumberHandling;
bool emitsTypeDiscriminator = derivedTypeDiscriminator?.Value is not null;
bool isCacheable = !emitsTypeDiscriminator && description is null && !hasDefaultValue;
bool isCacheable = !emitsTypeDiscriminator && description is null && !hasDefaultValue && !isNullableOfTElement;

if (!IsBuiltInConverter(effectiveConverter))
{
Expand Down Expand Up @@ -220,7 +221,8 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
defaultValue: defaultValue,
customNumberHandling: customNumberHandling,
customConverter: customConverter,
parentNullableOfT: type);
parentNullableOfT: type,
isNullableOfTElement: true);
}

if (isCacheable && typeInfo.Kind != JsonTypeInfoKind.None)
Expand Down Expand Up @@ -319,23 +321,15 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals
}
else if (type.IsEnum)
{
if (TryGetStringEnumConverterValues(typeInfo, effectiveConverter, out JsonArray? values))
if (TryGetStringEnumConverterValues(typeInfo, effectiveConverter, out enumValues))
{
if (values is null)
{
// enum declared with the flags attribute -- do not surface enum values in the JSON schema.
schemaType = JsonSchemaType.String;
}
else
schemaType = JsonSchemaType.String;

if (enumValues != null && isNullableOfTElement)
{
if (parentNullableOfT is not null)
{
// We're generating the schema for a nullable
// enum type. Append null to the "enum" array.
values.Add(null);
}

enumValues = values;
// We're generating the schema for a nullable
// enum type. Append null to the "enum" array.
enumValues.Add(null);
}
}
else
Expand Down Expand Up @@ -417,15 +411,15 @@ public static string ToJsonString(this JsonNode? node, bool writeIndented = fals

state.Push(property.Name);
JsonObject propertySchema = MapJsonSchemaCore(
propertyTypeInfo,
ref state,
typeInfo: propertyTypeInfo,
state: ref state,
title: null,
propertyDescription,
isPropertyNullableReferenceType,
property.CustomConverter,
propertyHasDefaultValue,
propertyDefaultValue,
propertyNumberHandling);
description: propertyDescription,
isNullableReferenceType: isPropertyNullableReferenceType,
customConverter: property.CustomConverter,
hasDefaultValue: propertyHasDefaultValue,
defaultValue: propertyDefaultValue,
customNumberHandling: propertyNumberHandling);

state.Pop();

Expand Down

0 comments on commit 8a8cd95

Please sign in to comment.