Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions components/model/grok/example/grok.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package main

import (
"context"
"fmt"
"io"
"log"
"os"

"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3"

"github.com/cloudwego/eino-ext/components/model/grok"
)

// Helper function to create pointers
func ptrOf[T any](v T) *T {
return &v
}

func main() {
ctx := context.Background()
apiKey := os.Getenv("GROK_API_KEY")
if apiKey == "" {
log.Fatal("GROK_API_KEY environment variable is not set")
}

// Create a new Grok model
cm, err := grok.NewChatModel(ctx, &grok.Config{
APIKey: apiKey,
Model: "grok-3",
MaxTokens: ptrOf(2000),
})
if err != nil {
log.Fatalf("NewChatModel of grok failed, err=%v", err)
}

fmt.Println("\n=== Basic Chat ===")
basicChat(ctx, cm)

fmt.Println("\n=== Streaming Chat ===")
streamingChat(ctx, cm)

fmt.Println("\n=== Function Calling ===")
functionCalling(ctx, cm)

fmt.Println("\n=== Advanced Options ===")
advancedOptions(ctx, cm)
}

func basicChat(ctx context.Context, cm model.ChatModel) {
messages := []*schema.Message{
{
Role: schema.System,
Content: "You are a helpful AI assistant. Be concise in your responses.",
},
{
Role: schema.User,
Content: "What is the capital of France?",
},
}

resp, err := cm.Generate(ctx, messages)
if err != nil {
log.Printf("Generate error: %v", err)
return
}

fmt.Printf("Assistant: %s\n", resp.Content)
if resp.ResponseMeta != nil && resp.ResponseMeta.Usage != nil {
fmt.Printf("Tokens used: %d (prompt) + %d (completion) = %d (total)\n",
resp.ResponseMeta.Usage.PromptTokens,
resp.ResponseMeta.Usage.CompletionTokens,
resp.ResponseMeta.Usage.TotalTokens)
}
}

func streamingChat(ctx context.Context, cm model.ChatModel) {
messages := []*schema.Message{
{
Role: schema.User,
Content: "Write a short poem about spring, word by word.",
},
}

stream, err := cm.Stream(ctx, messages)
if err != nil {
log.Printf("Stream error: %v", err)
return
}

fmt.Print("Assistant: ")
for {
resp, err := stream.Recv()
if err == io.EOF {
// 正常结束,不需要报错
break
}
if err != nil {
log.Printf("Stream receive error: %v", err)
return
}
fmt.Print(resp.Content)
}
fmt.Println()
}

func functionCalling(ctx context.Context, cm model.ChatModel) {
// Bind tools to the model
err := cm.BindTools([]*schema.ToolInfo{
{
Name: "get_weather",
Desc: "Get current weather information for a city",
ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3(&openapi3.Schema{
Type: "object",
Properties: map[string]*openapi3.SchemaRef{
"city": {
Value: &openapi3.Schema{
Type: "string",
Description: "The city name",
},
},
"unit": {
Value: &openapi3.Schema{
Type: "string",
Enum: []interface{}{"celsius", "fahrenheit"},
},
},
},
Required: []string{"city"},
}),
},
})
if err != nil {
log.Printf("Bind tools error: %v", err)
return
}

// Stream the response with a function call
streamResp, err := cm.Stream(ctx, []*schema.Message{
{
Role: schema.User,
Content: "What's the weather like in Paris today? Please use Celsius.",
},
})
if err != nil {
log.Printf("Generate error: %v", err)
return
}

msgs := make([]*schema.Message, 0)
for {
msg, err := streamResp.Recv()
if err == io.EOF {
break
}
if err != nil {
log.Printf("Stream receive error: %v", err)
return
}
msgs = append(msgs, msg)
}
resp, err := schema.ConcatMessages(msgs)
if err != nil {
log.Printf("Concat error: %v", err)
return
}

if len(resp.ToolCalls) > 0 {
fmt.Printf("Function called: %s\n", resp.ToolCalls[0].Function.Name)
fmt.Printf("Arguments: %s\n", resp.ToolCalls[0].Function.Arguments)

// Handle the function call with a mock response
weatherResp, err := cm.Generate(ctx, []*schema.Message{
{
Role: schema.User,
Content: "What's the weather like in Paris today? Please use Celsius.",
},
resp,
{
Role: schema.Tool,
ToolCallID: resp.ToolCalls[0].ID,
Content: `{"temperature": 18, "condition": "sunny"}`,
},
})
if err != nil {
log.Printf("Generate error: %v", err)
return
}
fmt.Printf("Final response: %s\n", weatherResp.Content)
} else {
fmt.Printf("No function was called. Response: %s\n", resp.Content)
}
}

// Advanced example showing TopK parameter usage
func advancedOptions(ctx context.Context, cm model.ChatModel) {
messages := []*schema.Message{
{
Role: schema.User,
Content: "Generate 5 creative business ideas.",
},
}

// Using TopK parameter to control diversity of tokens
resp, err := cm.Generate(ctx, messages, grok.WithTopK(50))
if err != nil {
log.Printf("Generate error: %v", err)
return
}

fmt.Printf("Assistant (with TopK=50): %s\n", resp.Content)
}
40 changes: 40 additions & 0 deletions components/model/grok/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module github.com/cloudwego/eino-ext/components/model/grok

go 1.24.1

require (
github.com/SimonMorphy/grok-go v1.0.0
github.com/cloudwego/eino v0.3.27
github.com/getkin/kin-openapi v0.118.0
)

require (
github.com/bytedance/sonic v1.13.2 // indirect
github.com/bytedance/sonic/loader v0.2.4 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/swag v0.19.5 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/invopop/yaml v0.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/nikolalohinski/gonja v1.5.3 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/perimeterx/marshmallow v1.1.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect
golang.org/x/arch v0.11.0 // indirect
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/sys v0.26.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading