Skip to content

Commit

Permalink
Add missing CancellationToken parameters (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcominerva committed Jun 7, 2023
2 parents 54738de + 9717d96 commit 8898b48
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 44 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ If necessary, it is possibile to provide a custom Cache by implementing the [ICh
{
private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();

public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default)
{
localCache[conversationId] = messages.ToList();
return Task.CompletedTask;
}

public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.TryGetValue(conversationId, out var messages);
return Task.FromResult(messages);
}

public Task RemoveAsync(Guid conversationId)
public Task RemoveAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.Remove(conversationId);
return Task.CompletedTask;
Expand Down
6 changes: 3 additions & 3 deletions samples/ChatGptConsole/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,18 @@ public class LocalMessageCache : IChatGptCache
{
private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();

public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default)
{
localCache[conversationId] = messages.ToList();
return Task.CompletedTask;
}
public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.TryGetValue(conversationId, out var messages);
return Task.FromResult(messages);
}

public Task RemoveAsync(Guid conversationId)
public Task RemoveAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.Remove(conversationId);
return Task.CompletedTask;
Expand Down
44 changes: 22 additions & 22 deletions src/ChatGptNet/ChatGptClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public ChatGptClient(HttpClient httpClient, IChatGptCache cache, ChatGptOptions
this.options = options;
}

public async Task<Guid> SetupAsync(Guid conversationId, string message)
public async Task<Guid> SetupAsync(Guid conversationId, string message, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(message);

Expand All @@ -52,7 +52,7 @@ public async Task<Guid> SetupAsync(Guid conversationId, string message)
}
};

await cache.SetAsync(conversationId, messages, options.MessageExpiration);
await cache.SetAsync(conversationId, messages, options.MessageExpiration, cancellationToken);
return conversationId;
}

Expand All @@ -66,7 +66,7 @@ public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message,
conversationId = Guid.NewGuid();
}

var messages = await CreateMessageListAsync(conversationId, message);
var messages = await CreateMessageListAsync(conversationId, message, cancellationToken);
var request = CreateRequest(messages, false, parameters, model);

var requestUri = options.ServiceConfiguration.GetServiceEndpoint(model ?? options.DefaultModel);
Expand All @@ -79,7 +79,7 @@ public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message,
if (response.IsSuccessful)
{
// Adds the response message to the conversation cache.
await UpdateHistoryAsync(conversationId, messages, response.Choices.First().Message);
await UpdateHistoryAsync(conversationId, messages, response.Choices.First().Message, cancellationToken);
}
else if (options.ThrowExceptionOnError)
{
Expand All @@ -99,7 +99,7 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
conversationId = Guid.NewGuid();
}

var messages = await CreateMessageListAsync(conversationId, message);
var messages = await CreateMessageListAsync(conversationId, message, cancellationToken);
var request = CreateRequest(messages, true, parameters, model);

var requestUri = options.ServiceConfiguration.GetServiceEndpoint(model ?? options.DefaultModel);
Expand Down Expand Up @@ -159,7 +159,7 @@ await UpdateHistoryAsync(conversationId, messages, new()
{
Role = ChatGptRoles.Assistant,
Content = contentBuilder.ToString()
});
}, cancellationToken);
}
else
{
Expand All @@ -176,32 +176,32 @@ await UpdateHistoryAsync(conversationId, messages, new()
}
}

public async Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId)
public async Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
var messages = await cache.GetAsync(conversationId) ?? Enumerable.Empty<ChatGptMessage>();
var messages = await cache.GetAsync(conversationId, cancellationToken) ?? Enumerable.Empty<ChatGptMessage>();
return messages;
}

public async Task DeleteConversationAsync(Guid conversationId, bool preserveSetup = false)
public async Task DeleteConversationAsync(Guid conversationId, bool preserveSetup = false, CancellationToken cancellationToken = default)
{
if (!preserveSetup)
{
// We don't want to preserve setup message, so just deletes all the cache history.
await cache.RemoveAsync(conversationId);
await cache.RemoveAsync(conversationId, cancellationToken);
}
else
{
var messages = await cache.GetAsync(conversationId);
var messages = await cache.GetAsync(conversationId, cancellationToken);
if (messages is not null)
{
// Removes all the messages, except system ones.
messages.RemoveAll(m => m.Role != ChatGptRoles.System);
await cache.SetAsync(conversationId, messages, options.MessageExpiration);
await cache.SetAsync(conversationId, messages, options.MessageExpiration, cancellationToken);
}
}
}

public async Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, bool replaceHistory = true)
public async Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, bool replaceHistory = true, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(messages);

Expand All @@ -215,25 +215,25 @@ public async Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<C
{
// If messages must replace history, just use the current list, discarding all the previously cached content.
// If messages.Count() > ChatGptOptions.MessageLimit, the UpdateCache take care of taking only the last messages.
await UpdateCacheAsync(conversationId, messages);
await UpdateCacheAsync(conversationId, messages, cancellationToken);
}
else
{
// Retrieves the current history and adds new messages.
var conversationHistory = await cache.GetAsync(conversationId) ?? new List<ChatGptMessage>();
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken) ?? new List<ChatGptMessage>();
conversationHistory.AddRange(messages);

// If messages total length > ChatGptOptions.MessageLimit, the UpdateCache take care of taking only the last messages.
await UpdateCacheAsync(conversationId, conversationHistory);
await UpdateCacheAsync(conversationId, conversationHistory, cancellationToken);
}

return conversationId;
}

private async Task<List<ChatGptMessage>> CreateMessageListAsync(Guid conversationId, string message)
private async Task<List<ChatGptMessage>> CreateMessageListAsync(Guid conversationId, string message, CancellationToken cancellationToken = default)
{
// Checks whether a list of messages for the given conversationId already exists.
var conversationHistory = await cache.GetAsync(conversationId);
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken);
List<ChatGptMessage> messages = conversationHistory is not null ? new(conversationHistory) : new();

messages.Add(new()
Expand All @@ -259,13 +259,13 @@ private ChatGptRequest CreateRequest(IEnumerable<ChatGptMessage> messages, bool
User = options.User,
};

private async Task UpdateHistoryAsync(Guid conversationId, IList<ChatGptMessage> messages, ChatGptMessage message)
private async Task UpdateHistoryAsync(Guid conversationId, IList<ChatGptMessage> messages, ChatGptMessage message, CancellationToken cancellationToken = default)
{
messages.Add(message);
await UpdateCacheAsync(conversationId, messages);
await UpdateCacheAsync(conversationId, messages, cancellationToken);
}

private async Task UpdateCacheAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages)
private async Task UpdateCacheAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, CancellationToken cancellationToken = default)
{
// If the maximum number of messages has been reached, deletes the oldest ones.
// Note: system message does not count for message limit.
Expand All @@ -285,7 +285,7 @@ private async Task UpdateCacheAsync(Guid conversationId, IEnumerable<ChatGptMess
messages = conversation.ToList();
}

await cache.SetAsync(conversationId, messages, options.MessageExpiration);
await cache.SetAsync(conversationId, messages, options.MessageExpiration, cancellationToken);
}

private static void EnsureErrorIsSet(ChatGptResponse response, HttpResponseMessage httpResponse)
Expand Down
6 changes: 3 additions & 3 deletions src/ChatGptNet/ChatGptMemoryCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ public ChatGptMemoryCache(IMemoryCache cache)
this.cache = cache;
}

public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default)
{
cache.Set(conversationId, messages, expiration);
return Task.CompletedTask;
}

public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
var messages = cache.Get<List<ChatGptMessage>?>(conversationId);
return Task.FromResult(messages);
}

public Task RemoveAsync(Guid conversationId)
public Task RemoveAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
cache.Remove(conversationId);
return Task.CompletedTask;
Expand Down
9 changes: 6 additions & 3 deletions src/ChatGptNet/IChatGptCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,25 @@ public interface IChatGptCache
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="messages">The list of messages.</param>
/// <param name="expiration">The amount of time in which messages must be stored in cache.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The <see cref="Task"/> corresponding to the asynchronous operation.</returns>
/// <seealso cref="ChatGptMessage"/>
Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration);
Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default);

/// <summary>
/// Gets the list of messages for the given <paramref name="conversationId"/>.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The message list of the conversation, or <see langword="null"/> if the Conversation Id does not exist.</returns>
/// <seealso cref="ChatGptMessage"/>
Task<List<ChatGptMessage>?> GetAsync(Guid conversationId);
Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default);

/// <summary>
/// Removes from the cache all the message for the given <paramref name="conversationId"/>.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The <see cref="Task"/> corresponding to the asynchronous operation.</returns>
Task RemoveAsync(Guid conversationId);
Task RemoveAsync(Guid conversationId, CancellationToken cancellationToken = default);
}
26 changes: 16 additions & 10 deletions src/ChatGptNet/IChatGptClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@ public interface IChatGptClient
/// Setups a new conversation with a system message, that is used to influence assistant behavior.
/// </summary>
/// <param name="message">The system message.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The unique identifier of the new conversation.</returns>
/// <remarks>This method creates a new conversation with a system message and a random Conversation Id. Then, call <see cref="AskAsync(Guid, string, ChatGptParameters, string, CancellationToken)"/> with this Id to start the actual conversation.</remarks>
/// <exception cref="ArgumentNullException"><paramref name="message"/> is <see langword="null"/>.</exception>
/// <seealso cref="AskAsync(Guid, string, ChatGptParameters, string, CancellationToken)"/>
Task<Guid> SetupAsync(string message)
=> SetupAsync(Guid.NewGuid(), message);
Task<Guid> SetupAsync(string message, CancellationToken cancellationToken = default)
=> SetupAsync(Guid.NewGuid(), message, cancellationToken);

/// <summary>
/// Setups a conversation with a system message, that is used to influence assistant behavior.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation, used to automatically retrieve previous messages in the chat history.</param>
/// <param name="message">The system message.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <remarks>This method creates a new conversation, with a system message and the given <paramref name="conversationId"/>. If a conversation with this Id already exists, it will be automatically cleared. Then, call <see cref="AskAsync(Guid, string, ChatGptParameters, string, CancellationToken)"/> to start the actual conversation.</remarks>
/// <exception cref="ArgumentNullException"><paramref name="message"/> is <see langword="null"/>.</exception>
/// <seealso cref="AskAsync(Guid, string, ChatGptParameters, string, CancellationToken)"/>
Task<Guid> SetupAsync(Guid conversationId, string message);
Task<Guid> SetupAsync(Guid conversationId, string message, CancellationToken cancellationToken = default);

/// <summary>
/// Requests a new chat interaction using the default completion model specified in the <see cref="ChatGptOptions.DefaultModel"/> property.
Expand Down Expand Up @@ -102,14 +104,16 @@ Task<Guid> SetupAsync(string message)
/// Retrieves a chat conversation from the cache.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The message list of the conversation, or <see cref="Enumerable.Empty{ChatGptMessage}"/> if the Conversation Id does not exist.</returns>
/// <seealso cref="ChatGptMessage"/>
Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId);
Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId, CancellationToken cancellationToken = default);

/// <summary>
/// Loads messages into a new conversation.
/// </summary>
/// <param name="messages">Messages to load into a new conversation.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The unique identifier of the new conversation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="messages"/> is <see langword="null"/>.</exception>
/// <remarks>
Expand All @@ -118,27 +122,29 @@ Task<Guid> SetupAsync(string message)
/// </remarks>
/// <seealso cref="ChatGptOptions.MessageLimit"/>
/// <seealso cref="AskStreamAsync(Guid, string, ChatGptParameters?, string?, CancellationToken)"/>
Task<Guid> LoadConversationAsync(IEnumerable<ChatGptMessage> messages)
=> LoadConversationAsync(Guid.NewGuid(), messages);
Task<Guid> LoadConversationAsync(IEnumerable<ChatGptMessage> messages, CancellationToken cancellationToken = default)
=> LoadConversationAsync(Guid.NewGuid(), messages, true, cancellationToken);

/// <summary>
/// Loads messages into conversation history.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="messages">The messages to load into conversation history.</param>
/// <param name="replaceHistory"><see langword="true"/> to replace all the existing messages; <see langword="false"/> to mantain them.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The unique identifier of the conversation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="messages"/> is <see langword="null"/>.</exception>
/// <remarks>The total number of messages never exceeds the message limit defined in <see cref="ChatGptOptions.MessageLimit"/>. If <paramref name="messages"/> contains more, only the latest ones are loaded.</remarks>
/// <seealso cref="ChatGptOptions.MessageLimit"/>
Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, bool replaceHistory = true);
Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, bool replaceHistory = true, CancellationToken cancellationToken = default);

/// <summary>
/// Deletes a chat conversation, clearing all the history.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="preserveSetup"><see langword="true"/> to maintain the system message that has been specified with the <see cref="SetupAsync(Guid, string)"/> method.</param>
/// <param name="preserveSetup"><see langword="true"/> to maintain the system message that has been specified with the <see cref="SetupAsync(Guid, string, CancellationToken)"/> method.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The <see cref="Task"/> corresponding to the asynchronous operation.</returns>
/// <seealso cref="SetupAsync(Guid, string)"/>
Task DeleteConversationAsync(Guid conversationId, bool preserveSetup = false);
/// <seealso cref="SetupAsync(Guid, string, CancellationToken)"/>
Task DeleteConversationAsync(Guid conversationId, bool preserveSetup = false, CancellationToken cancellationToken = default);
}

0 comments on commit 8898b48

Please sign in to comment.