Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 157 additions & 37 deletions components/model/ark/chat_completion_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"io"
"log"
"runtime/debug"

"github.com/eino-contrib/jsonschema"
Expand Down Expand Up @@ -297,7 +298,7 @@ func (cm *completionAPIChatModel) genRequest(in []*schema.Message, options *fmod
}

for _, msg := range in {
content, e := cm.toArkContent(msg.Content, msg.MultiContent)
content, e := cm.toArkContent(msg)
if e != nil {
return req, e
}
Expand Down Expand Up @@ -522,44 +523,163 @@ func (cm *completionAPIChatModel) toMessageToolCalls(toolCalls []*model.ToolCall
return ret
}

func (cm *completionAPIChatModel) toArkContent(content string, multiContent []schema.ChatMessagePart) (*model.ChatCompletionMessageContent, error) {
if len(multiContent) == 0 {
return &model.ChatCompletionMessageContent{StringValue: ptrOf(content)}, nil
}

parts := make([]*model.ChatCompletionMessageContentPart, 0, len(multiContent))

for _, part := range multiContent {
switch part.Type {
case schema.ChatMessagePartTypeText:
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeText,
Text: part.Text,
})
case schema.ChatMessagePartTypeImageURL:
if part.ImageURL == nil {
return nil, fmt.Errorf("ImageURL field must not be nil when Type is ChatMessagePartTypeImageURL")
func (cm *completionAPIChatModel) toArkContent(msg *schema.Message) (*model.ChatCompletionMessageContent, error) {
if len(msg.UserInputMultiContent) == 0 && len(msg.AssistantGenMultiContent) == 0 && len(msg.MultiContent) == 0 {
return &model.ChatCompletionMessageContent{StringValue: ptrOf(msg.Content)}, nil
}

var parts []*model.ChatCompletionMessageContentPart
if len(msg.UserInputMultiContent) > 0 {
parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.UserInputMultiContent))
for _, part := range msg.UserInputMultiContent {
switch part.Type {
case schema.ChatMessagePartTypeText:
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeText,
Text: part.Text,
})
case schema.ChatMessagePartTypeImageURL:
if part.Image == nil {
return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message")
}
var imageURL string
if part.Image.URL != nil && *part.Image.URL != "" {
imageURL = *part.Image.URL
} else if part.Image.Base64Data != nil && *part.Image.Base64Data != "" {
if part.Image.MIMEType == "" {
return nil, fmt.Errorf("image part must have MIMEType when using Base64Data")
}
imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType)
} else {
return nil, fmt.Errorf("image part for user input must contain either a URL or Base64Data, but got: %+v", part.Image)
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeImageURL,
ImageURL: &model.ChatMessageImageURL{
URL: imageURL,
Detail: model.ImageURLDetail(part.Image.Detail),
},
})
case schema.ChatMessagePartTypeVideoURL:
if part.Video == nil {
return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL in user message")
}
var videoURL string
if part.Video.URL != nil && *part.Video.URL != "" {
videoURL = *part.Video.URL
} else if part.Video.Base64Data != nil && *part.Video.Base64Data != "" {
if part.Video.MIMEType == "" {
return nil, fmt.Errorf("video part must have MIMEType when using Base64Data")
}
videoURL = ensureDataURL(*part.Video.Base64Data, part.Video.MIMEType)
} else {
return nil, fmt.Errorf("video part for user input must contain either a URL or Base64Data, but got: %+v", part.Video)
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeVideoURL,
VideoURL: &model.ChatMessageVideoURL{
URL: videoURL,
FPS: GetInputVideoFPS(part.Video),
},
})
default:
return nil, fmt.Errorf("unsupported chat message part type in user message: %s", part.Type)
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeImageURL,
ImageURL: &model.ChatMessageImageURL{
URL: part.ImageURL.URL,
Detail: model.ImageURLDetail(part.ImageURL.Detail),
},
})
case schema.ChatMessagePartTypeVideoURL:
if part.VideoURL == nil {
return nil, fmt.Errorf("VideoURL field must not be nil when Type is ChatMessagePartTypeVideoURL")
}
} else if len(msg.AssistantGenMultiContent) > 0 {
if msg.Role != schema.Assistant {
return nil, fmt.Errorf("AssistantGenMultiContent only used when Role is Assistant, but got role: %s", msg.Role)
}
parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.AssistantGenMultiContent))
for _, part := range msg.AssistantGenMultiContent {
switch part.Type {
case schema.ChatMessagePartTypeText:
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeText,
Text: part.Text,
})
case schema.ChatMessagePartTypeImageURL:
if part.Image == nil {
return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message")
}
var imageURL string
if part.Image.URL != nil && *part.Image.URL != "" {
imageURL = *part.Image.URL
} else if part.Image.Base64Data != nil && *part.Image.Base64Data != "" {
if part.Image.MIMEType == "" {
return nil, fmt.Errorf("image part must have MIMEType when using Base64Data")
}
imageURL = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType)
} else {
return nil, fmt.Errorf("image part for assistant output must contain either a URL or Base64Data, but got: %+v", part.Image)
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeImageURL,
ImageURL: &model.ChatMessageImageURL{
URL: imageURL,
},
})
case schema.ChatMessagePartTypeVideoURL:
if part.Video == nil {
return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL in assistant message")
}
var videoURL string
if part.Video.URL != nil && *part.Video.URL != "" {
videoURL = *part.Video.URL
} else if part.Video.Base64Data != nil && *part.Video.Base64Data != "" {
if part.Video.MIMEType == "" {
return nil, fmt.Errorf("video part must have MIMEType when using Base64Data")
}
videoURL = ensureDataURL(*part.Video.Base64Data, part.Video.MIMEType)
} else {
return nil, fmt.Errorf("video part for assistant output must contain either a URL or Base64Data, but got: %+v", part.Video)
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeVideoURL,
VideoURL: &model.ChatMessageVideoURL{
URL: videoURL,
FPS: GetOutputVideoFPS(part.Video),
},
})
default:
return nil, fmt.Errorf("unsupported chat message part type in assistant message: %s", part.Type)
}
}
} else if len(msg.MultiContent) > 0 {
log.Printf("warning: MultiContent is deprecated, use UserInputMultiContent or AssistantGenMultiContent instead")
parts = make([]*model.ChatCompletionMessageContentPart, 0, len(msg.MultiContent))
for _, part := range msg.MultiContent {
switch part.Type {
case schema.ChatMessagePartTypeText:
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeText,
Text: part.Text,
})
case schema.ChatMessagePartTypeImageURL:
if part.ImageURL == nil {
return nil, fmt.Errorf("ImageURL field must not be nil when Type is ChatMessagePartTypeImageURL")
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeImageURL,
ImageURL: &model.ChatMessageImageURL{
URL: part.ImageURL.URL,
Detail: model.ImageURLDetail(part.ImageURL.Detail),
},
})
case schema.ChatMessagePartTypeVideoURL:
if part.VideoURL == nil {
return nil, fmt.Errorf("VideoURL field must not be nil when Type is ChatMessagePartTypeVideoURL")
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeVideoURL,
VideoURL: &model.ChatMessageVideoURL{
URL: part.VideoURL.URL,
FPS: GetFPS(part.VideoURL),
},
})
default:
return nil, fmt.Errorf("unsupported chat message part type: %s", part.Type)
}
parts = append(parts, &model.ChatCompletionMessageContentPart{
Type: model.ChatCompletionMessageContentPartTypeVideoURL,
VideoURL: &model.ChatMessageVideoURL{
URL: part.VideoURL.URL,
FPS: GetFPS(part.VideoURL),
},
})
default:
return nil, fmt.Errorf("unsupported chat message part type: %s", part.Type)
}
}

Expand Down
124 changes: 123 additions & 1 deletion components/model/ark/chat_completion_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func TestChatCompletionAPIGenerate(t *testing.T) {
},
}

req, err := m.chatModel.toArkContent(multiModalMsg.Content, multiModalMsg.MultiContent)
req, err := m.chatModel.toArkContent(multiModalMsg)
convey.So(err, convey.ShouldBeNil)
convey.So(req.StringValue, convey.ShouldBeNil)
convey.So(req.ListValue, convey.ShouldHaveLength, 2)
Expand Down Expand Up @@ -330,3 +330,125 @@ func TestChatCompletionAPILogProbs(t *testing.T) {
},
}}))
}

// TestCompletionAPIChatModel_toArkContent exercises toArkContent against the
// four content sources a schema.Message can carry: plain Content,
// UserInputMultiContent, AssistantGenMultiContent and the legacy MultiContent.
// Each group checks one successful conversion plus the inputs that must be
// rejected with an error.
func TestCompletionAPIChatModel_toArkContent(t *testing.T) {
	cm := &completionAPIChatModel{}
	base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
	httpURL := "https://example.com/image.png"
	videoURL := "https://example.com/video.mp4"

	// failureCase pairs a sub-test name with a message that toArkContent is
	// expected to reject.
	type failureCase struct {
		name string
		msg  *schema.Message
	}

	// expectFailures registers one sub-test per case, each asserting that the
	// conversion returns a non-nil error.
	expectFailures := func(cases []failureCase) {
		for _, fc := range cases {
			fc := fc // keep a per-iteration copy for the closure below
			PatchConvey(fc.name, func() {
				_, err := cm.toArkContent(fc.msg)
				convey.So(err, convey.ShouldNotBeNil)
			})
		}
	}

	// expectParts registers a sub-test asserting that the conversion succeeds
	// and yields exactly wantLen list parts.
	expectParts := func(name string, msg *schema.Message, wantLen int) {
		PatchConvey(name, func() {
			content, err := cm.toArkContent(msg)
			convey.So(err, convey.ShouldBeNil)
			convey.So(content.ListValue, convey.ShouldHaveLength, wantLen)
		})
	}

	// Single-part message builders for the error-path tables.
	userMsg := func(part schema.MessageInputPart) *schema.Message {
		return &schema.Message{UserInputMultiContent: []schema.MessageInputPart{part}}
	}
	assistantMsg := func(part schema.MessageOutputPart) *schema.Message {
		return &schema.Message{
			Role:                     schema.Assistant,
			AssistantGenMultiContent: []schema.MessageOutputPart{part},
		}
	}
	legacyMsg := func(part schema.ChatMessagePart) *schema.Message {
		return &schema.Message{MultiContent: []schema.ChatMessagePart{part}}
	}

	PatchConvey("Test toArkContent Comprehensive", t, func() {
		PatchConvey("Pure Text Content", func() {
			plain := &schema.Message{Content: "just text"}
			got, err := cm.toArkContent(plain)
			convey.So(err, convey.ShouldBeNil)
			convey.So(*got.StringValue, convey.ShouldEqual, "just text")
		})

		PatchConvey("UserInputMultiContent", func() {
			expectParts("Success with all types", &schema.Message{
				UserInputMultiContent: []schema.MessageInputPart{
					{
						Type: schema.ChatMessagePartTypeText,
						Text: "some text",
					},
					{
						Type: schema.ChatMessagePartTypeImageURL,
						Image: &schema.MessageInputImage{
							MessagePartCommon: schema.MessagePartCommon{URL: &httpURL},
						},
					},
					{
						Type: schema.ChatMessagePartTypeVideoURL,
						Video: &schema.MessageInputVideo{
							MessagePartCommon: schema.MessagePartCommon{
								Base64Data: &base64Data,
								MIMEType:   "video/mp4",
							},
						},
					},
				},
			}, 3)

			expectFailures([]failureCase{
				{
					name: "Error on nil image",
					msg:  userMsg(schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: nil}),
				},
				{
					name: "Error on empty image data",
					msg:  userMsg(schema.MessageInputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{}}),
				},
				{
					name: "Error on nil video",
					msg:  userMsg(schema.MessageInputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: nil}),
				},
				{
					name: "Error on empty video data",
					msg:  userMsg(schema.MessageInputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageInputVideo{}}),
				},
			})
		})

		PatchConvey("AssistantGenMultiContent", func() {
			expectParts("Success with image and video", &schema.Message{
				Role: schema.Assistant,
				AssistantGenMultiContent: []schema.MessageOutputPart{
					{
						Type: schema.ChatMessagePartTypeText,
						Text: "some text",
					},
					{
						Type: schema.ChatMessagePartTypeImageURL,
						Image: &schema.MessageOutputImage{
							MessagePartCommon: schema.MessagePartCommon{URL: &httpURL},
						},
					},
					{
						Type: schema.ChatMessagePartTypeVideoURL,
						Video: &schema.MessageOutputVideo{
							MessagePartCommon: schema.MessagePartCommon{URL: &videoURL},
						},
					},
				},
			}, 3)

			expectFailures([]failureCase{
				{
					// Assistant-generated parts attached to a non-assistant role.
					name: "Error on wrong role",
					msg: &schema.Message{
						Role:                     schema.User,
						AssistantGenMultiContent: []schema.MessageOutputPart{{}},
					},
				},
				{
					name: "Error on nil image",
					msg:  assistantMsg(schema.MessageOutputPart{Type: schema.ChatMessagePartTypeImageURL, Image: nil}),
				},
				{
					name: "Error on empty image data",
					msg:  assistantMsg(schema.MessageOutputPart{Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{}}),
				},
				{
					name: "Error on nil video",
					msg:  assistantMsg(schema.MessageOutputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: nil}),
				},
				{
					name: "Error on empty video data",
					msg:  assistantMsg(schema.MessageOutputPart{Type: schema.ChatMessagePartTypeVideoURL, Video: &schema.MessageOutputVideo{}}),
				},
				{
					name: "Error on unsupported type",
					msg:  assistantMsg(schema.MessageOutputPart{Type: "unsupported"}),
				},
			})
		})

		PatchConvey("MultiContent (Legacy)", func() {
			expectParts("Success with all types", &schema.Message{
				MultiContent: []schema.ChatMessagePart{
					{
						Type: schema.ChatMessagePartTypeText,
						Text: "some text",
					},
					{
						Type:     schema.ChatMessagePartTypeImageURL,
						ImageURL: &schema.ChatMessageImageURL{URL: httpURL},
					},
					{
						Type:     schema.ChatMessagePartTypeVideoURL,
						VideoURL: &schema.ChatMessageVideoURL{URL: videoURL},
					},
				},
			}, 3)

			expectFailures([]failureCase{
				{
					name: "Error on nil ImageURL",
					msg:  legacyMsg(schema.ChatMessagePart{Type: schema.ChatMessagePartTypeImageURL, ImageURL: nil}),
				},
				{
					name: "Error on nil VideoURL",
					msg:  legacyMsg(schema.ChatMessagePart{Type: schema.ChatMessagePartTypeVideoURL, VideoURL: nil}),
				},
			})
		})
	})
}
2 changes: 1 addition & 1 deletion components/model/ark/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ func (cm *ChatModel) createContextByContextAPI(ctx context.Context, prefix []*sc
TruncationStrategy: truncation,
}
for _, msg := range prefix {
content, err := cm.chatModel.toArkContent(msg.Content, msg.MultiContent)
content, err := cm.chatModel.toArkContent(msg)
if err != nil {
return nil, fmt.Errorf("convert message fail: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ func main() {
}

multiModalMsg := schema.UserMessage("")
multiModalMsg.MultiContent = []schema.ChatMessagePart{
var url = "https://d2908q01vomqb2.cloudfront.net/887309d048beef83ad3eabf2a79a64a389ab1c9f/2023/07/13/DBBLOG-3334-image001.png"
multiModalMsg.UserInputMultiContent = []schema.MessageInputPart{
{
Type: schema.ChatMessagePartTypeText,
Text: "this picture is LangDChain's architecture, what's the picture's content",
},
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: "https://d2908q01vomqb2.cloudfront.net/887309d048beef83ad3eabf2a79a64a389ab1c9f/2023/07/13/DBBLOG-3334-image001.png",
Image: &schema.MessageInputImage{
MessagePartCommon: schema.MessagePartCommon{
URL: &url,
},
Detail: schema.ImageURLDetailAuto,
},
},
Expand Down
Loading
Loading