Skip to content

Commit 6420740

Browse files
committed
support multi modal inputs
1 parent c8aa470 commit 6420740

File tree

4 files changed

+106
-21
lines changed

4 files changed

+106
-21
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
216216
TargetModel: "test-model1",
217217
Body: &types.LLMRequestBody{
218218
ChatCompletions: &types.ChatCompletionsRequest{
219-
Messages: []types.Message{
219+
Messages: []types.Message[string]{
220220
{Role: "user", Content: "hello world"},
221221
{Role: "assistant", Content: "hi there"},
222222
},
@@ -251,7 +251,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
251251
TargetModel: "test-model1",
252252
Body: &types.LLMRequestBody{
253253
ChatCompletions: &types.ChatCompletionsRequest{
254-
Messages: []types.Message{
254+
Messages: []types.Message[string]{
255255
{Role: "system", Content: "You are a helpful assistant"},
256256
{Role: "user", Content: "Hello, how are you?"},
257257
},
@@ -284,7 +284,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
284284
TargetModel: "test-model1",
285285
Body: &types.LLMRequestBody{
286286
ChatCompletions: &types.ChatCompletionsRequest{
287-
Messages: []types.Message{
287+
Messages: []types.Message[string]{
288288
{Role: "system", Content: "You are a helpful assistant"},
289289
{Role: "user", Content: "Hello, how are you?"},
290290
{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
@@ -317,7 +317,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
317317
TargetModel: "test-model1",
318318
Body: &types.LLMRequestBody{
319319
ChatCompletions: &types.ChatCompletionsRequest{
320-
Messages: []types.Message{
320+
Messages: []types.Message[string]{
321321
{Role: "system", Content: "You are a helpful assistant"},
322322
{Role: "user", Content: "Hello, how are you?"},
323323
{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
@@ -442,16 +442,16 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
442442
for _, scenario := range scenarios {
443443
b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) {
444444
// Generate messages for this scenario
445-
messages := make([]types.Message, scenario.messageCount)
446-
messages[0] = types.Message{Role: "system", Content: "You are a helpful assistant."}
445+
messages := make([]types.Message[string], scenario.messageCount)
446+
messages[0] = types.Message[string]{Role: "system", Content: "You are a helpful assistant."}
447447

448448
for i := 1; i < scenario.messageCount; i++ {
449449
role := "user"
450450
if i%2 == 0 {
451451
role = "assistant"
452452
}
453453
content := randomPrompt(scenario.messageLength)
454-
messages[i] = types.Message{Role: role, Content: content}
454+
messages[i] = types.Message[string]{Role: role, Content: content}
455455
}
456456

457457
pod := &types.PodMetrics{

pkg/epp/scheduling/types/types.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package types
1818

1919
import (
20+
"encoding/json"
2021
"fmt"
2122

2223
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
@@ -48,12 +49,14 @@ func (r *LLMRequest) String() string {
4849

4950
// LLMRequestBody contains the request-body fields that we parse out as user input,
5051
// to be used in forming scheduling decisions.
51-
// An LLMRequestBody must contain exactly one of CompletionsRequest or ChatCompletionsRequest.
52+
// An LLMRequestBody must contain exactly one of CompletionsRequest, ChatCompletionsRequest, or MultiModalChatCompletionsRequest.
5253
type LLMRequestBody struct {
5354
// CompletionsRequest is the representation of the OpenAI /v1/completions request body.
5455
Completions *CompletionsRequest `json:"completions,omitempty"`
5556
// ChatCompletionsRequest is the representation of the OpenAI /v1/chat_completions request body.
5657
ChatCompletions *ChatCompletionsRequest `json:"chat_completions,omitempty"`
58+
// MultiModalChatCompletionsRequest is the representation of the OpenAI /v1/chat/completions request body.
59+
MultiModalChatCompletions *MultiModalChatCompletionsRequest `json:"multi_modal_chat_completions,omitempty"`
5760
}
5861

5962
// CompletionsRequest is a structured representation of the fields we parse out of the
@@ -79,8 +82,8 @@ func (r *CompletionsRequest) String() string {
7982
// API spec.
8083
type ChatCompletionsRequest struct {
8184
/* parameters from the official OpenAI chat-completions API */
82-
Messages []Message `json:"messages,omitempty"`
83-
Tools []interface{} `json:"tools,omitempty"`
85+
Messages []Message[string] `json:"messages,omitempty"`
86+
Tools []interface{} `json:"tools,omitempty"`
8487
/* parameters from the HuggingFace transformers chat-templates API */
8588
Documents []interface{} `json:"documents,omitempty"`
8689
ChatTemplate string `json:"chat_template,omitempty"`
@@ -97,16 +100,52 @@ func (r *ChatCompletionsRequest) String() string {
97100

98101
messagesLen := 0
99102
for _, msg := range r.Messages {
100-
messagesLen += len(msg.Content)
103+
data, _ := json.Marshal(msg.Content)
104+
messagesLen += len(data)
105+
}
106+
107+
return fmt.Sprintf("{MessagesLength: %d}", messagesLen)
108+
}
109+
110+
// MultiModalChatCompletionsRequest is a structured representation of the fields we parse out of the
111+
// /v1/chat/completions request body.
112+
// This struct includes fields usable for plugins and scheduling decisions - and not the entire
113+
// API spec.
114+
type MultiModalChatCompletionsRequest struct {
115+
/* parameters from the official OpenAI chat-completions API */
116+
Messages []Message[map[string]interface{}] `json:"messages,omitempty"`
117+
Tools []interface{} `json:"tools,omitempty"`
118+
/* parameters from the HuggingFace transformers chat-templates API */
119+
Documents []interface{} `json:"documents,omitempty"`
120+
ChatTemplate string `json:"chat_template,omitempty"`
121+
ReturnAssistantTokensMask bool `json:"return_assistant_tokens_mask,omitempty"`
122+
ContinueFinalMessage bool `json:"continue_final_message,omitempty"`
123+
AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"`
124+
ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
125+
}
126+
127+
func (r *MultiModalChatCompletionsRequest) String() string {
128+
if r == nil {
129+
return nilString
130+
}
131+
132+
messagesLen := 0
133+
for _, msg := range r.Messages {
134+
data, _ := json.Marshal(msg.Content)
135+
messagesLen += len(data)
101136
}
102137

103138
return fmt.Sprintf("{MessagesLength: %d}", messagesLen)
104139
}
105140

106141
// Message represents a single message in a chat-completions request.
107-
type Message struct {
142+
type Message[T ContentConstraint] struct {
108143
Role string
109-
Content string // TODO: support multi-modal content
144+
Content T
145+
}
146+
147+
type ContentConstraint interface {
148+
string | map[string]any
110149
}
111150

112151
type Pod interface {

pkg/epp/util/request/body.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package request
1818

1919
import (
2020
"encoding/json"
21+
"errors"
2122

2223
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2324
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
@@ -39,21 +40,36 @@ func ExtractRequestBody(rawBody map[string]any) (*types.LLMRequestBody, error) {
3940

4041
// Try chat completions
4142
var chatCompletions types.ChatCompletionsRequest
42-
if err = json.Unmarshal(jsonBytes, &chatCompletions); err != nil {
43-
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid request format"}
43+
if err = json.Unmarshal(jsonBytes, &chatCompletions); err == nil {
44+
if err = validateChatCompletionsMessages(chatCompletions.Messages); err != nil {
45+
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid chat-completions request: " + err.Error()}
46+
}
47+
return &types.LLMRequestBody{ChatCompletions: &chatCompletions}, nil
4448
}
4549

46-
if err = validateChatCompletionsMessages(chatCompletions.Messages); err != nil {
47-
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid chat-completions request: " + err.Error()}
50+
// Try multi-modal chat completions
51+
var multiModalChatCompletions types.MultiModalChatCompletionsRequest
52+
if err = json.Unmarshal(jsonBytes, &multiModalChatCompletions); err == nil {
53+
if err = validateMultiModalChatCompletionsMessages(multiModalChatCompletions.Messages); err != nil {
54+
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid multi-modal chat-completions request: " + err.Error()}
55+
}
56+
return &types.LLMRequestBody{MultiModalChatCompletions: &multiModalChatCompletions}, nil
4857
}
4958

50-
return &types.LLMRequestBody{ChatCompletions: &chatCompletions}, nil
59+
return nil, errors.New("invalid request body")
5160
}
5261

53-
func validateChatCompletionsMessages(messages []types.Message) error {
62+
func validateChatCompletionsMessages(messages []types.Message[string]) error {
5463
if len(messages) == 0 {
5564
return errutil.Error{Code: errutil.BadRequest, Msg: "chat-completions request must have at least one message"}
5665
}
5766

5867
return nil
5968
}
69+
70+
func validateMultiModalChatCompletionsMessages(messages []types.Message[map[string]interface{}]) error {
71+
if len(messages) == 0 {
72+
return errutil.Error{Code: errutil.BadRequest, Msg: "multi-modal chat-completions request must have at least one message"}
73+
}
74+
return nil
75+
}

pkg/epp/util/request/body_test.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,43 @@ func TestExtractRequestData(t *testing.T) {
5757
},
5858
want: &types.LLMRequestBody{
5959
ChatCompletions: &types.ChatCompletionsRequest{
60-
Messages: []types.Message{
60+
Messages: []types.Message[string]{
6161
{Role: "system", Content: "this is a system message"},
6262
{Role: "user", Content: "hello"},
6363
},
6464
},
6565
},
6666
},
67+
{
68+
name: "chat completions request body with multi-modal content",
69+
body: map[string]any{
70+
"model": "test",
71+
"messages": []any{
72+
map[string]any{
73+
"role": "system",
74+
"content": map[string]any{
75+
"type": "text",
76+
"text": "Describe this image in one sentence.",
77+
},
78+
},
79+
map[string]any{
80+
"role": "user",
81+
"content": map[string]any{
82+
"type": "image_url",
83+
"image_url": "https://example.com/images/dui.jpg.",
84+
},
85+
},
86+
},
87+
},
88+
want: &types.LLMRequestBody{
89+
MultiModalChatCompletions: &types.MultiModalChatCompletionsRequest{
90+
Messages: []types.Message[map[string]any]{
91+
{Role: "system", Content: map[string]any{"type": "text", "text": "Describe this image in one sentence."}},
92+
{Role: "user", Content: map[string]any{"type": "image_url", "image_url": "https://example.com/images/dui.jpg."}},
93+
},
94+
},
95+
},
96+
},
6797
{
6898
name: "chat completions with all optional fields",
6999
body: map[string]any{
@@ -81,7 +111,7 @@ func TestExtractRequestData(t *testing.T) {
81111
},
82112
want: &types.LLMRequestBody{
83113
ChatCompletions: &types.ChatCompletionsRequest{
84-
Messages: []types.Message{{Role: "user", Content: "hello"}},
114+
Messages: []types.Message[string]{{Role: "user", Content: "hello"}},
85115
Tools: []any{map[string]any{"type": "function"}},
86116
Documents: []any{map[string]any{"content": "doc"}},
87117
ChatTemplate: "custom template",

0 commit comments

Comments
 (0)