From 359d87efb8b04d2c8d87eb7e30538ae364325f26 Mon Sep 17 00:00:00 2001 From: Simon_Morphy Date: Sun, 4 May 2025 19:33:13 +0800 Subject: [PATCH] feat: add model support of x.ai's grok --- components/model/grok/example/grok.go | 230 +++++++++ components/model/grok/go.mod | 40 ++ components/model/grok/go.sum | 159 +++++++ components/model/grok/grok.go | 647 ++++++++++++++++++++++++++ 4 files changed, 1076 insertions(+) create mode 100644 components/model/grok/example/grok.go create mode 100644 components/model/grok/go.mod create mode 100644 components/model/grok/go.sum create mode 100644 components/model/grok/grok.go diff --git a/components/model/grok/example/grok.go b/components/model/grok/example/grok.go new file mode 100644 index 000000000..ec8a69d84 --- /dev/null +++ b/components/model/grok/example/grok.go @@ -0,0 +1,230 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "io" + "log" + "os" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/getkin/kin-openapi/openapi3" + + "github.com/cloudwego/eino-ext/components/model/grok" +) + +// Helper function to create pointers +func ptrOf[T any](v T) *T { + return &v +} + +func main() { + ctx := context.Background() + apiKey := os.Getenv("GROK_API_KEY") + if apiKey == "" { + log.Fatal("GROK_API_KEY environment variable is not set") + } + + // Create a new Grok model + cm, err := grok.NewChatModel(ctx, &grok.Config{ + APIKey: apiKey, + Model: "grok-3", + MaxTokens: ptrOf(2000), + }) + if err != nil { + log.Fatalf("NewChatModel of grok failed, err=%v", err) + } + + fmt.Println("\n=== Basic Chat ===") + basicChat(ctx, cm) + + fmt.Println("\n=== Streaming Chat ===") + streamingChat(ctx, cm) + + fmt.Println("\n=== Function Calling ===") + functionCalling(ctx, cm) + + fmt.Println("\n=== Advanced Options ===") + advancedOptions(ctx, cm) +} + +func basicChat(ctx context.Context, cm model.ChatModel) { + messages := []*schema.Message{ + { + Role: schema.System, + Content: "You are a helpful AI assistant. Be concise in your responses.", + }, + { + Role: schema.User, + Content: "What is the capital of France?", + }, + } + + resp, err := cm.Generate(ctx, messages) + if err != nil { + log.Printf("Generate error: %v", err) + return + } + + fmt.Printf("Assistant: %s\n", resp.Content) + if resp.ResponseMeta != nil && resp.ResponseMeta.Usage != nil { + fmt.Printf("Tokens used: %d (prompt) + %d (completion) = %d (total)\n", + resp.ResponseMeta.Usage.PromptTokens, + resp.ResponseMeta.Usage.CompletionTokens, + resp.ResponseMeta.Usage.TotalTokens) + } +} + +func streamingChat(ctx context.Context, cm model.ChatModel) { + messages := []*schema.Message{ + { + Role: schema.User, + Content: "Write a short poem about spring, word by word.", + }, + } + + stream, err := cm.Stream(ctx, messages) + if err != nil { + log.Printf("Stream error: %v", err) + return + } + + fmt.Print("Assistant: ") + for { + resp, err := stream.Recv() + if err == io.EOF { + // 正常结束,不需要报错 + break + } + if err != nil { + log.Printf("Stream receive error: %v", err) + return + } + fmt.Print(resp.Content) + } + fmt.Println() +} + +func functionCalling(ctx context.Context, cm model.ChatModel) { + // Bind tools to the model + err := cm.BindTools([]*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "Get current weather information for a city", + ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(&openapi3.Schema{ + Type: "object", + Properties: map[string]*openapi3.SchemaRef{ + "city": { + Value: &openapi3.Schema{ + Type: "string", + Description: "The city name", + }, + }, + "unit": { + Value: &openapi3.Schema{ + Type: "string", + Enum: []interface{}{"celsius", "fahrenheit"}, + }, + }, + }, + Required: []string{"city"}, + }), + }, + }) + if err != nil { + log.Printf("Bind tools error: %v", err) + return + } + + // Stream the response with a function call + streamResp, err := cm.Stream(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "What's the weather like in Paris today? Please use Celsius.", + }, + }) + if err != nil { + log.Printf("Generate error: %v", err) + return + } + + msgs := make([]*schema.Message, 0) + for { + msg, err := streamResp.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Printf("Stream receive error: %v", err) + return + } + msgs = append(msgs, msg) + } + resp, err := schema.ConcatMessages(msgs) + if err != nil { + log.Printf("Concat error: %v", err) + return + } + + if len(resp.ToolCalls) > 0 { + fmt.Printf("Function called: %s\n", resp.ToolCalls[0].Function.Name) + fmt.Printf("Arguments: %s\n", resp.ToolCalls[0].Function.Arguments) + + // Handle the function call with a mock response + weatherResp, err := cm.Generate(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "What's the weather like in Paris today? Please use Celsius.", + }, + resp, + { + Role: schema.Tool, + ToolCallID: resp.ToolCalls[0].ID, + Content: `{"temperature": 18, "condition": "sunny"}`, + }, + }) + if err != nil { + log.Printf("Generate error: %v", err) + return + } + fmt.Printf("Final response: %s\n", weatherResp.Content) + } else { + fmt.Printf("No function was called. Response: %s\n", resp.Content) + } +} + +// Advanced example showing TopK parameter usage +func advancedOptions(ctx context.Context, cm model.ChatModel) { + messages := []*schema.Message{ + { + Role: schema.User, + Content: "Generate 5 creative business ideas.", + }, + } + + // Using TopK parameter to control diversity of tokens + resp, err := cm.Generate(ctx, messages, grok.WithTopK(50)) + if err != nil { + log.Printf("Generate error: %v", err) + return + } + + fmt.Printf("Assistant (with TopK=50): %s\n", resp.Content) +} diff --git a/components/model/grok/go.mod b/components/model/grok/go.mod new file mode 100644 index 000000000..b842c3802 --- /dev/null +++ b/components/model/grok/go.mod @@ -0,0 +1,40 @@ +module github.com/cloudwego/eino-ext/components/model/grok + +go 1.24.1 + +require ( + github.com/SimonMorphy/grok-go v1.0.0 + github.com/cloudwego/eino v0.3.27 + github.com/getkin/kin-openapi v0.118.0 +) + +require ( + github.com/bytedance/sonic v1.13.2 // indirect + github.com/bytedance/sonic/loader v0.2.4 // indirect + github.com/cloudwego/base64x v0.1.5 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-openapi/swag v0.19.5 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/invopop/yaml v0.1.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/perimeterx/marshmallow v1.1.4 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/sys v0.26.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/model/grok/go.sum b/components/model/grok/go.sum new file mode 100644 index 000000000..a2b5349e6 --- /dev/null +++ b/components/model/grok/go.sum @@ -0,0 +1,159 @@ +github.com/SimonMorphy/grok-go v0.0.0-20250503133121-1ae8bb12750c h1:pYRoauOKeH9MSiRUb1FY1tPWFXENRAfESYwDyxI5zzw= +github.com/SimonMorphy/grok-go v0.0.0-20250503133121-1ae8bb12750c/go.mod h1:kG7gue5Rd9eso7JgME6wvT3pvwVa8AdOi0fUFguE4Pw= +github.com/SimonMorphy/grok-go v1.0.0 h1:h5pxFEutYnTUNpTfx5c4Hxno//CXxAoPUaUV2rE0QPk= +github.com/SimonMorphy/grok-go v1.0.0/go.mod h1:kG7gue5Rd9eso7JgME6wvT3pvwVa8AdOi0fUFguE4Pw= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= +github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= +github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= +github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/eino v0.3.27 h1:Oz4HcuivJyb+zT0W43Gmtb6wqmXZaYel0CS4iF6XsoI= +github.com/cloudwego/eino v0.3.27/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= +github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc= +github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw= +github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/components/model/grok/grok.go b/components/model/grok/grok.go new file mode 100644 index 000000000..b4768aec5 --- /dev/null +++ b/components/model/grok/grok.go @@ -0,0 +1,647 @@ +package grok + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "runtime/debug" + "time" + + grokgo "github.com/SimonMorphy/grok-go" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +var _ model.ToolCallingChatModel = (*ChatModel)(nil) + +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Config contains the configuration options for the Grok model +type Config struct { + // APIKey is your X.AI API key + // Required + APIKey string `json:"api_key"` + + // BaseURL is the custom API endpoint URL + // Optional. Default: "https://api.x.ai/v1/" + BaseURL *string `json:"base_url,omitempty"` + + // Timeout specifies the maximum duration to wait for API responses + // Optional. Default: 30 seconds + Timeout time.Duration `json:"timeout,omitempty"` + + // Model specifies which Grok model to use + // Required. Example: "grok-3-beta" + Model string `json:"model"` + + // MaxTokens limits the maximum number of tokens in the response + // Optional. Example: 1000 + MaxTokens *int `json:"max_tokens,omitempty"` + + // Temperature controls randomness in responses + // Range: [0.0, 2.0], where 0.0 is more focused and 2.0 is more creative + // Optional. Example: float32(0.7) + Temperature *float32 `json:"temperature,omitempty"` + + // TopP controls diversity via nucleus sampling + // Range: [0.0, 1.0], where 1.0 disables nucleus sampling + // Optional. Example: float32(0.95) + TopP *float32 `json:"top_p,omitempty"` + + // TopK controls diversity by limiting the top K tokens to sample from + // Optional. Example: 40 + TopK *int `json:"top_k,omitempty"` + + // Stop sequences where the API will stop generating further tokens + // Optional. Example: []string{"\n", "User:"} + Stop []string `json:"stop,omitempty"` + + // HTTPClient specifies the client to send HTTP requests + // Optional. + HTTPClient *http.Client `json:"http_client,omitempty"` +} + +// ChatModel represents a Grok chat model client. +type ChatModel struct { + cli *grokgo.Client + + model string + maxTokens *int + topP *float32 + temperature *float32 + topK *int + stop []string + oriTools []*schema.ToolInfo + tools []grokgo.Tool + toolChoice *schema.ToolChoice +} + +// NewChatModel creates a new Grok chat model instance +// +// Parameters: +// - ctx: The context for the operation +// - conf: Configuration for the Grok model +// +// Returns: +// - model.ChatModel: A chat model interface implementation +// - error: Any error that occurred during creation +// +// Example: +// +// model, err := grok.NewChatModel(ctx, &grok.Config{ +// APIKey: "your-api-key", +// Model: "grok-3-beta", +// MaxTokens: 1000, +// }) +func NewChatModel(ctx context.Context, config *Config) (*ChatModel, error) { + if config.APIKey == "" { + return nil, errors.New("api key is required") + } + if config.Model == "" { + return nil, errors.New("model is required") + } + + var opts []grokgo.ClientOption + if config.BaseURL != nil { + opts = append(opts, grokgo.WithBaseURL(*config.BaseURL)) + } + if config.Timeout > 0 { + opts = append(opts, grokgo.WithTimeout(config.Timeout)) + } + if config.HTTPClient != nil { + opts = append(opts, grokgo.WithHTTPClient(config.HTTPClient)) + } + + client, err := grokgo.NewClientWithOptions(config.APIKey, opts...) + if err != nil { + return nil, fmt.Errorf("create grok client fail: %w", err) + } + + return &ChatModel{ + cli: client, + model: config.Model, + maxTokens: config.MaxTokens, + temperature: config.Temperature, + topP: config.TopP, + topK: config.TopK, + stop: config.Stop, + }, nil +} + +func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) { + ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) + callbackInput := cm.getCallbackInput(input, opts...) + ctx = callbacks.OnStart(ctx, callbackInput) + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + // Prepare request parameters + req, err := cm.createChatCompletionRequest(input, opts...) + if err != nil { + return nil, err + } + + // Call API + resp, err := grokgo.CreateChatCompletion(ctx, cm.cli, req) + if err != nil { + return nil, fmt.Errorf("create chat completion fail: %w", err) + } + + // Convert response to schema message + message, err = cm.convertResponseToMessage(resp) + if err != nil { + return nil, fmt.Errorf("convert response to schema message fail: %w", err) + } + + callbacks.OnEnd(ctx, cm.getCallbackOutput(message)) + return message, nil +} + +func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) { + ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) + callbackInput := cm.getCallbackInput(input, opts...) + ctx = callbacks.OnStart(ctx, callbackInput) + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + // Prepare request parameters + req, err := cm.createChatCompletionRequest(input, opts...) + if err != nil { + return nil, err + } + req.Stream = true + + // Call API with streaming + stream, err := grokgo.CreateChatCompletionStream(ctx, cm.cli, req) + if err != nil { + return nil, fmt.Errorf("create chat completion stream fail: %w", err) + } + + sr, sw := schema.Pipe[*model.CallbackOutput](1) + go func() { + defer func() { + panicErr := recover() + _ = stream.Close() + + if panicErr != nil { + _ = sw.Send(nil, newPanicErr(panicErr, debug.Stack())) + } + + sw.Close() + }() + + var waitList []*schema.Message + for { + chunk, chunkErr := stream.Recv() + if errors.Is(chunkErr, io.EOF) { + return + } + if chunkErr != nil { + _ = sw.Send(nil, fmt.Errorf("receive stream chunk fail: %w", chunkErr)) + return + } + + message, err := cm.convertStreamResponseToMessage(chunk) + if err != nil { + _ = sw.Send(nil, fmt.Errorf("convert stream response to schema message fail: %w", err)) + return + } + + if message == nil { + continue + } + + if isMessageEmpty(message) { + waitList = append(waitList, message) + continue + } + + if len(waitList) != 0 { + message, err = schema.ConcatMessages(append(waitList, message)) + if err != nil { + _ = sw.Send(nil, fmt.Errorf("concat empty message fail: %w", err)) + return + } + waitList = []*schema.Message{} + } + + closed := sw.Send(cm.getCallbackOutput(message), nil) + if closed { + return + } + } + }() + + _, sr = callbacks.OnEndWithStreamOutput(ctx, sr) + return schema.StreamReaderWithConvert(sr, func(t *model.CallbackOutput) (*schema.Message, error) { + return t.Message, nil + }), nil +} + +func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + if len(tools) == 0 { + return nil, errors.New("no tools to bind") + } + grokTools, err := cm.toGrokTools(tools) + if err != nil { + return nil, fmt.Errorf("convert to grok tools fail: %w", err) + } + + tc := schema.ToolChoiceAllowed + ncm := *cm + ncm.tools = grokTools + ncm.oriTools = tools + ncm.toolChoice = &tc + return &ncm, nil +} + +func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error { + if len(tools) == 0 { + return errors.New("no tools to bind") + } + grokTools, err := cm.toGrokTools(tools) + if err != nil { + return fmt.Errorf("convert to grok tools fail: %w", err) + } + + cm.tools = grokTools + cm.oriTools = tools + tc := schema.ToolChoiceAllowed + cm.toolChoice = &tc + return nil +} + +func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error { + if len(tools) == 0 { + return errors.New("no tools to bind") + } + grokTools, err := cm.toGrokTools(tools) + if err != nil { + return fmt.Errorf("convert to grok tools fail: %w", err) + } + + cm.tools = grokTools + cm.oriTools = tools + tc := schema.ToolChoiceForced + cm.toolChoice = &tc + return nil +} + +func (cm *ChatModel) toGrokTools(tools []*schema.ToolInfo) ([]grokgo.Tool, error) { + result := make([]grokgo.Tool, 0, len(tools)) + for _, tool := range tools { + s, err := tool.ToOpenAPIV3() + if err != nil { + return nil, fmt.Errorf("convert to openapi v3 schema fail: %w", err) + } + + // Convert OpenAPI schema to Grok function parameters + params := &grokgo.FunctionParameters{ + Type: "object", + Properties: make(map[string]interface{}), + } + + if s.Properties != nil { + for name, prop := range s.Properties { + params.Properties[name] = prop.Value + } + } + + if len(s.Required) > 0 { + params.Required = s.Required + } + + result = append(result, grokgo.Tool{ + Type: "function", + Function: grokgo.Function{ + Name: tool.Name, + Description: tool.Desc, + Parameters: params, + }, + }) + } + + return result, nil +} + +func (cm *ChatModel) createChatCompletionRequest(input []*schema.Message, opts ...model.Option) (*grokgo.ChatCompletionRequest, error) { + if len(input) == 0 { + return nil, errors.New("input is empty") + } + + commonOptions := model.GetCommonOptions(&model.Options{ + Model: &cm.model, + Temperature: cm.temperature, + MaxTokens: cm.maxTokens, + TopP: cm.topP, + Stop: cm.stop, + Tools: nil, + ToolChoice: cm.toolChoice, + }, opts...) + + grokOptions := model.GetImplSpecificOptions(&options{ + TopK: cm.topK, + }, opts...) + + req := &grokgo.ChatCompletionRequest{ + Model: *commonOptions.Model, + } + + if commonOptions.MaxTokens != nil { + req.MaxTokens = *commonOptions.MaxTokens + } + + if commonOptions.Temperature != nil { + req.Temperature = float64(*commonOptions.Temperature) + } + + if commonOptions.TopP != nil { + req.TopP = float64(*commonOptions.TopP) + } + + if len(commonOptions.Stop) > 0 { + req.Stop = commonOptions.Stop + } + + if grokOptions.TopK != nil { + req.TopK = *grokOptions.TopK + } + + // Handle tools + tools := cm.tools + if commonOptions.Tools != nil { + var err error + if tools, err = cm.toGrokTools(commonOptions.Tools); err != nil { + return nil, err + } + } + + if len(tools) > 0 { + req.Tools = tools + } + + // Handle tool choice + if commonOptions.ToolChoice != nil { + switch *commonOptions.ToolChoice { + case schema.ToolChoiceForbidden: + req.ToolChoice = "none" + case schema.ToolChoiceAllowed: + req.ToolChoice = "auto" + case schema.ToolChoiceForced: + if len(tools) == 0 { + return nil, errors.New("tool choice is forced but tool is not provided") + } else if len(tools) == 1 { + req.ToolChoice = map[string]interface{}{ + "type": "function", + "function": map[string]string{"name": tools[0].Function.Name}, + } + } else { + req.ToolChoice = "required" + } + default: + return nil, fmt.Errorf("tool choice=%s not support", *commonOptions.ToolChoice) + } + } + + // Convert messages + messages, err := cm.convertMessagesToGrok(input) + if err != nil { + return nil, err + } + req.Messages = messages + + return req, nil +} + +func (cm *ChatModel) convertMessagesToGrok(messages []*schema.Message) ([]grokgo.ChatCompletionMessage, error) { + result := make([]grokgo.ChatCompletionMessage, 0, len(messages)) + for _, msg := range messages { + grokMsg := grokgo.ChatCompletionMessage{ + Role: convertRole(msg.Role), + Content: msg.Content, + } + + // Handle tool calls + if len(msg.ToolCalls) > 0 { + grokMsg.ToolCalls = make([]grokgo.APIToolCall, 0, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + grokMsg.ToolCalls = append(grokMsg.ToolCalls, grokgo.APIToolCall{ + ID: tc.ID, + Type: "function", + Function: grokgo.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + } + + // Handle tool response + if msg.Role == schema.Tool && msg.ToolCallID != "" { + grokMsg.ToolCallID = msg.ToolCallID + } + + result = append(result, grokMsg) + } + return result, nil +} + +func (cm *ChatModel) convertResponseToMessage(resp *grokgo.Response) (*schema.Message, error) { + if len(resp.Choices) == 0 { + return nil, errors.New("no choices in response") + } + + choice := resp.Choices[0] + message := &schema.Message{ + Role: schema.Assistant, + Content: choice.Message.Content, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: choice.FinishReason, + Usage: &schema.TokenUsage{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + }, + }, + } + + // Handle tool calls + if len(choice.Message.ToolCalls) > 0 { + message.ToolCalls = make([]schema.ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + message.ToolCalls = append(message.ToolCalls, schema.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: schema.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + } + + return message, nil +} + +func (cm *ChatModel) convertStreamResponseToMessage(resp *grokgo.StreamResponse) (*schema.Message, error) { + if len(resp.Choices) == 0 { + return nil, nil + } + + choice := resp.Choices[0] + message := &schema.Message{ + Role: schema.Assistant, + Content: choice.Delta.Content, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: choice.FinishReason, + }, + } + + if resp.Usage != nil { + message.ResponseMeta.Usage = &schema.TokenUsage{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + } + + // Handle tool calls + if len(choice.Delta.ToolCalls) > 0 { + message.ToolCalls = make([]schema.ToolCall, 0, len(choice.Delta.ToolCalls)) + for _, tc := range choice.Delta.ToolCalls { + message.ToolCalls = append(message.ToolCalls, schema.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: schema.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + } + + return message, nil +} + +func (cm *ChatModel) getCallbackInput(input []*schema.Message, opts ...model.Option) *model.CallbackInput { + result := &model.CallbackInput{ + Messages: input, + Tools: model.GetCommonOptions(&model.Options{ + Tools: cm.oriTools, + }, opts...).Tools, + Config: cm.getConfig(), + } + return result +} + +func (cm *ChatModel) getCallbackOutput(output *schema.Message) *model.CallbackOutput { + result := &model.CallbackOutput{ + Message: output, + Config: cm.getConfig(), + } + if output.ResponseMeta != nil && output.ResponseMeta.Usage != nil { + result.TokenUsage = &model.TokenUsage{ + PromptTokens: output.ResponseMeta.Usage.PromptTokens, + CompletionTokens: output.ResponseMeta.Usage.CompletionTokens, + TotalTokens: output.ResponseMeta.Usage.TotalTokens, + } + } + return result +} + +func (cm *ChatModel) getConfig() *model.Config { + result := &model.Config{ + Model: cm.model, + Stop: cm.stop, + } + if cm.maxTokens != nil { + result.MaxTokens = *cm.maxTokens + } + if cm.temperature != nil { + result.Temperature = *cm.temperature + } + if cm.topP != nil { + result.TopP = *cm.topP + } + return result +} + +func (cm *ChatModel) GetType() string { + return "Grok" +} + +func (cm *ChatModel) IsCallbacksEnabled() bool { + return true +} + +func convertRole(role schema.RoleType) string { + switch role { + case schema.Assistant: + return "assistant" + case schema.System: + return "system" + case schema.User: + return "user" + case schema.Tool: + return "tool" + default: + return string(role) + } +} + +func isMessageEmpty(message *schema.Message) bool { + return len(message.Content) == 0 && len(message.ToolCalls) == 0 && len(message.MultiContent) == 0 +} + +// options holds implementation-specific options for Grok +type options struct { + // TopK controls diversity by limiting the top K tokens to sample from + TopK *int +} + +// WithTopK sets the TopK parameter for the Grok model +func WithTopK(topK int) model.Option { + return model.WrapImplSpecificOptFn(func(o *options) { + o.TopK = &topK + }) +} + +type panicErr struct { + info any + stack []byte +} + +func (p *panicErr) Error() string { + return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack)) +} + +func newPanicErr(info any, stack []byte) error { + return &panicErr{ + info: info, + stack: stack, + } +}