Skip to content

Bring back AsIChatClient for OpenAI AssistantClient #6501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits on Jun 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ internal sealed partial class OpenAIChatClient : IChatClient
MoveDefaultKeywordToDescription = true,
});

/// <summary>Gets the default OpenAI endpoint.</summary>
private static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");

/// <summary>Metadata about the client.</summary>
private readonly ChatClientMetadata _metadata;

Expand All @@ -57,7 +54,7 @@ public OpenAIChatClient(ChatClient chatClient)
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(chatClient) as Uri ?? DefaultOpenAIEndpoint;
?.GetValue(chatClient) as Uri ?? OpenAIResponseChatClient.DefaultOpenAIEndpoint;
string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(chatClient) as string;

Expand Down Expand Up @@ -113,8 +110,6 @@ void IDisposable.Dispose()
// Nothing to dispose. Implementation required for the IChatClient interface.
}

private static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");

/// <summary>Converts an Extensions chat message enumerable to an OpenAI chat message enumerable.</summary>
private static IEnumerable<OpenAI.Chat.ChatMessage> ToOpenAIChatMessages(IEnumerable<ChatMessage> inputs, JsonSerializerOptions options)
{
Expand All @@ -125,12 +120,12 @@ void IDisposable.Dispose()
{
if (input.Role == ChatRole.System ||
input.Role == ChatRole.User ||
input.Role == ChatRoleDeveloper)
input.Role == OpenAIResponseChatClient.ChatRoleDeveloper)
{
var parts = ToOpenAIChatContent(input.Contents);
yield return
input.Role == ChatRole.System ? new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
input.Role == ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
input.Role == OpenAIResponseChatClient.ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
new UserChatMessage(parts) { ParticipantName = input.AuthorName };
}
else if (input.Role == ChatRole.Tool)
Expand Down Expand Up @@ -622,7 +617,7 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) =>
ChatMessageRole.User => ChatRole.User,
ChatMessageRole.Assistant => ChatRole.Assistant,
ChatMessageRole.Tool => ChatRole.Tool,
ChatMessageRole.Developer => ChatRoleDeveloper,
ChatMessageRole.Developer => OpenAIResponseChatClient.ChatRoleDeveloper,
_ => new ChatRole(role.ToString()),
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Diagnostics.CodeAnalysis;
using OpenAI;
using OpenAI.Assistants;
using OpenAI.Audio;
using OpenAI.Chat;
using OpenAI.Embeddings;
Expand All @@ -25,6 +26,19 @@ public static IChatClient AsIChatClient(this ChatClient chatClient) =>
public static IChatClient AsIChatClient(this OpenAIResponseClient responseClient) =>
new OpenAIResponseChatClient(responseClient);

/// <summary>Gets an <see cref="IChatClient"/> for use with this <see cref="AssistantClient"/>.</summary>
/// <param name="assistantClient">The <see cref="AssistantClient"/> instance to be accessed as an <see cref="IChatClient"/>.</param>
/// <param name="assistantId">The unique identifier of the assistant with which to interact.</param>
/// <param name="threadId">
/// An optional existing thread identifier for the chat session. This serves as a default, and may be overridden per call to
/// <see cref="IChatClient.GetResponseAsync"/> or <see cref="IChatClient.GetStreamingResponseAsync"/> via the <see cref="ChatOptions.ConversationId"/>
/// property. If no thread ID is provided via either mechanism, a new thread will be created for the request.
/// </param>
/// <returns>An <see cref="IChatClient"/> instance configured to interact with the specified assistant and thread.</returns>
[Experimental("OPENAI001")]
public static IChatClient AsIChatClient(this AssistantClient assistantClient, string assistantId, string? threadId = null)
{
    // Wrap the Assistants client in the adapter that implements the M.E.AI abstraction.
    return new OpenAIAssistantChatClient(assistantClient, assistantId, threadId);
}

/// <summary>Gets an <see cref="ISpeechToTextClient"/> for use with this <see cref="AudioClient"/>.</summary>
/// <param name="audioClient">The client.</param>
/// <returns>An <see cref="ISpeechToTextClient"/> that can be used to transcribe audio via the <see cref="AudioClient"/>.</returns>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ namespace Microsoft.Extensions.AI;
internal sealed partial class OpenAIResponseChatClient : IChatClient
{
/// <summary>Gets the default OpenAI endpoint.</summary>
private static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");
internal static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");

/// <summary>A <see cref="ChatRole"/> for "developer".</summary>
private static readonly ChatRole _chatRoleDeveloper = new("developer");
/// <summary>Gets a <see cref="ChatRole"/> for "developer".</summary>
internal static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");

/// <summary>Metadata about the client.</summary>
private readonly ChatClientMetadata _metadata;
Expand Down Expand Up @@ -88,7 +88,7 @@ public async Task<ChatResponse> GetResponseAsync(
// Convert and return the results.
ChatResponse response = new()
{
ConversationId = openAIResponse.Id,
ConversationId = openAIOptions.StoredOutputEnabled is false ? null : openAIResponse.Id,
CreatedAt = openAIResponse.CreatedAt,
FinishReason = ToFinishReason(openAIResponse.IncompleteStatusDetails?.Reason),
Messages = [new(ChatRole.Assistant, [])],
Expand Down Expand Up @@ -167,6 +167,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
// Make the call to the OpenAIResponseClient and process the streaming results.
DateTimeOffset? createdAt = null;
string? responseId = null;
string? conversationId = null;
string? modelId = null;
string? lastMessageId = null;
ChatRole? lastRole = null;
Expand All @@ -179,18 +180,19 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
case StreamingResponseCreatedUpdate createdUpdate:
createdAt = createdUpdate.Response.CreatedAt;
responseId = createdUpdate.Response.Id;
conversationId = openAIOptions.StoredOutputEnabled is false ? null : responseId;
modelId = createdUpdate.Response.Model;
goto default;

case StreamingResponseCompletedUpdate completedUpdate:
yield return new()
{
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
ConversationId = conversationId,
CreatedAt = createdAt,
FinishReason =
ToFinishReason(completedUpdate.Response?.IncompleteStatusDetails?.Reason) ??
(functionCallInfos is not null ? ChatFinishReason.ToolCalls : ChatFinishReason.Stop),
Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
ConversationId = responseId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
Expand Down Expand Up @@ -223,7 +225,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
lastRole = ToChatRole(messageItem?.Role);
yield return new ChatResponseUpdate(lastRole, outputTextDeltaUpdate.Delta)
{
ConversationId = responseId,
ConversationId = conversationId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
Expand Down Expand Up @@ -258,7 +260,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
lastRole = ChatRole.Assistant;
yield return new ChatResponseUpdate(lastRole, [fci])
{
ConversationId = responseId,
ConversationId = conversationId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
Expand All @@ -275,7 +277,6 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
case StreamingResponseErrorUpdate errorUpdate:
yield return new ChatResponseUpdate
{
ConversationId = responseId,
Contents =
[
new ErrorContent(errorUpdate.Message)
Expand All @@ -284,6 +285,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
Details = errorUpdate.Param,
}
],
ConversationId = conversationId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
Expand All @@ -296,21 +298,21 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
case StreamingResponseRefusalDoneUpdate refusalDone:
yield return new ChatResponseUpdate
{
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
ConversationId = conversationId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
RawRepresentation = streamingUpdate,
ResponseId = responseId,
Role = lastRole,
ConversationId = responseId,
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
};
break;

default:
yield return new ChatResponseUpdate
{
ConversationId = responseId,
ConversationId = conversationId,
CreatedAt = createdAt,
MessageId = lastMessageId,
ModelId = modelId,
Expand All @@ -334,7 +336,7 @@ private static ChatRole ToChatRole(MessageRole? role) =>
role switch
{
MessageRole.System => ChatRole.System,
MessageRole.Developer => _chatRoleDeveloper,
MessageRole.Developer => ChatRoleDeveloper,
MessageRole.User => ChatRole.User,
_ => ChatRole.Assistant,
};
Expand Down Expand Up @@ -452,7 +454,7 @@ private static IEnumerable<ResponseItem> ToOpenAIResponseItems(
foreach (ChatMessage input in inputs)
{
if (input.Role == ChatRole.System ||
input.Role == _chatRoleDeveloper)
input.Role == ChatRoleDeveloper)
{
string text = input.Text;
if (!string.IsNullOrWhiteSpace(text))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,9 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange

// Second time, the calls to the LLM don't happen, but the function is called again
var secondResponse = await chatClient.GetResponseAsync([message]);
Assert.Equal(response.Text, secondResponse.Text);
Assert.Equal(2, functionCallCount);
Assert.Equal(FunctionInvokingChatClientSetsConversationId ? 3 : 2, llmCallCount!.CallCount);
Assert.Equal(response.Text, secondResponse.Text);
}

public virtual bool FunctionInvokingChatClientSetsConversationId => false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ private static BinaryEmbedding QuantizeToBinary(Embedding<float> embedding)
{
if (vector[i] > 0)
{
result[i / 8] = true;
result[i] = true;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
#pragma warning disable CA1822 // Mark members as static
#pragma warning disable CA2000 // Dispose objects before losing scope
#pragma warning disable S1135 // Track uses of "TODO" tags
#pragma warning disable xUnit1013 // Public method should be marked as test

using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using OpenAI.Assistants;
using Xunit;

namespace Microsoft.Extensions.AI;

/// <summary>Integration tests for the OpenAI Assistants-backed <see cref="IChatClient"/> adapter.</summary>
public class OpenAIAssistantChatClientIntegrationTests : ChatClientIntegrationTests
{
    /// <summary>
    /// Creates an assistant-backed <see cref="IChatClient"/> for the shared integration test suite,
    /// or returns <see langword="null"/> when no OpenAI credentials are configured (tests are then skipped).
    /// </summary>
    protected override IChatClient? CreateChatClient()
    {
        var openAIClient = IntegrationTestHelpers.GetOpenAIClient();
        if (openAIClient is null)
        {
            return null;
        }

        AssistantClient ac = openAIClient.GetAssistantClient();

        // Reuse an existing assistant when one is available to avoid accumulating assistants in the account.
        var assistant =
            ac.GetAssistants().FirstOrDefault() ??
            ac.CreateAssistant("gpt-4o-mini");

        return ac.AsIChatClient(assistant.Id);
    }

    public override bool FunctionInvokingChatClientSetsConversationId => true;

    // These tests aren't written in a way that works well with threads.
    public override Task Caching_AfterFunctionInvocation_FunctionOutputChangedAsync() => Task.CompletedTask;
    public override Task Caching_AfterFunctionInvocation_FunctionOutputUnchangedAsync() => Task.CompletedTask;

    // Assistants doesn't support data URIs.
    public override Task MultiModal_DescribeImage() => Task.CompletedTask;
    public override Task MultiModal_DescribePdf() => Task.CompletedTask;

    // [Fact] // uncomment and run to clear out _all_ threads in your OpenAI account
    public async Task DeleteAllThreads()
    {
        using HttpClient client = new(new HttpClientHandler
        {
            AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
        });

        // These values need to be filled in. The bearer token needs to be sniffed from a browser
        // session interacting with the dashboard (e.g. use F12 networking tools to look at request headers
        // made to "https://api.openai.com/v1/threads?limit=10" after clicking on Assistants | Threads in the
        // OpenAI portal dashboard).
        client.DefaultRequestHeaders.Add("authorization", "Bearer sess-ENTERYOURSESSIONTOKEN");
        client.DefaultRequestHeaders.Add("openai-organization", "org-ENTERYOURORGID");
        client.DefaultRequestHeaders.Add("openai-project", "proj_ENTERYOURPROJECTID");

        AssistantClient ac = new AssistantClient(Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey")!);
        while (true)
        {
            string listing = await client.GetStringAsync("https://api.openai.com/v1/threads?limit=100");

            // The listing may mention the same thread ID more than once; deduplicate so we
            // don't attempt to delete a thread twice (the second delete would fail the assert).
            string[] threadIds = Regex.Matches(listing, @"thread_\w+")
                .Select(m => m.Value)
                .Distinct()
                .ToArray();
            if (threadIds.Length == 0)
            {
                break;
            }

            foreach (string threadId in threadIds)
            {
                var dr = await ac.DeleteThreadAsync(threadId);
                Assert.True(dr.Value.Deleted);
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.ClientModel;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using OpenAI;
using OpenAI.Assistants;
using Xunit;

#pragma warning disable S103 // Lines should not be too long
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

namespace Microsoft.Extensions.AI;

/// <summary>Unit tests for the OpenAI Assistants-backed <see cref="IChatClient"/> adapter.</summary>
public class OpenAIAssistantChatClientTests
{
    [Fact]
    public void AsIChatClient_InvalidArgs_Throws()
    {
        // Both the client instance and the assistant ID are required.
        Assert.Throws<ArgumentNullException>("assistantClient", () => ((AssistantClient)null!).AsIChatClient("assistantId"));
        Assert.Throws<ArgumentNullException>("assistantId", () => new AssistantClient("ignored").AsIChatClient(null!));
    }

    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public void AsIChatClient_OpenAIClient_ProducesExpectedMetadata(bool useAzureOpenAI)
    {
        Uri endpoint = new("http://localhost/some/endpoint");

        // Exercise both the OpenAI and Azure OpenAI client flavors.
        OpenAIClient client;
        if (useAzureOpenAI)
        {
            client = new AzureOpenAIClient(endpoint, new ApiKeyCredential("key"));
        }
        else
        {
            client = new OpenAIClient(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint });
        }

        IChatClient[] chatClients =
        [
            client.GetAssistantClient().AsIChatClient("assistantId"),
            client.GetAssistantClient().AsIChatClient("assistantId", "threadId"),
        ];

        foreach (IChatClient chatClient in chatClients)
        {
            ChatClientMetadata? metadata = chatClient.GetService<ChatClientMetadata>();
            Assert.Equal("openai", metadata?.ProviderName);
            Assert.Equal(endpoint, metadata?.ProviderUri);
        }
    }

    [Fact]
    public void GetService_AssistantClient_SuccessfullyReturnsUnderlyingClient()
    {
        AssistantClient assistantClient = new OpenAIClient("key").GetAssistantClient();
        IChatClient chatClient = assistantClient.AsIChatClient("assistantId");

        // The adapter surfaces its underlying AssistantClient, but not unrelated services.
        Assert.Same(assistantClient, chatClient.GetService<AssistantClient>());
        Assert.Null(chatClient.GetService<OpenAIClient>());

        using IChatClient pipeline = chatClient
            .AsBuilder()
            .UseFunctionInvocation()
            .UseOpenTelemetry()
            .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
            .Build();

        // Each decorator in the pipeline is discoverable via GetService.
        Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
        Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
        Assert.NotNull(pipeline.GetService<CachingChatClient>());
        Assert.NotNull(pipeline.GetService<OpenTelemetryChatClient>());

        // The inner client remains reachable through the full pipeline,
        // and the outermost IChatClient is the function-invocation decorator.
        Assert.Same(assistantClient, pipeline.GetService<AssistantClient>());
        Assert.IsType<FunctionInvokingChatClient>(pipeline.GetService<IChatClient>());
    }
}
Loading
Loading