Skip to content

Commit d634baf

Browse files
authored
feat(acl_openai): sort the required fields in the openapi3 schema in … (#374)
1 parent 2ccbac3 commit d634baf

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

libs/acl/openai/chat_model.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"io"
2424
"net/http"
2525
"runtime/debug"
26+
"slices"
2627

2728
"github.com/getkin/kin-openapi/openapi3"
2829
"github.com/meguminnnnnnnnn/go-openai"
@@ -791,6 +792,33 @@ func resolveStreamResponse(resp openai.ChatCompletionStreamResponse) (msg *schem
791792
}
792793

793794
func toTools(tis []*schema.ToolInfo) ([]tool, error) {
795+
var sortArrayFields func(*openapi3.Schema)
796+
sortArrayFields = func(sc *openapi3.Schema) {
797+
if sc == nil {
798+
return
799+
}
800+
switch sc.Type {
801+
case openapi3.TypeObject:
802+
if len(sc.Required) == 0 {
803+
return
804+
}
805+
806+
slices.Sort(sc.Required)
807+
808+
for _, v := range sc.Properties {
809+
sortArrayFields(v.Value)
810+
}
811+
812+
case openapi3.TypeArray:
813+
if sc.Items != nil && sc.Items.Value != nil {
814+
sortArrayFields(sc.Items.Value)
815+
}
816+
817+
default:
818+
return
819+
}
820+
}
821+
794822
tools := make([]tool, len(tis))
795823
for i := range tis {
796824
ti := tis[i]
@@ -803,6 +831,8 @@ func toTools(tis []*schema.ToolInfo) ([]tool, error) {
803831
return nil, fmt.Errorf("failed to convert tool parameters to JSONSchema: %w", err)
804832
}
805833

834+
sortArrayFields(paramsJSONSchema)
835+
806836
tools[i] = tool{
807837
Function: &functionDefinition{
808838
Name: ti.Name,

libs/acl/openai/chat_model_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"math/rand"
2121
"testing"
2222

23+
"github.com/bytedance/mockey"
2324
goopenai "github.com/meguminnnnnnnnn/go-openai"
2425
"github.com/stretchr/testify/assert"
2526

@@ -186,3 +187,54 @@ func TestClientWithExtraHeader(t *testing.T) {
186187
WithExtraHeader(map[string]string{"test": "test"}),
187188
}), 1)
188189
}
190+
191+
func TestToTools(t *testing.T) {
192+
mockey.PatchConvey("", t, func() {
193+
mockTools := []*schema.ToolInfo{
194+
{
195+
Name: "test tool name",
196+
Desc: "description of test tool",
197+
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
198+
"126": {
199+
Type: schema.String,
200+
Required: true,
201+
},
202+
"123": {
203+
Type: schema.Array,
204+
Required: true,
205+
ElemInfo: &schema.ParameterInfo{
206+
Type: schema.Object,
207+
Required: true,
208+
SubParams: map[string]*schema.ParameterInfo{
209+
"459": {
210+
Type: schema.String,
211+
Required: true,
212+
},
213+
"458": {
214+
Type: schema.String,
215+
Required: true,
216+
},
217+
"457": {
218+
Type: schema.String,
219+
Required: true,
220+
},
221+
},
222+
},
223+
},
224+
"129": {
225+
Type: schema.Object,
226+
Required: true,
227+
},
228+
}),
229+
},
230+
}
231+
232+
tools, err := toTools(mockTools)
233+
assert.Nil(t, err)
234+
assert.Len(t, tools, 1)
235+
236+
sc := tools[0].Function.Parameters
237+
assert.Equal(t, []string{"123", "126", "129"}, sc.Required)
238+
assert.Equal(t, []string{"457", "458", "459"}, sc.Properties["123"].Value.Items.Value.Required)
239+
})
240+
}

0 commit comments

Comments
 (0)