Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
</Choose>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.9.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
</Choose>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.9.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<metadata>
<id>AWSSDK.Extensions.Bedrock.MEAI</id>
<title>AWSSDK - Bedrock integration with Microsoft.Extensions.AI.</title>
<version>4.0.2.0</version>
<version>4.0.3.0</version>
<authors>Amazon Web Services</authors>
<description>Implementations of Microsoft.Extensions.AI's abstractions for Bedrock.</description>
<language>en-US</language>
Expand All @@ -13,19 +13,19 @@
<icon>images\AWSLogo.png</icon>
<dependencies>
<group targetFramework="net472">
<dependency id="AWSSDK.Core" version="4.0.0.4" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
<dependency id="AWSSDK.Core" version="4.0.0.26" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.6" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.9.0" />
</group>
<group targetFramework="netstandard2.0">
<dependency id="AWSSDK.Core" version="4.0.0.4" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
<dependency id="AWSSDK.Core" version="4.0.0.26" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.6" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.9.0" />
</group>
<group targetFramework="net8.0">
<dependency id="AWSSDK.Core" version="4.0.0.4" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
<dependency id="AWSSDK.Core" version="4.0.0.26" />
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.6" />
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.9.0" />
</group>
</dependencies>
</metadata>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

using Microsoft.Extensions.AI;
using System;
using System.Diagnostics.CodeAnalysis;

namespace Amazon.BedrockRuntime;

Expand Down Expand Up @@ -53,4 +54,18 @@ public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerato
this IAmazonBedrockRuntime runtime, string? defaultModelId = null, int? defaultModelDimensions = null) =>
runtime is not null ? new BedrockEmbeddingGenerator(runtime, defaultModelId, defaultModelDimensions) :
throw new ArgumentNullException(nameof(runtime));

/// <summary>Gets an <see cref="IImageGenerator"/> for the specified <see cref="IAmazonBedrockRuntime"/> instance.</summary>
/// <param name="runtime">The runtime instance to be represented as an <see cref="IImageGenerator"/>.</param>
/// <param name="defaultModelId">
/// The default model ID to use when no model is specified in a request. If not specified,
/// a model must be provided in the <see cref="ImageGenerationOptions.ModelId"/> passed to <see cref="IImageGenerator.GenerateAsync"/>.
/// </param>
/// <returns>An <see cref="IImageGenerator"/> instance representing the <see cref="IAmazonBedrockRuntime"/> instance.</returns>
/// <exception cref="ArgumentNullException"><paramref name="runtime"/> is <see langword="null"/>.</exception>
[Experimental("MEAI001")]
public static IImageGenerator AsIImageGenerator(
this IAmazonBedrockRuntime runtime, string? defaultModelId = null) =>
runtime is not null ? new BedrockImageGenerator(runtime, defaultModelId) :
throw new ArgumentNullException(nameof(runtime));
}
43 changes: 39 additions & 4 deletions extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public async Task<ChatResponse> GetResponseAsync(

ChatMessage result = new()
{
CreatedAt = DateTimeOffset.UtcNow,
RawRepresentation = response.Output?.Message,
Role = ChatRole.Assistant,
MessageId = Guid.NewGuid().ToString("N"),
Expand All @@ -97,6 +98,23 @@ public async Task<ChatResponse> GetResponseAsync(
result.Contents.Add(new TextContent(text) { RawRepresentation = content });
}

if (content.CitationsContent is { } citations &&
citations.Citations is { Count: > 0 } &&
citations.Content is { Count: > 0 })
{
int count = Math.Min(citations.Citations.Count, citations.Content.Count);
for (int i = 0; i < count; i++)
{
TextContent tc = new(citations.Content[i]?.Text) { RawRepresentation = citations.Content[i] };
tc.Annotations = [new CitationAnnotation()
{
Title = citations.Citations[i].Title,
Snippet = citations.Citations[i].SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
result.Contents.Add(tc);
}
}

if (content.ReasoningContent is { ReasoningText.Text: not null } reasoningContent)
{
TextReasoningContent trc = new(reasoningContent.ReasoningText.Text) { RawRepresentation = content };
Expand Down Expand Up @@ -126,7 +144,11 @@ public async Task<ChatResponse> GetResponseAsync(

if (content.Document is { Source.Bytes: { } documentBytes, Format: { } documentFormat })
{
result.Contents.Add(new DataContent(documentBytes.ToArray(), GetMimeType(documentFormat)) { RawRepresentation = content });
result.Contents.Add(new DataContent(documentBytes.ToArray(), GetMimeType(documentFormat))
{
RawRepresentation = content,
Name = content.Document.Name
});
}

if (content.ToolUse is { } toolUse)
Expand All @@ -143,7 +165,7 @@ public async Task<ChatResponse> GetResponseAsync(

return new(result)
{
CreatedAt = DateTimeOffset.UtcNow,
CreatedAt = result.CreatedAt,
FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null,
RawRepresentation = response,
ResponseId = Guid.NewGuid().ToString("N"),
Expand Down Expand Up @@ -205,14 +227,26 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(

if (contentBlockDelta.Delta.Text is string text)
{
yield return new(ChatRole.Assistant, text)
ChatResponseUpdate textUpdate = new(ChatRole.Assistant, text)
{
CreatedAt = DateTimeOffset.UtcNow,
MessageId = messageId,
RawRepresentation = update,
FinishReason = finishReason,
ResponseId = responseId,
};

if (contentBlockDelta.Delta.Citation is { } citation &&
(citation.Title is not null || citation.SourceContent is { Count: > 0 }))
{
textUpdate.Contents[0].Annotations = [new CitationAnnotation()
{
Title = citation.Title,
Snippet = citation.SourceContent?.Select(c => c.Text).FirstOrDefault(),
}];
}

yield return textUpdate;
}

if (contentBlockDelta.Delta.ReasoningContent is { Text: not null } reasoningContent)
Expand Down Expand Up @@ -468,6 +502,7 @@ private static List<ContentBlock> CreateContents(ChatMessage message)
{
Source = new() { Bytes = new(dc.Data.ToArray()) },
Format = docFormat,
Name = dc.Name ?? "file",
}
});
}
Expand Down Expand Up @@ -693,7 +728,7 @@ private static Document ToDocument(JsonElement json)
{
foreach (AITool tool in tools)
{
if (tool is not AIFunction f)
if (tool is not AIFunctionDeclaration f)
{
continue;
}
Expand Down
207 changes: 207 additions & 0 deletions extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockImageGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

using Amazon.BedrockRuntime.Model;
using Microsoft.Extensions.AI;
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;

namespace Amazon.BedrockRuntime;

[Experimental("MEAI001")]
internal sealed partial class BedrockImageGenerator : IImageGenerator
{
/// <summary>The wrapped <see cref="IAmazonBedrockRuntime"/> instance.</summary>
private readonly IAmazonBedrockRuntime _runtime;
/// <summary>Default model ID to use when no model is specified in the request.</summary>
private readonly string? _modelId;
/// <summary>Metadata describing the image generator.</summary>
private readonly ImageGeneratorMetadata _metadata;

/// <summary>
/// Initializes a new instance of the <see cref="BedrockImageGenerator"/> class.
/// </summary>
/// <param name="runtime">The <see cref="IAmazonBedrockRuntime"/> instance to wrap.</param>
/// <param name="defaultModelId">Model ID to use as the default when no model ID is specified in a request.</param>
public BedrockImageGenerator(IAmazonBedrockRuntime runtime, string? defaultModelId)
{
Debug.Assert(runtime is not null);

_runtime = runtime!;
_modelId = defaultModelId;

_metadata = new(AmazonBedrockRuntimeExtensions.ProviderName, defaultModelId: defaultModelId);
}

public void Dispose()
{
// Do not dispose of _runtime, as this instance doesn't own it.
}

/// <inheritdoc />

/// <inheritdoc />
public object? GetService(Type serviceType, object? serviceKey)
{
if (serviceType is null)
{
throw new ArgumentNullException(nameof(serviceType));
}

return
serviceKey is not null ? null :
serviceType == typeof(ImageGeneratorMetadata) ? _metadata :
serviceType.IsInstanceOfType(_runtime) ? _runtime :
serviceType.IsInstanceOfType(this) ? this :
null;
}

public async Task<ImageGenerationResponse> GenerateAsync(
ImageGenerationRequest request, ImageGenerationOptions? options = null, CancellationToken cancellationToken = default)
{
if (request is null)
{
throw new ArgumentNullException(nameof(request));
}

int numImages = options?.Count ?? 1;
if (numImages < 1)
{
throw new ArgumentOutOfRangeException(nameof(options), "The number of images must be at least 1.");
}

InvokeModelRequest invokeRequest = options?.RawRepresentationFactory?.Invoke(this) as InvokeModelRequest ?? new();
invokeRequest.ModelId ??= options?.ModelId ?? _modelId;
invokeRequest.Accept ??= "application/json";
invokeRequest.ContentType ??= "application/json";
if (invokeRequest.Body is null)
{
JsonObject body = new();

// Each model has its own way of specifying the prompt and image generation parameters, unfortunately.
// The following logic handles the most common cases today, but may need to be extended for
// future models.

if (invokeRequest.ModelId?.IndexOf("stability", StringComparison.OrdinalIgnoreCase) >= 0)
{
// Stability AI models
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-stability-diffusion.html

if (invokeRequest.ModelId?.IndexOf("stable-diffusion", StringComparison.OrdinalIgnoreCase) >= 0)
{
JsonArray textPrompts = new();
for (int i = 0; i < numImages; i++)
{
textPrompts.Add((JsonNode)new JsonObject { ["text"] = request.Prompt ?? "" });
}
body["text_prompts"] = textPrompts;

if (options?.ImageSize?.Width is int width && options.ImageSize?.Height is int height)
{
body["width"] = width;
body["height"] = height;
}
}
else
{
body["prompt"] = request.Prompt ?? "";
}
}
else
{
// Amazon models (e.g. Titan, Nova Canvas)
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan.html

JsonObject textToImageParams = new() { ["text"] = request.Prompt ?? "" };
if (request.OriginalImages?.OfType<DataContent>().Where(d => d.HasTopLevelMediaType("image")).FirstOrDefault() is DataContent image)
{
textToImageParams["conditionImage"] = image.Base64Data.ToString();
}

JsonObject imageGenerationConfig = new()
{
["seed"] =
#if NET
Random.Shared.Next(),
#else
new Random().Next(),
#endif
};

if (options?.ImageSize?.Width is int width && options.ImageSize?.Height is int height)
{
imageGenerationConfig["width"] = width;
imageGenerationConfig["height"] = height;
}

if (numImages > 1)
{
imageGenerationConfig["numberOfImages"] = Math.Min(numImages, 5);
}

body["taskType"] = "TEXT_IMAGE";
body["textToImageParams"] = textToImageParams;
body["imageGenerationConfig"] = imageGenerationConfig;
}

invokeRequest.Body = new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(body, BedrockJsonContext.Default.JsonNode));
}

InvokeModelResponse rawResponse = await _runtime.InvokeModelAsync(invokeRequest, cancellationToken).ConfigureAwait(false);

ImageGenerationResponse result = new() { RawRepresentation = rawResponse };

using JsonDocument doc = JsonDocument.Parse(rawResponse.Body);
JsonElement root = doc.RootElement;

const string DefaultGeneratedImageMimeType = "image/png";

if (root.TryGetProperty("artifacts", out JsonElement artifactElement) && artifactElement.ValueKind == JsonValueKind.Array)
{
foreach (var element in artifactElement.EnumerateArray())
{
if (element.TryGetProperty("base64", out JsonElement base64Element) &&
base64Element.ValueKind == JsonValueKind.String)
{
result.Contents.Add(new DataContent(Convert.FromBase64String(base64Element.GetString()!), DefaultGeneratedImageMimeType));
}
}
}
else if (root.TryGetProperty("images", out JsonElement imagesElement) && imagesElement.ValueKind == JsonValueKind.Array)
{
foreach (var image in imagesElement.EnumerateArray())
{
if (image.ValueKind == JsonValueKind.String)
{
result.Contents.Add(new DataContent(Convert.FromBase64String(image.GetString()!), DefaultGeneratedImageMimeType));
}
}
}

if (result.Contents is not { Count: > 0 })
{
throw new InvalidOperationException("Image generation did not produce any images.");
}

return result;
}
}
Loading