From 9d009b5d18ad1e59e61071f077e5b7cd7488dda2 Mon Sep 17 00:00:00 2001 From: "lvxinyu.1117" Date: Wed, 24 Sep 2025 15:25:59 +0800 Subject: [PATCH] feat: Enable Gemini multimodal output and input conversion for Ark, Claude, Ollama, and Gemini --- components/model/ark/chat_completion_api.go | 198 ++++++++-- .../model/ark/chat_completion_api_test.go | 157 +++++++- components/model/ark/chatmodel.go | 2 +- .../generate_with_image.go | 9 +- components/model/ark/go.mod | 2 +- components/model/ark/go.sum | 5 +- components/model/ark/message_extra.go | 52 ++- components/model/ark/responses_api.go | 152 ++++++-- components/model/ark/responses_api_test.go | 176 ++++++++- components/model/arkbot/chatmodel.go | 147 ++++++-- components/model/arkbot/chatmodel_test.go | 171 ++++++++- components/model/arkbot/go.mod | 2 +- components/model/arkbot/go.sum | 5 +- components/model/claude/claude.go | 76 ++++ components/model/claude/claude_test.go | 221 +++++++++++ .../claude/examples/basic_usage/claude.go | 11 +- components/model/claude/go.mod | 15 +- components/model/claude/go.sum | 23 ++ components/model/gemini/examples/gemini.go | 35 ++ components/model/gemini/gemini.go | 355 +++++++++++++++++- components/model/gemini/gemini_test.go | 333 ++++++++++++++-- components/model/gemini/go.mod | 11 +- components/model/gemini/go.sum | 18 + components/model/gemini/message_extra.go | 92 +++++ components/model/gemini/option.go | 7 + components/model/ollama/chatmodel.go | 96 +++-- components/model/ollama/chatmodel_test.go | 273 +++++++++++++- .../model/ollama/examples/image/image.go | 12 +- components/model/ollama/go.mod | 11 +- components/model/ollama/go.sum | 18 + components/model/qwen/go.mod | 2 +- components/model/qwen/go.sum | 2 + 32 files changed, 2472 insertions(+), 217 deletions(-) create mode 100644 components/model/gemini/message_extra.go diff --git a/components/model/ark/chat_completion_api.go b/components/model/ark/chat_completion_api.go index 9f119d302..7d3f3e9c9 100644 --- a/components/model/ark/chat_completion_api.go +++ b/components/model/ark/chat_completion_api.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "log" "runtime/debug" "github.com/eino-contrib/jsonschema" @@ -297,7 +298,7 @@ func (cm *completionAPIChatModel) genRequest(in []*schema.Message, options *fmod } for _, msg := range in { - content, e := cm.toArkContent(msg.Content, msg.MultiContent) + content, e := cm.toArkContent(msg) if e != nil { return req, e } @@ -522,44 +523,167 @@ func (cm *completionAPIChatModel) toMessageToolCalls(toolCalls []*model.ToolCall return ret } -func (cm *completionAPIChatModel) toArkContent(content string, multiContent []schema.ChatMessagePart) (*model.ChatCompletionMessageContent, error) { - if len(multiContent) == 0 { - return &model.ChatCompletionMessageContent{StringValue: ptrOf(content)}, nil - } - - parts := make([]*model.ChatCompletionMessageContentPart, 0, len(multiContent)) - - for _, part := range multiContent { - switch part.Type { - case schema.ChatMessagePartTypeText: - parts = append(parts, &model.ChatCompletionMessageContentPart{ - Type: model.ChatCompletionMessageContentPartTypeText, - Text: part.Text, - }) - case schema.ChatMessagePartTypeImageURL: - if part.ImageURL == nil { - return nil, fmt.Errorf("ImageURL field must not be nil when Type is ChatMessagePartTypeImageURL") +func (cm *completionAPIChatModel) toArkContent(msg *schema.Message) (*model.ChatCompletionMessageContent, error) { + if len(msg.UserInputMultiContent) == 0 && len(msg.AssistantGenMultiContent) == 0 && len(msg.MultiContent) == 0 { + return &model.ChatCompletionMessageContent{StringValue: ptrOf(msg.Content)}, nil + } + + var parts []*model.ChatCompletionMessageContentPart + if len(msg.UserInputMultiContent) > 0 && len(msg.AssistantGenMultiContent) > 0 { + return nil, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } + + if len(msg.UserInputMultiContent) > 0 { + parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.UserInputMultiContent)) + for _, part := range msg.UserInputMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: part.Text, + }) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") + } + var imageURL string + if part.Image.URL != nil && *part.Image.URL != "" { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil && *part.Image.Base64Data != "" { + if part.Image.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when using Base64Data") + } + imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + } else { + return nil, fmt.Errorf("image part for user input must contain either a URL or Base64Data, but got: %+v", part.Image) + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: imageURL, + Detail: model.ImageURLDetail(part.Image.Detail), + }, + }) + case schema.ChatMessagePartTypeVideoURL: + if part.Video == nil { + return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL in user message") + } + var videoURL string + if part.Video.URL != nil && *part.Video.URL != "" { + videoURL = *part.Video.URL + } else if part.Video.Base64Data != nil && *part.Video.Base64Data != "" { + if part.Video.MIMEType == "" { + return nil, fmt.Errorf("video part must have MIMEType when using Base64Data") + } + videoURL = ensureDataURL(*part.Video.Base64Data, part.Video.MIMEType) + } else { + return nil, fmt.Errorf("video part for user input must contain either a URL or Base64Data, but got: %+v", part.Video) + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeVideoURL, + VideoURL: &model.ChatMessageVideoURL{ + URL: videoURL, + FPS: GetInputVideoFPS(part.Video), + }, + }) + default: + return nil, fmt.Errorf("unsupported chat message part type in user message: %s", part.Type) } - parts = append(parts, &model.ChatCompletionMessageContentPart{ - Type: model.ChatCompletionMessageContentPartTypeImageURL, - ImageURL: &model.ChatMessageImageURL{ - URL: part.ImageURL.URL, - Detail: model.ImageURLDetail(part.ImageURL.Detail), - }, - }) - case schema.ChatMessagePartTypeVideoURL: - if part.VideoURL == nil { - return nil, fmt.Errorf("VideoURL field must not be nil when Type is ChatMessagePartTypeVideoURL") + } + } else if len(msg.AssistantGenMultiContent) > 0 { + if msg.Role != schema.Assistant { + return nil, fmt.Errorf("AssistantGenMultiContent only used when Role is Assistant, but got role: %s", msg.Role) + } + parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.AssistantGenMultiContent)) + for _, part := range msg.AssistantGenMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: part.Text, + }) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message") + } + var imageURL string + if part.Image.URL != nil && *part.Image.URL != "" { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil && *part.Image.Base64Data != "" { + if part.Image.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when using Base64Data") + } + imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + } else { + return nil, fmt.Errorf("image part for assistant output must contain either a URL or Base64Data, but got: %+v", part.Image) + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: imageURL, + }, + }) + case schema.ChatMessagePartTypeVideoURL: + if part.Video == nil { + return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL in assistant message") + } + var videoURL string + if part.Video.URL != nil && *part.Video.URL != "" { + videoURL = *part.Video.URL + } else if part.Video.Base64Data != nil && *part.Video.Base64Data != "" { + if part.Video.MIMEType == "" { + return nil, fmt.Errorf("video part must have MIMEType when using Base64Data") + } + videoURL = ensureDataURL(*part.Video.Base64Data, part.Video.MIMEType) + } else { + return nil, fmt.Errorf("video part for assistant output must contain either a URL or Base64Data, but got: %+v", part.Video) + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeVideoURL, + VideoURL: &model.ChatMessageVideoURL{ + URL: videoURL, + FPS: GetOutputVideoFPS(part.Video), + }, + }) + default: + return nil, fmt.Errorf("unsupported chat message part type in assistant message: %s", part.Type) + } + } + } else if len(msg.MultiContent) > 0 { + log.Printf("warning: MultiContent is deprecated, use UserInputMultiContent or AssistantGenMultiContent instead") + parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.MultiContent)) + for _, part := range msg.MultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: part.Text, + }) + case schema.ChatMessagePartTypeImageURL: + if part.ImageURL == nil { + return nil, fmt.Errorf("ImageURL field must not be nil when Type is ChatMessagePartTypeImageURL") + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: part.ImageURL.URL, + Detail: model.ImageURLDetail(part.ImageURL.Detail), + }, + }) + case schema.ChatMessagePartTypeVideoURL: + if part.VideoURL == nil { + return nil, fmt.Errorf("VideoURL field must not be nil when Type is ChatMessagePartTypeVideoURL") + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeVideoURL, + VideoURL: &model.ChatMessageVideoURL{ + URL: part.VideoURL.URL, + FPS: GetFPS(part.VideoURL), + }, + }) + default: + return nil, fmt.Errorf("unsupported chat message part type: %s", part.Type) } - parts = append(parts, &model.ChatCompletionMessageContentPart{ - Type: model.ChatCompletionMessageContentPartTypeVideoURL, - VideoURL: &model.ChatMessageVideoURL{ - URL: part.VideoURL.URL, - FPS: GetFPS(part.VideoURL), - }, - }) - default: - return nil, fmt.Errorf("unsupported chat message part type: %s", part.Type) } } diff --git a/components/model/ark/chat_completion_api_test.go b/components/model/ark/chat_completion_api_test.go index dd726e4b3..38fa6dcfb 100644 --- a/components/model/ark/chat_completion_api_test.go +++ b/components/model/ark/chat_completion_api_test.go @@ -280,7 +280,7 @@ func TestChatCompletionAPIGenerate(t *testing.T) { }, } - req, err := m.chatModel.toArkContent(multiModalMsg.Content, multiModalMsg.MultiContent) + req, err := m.chatModel.toArkContent(multiModalMsg) convey.So(err, convey.ShouldBeNil) convey.So(req.StringValue, convey.ShouldBeNil) convey.So(req.ListValue, convey.ShouldHaveLength, 2) @@ -330,3 +330,158 @@ func TestChatCompletionAPILogProbs(t *testing.T) { }, }})) } + +func TestCompletionAPIChatModel_toArkContent(t *testing.T) { + cm := &completionAPIChatModel{} + base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + httpURL := "https://example.com/image.png" + videoURL := "https://example.com/video.mp4" + + PatchConvey("Test toArkContent Comprehensive", t, func() { + PatchConvey("Pure Text Content", func() { + msg := &schema.Message{Content: "just text"} + content, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(*content.StringValue, convey.ShouldEqual, "just text") + }) + + PatchConvey("UserInputMultiContent", func() { + PatchConvey("Success with all types", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "some text"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + {Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "video/mp4"}}}, + }, + } + content, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 3) + }) + PatchConvey("Error on nil image", func() { + msg := &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeImageURL, Image: nil}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on empty image data", func() { + msg := &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on nil video", func() { + msg := &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeVideoURL, Video: nil}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on empty video data", func() { + msg := &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("Error on missing MIMEType for image", func() { + msg := &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("Error on missing MIMEType for video", func() { + msg := &schema.Message{UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) + + PatchConvey("AssistantGenMultiContent", func() { + PatchConvey("Success with image and video", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "some text"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + {Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{URL: &videoURL}}}, + }, + } + content, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 3) + }) + PatchConvey("Error on wrong role", func() { + msg := &schema.Message{Role: schema.User, AssistantGenMultiContent: []schema.MessageOutputPart{{}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on nil image", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeImageURL, Image: nil}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on empty image data", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on nil video", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeVideoURL, Video: nil}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on empty video data", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on unsupported type", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: "unsupported"}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("Error on missing MIMEType for image", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("Error on missing MIMEType for video", func() { + msg := &schema.Message{Role: schema.Assistant, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) + + PatchConvey("MultiContent (Legacy)", func() { + PatchConvey("Success with all types", func() { + msg := &schema.Message{ + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeText, Text: "some text"}, + {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: httpURL}}, + {Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{URL: videoURL}}, + }, + } + content, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 3) + }) + PatchConvey("Error on nil ImageURL", func() { + msg := &schema.Message{MultiContent: []schema.ChatMessagePart{{Type: schema.ChatMessagePartTypeImageURL, ImageURL: nil}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + PatchConvey("Error on nil VideoURL", func() { + msg := &schema.Message{MultiContent: []schema.ChatMessagePart{{Type: schema.ChatMessagePartTypeVideoURL, VideoURL: nil}}} + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) + + PatchConvey("Error on both UserInputMultiContent and AssistantGenMultiContent", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "user"}}, + AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "assistant"}}, + } + _, err := cm.toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} diff --git a/components/model/ark/chatmodel.go b/components/model/ark/chatmodel.go index 082778e60..1dc2a0fbb 100644 --- a/components/model/ark/chatmodel.go +++ b/components/model/ark/chatmodel.go @@ -521,7 +521,7 @@ func (cm *ChatModel) createContextByContextAPI(ctx context.Context, prefix []*sc TruncationStrategy: truncation, } for _, msg := range prefix { - content, err := cm.chatModel.toArkContent(msg.Content, msg.MultiContent) + content, err := cm.chatModel.toArkContent(msg) if err != nil { return nil, fmt.Errorf("convert message fail: %w", err) } diff --git a/components/model/ark/examples/generate_with_image/generate_with_image.go b/components/model/ark/examples/generate_with_image/generate_with_image.go index b33fbd735..d171c9033 100644 --- a/components/model/ark/examples/generate_with_image/generate_with_image.go +++ b/components/model/ark/examples/generate_with_image/generate_with_image.go @@ -39,15 +39,18 @@ func main() { } multiModalMsg := schema.UserMessage("") - multiModalMsg.MultiContent = []schema.ChatMessagePart{ + var url = "https://d2908q01vomqb2.cloudfront.net/887309d048beef83ad3eabf2a79a64a389ab1c9f/2023/07/13/DBBLOG-3334-image001.png" + multiModalMsg.UserInputMultiContent = []schema.MessageInputPart{ { Type: schema.ChatMessagePartTypeText, Text: "this picture is LangDChain's architecture, what's the picture's content", }, { Type: schema.ChatMessagePartTypeImageURL, - ImageURL: &schema.ChatMessageImageURL{ - URL: "https://d2908q01vomqb2.cloudfront.net/887309d048beef83ad3eabf2a79a64a389ab1c9f/2023/07/13/DBBLOG-3334-image001.png", + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + URL: &url, + }, Detail: schema.ImageURLDetailAuto, }, }, diff --git a/components/model/ark/go.mod b/components/model/ark/go.mod index b3470bf89..eae4c3fe9 100644 --- a/components/model/ark/go.mod +++ b/components/model/ark/go.mod @@ -9,7 +9,7 @@ require ( github.com/eino-contrib/jsonschema v1.0.1 github.com/openai/openai-go v1.10.1 github.com/smartystreets/goconvey v1.8.1 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/volcengine/volcengine-go-sdk v1.1.37 ) diff --git a/components/model/ark/go.sum b/components/model/ark/go.sum index 3800eabae..fb7ff6a7f 100644 --- a/components/model/ark/go.sum +++ b/components/model/ark/go.sum @@ -155,8 +155,11 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/components/model/ark/message_extra.go b/components/model/ark/message_extra.go index 24b3c626c..990963bad 100644 --- a/components/model/ark/message_extra.go +++ b/components/model/ark/message_extra.go @@ -157,14 +157,62 @@ func SetFPS(part *schema.ChatMessageVideoURL, fps float64) { if part == nil { return } - part.Extra[videoURLFPS] = fps + if part.Extra == nil { + part.Extra = make(map[string]any) + } + setFPS(part.Extra, fps) } func GetFPS(part *schema.ChatMessageVideoURL) *float64 { if part == nil { return nil } - fps, ok := part.Extra[videoURLFPS].(float64) + return getFPS(part.Extra) +} + +func SetInputVideoFPS(part *schema.MessageInputVideo, fps float64) { + if part == nil { + return + } + if part.Extra == nil { + part.Extra = make(map[string]any) + } + setFPS(part.Extra, fps) +} + +func GetInputVideoFPS(part *schema.MessageInputVideo) *float64 { + if part == nil { + return nil + } + return getFPS(part.Extra) +} + +func SetOutputVideoFPS(part *schema.MessageOutputVideo, fps float64) { + if part == nil { + return + } + if part.Extra == nil { + part.Extra = make(map[string]any) + } + setFPS(part.Extra, fps) +} + +func GetOutputVideoFPS(part *schema.MessageOutputVideo) *float64 { + if part == nil { + return nil + } + return getFPS(part.Extra) +} + +func setFPS(extra map[string]any, fps float64) { + extra[videoURLFPS] = fps +} + +func getFPS(extra map[string]any) *float64 { + if extra == nil { + return nil + } + fps, ok := extra[videoURLFPS].(float64) if !ok { return nil } diff --git a/components/model/ark/responses_api.go b/components/model/ark/responses_api.go index a69e90b23..7e2cc79e5 100644 --- a/components/model/ark/responses_api.go +++ b/components/model/ark/responses_api.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "runtime/debug" + "strings" "time" "github.com/bytedance/sonic" @@ -658,7 +659,7 @@ func (cm *responsesAPIChatModel) toOpenaiMultiModalContent(msg *schema.Message) content := responses.EasyInputMessageContentUnionParam{} if msg.Content != "" { - if len(msg.MultiContent) == 0 { + if len(msg.MultiContent) == 0 && len(msg.UserInputMultiContent) == 0 && len(msg.AssistantGenMultiContent) == 0 { content.OfString = param.NewOpt(msg.Content) return content, nil } @@ -670,37 +671,122 @@ func (cm *responsesAPIChatModel) toOpenaiMultiModalContent(msg *schema.Message) }) } - for _, c := range msg.MultiContent { - switch c.Type { - case schema.ChatMessagePartTypeText: - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{ - Text: c.Text, - }, - }) + if len(msg.UserInputMultiContent) > 0 && len(msg.AssistantGenMultiContent) > 0 { + return content, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } - case schema.ChatMessagePartTypeImageURL: - if c.ImageURL == nil { - continue + if len(msg.UserInputMultiContent) > 0 { + for _, part := range msg.UserInputMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputText: &responses.ResponseInputTextParam{ + Text: part.Text, + }, + }) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return content, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") + } else { + var imageURL string + if part.Image.URL != nil { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil { + if part.Image.MIMEType == "" { + return content, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + } + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: param.NewOpt(imageURL), + }, + }) + } + case schema.ChatMessagePartTypeFileURL: + if part.File == nil { + return content, fmt.Errorf("file field must not be nil when Type is ChatMessagePartTypeFileURL in user message") + } else { + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputFile: &responses.ResponseInputFileParam{ + FileURL: param.NewOpt(*part.File.URL), + }, + }) + } + default: + return content, fmt.Errorf("unsupported content type in UserInputMultiContent: %s", part.Type) } - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: param.NewOpt(c.ImageURL.URL), - }, - }) - - case schema.ChatMessagePartTypeFileURL: - if c.FileURL == nil { - continue + } + return content, nil + } else if len(msg.AssistantGenMultiContent) > 0 { + if msg.Role != schema.Assistant { + return content, fmt.Errorf("AssistantGenMultiContent is only allowed for messages with role 'assistant', but got role '%s'", msg.Role) + } + for _, part := range msg.AssistantGenMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputText: &responses.ResponseInputTextParam{ + Text: part.Text, + }, + }) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return content, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message") + } else { + var imageURL string + if part.Image.URL != nil { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil { + if part.Image.MIMEType == "" { + return content, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + } + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: param.NewOpt(imageURL), + }, + }) + } + default: + return content, fmt.Errorf("unsupported content type in AssistantGenMultiContent: %s", part.Type) } - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputFile: &responses.ResponseInputFileParam{ - FileURL: param.NewOpt(c.FileURL.URL), - }, - }) + } + return content, nil + } else { + for _, c := range msg.MultiContent { + switch c.Type { + case schema.ChatMessagePartTypeText: + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputText: &responses.ResponseInputTextParam{ + Text: c.Text, + }, + }) - default: - return content, fmt.Errorf("unsupported content type: %s", c.Type) + case schema.ChatMessagePartTypeImageURL: + if c.ImageURL == nil { + continue + } + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: param.NewOpt(c.ImageURL.URL), + }, + }) + + case schema.ChatMessagePartTypeFileURL: + if c.FileURL == nil { + continue + } + content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ + OfInputFile: &responses.ResponseInputFileParam{ + FileURL: param.NewOpt(c.FileURL.URL), + }, + }) + + default: + return content, fmt.Errorf("unsupported content type: %s", c.Type) + } } } @@ -839,3 +925,13 @@ func (cm *responsesAPIChatModel) getOptions(opts []model.Option) (*model.Options return options, arkOpts, nil } + +func ensureDataURL(dataOfBase64, mimeType string) string { + if strings.HasPrefix(dataOfBase64, "data:") { + return dataOfBase64 + } + if mimeType == "" { + return dataOfBase64 + } + return fmt.Sprintf("data:%s;base64,%s", mimeType, dataOfBase64) +} diff --git a/components/model/ark/responses_api_test.go b/components/model/ark/responses_api_test.go index 436bea721..b03d40c14 100644 --- a/components/model/ark/responses_api_test.go +++ b/components/model/ark/responses_api_test.go @@ -211,7 +211,7 @@ func TestResponsesAPIChatModelToOpenaiMultiModalContent(t *testing.T) { { Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{ - URL: "http://example.com/image.png", + URL: "https://example.com/image.png", }, }, }, @@ -222,7 +222,7 @@ func TestResponsesAPIChatModelToOpenaiMultiModalContent(t *testing.T) { contentList := content.OfInputItemContentList assert.Equal(t, 1, len(contentList)) - assert.Equal(t, "http://example.com/image.png", contentList[0].OfInputImage.ImageURL.Value) + assert.Equal(t, "https://example.com/image.png", contentList[0].OfInputImage.ImageURL.Value) }) PatchConvey("text and file message", t, func() { @@ -233,7 +233,7 @@ func TestResponsesAPIChatModelToOpenaiMultiModalContent(t *testing.T) { { Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{ - URL: "http://example.com/file.pdf", + URL: "https://example.com/file.pdf", }, }, }, @@ -245,7 +245,7 @@ func TestResponsesAPIChatModelToOpenaiMultiModalContent(t *testing.T) { contentList := content.OfInputItemContentList assert.Equal(t, 2, len(contentList)) assert.Equal(t, "Here is the file.", contentList[0].OfInputText.Text) - assert.Equal(t, "http://example.com/file.pdf", contentList[1].OfInputFile.FileURL.Value) + assert.Equal(t, "https://example.com/file.pdf", contentList[1].OfInputFile.FileURL.Value) }) PatchConvey("unknown modal type", t, func() { @@ -409,10 +409,9 @@ func TestResponsesAPIChatModelInjectCache(t *testing.T) { }) } -func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { +func TestResponsesAPIChatModelReceivedStreamResponse_ResponseCreatedEvent(t *testing.T) { cm := &responsesAPIChatModel{} streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} - PatchConvey("ResponseCreatedEvent", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(Sequence(true).Then(false)).Build() @@ -426,7 +425,11 @@ func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { cm.receivedStreamResponse(streamResp, nil, true, nil) assert.Equal(t, 1, mocker.Times()) }) +} +func TestResponsesAPIChatModelReceivedStreamResponse_ResponseCompletedEvent(t *testing.T) { + cm := &responsesAPIChatModel{} + streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseCompletedEvent", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(true).Build() @@ -441,7 +444,11 @@ func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { cm.receivedStreamResponse(streamResp, nil, true, nil) assert.Equal(t, 1, mocker.Times()) }) +} +func TestResponsesAPIChatModelReceivedStreamResponse_ResponseErrorEvent(t *testing.T) { + cm := &responsesAPIChatModel{} + streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseErrorEvent", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(true).Build() @@ -457,7 +464,11 @@ func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { cm.receivedStreamResponse(streamResp, nil, true, nil) assert.Equal(t, 1, mocker.Times()) }) +} +func TestResponsesAPIChatModelReceivedStreamResponse_ResponseIncompleteEvent(t *testing.T) { + cm := &responsesAPIChatModel{} + streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseIncompleteEvent", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(Sequence(true).Then(false)).Build() @@ -472,7 +483,11 @@ func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { cm.receivedStreamResponse(streamResp, nil, true, nil) assert.Equal(t, 1, mocker.Times()) }) +} +func TestResponsesAPIChatModelReceivedStreamResponse_ResponseFailedEvent(t *testing.T) { + cm := &responsesAPIChatModel{} + streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseFailedEvent", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(true).Build() @@ -487,7 +502,11 @@ func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { cm.receivedStreamResponse(streamResp, nil, true, nil) assert.Equal(t, 1, mocker.Times()) }) +} +func TestResponsesAPIChatModelReceivedStreamResponse_Default(t *testing.T) { + cm := &responsesAPIChatModel{} + streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("Default", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(Sequence(true).Then(false)).Build() @@ -502,7 +521,11 @@ func TestResponsesAPIChatModelReceivedStreamResponse(t *testing.T) { cm.receivedStreamResponse(streamResp, nil, true, nil) assert.Equal(t, 1, mocker.Times()) }) +} +func TestResponsesAPIChatModelReceivedStreamResponse_ToolCallMetaMsg(t *testing.T) { + cm := &responsesAPIChatModel{} + streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("toolCallMetaMsg", t, func() { MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). Return(Sequence(true).Then(true).Then(false)).Build() @@ -686,3 +709,144 @@ func TestGetArkRequestID(t *testing.T) { t.Log("eq") } } + +func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { + cm := &responsesAPIChatModel{} + base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + httpURL := "https://example.com/image.png" + fileURL := "https://example.com/file.pdf" + + PatchConvey("Test toOpenaiMultiModalContent Comprehensive", t, func() { + PatchConvey("Pure Text Content", func() { + msg := &schema.Message{Content: "just text"} + content, err := cm.toOpenaiMultiModalContent(msg) + assert.Nil(t, err) + assert.Equal(t, "just text", content.OfString.Value) + }) + + PatchConvey("UserInputMultiContent", func() { + PatchConvey("Success with all types", func() { + msg := &schema.Message{ + Content: "initial text", + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: " more text"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, + {Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{URL: &fileURL}}}, + }, + } + content, err := cm.toOpenaiMultiModalContent(msg) + assert.Nil(t, err) + assert.Len(t, content.OfInputItemContentList, 5) // initial text + 4 parts + }) + + PatchConvey("Error on missing MIMEType for Base64", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}, + }, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "image part must have MIMEType when use Base64Data") + }) + + PatchConvey("Error on nil Image", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, + }, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "image field must not be nil") + }) + + PatchConvey("Error on nil File", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeFileURL, File: nil}, + }, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "file field must not be nil") + }) + }) + + PatchConvey("AssistantGenMultiContent", func() { + PatchConvey("Success with all types", func() { + msg := &schema.Message{ + Role: schema.Assistant, + Content: "assistant text", + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeText, Text: " more assistant text"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, + }, + } + content, err := cm.toOpenaiMultiModalContent(msg) + assert.Nil(t, err) + assert.Len(t, content.OfInputItemContentList, 4) // initial text + 3 parts + }) + + PatchConvey("Error on wrong role", func() { + msg := &schema.Message{ + Role: schema.User, + AssistantGenMultiContent: []schema.MessageOutputPart{{}}, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "AssistantGenMultiContent is only allowed for messages with role 'assistant'") + }) + + PatchConvey("Error on nil Image", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, + }, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "image field must not be nil") + }) + + PatchConvey("Error on missing MIMEType for Base64", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}, + }, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "image part must have MIMEType when use Base64Data") + }) + }) + + PatchConvey("MultiContent (Legacy)", func() { + msg := &schema.Message{ + Content: "legacy text", + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeText, Text: " more legacy text"}, + {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: httpURL}}, + {Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{URL: fileURL}}, + }, + } + content, err := cm.toOpenaiMultiModalContent(msg) + assert.Nil(t, err) + assert.Len(t, content.OfInputItemContentList, 4) // initial text + 3 parts + }) + + PatchConvey("Error on both UserInputMultiContent and AssistantGenMultiContent", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "user"}}, + AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "assistant"}}, + } + _, err := cm.toOpenaiMultiModalContent(msg) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + }) + }) +} diff --git a/components/model/arkbot/chatmodel.go b/components/model/arkbot/chatmodel.go index 4d646cce9..b425e3900 100644 --- a/components/model/arkbot/chatmodel.go +++ b/components/model/arkbot/chatmodel.go @@ -21,8 +21,10 @@ import ( "errors" "fmt" "io" + "log" "net/http" "runtime/debug" + "strings" "time" "github.com/volcengine/volcengine-go-sdk/service/arkruntime" @@ -216,8 +218,9 @@ func (cm *ChatModel) CreatePrefixCache(ctx context.Context, prefix []*schema.Mes Messages: make([]*model.ChatCompletionMessage, 0, len(prefix)), TTL: nil, } + for _, msg := range prefix { - content, err := toArkContent(msg.Content, msg.MultiContent) + content, err := toArkContent(msg) if err != nil { return nil, fmt.Errorf("create prefix fail, convert message fail: %w", err) } @@ -457,7 +460,7 @@ func (cm *ChatModel) genRequest(in []*schema.Message, options *fmodel.Options) ( } for _, msg := range in { - content, e := toArkContent(msg.Content, msg.MultiContent) + content, e := toArkContent(msg) if e != nil { return req, e } @@ -713,34 +716,112 @@ func toMessageToolCalls(toolCalls []*model.ToolCall) []schema.ToolCall { return ret } -func toArkContent(content string, multiContent []schema.ChatMessagePart) (*model.ChatCompletionMessageContent, error) { - if len(multiContent) == 0 { - return &model.ChatCompletionMessageContent{StringValue: ptrOf(content)}, nil - } - - parts := make([]*model.ChatCompletionMessageContentPart, 0, len(multiContent)) - - for _, part := range multiContent { - switch part.Type { - case schema.ChatMessagePartTypeText: - parts = append(parts, &model.ChatCompletionMessageContentPart{ - Type: model.ChatCompletionMessageContentPartTypeText, - Text: part.Text, - }) - case schema.ChatMessagePartTypeImageURL: - if part.ImageURL == nil { - return nil, fmt.Errorf("ImageURL field must not be nil when Type is ChatMessagePartTypeImageURL") +func toArkContent(msg *schema.Message) (*model.ChatCompletionMessageContent, error) { + var parts []*model.ChatCompletionMessageContentPart + + if len(msg.UserInputMultiContent) > 0 && len(msg.AssistantGenMultiContent) > 0 { + return nil, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } + + if len(msg.UserInputMultiContent) > 0 { + parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.UserInputMultiContent)) + for _, part := range msg.UserInputMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: part.Text, + }) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") + } + var imageURL string + if part.Image.URL != nil && *part.Image.URL != "" { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil && *part.Image.Base64Data != "" { + if part.Image.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when using Base64Data") + } + imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + } else { + return nil, fmt.Errorf("image part for user input must contain either a URL or Base64Data, but got: %+v", part.Image) + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: imageURL, + Detail: model.ImageURLDetail(part.Image.Detail), + }, + }) + default: + return nil, fmt.Errorf("unsupported chat message part type in user message: %s", part.Type) + } + } + } else if len(msg.AssistantGenMultiContent) > 0 { + if msg.Role != schema.Assistant { + return nil, fmt.Errorf("AssistantGenMultiContent only used when Role is Assistant, but got role: %s", msg.Role) + } + parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.AssistantGenMultiContent)) + for _, part := range msg.AssistantGenMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: part.Text, + }) + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message") + } + var imageURL string + if part.Image.URL != nil && *part.Image.URL != "" { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil && *part.Image.Base64Data != "" { + if part.Image.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when using Base64Data") + } + imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + } else { + return nil, fmt.Errorf("image part for assistant output must contain either a URL or Base64Data, but got: %+v", part.Image) + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: imageURL, + }, + }) + default: + return nil, fmt.Errorf("unsupported chat message part type in assistant message: %s", part.Type) + } + } + } else if len(msg.MultiContent) > 0 { + log.Printf("warning: MultiContent is deprecated, use UserInputMultiContent or AssistantGenMultiContent instead") + parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.MultiContent)) + for _, part := range msg.MultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: part.Text, + }) + case schema.ChatMessagePartTypeImageURL: + if part.ImageURL == nil { + return nil, fmt.Errorf("ImageURL field must not be nil when Type is ChatMessagePartTypeImageURL") + } + parts = append(parts, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: part.ImageURL.URL, + Detail: model.ImageURLDetail(part.ImageURL.Detail), + }, + }) + default: + return nil, fmt.Errorf("unsupported chat message part type: %s", part.Type) } - parts = append(parts, &model.ChatCompletionMessageContentPart{ - Type: model.ChatCompletionMessageContentPartTypeImageURL, - ImageURL: &model.ChatMessageImageURL{ - URL: part.ImageURL.URL, - Detail: model.ImageURLDetail(part.ImageURL.Detail), - }, - }) - default: - return nil, fmt.Errorf("unsupported chat message part type: %s", part.Type) } + } else { + return &model.ChatCompletionMessageContent{StringValue: ptrOf(msg.Content)}, nil } return &model.ChatCompletionMessageContent{ @@ -795,6 +876,16 @@ func toTools(tls []*schema.ToolInfo) ([]tool, error) { return tools, nil } +func ensureDataURL(dataOfBase64, mimeType string) string { + if strings.HasPrefix(dataOfBase64, "data:") { + return dataOfBase64 + } + if mimeType == "" { + return dataOfBase64 + } + return fmt.Sprintf("data:%s;base64,%s", mimeType, dataOfBase64) +} + func closeArkStreamReader(r *autils.BotChatCompletionStreamReader) error { if r == nil || r.Response == nil || r.Response.Body == nil { return nil diff --git a/components/model/arkbot/chatmodel_test.go b/components/model/arkbot/chatmodel_test.go index 2ff4ca2be..575db0558 100644 --- a/components/model/arkbot/chatmodel_test.go +++ b/components/model/arkbot/chatmodel_test.go @@ -176,7 +176,7 @@ func Test_Generate(t *testing.T) { }, } - req, err := toArkContent(multiModalMsg.Content, multiModalMsg.MultiContent) + req, err := toArkContent(multiModalMsg) convey.So(err, convey.ShouldBeNil) convey.So(req.StringValue, convey.ShouldBeNil) convey.So(req.ListValue, convey.ShouldHaveLength, 2) @@ -370,3 +370,172 @@ func TestLogProbs(t *testing.T) { }, }})) } + +func Test_toArkContent(t *testing.T) { + PatchConvey("test toArkContent", t, func() { + base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + httpURL := "https://example.com/image.png" + + PatchConvey("UserInputMultiContent", func() { + PatchConvey("success", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "hello"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + }, + } + content, err := toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 3) + convey.So(content.ListValue[0].Text, convey.ShouldEqual, "hello") + convey.So(content.ListValue[1].ImageURL.URL, convey.ShouldContainSubstring, "data:image/png;base64,") + convey.So(content.ListValue[2].ImageURL.URL, convey.ShouldEqual, httpURL) + }) + + PatchConvey("nil image", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, + }, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("empty image", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{}}, + }, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("no mime type", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}, + }, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) + + PatchConvey("AssistantGenMultiContent", func() { + PatchConvey("success", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "assistant response"}, + }, + } + content, err := toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 1) + }) + + PatchConvey("success with image", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + URL: &httpURL, + }, + }, + }, + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &base64Data, + MIMEType: "image/png", + }, + }, + }, + }, + } + content, err := toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 2) + convey.So(content.ListValue[0].ImageURL.URL, convey.ShouldEqual, httpURL) + convey.So(content.ListValue[1].ImageURL.URL, convey.ShouldContainSubstring, "data:image/png;base64,") + }) + + PatchConvey("wrong role", func() { + msg := &schema.Message{ + Role: schema.User, + AssistantGenMultiContent: []schema.MessageOutputPart{{}}, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("nil image", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, + }, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("empty image", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{}}, + }, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + + PatchConvey("no mime type", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}, + }, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) + + PatchConvey("MultiContent", func() { + msg := &schema.Message{ + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeText, Text: "legacy"}, + {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: httpURL}}, + }, + } + content, err := toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(content.ListValue, convey.ShouldHaveLength, 2) + }) + + PatchConvey("Text Content", func() { + msg := &schema.Message{Content: "just text"} + content, err := toArkContent(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(*content.StringValue, convey.ShouldEqual, "just text") + }) + + PatchConvey("both UserInputMultiContent and AssistantGenMultiContent", func() { + msg := &schema.Message{ + UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "user"}}, + AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "assistant"}}, + } + _, err := toArkContent(msg) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} diff --git a/components/model/arkbot/go.mod b/components/model/arkbot/go.mod index 603324c81..03ee7795f 100644 --- a/components/model/arkbot/go.mod +++ b/components/model/arkbot/go.mod @@ -7,7 +7,7 @@ require ( github.com/cloudwego/eino v0.5.5 github.com/eino-contrib/jsonschema v1.0.1 github.com/smartystreets/goconvey v1.8.1 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/volcengine/volcengine-go-sdk v1.1.16 ) diff --git a/components/model/arkbot/go.sum b/components/model/arkbot/go.sum index 565f97870..fe7385623 100644 --- a/components/model/arkbot/go.sum +++ b/components/model/arkbot/go.sum @@ -149,8 +149,11 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= diff --git a/components/model/claude/claude.go b/components/model/claude/claude.go index 6f4054b32..4b510518b 100644 --- a/components/model/claude/claude.go +++ b/components/model/claude/claude.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "runtime/debug" "strings" @@ -612,13 +613,82 @@ func convSchemaMessage(message *schema.Message) (mp anthropic.MessageParam, err } } + if len(message.UserInputMultiContent) > 0 && len(message.AssistantGenMultiContent) > 0 { + return mp, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } + if len(message.Content) > 0 { if len(message.ToolCallID) > 0 { messageParams = append(messageParams, anthropic.NewToolResultBlock(message.ToolCallID, message.Content, false)) } else { messageParams = append(messageParams, anthropic.NewTextBlock(message.Content)) } + } else if len(message.UserInputMultiContent) > 0 { + for i := range message.UserInputMultiContent { + switch message.UserInputMultiContent[i].Type { + case schema.ChatMessagePartTypeText: + messageParams = append(messageParams, anthropic.NewTextBlock(message.UserInputMultiContent[i].Text)) + case schema.ChatMessagePartTypeImageURL: + if message.UserInputMultiContent[i].Image == nil { + return mp, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") + } + image := message.UserInputMultiContent[i].Image + if image.URL != nil && *image.URL != "" { + messageParams = append(messageParams, anthropic.NewImageBlock(anthropic.URLImageSourceParam{ + URL: *image.URL, + })) + } else if image.Base64Data != nil && *image.Base64Data != "" { + if image.MIMEType == "" { + return mp, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + if strings.HasPrefix(*image.Base64Data, "data:") { + return mp, fmt.Errorf("Base64Data should be a raw base64 string, but it has a 'data:' prefix") + } + messageParams = append(messageParams, anthropic.NewImageBlockBase64(image.MIMEType, *image.Base64Data)) + } else { + return mp, fmt.Errorf("image part must have either a URL or Base64Data") + } + default: + return mp, fmt.Errorf("anthropic message type not supported: %s", message.UserInputMultiContent[i].Type) + } + } + } else if len(message.AssistantGenMultiContent) > 0 { + if message.Role != schema.Assistant { + return mp, fmt.Errorf("AssistantGenMultiContent is only allowed for messages with role 'assistant', but got role '%s'", message.Role) + } + for i := range message.AssistantGenMultiContent { + switch message.AssistantGenMultiContent[i].Type { + case schema.ChatMessagePartTypeText: + messageParams = append(messageParams, anthropic.NewTextBlock(message.AssistantGenMultiContent[i].Text)) + case schema.ChatMessagePartTypeImageURL: + if message.AssistantGenMultiContent[i].Image == nil { + return mp, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message") + } + image := message.AssistantGenMultiContent[i].Image + if image.URL != nil && *image.URL != "" { + messageParams = append(messageParams, anthropic.NewImageBlock(anthropic.URLImageSourceParam{ + URL: *image.URL, + })) + } else if image.Base64Data != nil && *image.Base64Data != "" { + if image.MIMEType == "" { + return mp, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + if strings.HasPrefix(*image.Base64Data, "data:") { + return mp, fmt.Errorf("Base64Data should be a raw base64 string, but it has a 'data:' prefix") + } + messageParams = append(messageParams, anthropic.NewImageBlockBase64(image.MIMEType, *image.Base64Data)) + } else { + return mp, fmt.Errorf("image part must have either a URL or Base64Data") + } + default: + return mp, fmt.Errorf("anthropic message type not supported: %s", message.AssistantGenMultiContent[i].Type) + } + } } else { + // The `MultiContent` field is deprecated. In its design, the `URL` field of `ImageURL` + // could contain either an HTTP URL or a Base64-encoded DATA URL. This is different from the new + // `UserInputMultiContent` and `AssistantGenMultiContent` fields, where `URL` and `Base64Data` are separate. + log.Printf("MultiContent is deprecated, please use UserInputMultiContent or AssistantGenMultiContent instead") for i := range message.MultiContent { switch message.MultiContent[i].Type { case schema.ChatMessagePartTypeText: @@ -627,6 +697,12 @@ func convSchemaMessage(message *schema.Message) (mp anthropic.MessageParam, err if message.MultiContent[i].ImageURL == nil { continue } + if strings.HasPrefix(message.MultiContent[i].ImageURL.URL, "http") { + messageParams = append(messageParams, anthropic.NewImageBlock(anthropic.URLImageSourceParam{ + URL: message.MultiContent[i].ImageURL.URL, + })) + continue + } mediaType, data, err_ := convImageBase64(message.MultiContent[i].ImageURL.URL) if err_ != nil { return mp, fmt.Errorf("extract base64 image fail: %w", err_) diff --git a/components/model/claude/claude_test.go b/components/model/claude/claude_test.go index 1742d0113..55cccfbb1 100644 --- a/components/model/claude/claude_test.go +++ b/components/model/claude/claude_test.go @@ -354,3 +354,224 @@ func TestInjectContentBlockBreakPoint(t *testing.T) { injectContentBlockBreakPoint(lastBlock) assert.NotEmpty(t, lastBlock.OfToolResult.CacheControl.Type) } + +func Test_convSchemaMessage_MultiContent(t *testing.T) { + rawBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + invalidDataURL := "data:image/png;base64," + rawBase64 + httpURL := "https://example.com/image.png" + + t.Run("UserInputMultiContent", func(t *testing.T) { + t.Run("success with base64", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "hello"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &rawBase64, MIMEType: "image/png"}}}, + }, + } + result, err := convSchemaMessage(msg) + assert.NoError(t, err) + assert.Len(t, result.Content, 2) + assert.Equal(t, "hello", result.Content[0].OfText.Text) + assert.Equal(t, anthropic.Base64ImageSourceMediaType("image/png"), result.Content[1].OfImage.Source.OfBase64.MediaType) + }) + + t.Run("success with url", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + }, + } + result, err := convSchemaMessage(msg) + assert.NoError(t, err) + assert.Len(t, result.Content, 1) + assert.Equal(t, httpURL, result.Content[0].OfImage.Source.OfURL.URL) + }) + + t.Run("error with data url prefix", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidDataURL, MIMEType: "image/png"}}}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "Base64Data should be a raw base64 string") + }) + + t.Run("error with no mime type for base64", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &rawBase64}}}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "image part must have MIMEType when use Base64Data") + }) + + t.Run("error with no url or base64", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{}}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "image part must have either a URL or Base64Data") + }) + }) + + t.Run("AssistantGenMultiContent", func(t *testing.T) { + t.Run("success with image", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &rawBase64, + MIMEType: "image/png", + }, + }, + }, + {Type: schema.ChatMessagePartTypeText, Text: "some text"}, + }, + } + result, err := convSchemaMessage(msg) + assert.NoError(t, err) + assert.Len(t, result.Content, 2) + assert.Equal(t, anthropic.Base64ImageSourceMediaType("image/png"), result.Content[0].OfImage.Source.OfBase64.MediaType) + assert.Equal(t, "some text", result.Content[1].OfText.Text) + }) + + t.Run("success with url", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{URL: &httpURL}}}, + }, + } + result, err := convSchemaMessage(msg) + assert.NoError(t, err) + assert.Len(t, result.Content, 1) + assert.Equal(t, httpURL, result.Content[0].OfImage.Source.OfURL.URL) + }) + + t.Run("error with wrong role", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "some text"}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "AssistantGenMultiContent is only allowed for messages with role 'assistant', but got role 'user'") + }) + + t.Run("error with data url prefix", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidDataURL, MIMEType: "image/png"}}}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "Base64Data should be a raw base64 string") + }) + + t.Run("error with no mime type for base64", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &rawBase64}}}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "image part must have MIMEType when use Base64Data") + }) + + t.Run("error with no url or base64", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{}}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "image part must have either a URL or Base64Data") + }) + }) + + t.Run("MultiContent backward compatibility", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeText, Text: "legacy"}, + {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: invalidDataURL}}, + }, + } + result, err := convSchemaMessage(msg) + assert.NoError(t, err) + assert.Len(t, result.Content, 2) + assert.Equal(t, "legacy", result.Content[0].OfText.Text) + assert.Equal(t, anthropic.Base64ImageSourceMediaType("image/png"), result.Content[1].OfImage.Source.OfBase64.MediaType) + }) + + t.Run("MultiContent backward compatibility with http url", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: httpURL}}, + }, + } + result, err := convSchemaMessage(msg) + assert.NoError(t, err) + assert.Len(t, result.Content, 1) + assert.Equal(t, httpURL, result.Content[0].OfImage.Source.OfURL.URL) + }) + + t.Run("error with both UserInputMultiContent and AssistantGenMultiContent", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "user"}}, + AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "assistant"}}, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + }) + + t.Run("error with nil image in UserInputMultiContent", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "image field must not be nil") + }) + + t.Run("error with nil image in AssistantGenMultiContent", func(t *testing.T) { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, + }, + } + _, err := convSchemaMessage(msg) + assert.Error(t, err) + assert.ErrorContains(t, err, "image field must not be nil") + }) +} diff --git a/components/model/claude/examples/basic_usage/claude.go b/components/model/claude/examples/basic_usage/claude.go index e41fc8374..90f76b03b 100644 --- a/components/model/claude/examples/basic_usage/claude.go +++ b/components/model/claude/examples/basic_usage/claude.go @@ -242,19 +242,22 @@ func imageProcessing(ctx context.Context, cm model.BaseChatModel) { if err != nil { log.Fatalf("read file failed, err=%v", err) } + base64Str := base64.StdEncoding.EncodeToString(imageBinary) resp, err := cm.Generate(ctx, []*schema.Message{ { Role: schema.User, - MultiContent: []schema.ChatMessagePart{ + UserInputMultiContent: []schema.MessageInputPart{ { Type: schema.ChatMessagePartTypeText, Text: "What do you see in this image?", }, { Type: schema.ChatMessagePartTypeImageURL, - ImageURL: &schema.ChatMessageImageURL{ - URL: "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(imageBinary), - MIMEType: "image/jpeg", + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &base64Str, + MIMEType: "image/jpeg", + }, }, }, }, diff --git a/components/model/claude/go.mod b/components/model/claude/go.mod index c72143c57..4601b2a3c 100644 --- a/components/model/claude/go.mod +++ b/components/model/claude/go.mod @@ -7,9 +7,9 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.29.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.54 github.com/bytedance/mockey v1.2.13 - github.com/cloudwego/eino v0.4.7 - github.com/eino-contrib/jsonschema v1.0.0 - github.com/stretchr/testify v1.9.0 + github.com/cloudwego/eino v0.5.5 + github.com/eino-contrib/jsonschema v1.0.1 + github.com/stretchr/testify v1.11.1 github.com/wk8/go-ordered-map/v2 v2.1.8 ) @@ -28,9 +28,10 @@ require ( github.com/aws/smithy-go v1.22.1 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/bytedance/sonic v1.13.2 // indirect - github.com/bytedance/sonic/loader v0.2.4 // indirect - github.com/cloudwego/base64x v0.1.5 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/getkin/kin-openapi v0.118.0 // indirect @@ -42,7 +43,7 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect - github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/components/model/claude/go.sum b/components/model/claude/go.sum index 6d9a2bcd5..cd60a6393 100644 --- a/components/model/claude/go.sum +++ b/components/model/claude/go.sum @@ -37,18 +37,34 @@ github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMU github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/eino v0.4.7 h1:wwqsFWCuzCQuhw1dYKqHjGWULzjDjFfN9sTn/cezYV4= github.com/cloudwego/eino v0.4.7/go.mod h1:1TDlOmwGSsbCJaWB92w9YLZi2FL0WRZoRcD4eMvqikg= +github.com/cloudwego/eino v0.5.4-0.20250917102129-a48857bf4c79 h1:voV9TjDX/AMHDf95wUgutcMZWGvWW92qKSBo7N3o0Ek= +github.com/cloudwego/eino v0.5.4-0.20250917102129-a48857bf4c79/go.mod h1:S38tlNO4cNqFfGJKQSJZimxjzc9JDJKdf2eW3FEEfdc= +github.com/cloudwego/eino v0.5.4-0.20250925133640-10f7a8ffec1b h1:pw2BTeX19inORUtLJ00iRA/dwcLHAE8jqMaIUXwd6cg= +github.com/cloudwego/eino v0.5.4-0.20250925133640-10f7a8ffec1b/go.mod h1:JxKeWsO8iUZfKh3iE4iN0JCvCzYLRNyqjSag/RisPbc= +github.com/cloudwego/eino v0.5.5-0.20251009130421-8c297fc3c521 h1:WIk/K6QcstEj2+xtBWSaSTylUqxtt1vV3K2k0QoMUnQ= +github.com/cloudwego/eino v0.5.5-0.20251009130421-8c297fc3c521/go.mod h1:XolsJjKmiA+g9Dvr1vBJxGyqCksx52Ia/O4Iq+iMmeI= +github.com/cloudwego/eino v0.5.5 h1:CsUC+DQfMDjfHZoM9n4GHab5bZNMVqwqJb6dLa3A7VY= +github.com/cloudwego/eino v0.5.5/go.mod h1:XolsJjKmiA+g9Dvr1vBJxGyqCksx52Ia/O4Iq+iMmeI= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -57,6 +73,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eino-contrib/jsonschema v1.0.0 h1:dXxbhGNZuI3+xNi8x3JT8AGyoXz6Pff6mRvmpjVl5Ww= github.com/eino-contrib/jsonschema v1.0.0/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/jsonschema v1.0.1 h1:Ty2r/J+mHUGz3tqQNympPiTeaCVTST09yvTKlFlZUCA= +github.com/eino-contrib/jsonschema v1.0.1/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= @@ -89,6 +107,8 @@ github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfV github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -150,6 +170,9 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/components/model/gemini/examples/gemini.go b/components/model/gemini/examples/gemini.go index 425f3d776..5222260d2 100644 --- a/components/model/gemini/examples/gemini.go +++ b/components/model/gemini/examples/gemini.go @@ -69,6 +69,9 @@ func main() { fmt.Println("\n=== Image Processing ===") imageProcessing(ctx, client) + + fmt.Println("\n=== Image Generation ===") + generateImage(ctx, client) } func basicChat(ctx context.Context, cm model.ChatModel) { @@ -239,3 +242,35 @@ func imageProcessing(ctx context.Context, client *genai.Client) { } fmt.Printf("Assistant: %s\n", resp.Content) } + +func generateImage(ctx context.Context, client *genai.Client) { + cm, err := gemini.NewChatModel(ctx, &gemini.Config{ + Client: client, + Model: "gemini-2.5-flash-image-preview", + ResponseModalities: []gemini.GeminiResponseModalities{ + gemini.GeminiResponseModalitiesText, + gemini.GeminiResponseModalitiesImage, + }, + }) + if err != nil { + log.Printf("NewChatModel error: %v", err) + return + } + + resp, err := cm.Generate(ctx, []*schema.Message{ + { + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "Generate an image of a cat", + }, + }, + }, + }) + if err != nil { + log.Printf("Generate error: %v", err) + return + } + fmt.Printf("Assistant: %s\n", resp) +} diff --git a/components/model/gemini/gemini.go b/components/model/gemini/gemini.go index e01ddc53a..2aea27c19 100644 --- a/components/model/gemini/gemini.go +++ b/components/model/gemini/gemini.go @@ -18,9 +18,12 @@ package gemini import ( "context" + "encoding/base64" "errors" "fmt" + "log" "runtime/debug" + "strings" "github.com/bytedance/sonic" "github.com/eino-contrib/jsonschema" @@ -64,6 +67,7 @@ func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) { enableCodeExecution: cfg.EnableCodeExecution, safetySettings: cfg.SafetySettings, thinkingConfig: cfg.ThinkingConfig, + responseModalities: cfg.ResponseModalities, }, nil } @@ -110,6 +114,10 @@ type Config struct { SafetySettings []*genai.SafetySetting ThinkingConfig *genai.ThinkingConfig + + // ResponseModalities specifies the modalities the model can return. + // Optional. + ResponseModalities []GeminiResponseModalities } type ChatModel struct { @@ -127,6 +135,7 @@ type ChatModel struct { enableCodeExecution bool safetySettings []*genai.SafetySetting thinkingConfig *genai.ThinkingConfig + responseModalities []GeminiResponseModalities } func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) { @@ -288,8 +297,9 @@ func (cm *ChatModel) genInputAndConf(input []*schema.Message, opts ...model.Opti ToolChoice: cm.toolChoice, }, opts...) geminiOptions := model.GetImplSpecificOptions(&options{ - TopK: cm.topK, - ResponseSchema: cm.responseSchema, + TopK: cm.topK, + ResponseSchema: cm.responseSchema, + ResponseModalities: cm.responseModalities, }, opts...) conf := &model.Config{} @@ -375,6 +385,13 @@ func (cm *ChatModel) genInputAndConf(input []*schema.Message, opts ...model.Opti } } + if len(geminiOptions.ResponseModalities) > 0 { + m.ResponseModalities = make([]string, len(geminiOptions.ResponseModalities)) + for i, v := range geminiOptions.ResponseModalities { + m.ResponseModalities[i] = string(v) + } + } + nInput := make([]*schema.Message, len(input)) copy(nInput, input) if len(input) > 1 && input[0].Role == schema.System { @@ -640,12 +657,183 @@ func (cm *ChatModel) convSchemaMessage(message *schema.Message) (*genai.Content, if message.Content != "" { content.Parts = append(content.Parts, genai.NewPartFromText(message.Content)) } - content.Parts = append(content.Parts, cm.convMedia(message.MultiContent)...) + if len(message.UserInputMultiContent) > 0 && len(message.AssistantGenMultiContent) > 0 { + return nil, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } + if message.UserInputMultiContent != nil { + parts, err := cm.convInputMedia(message.UserInputMultiContent) + if err != nil { + return nil, err + } + content.Parts = append(content.Parts, parts...) + } else if message.AssistantGenMultiContent != nil { + if message.Role != schema.Assistant { + return nil, fmt.Errorf("assistant gen multi content only support assistant role, got %s", message.Role) + } + parts, err := cm.convOutputMedia(message.AssistantGenMultiContent) + if err != nil { + return nil, err + } + content.Parts = append(content.Parts, parts...) + } else if message.MultiContent != nil { + log.Printf("MultiContent field is deprecated, please use UserInputMultiContent or AssistantGenMultiContent instead") + parts, err := cm.convMedia(message.MultiContent) + if err != nil { + return nil, err + } + content.Parts = parts + } } return content, nil } -func (cm *ChatModel) convMedia(contents []schema.ChatMessagePart) []*genai.Part { +func (cm *ChatModel) convInputMedia(contents []schema.MessageInputPart) ([]*genai.Part, error) { + result := make([]*genai.Part, 0, len(contents)) + for _, content := range contents { + switch content.Type { + case schema.ChatMessagePartTypeText: + result = append(result, genai.NewPartFromText(content.Text)) + case schema.ChatMessagePartTypeImageURL: + if content.Image == nil { + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") + } + if content.Image.Base64Data != nil { + data, err := decodeBase64DataURL(*content.Image.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.Image.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for image parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.Image.MIMEType)) + } else if content.Image != nil && content.Image.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for image parts, please use Base64Data instead") + } + case schema.ChatMessagePartTypeAudioURL: + if content.Audio == nil { + return nil, fmt.Errorf("audio field must not be nil when Type is ChatMessagePartTypeAudioURL in user message") + } + if content.Audio.Base64Data != nil { + data, err := decodeBase64DataURL(*content.Audio.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.Audio.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for audio parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.Audio.MIMEType)) + } else if content.Audio != nil && content.Audio.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for audio parts, please use Base64Data instead") + } + case schema.ChatMessagePartTypeVideoURL: + if content.Video == nil { + return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL in user message") + } + if content.Video.Extra != nil { + videoMetaData := GetInputVideoMetaData(content.Video) + if videoMetaData != nil { + result = append(result, &genai.Part{VideoMetadata: videoMetaData}) + } + } + if content.Video.Base64Data != nil { + data, err := decodeBase64DataURL(*content.Video.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.Video.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for video parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.Video.MIMEType)) + } else if content.Video.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for video parts, please use Base64Data instead") + } + case schema.ChatMessagePartTypeFileURL: + if content.File == nil { + return nil, fmt.Errorf("file field must not be nil when Type is ChatMessagePartTypeFileURL in user message") + } + if content.File.Base64Data != nil { + data, err := decodeBase64DataURL(*content.File.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.File.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for file parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.File.MIMEType)) + } else if content.File != nil && content.File.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for file parts, please use Base64Data instead") + } + } + } + return result, nil +} + +func (cm *ChatModel) convOutputMedia(contents []schema.MessageOutputPart) ([]*genai.Part, error) { + result := make([]*genai.Part, 0, len(contents)) + for _, content := range contents { + switch content.Type { + case schema.ChatMessagePartTypeText: + result = append(result, genai.NewPartFromText(content.Text)) + case schema.ChatMessagePartTypeImageURL: + if content.Image == nil { + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message") + } + if content.Image.Base64Data != nil { + data, err := decodeBase64DataURL(*content.Image.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.Image.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for image parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.Image.MIMEType)) + } else if content.Image != nil && content.Image.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for image parts, please use Base64Data instead") + } + case schema.ChatMessagePartTypeAudioURL: + if content.Audio == nil { + return nil, fmt.Errorf("audio field must not be nil when Type is ChatMessagePartTypeAudioURL in assistant message") + } + if content.Audio.Base64Data != nil { + data, err := decodeBase64DataURL(*content.Audio.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.Audio.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for audio parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.Audio.MIMEType)) + } else if content.Audio != nil && content.Audio.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for audio parts, please use Base64Data instead") + } + case schema.ChatMessagePartTypeVideoURL: + if content.Video == nil { + return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL in assistant message") + } + if content.Video.Extra != nil { + videoMetaData := GetOutputVideoMetaData(content.Video) + if videoMetaData != nil { + result = append(result, &genai.Part{VideoMetadata: videoMetaData}) + } + } + if content.Video.Base64Data != nil { + data, err := decodeBase64DataURL(*content.Video.Base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + if content.Video.MIMEType == "" { + return nil, fmt.Errorf("MIMEType is required for video parts with Base64Data") + } + result = append(result, genai.NewPartFromBytes(data, content.Video.MIMEType)) + } else if content.Video.URL != nil { + return nil, fmt.Errorf("gemini: URL is not supported for video parts, please use Base64Data instead") + } + } + } + return result, nil +} + +func (cm *ChatModel) convMedia(contents []schema.ChatMessagePart) ([]*genai.Part, error) { result := make([]*genai.Part, 0, len(contents)) for _, content := range contents { switch content.Type { @@ -653,23 +841,93 @@ func (cm *ChatModel) convMedia(contents []schema.ChatMessagePart) []*genai.Part result = append(result, genai.NewPartFromText(content.Text)) case schema.ChatMessagePartTypeImageURL: if content.ImageURL != nil { - result = append(result, genai.NewPartFromURI(content.ImageURL.URI, content.ImageURL.MIMEType)) + if content.ImageURL.URI != "" { + result = append(result, genai.NewPartFromURI(content.ImageURL.URI, content.ImageURL.MIMEType)) + } else { + data, err := decodeBase64DataURL(content.ImageURL.URL) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + result = append(result, genai.NewPartFromBytes(data, content.ImageURL.MIMEType)) + } } case schema.ChatMessagePartTypeAudioURL: if content.AudioURL != nil { - result = append(result, genai.NewPartFromURI(content.AudioURL.URI, content.AudioURL.MIMEType)) + if content.AudioURL.URI != "" { + result = append(result, genai.NewPartFromURI(content.AudioURL.URI, content.AudioURL.MIMEType)) + } else { + data, err := decodeBase64DataURL(content.AudioURL.URL) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + result = append(result, genai.NewPartFromBytes(data, content.AudioURL.MIMEType)) + } } case schema.ChatMessagePartTypeVideoURL: if content.VideoURL != nil { - result = append(result, genai.NewPartFromURI(content.VideoURL.URI, content.VideoURL.MIMEType)) + if content.VideoURL.Extra != nil { + videoMetaData := GetVideoMetaData(content.VideoURL) + if videoMetaData != nil { + result = append(result, &genai.Part{ + VideoMetadata: videoMetaData, + }) + } + } + if content.VideoURL.URI != "" { + result = append(result, genai.NewPartFromURI(content.VideoURL.URI, content.VideoURL.MIMEType)) + } else { + data, err := decodeBase64DataURL(content.VideoURL.URL) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + result = append(result, genai.NewPartFromBytes(data, content.VideoURL.MIMEType)) + } } case schema.ChatMessagePartTypeFileURL: if content.FileURL != nil { - result = append(result, genai.NewPartFromURI(content.FileURL.URI, content.FileURL.MIMEType)) + if content.FileURL.URI != "" { + result = append(result, genai.NewPartFromURI(content.FileURL.URI, content.FileURL.MIMEType)) + } else { + data, err := decodeBase64DataURL(content.FileURL.URL) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data URL: %w", err) + } + result = append(result, genai.NewPartFromBytes(data, content.FileURL.MIMEType)) + } } } } - return result + return result, nil +} + +// decodeBase64DataURL decodes a base64 data URL string into raw bytes. +// It correctly handles the "data:[];base64," prefix. +func decodeBase64DataURL(dataURL string) ([]byte, error) { + // Check if a web URL is passed by mistake. + if strings.HasPrefix(dataURL, "http") { + return nil, fmt.Errorf("invalid input: expected base64 data or data URL, but got a web URL starting with 'http'. Please fetch the content from the URL first") + } + // Find the comma that separates the prefix from the data + commaIndex := strings.Index(dataURL, ",") + if commaIndex == -1 { + // If no comma, assume it's a raw base64 string and try to decode it directly. + decoded, err := base64.StdEncoding.DecodeString(dataURL) + if err != nil { + return nil, fmt.Errorf("failed to decode raw base64 data: %w", err) + } + return decoded, nil + } + + // Extract the base64 part of the data URL + base64Data := dataURL[commaIndex+1:] + + // Decode the base64 string + decoded, err := base64.StdEncoding.DecodeString(base64Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data from data URL: %w", err) + } + + return decoded, nil } func (cm *ChatModel) convResponse(resp *genai.GenerateContentResponse) (*schema.Message, error) { @@ -710,12 +968,21 @@ func (cm *ChatModel) convCandidate(candidate *genai.Candidate) (*schema.Message, result.Role = schema.User } - var texts []string + var ( + texts []string + outParts []schema.MessageOutputPart + contentBuilder strings.Builder + ) for _, part := range candidate.Content.Parts { if part.Thought { result.ReasoningContent = part.Text } else if len(part.Text) > 0 { texts = append(texts, part.Text) + contentBuilder.WriteString(part.Text) + outParts = append(outParts, schema.MessageOutputPart{ + Type: schema.ChatMessagePartTypeText, + Text: part.Text, + }) } if part.FunctionCall != nil { fc, err := convFC(part.FunctionCall) @@ -726,14 +993,24 @@ func (cm *ChatModel) convCandidate(candidate *genai.Candidate) (*schema.Message, } if part.CodeExecutionResult != nil { texts = append(texts, part.CodeExecutionResult.Output) + outParts = append(outParts, schema.MessageOutputPart{ + Type: schema.ChatMessagePartTypeText, + Text: part.CodeExecutionResult.Output, + }) } if part.ExecutableCode != nil { texts = append(texts, part.ExecutableCode.Code) + outParts = append(outParts, schema.MessageOutputPart{ + Type: schema.ChatMessagePartTypeText, + Text: part.ExecutableCode.Code, + }) + } + if part.InlineData != nil && part.InlineData.Data != nil { + outParts = append(outParts, toMultiOutPart(part)) } } - if len(texts) == 1 { - result.Content = texts[0] - } else if len(texts) > 1 { + result.Content = contentBuilder.String() + if len(texts) > 1 { for _, text := range texts { result.MultiContent = append(result.MultiContent, schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, @@ -741,10 +1018,54 @@ func (cm *ChatModel) convCandidate(candidate *genai.Candidate) (*schema.Message, }) } } + if len(outParts) > 0 { + result.AssistantGenMultiContent = outParts + } } return result, nil } +func toMultiOutPart(part *genai.Part) schema.MessageOutputPart { + if part == nil { + return schema.MessageOutputPart{} + } + res := schema.MessageOutputPart{} + if part.InlineData != nil { + mimeType := part.InlineData.MIMEType + multiMediaData := part.InlineData.Data + encodedStr := base64.StdEncoding.EncodeToString(multiMediaData) + switch { + case strings.HasPrefix(mimeType, "image/"): + res.Type = schema.ChatMessagePartTypeImageURL + res.Image = &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &encodedStr, + MIMEType: mimeType, + }, + } + case strings.HasPrefix(mimeType, "audio/"): + res.Type = schema.ChatMessagePartTypeAudioURL + res.Audio = &schema.MessageOutputAudio{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &encodedStr, + MIMEType: mimeType, + }, + } + case strings.HasPrefix(mimeType, "video/"): + res.Type = schema.ChatMessagePartTypeVideoURL + res.Video = &schema.MessageOutputVideo{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &encodedStr, + MIMEType: mimeType, + }, + } + default: + log.Printf("gemini part type not support: %s", mimeType) + } + } + return res +} + func convFC(tp *genai.FunctionCall) (*schema.ToolCall, error) { args, err := sonic.MarshalString(tp.Args) if err != nil { @@ -781,6 +1102,14 @@ func (cm *ChatModel) IsCallbacksEnabled() bool { return true } +type GeminiResponseModalities string + +const ( + GeminiResponseModalitiesText GeminiResponseModalities = "TEXT" + GeminiResponseModalitiesImage GeminiResponseModalities = "IMAGE" + GeminiResponseModalitiesAudio GeminiResponseModalities = "AUDIO" +) + const ( roleModel = "model" roleUser = "user" diff --git a/components/model/gemini/gemini_test.go b/components/model/gemini/gemini_test.go index e209b8c6a..455ddd9fa 100644 --- a/components/model/gemini/gemini_test.go +++ b/components/model/gemini/gemini_test.go @@ -18,8 +18,10 @@ package gemini import ( "context" + "encoding/base64" "io" "testing" + "time" "github.com/bytedance/mockey" "github.com/bytedance/sonic" @@ -269,49 +271,302 @@ func TestWithTools(t *testing.T) { assert.Equal(t, "test tool name", ncm.(*ChatModel).origTools[0].Name) } -func TestChatModelConvMedia(t *testing.T) { - cm := &ChatModel{model: "test model"} - contents := []schema.ChatMessagePart{ - { - Type: schema.ChatMessagePartTypeText, - Text: "test text", - }, - { - Type: schema.ChatMessagePartTypeImageURL, - ImageURL: &schema.ChatMessageImageURL{ - URI: "test uri", - MIMEType: "test mime type", +func Test_toMultiOutPart(t *testing.T) { + t.Run("nil part", func(t *testing.T) { + part := toMultiOutPart(nil) + assert.Empty(t, part) + }) + + t.Run("nil inline data", func(t *testing.T) { + part := toMultiOutPart(&genai.Part{InlineData: nil}) + assert.Empty(t, part) + }) + + t.Run("image part", func(t *testing.T) { + data := []byte("fake-image-data") + encoded := base64.StdEncoding.EncodeToString(data) + part := toMultiOutPart(&genai.Part{ + InlineData: &genai.Blob{ + MIMEType: "image/png", + Data: data, }, - }, - { - Type: schema.ChatMessagePartTypeFileURL, - FileURL: &schema.ChatMessageFileURL{ - URI: "test uri", - MIMEType: "test mime type", + }) + assert.Equal(t, schema.ChatMessagePartTypeImageURL, part.Type) + assert.NotNil(t, part.Image) + assert.Equal(t, "image/png", part.Image.MIMEType) + assert.Equal(t, encoded, *part.Image.Base64Data) + }) + + t.Run("audio part", func(t *testing.T) { + data := []byte("fake-audio-data") + encoded := base64.StdEncoding.EncodeToString(data) + part := toMultiOutPart(&genai.Part{ + InlineData: &genai.Blob{ + MIMEType: "audio/mp3", + Data: data, }, - }, - { - Type: schema.ChatMessagePartTypeAudioURL, - AudioURL: &schema.ChatMessageAudioURL{ - URI: "test uri", - MIMEType: "test mime type", + }) + assert.Equal(t, schema.ChatMessagePartTypeAudioURL, part.Type) + assert.NotNil(t, part.Audio) + assert.Equal(t, "audio/mp3", part.Audio.MIMEType) + assert.Equal(t, encoded, *part.Audio.Base64Data) + }) + + t.Run("video part", func(t *testing.T) { + data := []byte("fake-video-data") + encoded := base64.StdEncoding.EncodeToString(data) + part := toMultiOutPart(&genai.Part{ + InlineData: &genai.Blob{ + MIMEType: "video/mp4", + Data: data, }, - }, - { - Type: schema.ChatMessagePartTypeVideoURL, - VideoURL: &schema.ChatMessageVideoURL{ - URI: "test uri", - MIMEType: "test mime type", + }) + assert.Equal(t, schema.ChatMessagePartTypeVideoURL, part.Type) + assert.NotNil(t, part.Video) + assert.Equal(t, "video/mp4", part.Video.MIMEType) + assert.Equal(t, encoded, *part.Video.Base64Data) + }) + + t.Run("unsupported type", func(t *testing.T) { + part := toMultiOutPart(&genai.Part{ + InlineData: &genai.Blob{ + MIMEType: "application/pdf", + Data: []byte("fake-pdf-data"), }, - }, - } + }) + assert.Empty(t, part.Type) + assert.Nil(t, part.Image) + assert.Nil(t, part.Audio) + assert.Nil(t, part.Video) + }) +} + +func TestChatModel_convMedia(t *testing.T) { + t.Run("convMedia", func(t *testing.T) { + cm := &ChatModel{model: "test model"} + base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + dataURL := "data:image/png;base64," + base64Data + t.Run("success", func(t *testing.T) { + contents := []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "test text", + }, + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{URL: dataURL, MIMEType: "image/png"}, + }, + { + Type: schema.ChatMessagePartTypeFileURL, + FileURL: &schema.ChatMessageFileURL{URL: dataURL, MIMEType: "application/pdf"}, + }, + { + Type: schema.ChatMessagePartTypeAudioURL, + AudioURL: &schema.ChatMessageAudioURL{URL: dataURL, MIMEType: "audio/mp3"}, + }, + { + Type: schema.ChatMessagePartTypeVideoURL, + VideoURL: &schema.ChatMessageVideoURL{URL: dataURL, MIMEType: "video/mp4"}, + }, + } + + parts, err := cm.convMedia(contents) + assert.NoError(t, err) + assert.Len(t, parts, 5) + assert.Equal(t, "test text", parts[0].Text) + + decodedData, err := base64.StdEncoding.DecodeString(base64Data) + assert.NoError(t, err) + + assert.Equal(t, "image/png", parts[1].InlineData.MIMEType) + assert.Equal(t, decodedData, parts[1].InlineData.Data) + assert.Equal(t, "application/pdf", parts[2].InlineData.MIMEType) + assert.Equal(t, decodedData, parts[2].InlineData.Data) + assert.Equal(t, "audio/mp3", parts[3].InlineData.MIMEType) + assert.Equal(t, decodedData, parts[3].InlineData.Data) + assert.Equal(t, "video/mp4", parts[4].InlineData.MIMEType) + assert.Equal(t, decodedData, parts[4].InlineData.Data) + }) + + t.Run("with video metadata", func(t *testing.T) { + videoPart := &schema.ChatMessageVideoURL{URL: dataURL, MIMEType: "video/mp4"} + SetVideoMetaData(videoPart, &genai.VideoMetadata{ + StartOffset: time.Second, + EndOffset: time.Second * 5, + }) + contents := []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeVideoURL, + VideoURL: videoPart, + }, + } + parts, err := cm.convMedia(contents) + assert.NoError(t, err) + assert.Len(t, parts, 2) + assert.NotNil(t, parts[0].VideoMetadata) + assert.Equal(t, time.Second, parts[0].VideoMetadata.StartOffset) + assert.Equal(t, time.Second*5, parts[0].VideoMetadata.EndOffset) + }) + + t.Run("with invalid data url", func(t *testing.T) { + contents := []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{URL: ""}, + }, + } + _, err := cm.convMedia(contents) + assert.Error(t, err) + }) + }) + cm := &ChatModel{model: "test model"} + base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + + t.Run("convInputMedia", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + contents := []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "hello"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, + {Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageInputAudio{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "audio/mp3"}}}, + {Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "video/mp4"}}}, + {Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "application/pdf"}}}, + } + parts, err := cm.convInputMedia(contents) + assert.NoError(t, err) + assert.Len(t, parts, 5) + assert.Equal(t, "hello", parts[0].Text) + assert.Equal(t, "image/png", parts[1].InlineData.MIMEType) + assert.Equal(t, "audio/mp3", parts[2].InlineData.MIMEType) + assert.Equal(t, "video/mp4", parts[3].InlineData.MIMEType) + assert.Equal(t, "application/pdf", parts[4].InlineData.MIMEType) + // check data + decodedData, err := base64.StdEncoding.DecodeString(base64Data) + assert.NoError(t, err) + assert.Equal(t, decodedData, parts[1].InlineData.Data) + }) + + t.Run("with video metadata", func(t *testing.T) { + videoPart := &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "video/mp4"}} + SetInputVideoMetaData(videoPart, &genai.VideoMetadata{ + StartOffset: time.Second, + EndOffset: time.Second * 5, + }) + contents := []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeVideoURL, + Video: videoPart, + }, + } + parts, err := cm.convInputMedia(contents) + assert.NoError(t, err) + assert.Len(t, parts, 2) + assert.NotNil(t, parts[0].VideoMetadata) + assert.Equal(t, time.Second, parts[0].VideoMetadata.StartOffset) + assert.Equal(t, time.Second*5, parts[0].VideoMetadata.EndOffset) + }) + + t.Run("error cases", func(t *testing.T) { + url := "https://example.com/image.png" + invalidBase64 := "invalid-base64" + testCases := []struct { + name string + content schema.MessageInputPart + }{ + {name: "Image with URL", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "Audio with URL", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageInputAudio{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "Video with URL", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "File with URL", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "Image with invalid base64", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "Image without MIMEType", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "Audio with invalid base64", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageInputAudio{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "Audio without MIMEType", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageInputAudio{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "Video with invalid base64", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "Video without MIMEType", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "File with invalid base64", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "File without MIMEType", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeFileURL, File: &schema.MessageInputFile{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "Image with nil media", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: nil}}, + {name: "Audio with nil media", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: nil}}, + {name: "Video with nil media", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: nil}}, + {name: "File with nil media", content: schema.MessageInputPart{Type: schema.ChatMessagePartTypeFileURL, File: nil}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := cm.convInputMedia([]schema.MessageInputPart{tc.content}) + assert.Error(t, err) + }) + } + }) + }) + + t.Run("convOutputMedia", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + contents := []schema.MessageOutputPart{ + {Type: schema.ChatMessagePartTypeText, Text: "hello"}, + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, + {Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageOutputAudio{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "audio/mp3"}}}, + {Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "video/mp4"}}}, + } + parts, err := cm.convOutputMedia(contents) + assert.NoError(t, err) + assert.Len(t, parts, 4) + assert.Equal(t, "hello", parts[0].Text) + assert.Equal(t, "image/png", parts[1].InlineData.MIMEType) + assert.Equal(t, "audio/mp3", parts[2].InlineData.MIMEType) + assert.Equal(t, "video/mp4", parts[3].InlineData.MIMEType) + // check data + decodedData, err := base64.StdEncoding.DecodeString(base64Data) + assert.NoError(t, err) + assert.Equal(t, decodedData, parts[1].InlineData.Data) + }) - parts := cm.convMedia(contents) - assert.Equal(t, 5, len(parts)) - assert.Equal(t, "test text", parts[0].Text) + t.Run("with video metadata", func(t *testing.T) { + videoPart := &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "video/mp4"}} + SetOutputVideoMetaData(videoPart, &genai.VideoMetadata{ + StartOffset: time.Second, + EndOffset: time.Second * 5, + }) + contents := []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeVideoURL, + Video: videoPart, + }, + } + parts, err := cm.convOutputMedia(contents) + assert.NoError(t, err) + assert.Len(t, parts, 2) + assert.NotNil(t, parts[0].VideoMetadata) + assert.Equal(t, time.Second, parts[0].VideoMetadata.StartOffset) + assert.Equal(t, time.Second*5, parts[0].VideoMetadata.EndOffset) + }) - for i := 1; i < len(parts); i++ { - assert.Equal(t, "test uri", parts[i].FileData.FileURI) - assert.Equal(t, "test mime type", parts[i].FileData.MIMEType) - } + t.Run("error cases", func(t *testing.T) { + url := "https://example.com/image.png" + invalidBase64 := "invalid-base64" + testCases := []struct { + name string + content schema.MessageOutputPart + }{ + {name: "Image with URL", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "Audio with URL", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageOutputAudio{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "Video with URL", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{URL: &url}}}}, + {name: "Image with invalid base64", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "Image without MIMEType", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "Audio with invalid base64", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageOutputAudio{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "Audio without MIMEType", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: &schema.MessageOutputAudio{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "Video with invalid base64", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &invalidBase64}}}}, + {name: "Video without MIMEType", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}}, + {name: "Image with nil media", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeImageURL, Image: nil}}, + {name: "Audio with nil media", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeAudioURL, Audio: nil}}, + {name: "Video with nil media", content: schema.MessageOutputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: nil}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := cm.convOutputMedia([]schema.MessageOutputPart{tc.content}) + assert.Error(t, err) + }) + } + }) + }) } diff --git a/components/model/gemini/go.mod b/components/model/gemini/go.mod index 8ec7dcab4..f36cf6013 100644 --- a/components/model/gemini/go.mod +++ b/components/model/gemini/go.mod @@ -4,9 +4,9 @@ go 1.23.0 require ( github.com/bytedance/mockey v1.2.13 - github.com/bytedance/sonic v1.13.2 - github.com/cloudwego/eino v0.4.7 - github.com/eino-contrib/jsonschema v1.0.0 + github.com/bytedance/sonic v1.14.1 + github.com/cloudwego/eino v0.5.5 + github.com/eino-contrib/jsonschema v1.0.1 github.com/getkin/kin-openapi v0.118.0 github.com/stretchr/testify v1.10.0 github.com/wk8/go-ordered-map/v2 v2.1.8 @@ -19,8 +19,9 @@ require ( cloud.google.com/go/compute/metadata v0.5.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/bytedance/sonic/loader v0.2.4 // indirect - github.com/cloudwego/base64x v0.1.5 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect diff --git a/components/model/gemini/go.sum b/components/model/gemini/go.sum index f8440fc62..675dedbfc 100644 --- a/components/model/gemini/go.sum +++ b/components/model/gemini/go.sum @@ -15,20 +15,36 @@ github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMU github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/mockey v1.2.13 h1:jokWZAm/pUEbD939Rhznz615MKUCZNuvCFQlJ2+ntoo= github.com/bytedance/mockey v1.2.13/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/eino v0.4.7 h1:wwqsFWCuzCQuhw1dYKqHjGWULzjDjFfN9sTn/cezYV4= github.com/cloudwego/eino v0.4.7/go.mod h1:1TDlOmwGSsbCJaWB92w9YLZi2FL0WRZoRcD4eMvqikg= +github.com/cloudwego/eino v0.5.4-0.20250917102129-a48857bf4c79 h1:voV9TjDX/AMHDf95wUgutcMZWGvWW92qKSBo7N3o0Ek= +github.com/cloudwego/eino v0.5.4-0.20250917102129-a48857bf4c79/go.mod h1:S38tlNO4cNqFfGJKQSJZimxjzc9JDJKdf2eW3FEEfdc= +github.com/cloudwego/eino v0.5.4-0.20250925133640-10f7a8ffec1b h1:pw2BTeX19inORUtLJ00iRA/dwcLHAE8jqMaIUXwd6cg= +github.com/cloudwego/eino v0.5.4-0.20250925133640-10f7a8ffec1b/go.mod h1:JxKeWsO8iUZfKh3iE4iN0JCvCzYLRNyqjSag/RisPbc= +github.com/cloudwego/eino v0.5.5-0.20251009130421-8c297fc3c521 h1:WIk/K6QcstEj2+xtBWSaSTylUqxtt1vV3K2k0QoMUnQ= +github.com/cloudwego/eino v0.5.5-0.20251009130421-8c297fc3c521/go.mod h1:XolsJjKmiA+g9Dvr1vBJxGyqCksx52Ia/O4Iq+iMmeI= +github.com/cloudwego/eino v0.5.5 h1:CsUC+DQfMDjfHZoM9n4GHab5bZNMVqwqJb6dLa3A7VY= +github.com/cloudwego/eino v0.5.5/go.mod h1:XolsJjKmiA+g9Dvr1vBJxGyqCksx52Ia/O4Iq+iMmeI= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -38,6 +54,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eino-contrib/jsonschema v1.0.0 h1:dXxbhGNZuI3+xNi8x3JT8AGyoXz6Pff6mRvmpjVl5Ww= github.com/eino-contrib/jsonschema v1.0.0/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/jsonschema v1.0.1 h1:Ty2r/J+mHUGz3tqQNympPiTeaCVTST09yvTKlFlZUCA= +github.com/eino-contrib/jsonschema v1.0.1/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= diff --git a/components/model/gemini/message_extra.go b/components/model/gemini/message_extra.go new file mode 100644 index 000000000..f868700d3 --- /dev/null +++ b/components/model/gemini/message_extra.go @@ -0,0 +1,92 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gemini + +import ( + "github.com/cloudwego/eino/schema" + "google.golang.org/genai" +) + +const videoMetaDataKey = "gemini_video_meta_data" + +// Deprecated: use SetInputVideoMetaData or SetOutputVideoMetaData instead. +func SetVideoMetaData(part *schema.ChatMessageVideoURL, metaData *genai.VideoMetadata) { + if part == nil { + return + } + if part.Extra == nil { + part.Extra = make(map[string]any) + } + setVideoMetaData(part.Extra, metaData) +} + +// Deprecated: use GetInputVideoMetaData or GetOutputVideoMetaData instead. +func GetVideoMetaData(part *schema.ChatMessageVideoURL) *genai.VideoMetadata { + if part == nil || part.Extra == nil { + return nil + } + return getVideoMetaData(part.Extra) +} + +func SetInputVideoMetaData(part *schema.MessageInputVideo, metaData *genai.VideoMetadata) { + if part == nil { + return + } + if part.Extra == nil { + part.Extra = make(map[string]any) + } + setVideoMetaData(part.Extra, metaData) +} + +func GetInputVideoMetaData(part *schema.MessageInputVideo) *genai.VideoMetadata { + if part == nil || part.Extra == nil { + return nil + } + return getVideoMetaData(part.Extra) +} + +func SetOutputVideoMetaData(part *schema.MessageOutputVideo, metaData *genai.VideoMetadata) { + if part == nil { + return + } + if part.Extra == nil { + part.Extra = make(map[string]any) + } + setVideoMetaData(part.Extra, metaData) +} + +func GetOutputVideoMetaData(part *schema.MessageOutputVideo) *genai.VideoMetadata { + if part == nil || part.Extra == nil { + return nil + } + return getVideoMetaData(part.Extra) +} + +func setVideoMetaData(extra map[string]any, metaData *genai.VideoMetadata) { + extra[videoMetaDataKey] = metaData +} + +func getVideoMetaData(extra map[string]any) *genai.VideoMetadata { + if extra == nil { + return nil + } + videoMetaData, ok := extra[videoMetaDataKey].(*genai.VideoMetadata) + if !ok { + return nil + } + return videoMetaData +} diff --git a/components/model/gemini/option.go b/components/model/gemini/option.go index da86fe6b8..fd50449d1 100644 --- a/components/model/gemini/option.go +++ b/components/model/gemini/option.go @@ -29,6 +29,7 @@ type options struct { ResponseSchema *openapi3.Schema ResponseJSONSchema *jsonschema.Schema ThinkingConfig *genai.ThinkingConfig + ResponseModalities []GeminiResponseModalities } func WithTopK(k int32) model.Option { @@ -54,3 +55,9 @@ func WithThinkingConfig(t *genai.ThinkingConfig) model.Option { o.ThinkingConfig = t }) } + +func WithResponseModalities(m []GeminiResponseModalities) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.ResponseModalities = m + }) +} diff --git a/components/model/ollama/chatmodel.go b/components/model/ollama/chatmodel.go index 670aaa767..aae79559c 100644 --- a/components/model/ollama/chatmodel.go +++ b/components/model/ollama/chatmodel.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "net/url" "runtime/debug" @@ -69,7 +70,7 @@ type ChatModel struct { } // NewChatModel initializes a new instance of ChatModel with provided configuration. -func NewChatModel(ctx context.Context, config *ChatModelConfig) (*ChatModel, error) { +func NewChatModel(_ context.Context, config *ChatModelConfig) (*ChatModel, error) { if config == nil { return nil, errors.New("config must not be nil") } @@ -234,7 +235,7 @@ func (cm *ChatModel) IsCallbacksEnabled() bool { return true } -func (cm *ChatModel) genRequest(ctx context.Context, stream bool, in []*schema.Message, opts ...model.Option) ( +func (cm *ChatModel) genRequest(_ context.Context, stream bool, in []*schema.Message, opts ...model.Option) ( req *api.ChatRequest, cbInput *model.CallbackInput, err error) { var ( @@ -356,43 +357,90 @@ func toOllamaMessage(einoMsg *schema.Message) (api.Message, error) { Thinking: einoMsg.ReasoningContent, ToolCalls: toolCalls, } - if len(einoMsg.MultiContent) == 0 { + + if len(einoMsg.UserInputMultiContent) == 0 && len(einoMsg.AssistantGenMultiContent) == 0 && len(einoMsg.MultiContent) == 0 { om.Content = einoMsg.Content return om, nil } content := "" var images []api.ImageData - for _, mc := range einoMsg.MultiContent { - switch mc.Type { - case schema.ChatMessagePartTypeText: - content += mc.Text - case schema.ChatMessagePartTypeImageURL: - if mc.ImageURL == nil { - return api.Message{}, errors.New("image url is required") - } - if err := validateImageURL(mc.ImageURL.URL); err != nil { - return api.Message{}, err + if len(einoMsg.UserInputMultiContent) > 0 && len(einoMsg.AssistantGenMultiContent) > 0 { + return api.Message{}, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + } + + if len(einoMsg.UserInputMultiContent) > 0 { + for _, part := range einoMsg.UserInputMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + content += part.Text + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return api.Message{}, fmt.Errorf("image is required in UserInputMultiContent, but got nil") + } + if part.Image.URL != nil { + return api.Message{}, fmt.Errorf("ollama model only supports base64-encoded strings for the raw binary, but got URL: %s", *part.Image.URL) + } + if part.Image.Base64Data == nil { + return api.Message{}, fmt.Errorf("image is required in UserInputMultiContent, but got nil Base64Data") + } + images = append(images, api.ImageData(*part.Image.Base64Data)) + default: + return api.Message{}, fmt.Errorf("unsupported content type in UserInputMultiContent: %s", part.Type) } + } + } else if len(einoMsg.AssistantGenMultiContent) > 0 { + if einoMsg.Role != schema.Assistant { + return api.Message{}, fmt.Errorf("AssistantGenMultiContent can only be used with assistant role, but got role '%s'", einoMsg.Role) + } + for _, part := range einoMsg.AssistantGenMultiContent { + switch part.Type { + case schema.ChatMessagePartTypeText: + content += part.Text + case schema.ChatMessagePartTypeImageURL: + if part.Image == nil { + return api.Message{}, fmt.Errorf("image is required in AssistantGenMultiContent, but got nil") + } + if part.Image.URL != nil { + return api.Message{}, fmt.Errorf("ollama model only supports base64-encoded strings for the raw binary, but got URL: %s", *part.Image.URL) + } + if part.Image.Base64Data == nil { + return api.Message{}, fmt.Errorf("image is required in AssistantGenMultiContent, but got nil Base64Data") + } + images = append(images, api.ImageData(*part.Image.Base64Data)) + default: + return api.Message{}, fmt.Errorf("unsupported content type in AssistantGenMultiContent: %s", part.Type) + } + } + } else if len(einoMsg.MultiContent) > 0 { + log.Printf("MultiContent is deprecated, please use UserInputMultiContent or AssistantGenMultiContent instead") + for _, mc := range einoMsg.MultiContent { + switch mc.Type { + case schema.ChatMessagePartTypeText: + content += mc.Text + case schema.ChatMessagePartTypeImageURL: + if mc.ImageURL == nil { + return api.Message{}, errors.New("image url is required") + } + if err := validateImageURL(mc.ImageURL.URL); err != nil { + return api.Message{}, err + } - images = append(images, api.ImageData(mc.ImageURL.URL)) - default: - return api.Message{}, fmt.Errorf("unsupported content type: %s", mc.Type) + images = append(images, api.ImageData(mc.ImageURL.URL)) + default: + return api.Message{}, fmt.Errorf("unsupported content type: %s", mc.Type) + } } } - return api.Message{ - Role: string(einoMsg.Role), - Content: content, - Images: images, - Thinking: einoMsg.ReasoningContent, - ToolCalls: toolCalls, - }, nil + om.Content = content + om.Images = images + return om, nil } func validateImageURL(url string) error { - if strings.HasPrefix(url, "http") || strings.HasPrefix(url, "https") { + if strings.HasPrefix(url, "http") { return errors.New("ollama model only supports base64-encoded strings for the raw binary") } return nil diff --git a/components/model/ollama/chatmodel_test.go b/components/model/ollama/chatmodel_test.go index b4d6ee5b5..a75032eb1 100644 --- a/components/model/ollama/chatmodel_test.go +++ b/components/model/ollama/chatmodel_test.go @@ -167,7 +167,8 @@ func Test_Generate(t *testing.T) { }), convey.ShouldBeNil) PatchConvey("test chat error", func() { - Mock(GetMethod(cli, "Chat")).To(MockChatInvokeError).Build() + mocker := Mock(GetMethod(cli, "Chat")).To(MockChatInvokeError).Build() + defer mocker.UnPatch() outMsg, err := m.Generate(ctx, msgs) @@ -176,7 +177,8 @@ func Test_Generate(t *testing.T) { }) PatchConvey("test resolveChatResponse error", func() { - Mock(GetMethod(cli, "Chat")).To(MockChatInvokeError).Build() + mocker := Mock(GetMethod(cli, "Chat")).To(MockChatInvokeError).Build() + defer mocker.UnPatch() outMsg, err := m.Generate(ctx, msgs) convey.So(err, convey.ShouldNotBeNil) @@ -184,7 +186,8 @@ func Test_Generate(t *testing.T) { }) PatchConvey("test success", func() { - Mock(GetMethod(cli, "Chat")).To(MockChatInvoke).Build() + mocker := Mock(GetMethod(cli, "Chat")).To(MockChatInvoke).Build() + defer mocker.UnPatch() outMsg, err := m.Generate(ctx, msgs, model.WithTemperature(1), @@ -229,7 +232,8 @@ func Test_Stream(t *testing.T) { } PatchConvey("test chan err", func() { - Mock(GetMethod(cli, "Chat")).To(MockChatStreamError).Build() + mocker := Mock(GetMethod(cli, "Chat")).To(MockChatStreamError).Build() + defer mocker.UnPatch() outStream, err := m.Stream(ctx, msgs) convey.So(err, convey.ShouldBeNil) @@ -237,7 +241,8 @@ func Test_Stream(t *testing.T) { }) PatchConvey("test chan success", func() { - Mock(GetMethod(cli, "Chat")).Return(MockChatStream).Build() + mocker := Mock(GetMethod(cli, "Chat")).Return(MockChatStream).Build() + defer mocker.UnPatch() outStream, err := m.Stream(ctx, msgs) convey.So(err, convey.ShouldBeNil) @@ -302,3 +307,261 @@ func TestWithTools(t *testing.T) { assert.Equal(t, "test model", ncm.(*ChatModel).config.Model) assert.Equal(t, "test tool name", ncm.(*ChatModel).tools[0].Name) } + +func Test_toOllamaMessage(t *testing.T) { + base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + PatchConvey("test toOllamaMessage", t, func() { + PatchConvey("test simple message", func() { + msg := &schema.Message{ + Role: schema.User, + Content: "test content", + } + ollamaMsg, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(ollamaMsg.Role, convey.ShouldEqual, string(schema.User)) + convey.So(ollamaMsg.Content, convey.ShouldEqual, "test content") + convey.So(ollamaMsg.Images, convey.ShouldBeEmpty) + }) + + PatchConvey("test UserInputMultiContent", func() { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "hello", + }, + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &base64Data, + }, + }, + }, + }, + } + ollamaMsg, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(ollamaMsg.Role, convey.ShouldEqual, string(schema.User)) + convey.So(ollamaMsg.Content, convey.ShouldEqual, "hello") + convey.So(len(ollamaMsg.Images), convey.ShouldEqual, 1) + convey.So(string(ollamaMsg.Images[0]), convey.ShouldEqual, base64Data) + }) + + PatchConvey("test AssistantGenMultiContent", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "hello", + }, + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &base64Data, + }, + }, + }, + }, + } + ollamaMsg, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(ollamaMsg.Role, convey.ShouldEqual, string(schema.Assistant)) + convey.So(ollamaMsg.Content, convey.ShouldEqual, "hello") + convey.So(len(ollamaMsg.Images), convey.ShouldEqual, 1) + convey.So(string(ollamaMsg.Images[0]), convey.ShouldEqual, base64Data) + }) + + PatchConvey("test AssistantGenMultiContent with correct role", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "world", + }, + }, + } + ollamaMsg, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(ollamaMsg.Role, convey.ShouldEqual, string(schema.Assistant)) + convey.So(ollamaMsg.Content, convey.ShouldEqual, "world") + }) + + PatchConvey("test AssistantGenMultiContent with incorrect role", func() { + msg := &schema.Message{ + Role: schema.User, // Incorrect role + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "world", + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "AssistantGenMultiContent can only be used with assistant role") + }) + + PatchConvey("test MultiContent compatibility", func() { + msg := &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeText, + Text: "legacy content", + }, + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{ + URL: base64Data, + }, + }, + }, + } + ollamaMsg, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldBeNil) + convey.So(ollamaMsg.Content, convey.ShouldEqual, "legacy content") + convey.So(len(ollamaMsg.Images), convey.ShouldEqual, 1) + convey.So(string(ollamaMsg.Images[0]), convey.ShouldEqual, base64Data) + }) + + PatchConvey("test error on http URL in MultiContent", func() { + msg := &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{ + URL: "http://example.com/image.png", + }, + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "ollama model only supports base64-encoded strings") + }) + + PatchConvey("test error on nil image in UserInputMultiContent", func() { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: nil, // Nil image + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "image is required in UserInputMultiContent, but got nil") + }) + + PatchConvey("test error on nil image in AssistantGenMultiContent", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: nil, // Nil image + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "image is required in AssistantGenMultiContent, but got nil") + }) + + PatchConvey("test error on nil ImageURL in MultiContent", func() { + msg := &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: nil, // Nil ImageURL + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "image url is required") + }) + + PatchConvey("test error on URL in UserInputMultiContent", func() { + url := "http://example.com/image.png" + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{URL: &url}}, + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "ollama model only supports base64-encoded strings") + }) + + PatchConvey("test error on nil Base64Data in UserInputMultiContent", func() { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: nil}}, + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "image is required in UserInputMultiContent, but got nil Base64Data") + }) + + PatchConvey("test error on URL in AssistantGenMultiContent", func() { + url := "http://example.com/image.png" + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{URL: &url}}, + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "ollama model only supports base64-encoded strings") + }) + + PatchConvey("test error on nil Base64Data in AssistantGenMultiContent", func() { + msg := &schema.Message{ + Role: schema.Assistant, + AssistantGenMultiContent: []schema.MessageOutputPart{ + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: nil}}, + }, + }, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldEqual, "image is required in AssistantGenMultiContent, but got nil Base64Data") + }) + + PatchConvey("test error on both UserInputMultiContent and AssistantGenMultiContent", func() { + msg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "user"}}, + AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "assistant"}}, + } + _, err := toOllamaMessage(msg) + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + }) + }) +} diff --git a/components/model/ollama/examples/image/image.go b/components/model/ollama/examples/image/image.go index addba56d4..fd3e9e27b 100644 --- a/components/model/ollama/examples/image/image.go +++ b/components/model/ollama/examples/image/image.go @@ -18,6 +18,7 @@ package main import ( "context" + "encoding/base64" "log" "os" @@ -43,18 +44,23 @@ func main() { log.Fatalf("os.ReadFile failed, err=%v\n", err) } + imageStr := base64.StdEncoding.EncodeToString(image) + resp, err := chatModel.Generate(ctx, []*schema.Message{ { Role: schema.User, - MultiContent: []schema.ChatMessagePart{ + UserInputMultiContent: []schema.MessageInputPart{ { Type: schema.ChatMessagePartTypeText, Text: "describe this image", }, { Type: schema.ChatMessagePartTypeImageURL, - ImageURL: &schema.ChatMessageImageURL{ - URL: string(image), + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &imageStr, + }, + Detail: schema.ImageURLDetailAuto, }, }, }, diff --git a/components/model/ollama/go.mod b/components/model/ollama/go.mod index 106f5d02e..8f9cc5939 100644 --- a/components/model/ollama/go.mod +++ b/components/model/ollama/go.mod @@ -6,7 +6,7 @@ toolchain go1.24.2 require ( github.com/bytedance/mockey v1.2.14 - github.com/cloudwego/eino v0.4.7 + github.com/cloudwego/eino v0.5.5 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.10.0 ) @@ -16,12 +16,13 @@ require github.com/eino-contrib/ollama v0.1.0 require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect - github.com/cloudwego/base64x v0.1.5 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/eino-contrib/jsonschema v1.0.0 // indirect + github.com/eino-contrib/jsonschema v1.0.1 // indirect github.com/getkin/kin-openapi v0.118.0 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/swag v0.19.5 // indirect @@ -32,7 +33,7 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/kr/pretty v0.3.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/components/model/ollama/go.sum b/components/model/ollama/go.sum index 023b3e396..86306377c 100644 --- a/components/model/ollama/go.sum +++ b/components/model/ollama/go.sum @@ -7,18 +7,32 @@ github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMU github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/mockey v1.2.14 h1:KZaFgPdiUwW+jOWFieo3Lr7INM1P+6adO3hxZhDswY8= github.com/bytedance/mockey v1.2.14/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/eino v0.4.7 h1:wwqsFWCuzCQuhw1dYKqHjGWULzjDjFfN9sTn/cezYV4= github.com/cloudwego/eino v0.4.7/go.mod h1:1TDlOmwGSsbCJaWB92w9YLZi2FL0WRZoRcD4eMvqikg= +github.com/cloudwego/eino v0.5.4-0.20250919093302-940f314e67dd h1:xYp+MDCsUC2GFBapfXQO03FmdF/xJ1iLI/tTP9rTLxk= +github.com/cloudwego/eino v0.5.4-0.20250919093302-940f314e67dd/go.mod h1:JxKeWsO8iUZfKh3iE4iN0JCvCzYLRNyqjSag/RisPbc= +github.com/cloudwego/eino v0.5.4-0.20250925133640-10f7a8ffec1b h1:pw2BTeX19inORUtLJ00iRA/dwcLHAE8jqMaIUXwd6cg= +github.com/cloudwego/eino v0.5.4-0.20250925133640-10f7a8ffec1b/go.mod h1:JxKeWsO8iUZfKh3iE4iN0JCvCzYLRNyqjSag/RisPbc= +github.com/cloudwego/eino v0.5.5-0.20251009130421-8c297fc3c521 h1:WIk/K6QcstEj2+xtBWSaSTylUqxtt1vV3K2k0QoMUnQ= +github.com/cloudwego/eino v0.5.5-0.20251009130421-8c297fc3c521/go.mod h1:XolsJjKmiA+g9Dvr1vBJxGyqCksx52Ia/O4Iq+iMmeI= +github.com/cloudwego/eino v0.5.5 h1:CsUC+DQfMDjfHZoM9n4GHab5bZNMVqwqJb6dLa3A7VY= +github.com/cloudwego/eino v0.5.5/go.mod h1:XolsJjKmiA+g9Dvr1vBJxGyqCksx52Ia/O4Iq+iMmeI= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -28,6 +42,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eino-contrib/jsonschema v1.0.0 h1:dXxbhGNZuI3+xNi8x3JT8AGyoXz6Pff6mRvmpjVl5Ww= github.com/eino-contrib/jsonschema v1.0.0/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/jsonschema v1.0.1 h1:Ty2r/J+mHUGz3tqQNympPiTeaCVTST09yvTKlFlZUCA= +github.com/eino-contrib/jsonschema v1.0.1/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/eino-contrib/ollama v0.1.0 h1:z1NaMdKW6X1ftP8g5xGGR5zDRPUtuTKFq35vBQgxsN4= github.com/eino-contrib/ollama v0.1.0/go.mod h1:mYsQ7b3DeqY8bHPuD3MZJYTqkgyL6LoemxoP/B7ZNhA= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -67,6 +83,8 @@ github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALr github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= diff --git a/components/model/qwen/go.mod b/components/model/qwen/go.mod index af08f461b..185ae7f88 100644 --- a/components/model/qwen/go.mod +++ b/components/model/qwen/go.mod @@ -29,7 +29,7 @@ require ( github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/meguminnnnnnnnn/go-openai v0.0.0-20250821095446-07791bea23a0 // indirect + github.com/meguminnnnnnnnn/go-openai v0.0.0-20250922074340-88b080a04c97 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect diff --git a/components/model/qwen/go.sum b/components/model/qwen/go.sum index 6c839dd85..db81f9a5c 100644 --- a/components/model/qwen/go.sum +++ b/components/model/qwen/go.sum @@ -78,6 +78,8 @@ github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/meguminnnnnnnnn/go-openai v0.0.0-20250821095446-07791bea23a0 h1:nIohpHs1ViKR0SVgW/cbBstHjmnqFZDM9RqgX9m9Xu8= github.com/meguminnnnnnnnn/go-openai v0.0.0-20250821095446-07791bea23a0/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= +github.com/meguminnnnnnnnn/go-openai v0.0.0-20250922074340-88b080a04c97 h1:DOqL77Pcj66zqyoCg+4cCixzrbhccMCnElkWNqD8X/A= +github.com/meguminnnnnnnnn/go-openai v0.0.0-20250922074340-88b080a04c97/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=