Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 75 additions & 34 deletions components/model/claude/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ import (
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/anthropics/anthropic-sdk-go/packages/param"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/anthropics/anthropic-sdk-go/vertex"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"

"github.com/cloudwego/eino/components"
"golang.org/x/oauth2/google"

"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
Expand All @@ -60,40 +61,51 @@ var _ model.ToolCallingChatModel = (*ChatModel)(nil)
// })
func NewChatModel(ctx context.Context, config *Config) (*ChatModel, error) {
var cli anthropic.Client
if !config.ByBedrock {
var opts []option.RequestOption

opts = append(opts, option.WithAPIKey(config.APIKey))

if config.BaseURL != nil {
opts = append(opts, option.WithBaseURL(*config.BaseURL))
}

if config.HTTPClient != nil {
opts = append(opts, option.WithHTTPClient(config.HTTPClient))
}

cli = anthropic.NewClient(opts...)
} else {
var opts []func(*awsConfig.LoadOptions) error
if config.ByBedrock {
var opts []func(*awsconfig.LoadOptions) error
if config.Region != "" {
opts = append(opts, awsConfig.WithRegion(config.Region))
opts = append(opts, awsconfig.WithRegion(config.Region))
}
if config.SecretAccessKey != "" && config.AccessKey != "" {
opts = append(opts, awsConfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
opts = append(opts, awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
config.AccessKey,
config.SecretAccessKey,
config.SessionToken,
)))
} else if config.Profile != "" {
opts = append(opts, awsConfig.WithSharedConfigProfile(config.Profile))
opts = append(opts, awsconfig.WithSharedConfigProfile(config.Profile))
}

if config.HTTPClient != nil {
opts = append(opts, awsConfig.WithHTTPClient(config.HTTPClient))
opts = append(opts, awsconfig.WithHTTPClient(config.HTTPClient))
}
cli = anthropic.NewClient(bedrock.WithLoadDefaultConfig(ctx, opts...))
} else if config.ByVertex {
if config.GoogleCredentials != nil {
cli = anthropic.NewClient(
vertex.WithCredentials(ctx, config.Region, config.ProjectID, config.GoogleCredentials),
)
} else {
cli = anthropic.NewClient(
vertex.WithGoogleAuth(ctx, config.Region, config.ProjectID, config.Scopes...),
)
}
} else {
var opts []option.RequestOption

opts = append(opts, option.WithAPIKey(config.APIKey))

if config.BaseURL != nil {
opts = append(opts, option.WithBaseURL(*config.BaseURL))
}

if config.HTTPClient != nil {
opts = append(opts, option.WithHTTPClient(config.HTTPClient))
}

cli = anthropic.NewClient(opts...)
}

return &ChatModel{
cli: cli,
maxTokens: config.MaxTokens,
Expand All @@ -110,6 +122,8 @@ func NewChatModel(ctx context.Context, config *Config) (*ChatModel, error) {
// Config contains the configuration options for the Claude model
type Config struct {
// ByBedrock indicates whether to use Bedrock Service
// If both [Config.ByBedrock] and [Config.ByVertex] are set to true,
// the [Config.ByBedrock] configuration will take precedence.
// Required for Bedrock
ByBedrock bool

Expand All @@ -134,9 +148,30 @@ type Config struct {
// Optional for Bedrock
Profile string

// Region is your Bedrock API region
// Obtain from: https://docs.aws.amazon.com/bedrock/latest/userguide/getting-started.html
// Optional for Bedrock
// ByVertex indicates whether to use Google Cloud Vertex AI
// If both [Config.ByBedrock] and [Config.ByVertex] are set to true,
// the [Config.ByBedrock] configuration will take precedence.
// Required for Google Vertex
ByVertex bool

// ProjectID is your Google Cloud project ID
// Obtain from: https://cloud.google.com/resource-manager/docs/creating-managing-projects
// Required for Google Vertex
ProjectID string

// Scopes is your list of Google Cloud OAuth scopes
// Obtain from: https://developers.google.com/identity/protocols/oauth2/scopes
// Required for Google Vertex
Scopes []string

// GoogleCredentials is your Google Cloud credentials.
// If set, these credentials will be used for authentication.
// If not set, the default application credentials will be used:
// https://cloud.google.com/docs/authentication/application-default-credentials
GoogleCredentials *google.Credentials

// Region is your Bedrock API region if using Bedrock, or your Google Cloud region if using Vertex AI
// Optional for Bedrock, but required for Google Vertex
Region string

// BaseURL is the custom API endpoint URL
Expand All @@ -149,8 +184,12 @@ type Config struct {
// Required
APIKey string

// Model specifies which Claude model to use
// Required
// Model specifies which Claude model to use.
//
// Note that the model names are different for Bedrock and Vertex.
// See https://docs.anthropic.com/en/api/claude-on-amazon-bedrock and https://docs.anthropic.com/en/api/claude-on-vertex-ai.
//
// Required.
Model string

// MaxTokens limits the maximum number of tokens in the response
Expand Down Expand Up @@ -308,7 +347,6 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts .
_ = sw.Send(nil, stream.Err())
return
}

}()
_, sr = callbacks.OnEndWithStreamOutput(ctx, sr)
return schema.StreamReaderWithConvert(sr, func(t *model.CallbackOutput) (*schema.Message, error) {
Expand Down Expand Up @@ -390,7 +428,8 @@ func toAnthropicToolParam(tools []*schema.ToolInfo) ([]anthropic.ToolUnionParam,
Name: tool.Name,
Description: param.NewOpt(tool.Desc),
InputSchema: inputSchema,
}})
},
})
}

return result, nil
Expand Down Expand Up @@ -420,7 +459,8 @@ func preProcessMessages(input []*schema.Message) ([]*schema.Message, []*schema.M
}

func (cm *ChatModel) genMessageNewParams(input []*schema.Message, opts ...model.Option) (
anthropic.MessageNewParams, error) {
anthropic.MessageNewParams, error,
) {
if len(input) == 0 {
return anthropic.MessageNewParams{}, fmt.Errorf("input is empty")
}
Expand All @@ -442,7 +482,8 @@ func (cm *ChatModel) genMessageNewParams(input []*schema.Message, opts ...model.
claudeOptions := model.GetImplSpecificOptions(&options{
TopK: cm.topK,
Thinking: cm.thinking,
DisableParallelToolUse: cm.disableParallelToolUse}, opts...)
DisableParallelToolUse: cm.disableParallelToolUse,
}, opts...)

params := anthropic.MessageNewParams{}
if commonOptions.Model != nil {
Expand Down Expand Up @@ -593,7 +634,6 @@ func (cm *ChatModel) IsCallbacksEnabled() bool {
}

func convSchemaMessage(message *schema.Message) (mp anthropic.MessageParam, err error) {

var messageParams []anthropic.ContentBlockParamUnion
if len(message.Content) > 0 {
if len(message.ToolCallID) > 0 {
Expand Down Expand Up @@ -673,7 +713,8 @@ type streamContext struct {
}

func convContentBlockToEinoMsg(
contentBlock any, dstMsg *schema.Message, streamCtx *streamContext) error {
contentBlock any, dstMsg *schema.Message, streamCtx *streamContext,
) error {
// case anthropic.TextBlock:
// case anthropic.ToolUseBlock:
// case anthropic.ServerToolUseBlock:
Expand Down
10 changes: 9 additions & 1 deletion components/model/claude/examples/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,19 @@ func main() {

// 创建 Claude 模型
cm, err := claude.NewChatModel(ctx, &claude.Config{
// if you want to use Aws Bedrock Service, set these four field.
// if you want to use Aws Bedrock Service, set these four fields.
// ByBedrock: true,
// AccessKey: "",
// SecretAccessKey: "",
// Region: "us-west-2",

// if you want to use Google Cloud Vertex AI, set these fields.
// ByVertex: true,
// ProjectID: "your-project-id",
// Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"},
// Region: "global",
// GoogleCredentials: &google.Credentials{}, // Optional

APIKey: apiKey,
// Model: "claude-3-5-sonnet-20240620",
BaseURL: baseURLPtr,
Expand Down
25 changes: 25 additions & 0 deletions components/model/claude/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ require (
github.com/eino-contrib/jsonschema v1.0.0
github.com/stretchr/testify v1.9.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/oauth2 v0.21.0
)

require (
cloud.google.com/go/auth v0.7.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.3 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.33.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 // indirect
Expand All @@ -33,9 +37,16 @@ require (
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/getkin/kin-openapi v0.118.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/swag v0.19.5 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/invopop/yaml v0.1.0 // indirect
Expand All @@ -62,10 +73,24 @@ require (
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/arch v0.11.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/api v0.189.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade // indirect
google.golang.org/grpc v1.64.1 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading