Skip to content

Commit 16eca51

Browse files
committed
Merge branch 'release/v1.5.8'
2 parents 90f0a89 + 56b9e09 commit 16eca51

File tree

11 files changed

+331
-170
lines changed

11 files changed

+331
-170
lines changed

agents/agent.go

Lines changed: 189 additions & 64 deletions
Large diffs are not rendered by default.

agents/option.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,24 @@ func WithModel(model string) Option {
3333
}
3434
}
3535

36-
func WithTemperature(temperature float32) Option {
36+
func WithTemperature(temperature float64) Option {
3737
return func(c *Config) {
3838
c.temperature = temperature
3939
}
4040
}
4141

42+
func WithTopP(topP float64) Option {
43+
return func(c *Config) {
44+
c.topP = topP
45+
}
46+
}
47+
48+
func WithTopK(topK int) Option {
49+
return func(c *Config) {
50+
c.topK = topK
51+
}
52+
}
53+
4254
func WithMaxTokens(maxTokens int) Option {
4355
return func(c *Config) {
4456
c.maxTokens = maxTokens

agents/tool_agent.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,21 @@ func (t *ToolAgent[I, T, O]) SetModel(model string) {
6161
t.end.model = model
6262
}
6363

64-
func (t *ToolAgent[I, T, O]) SetTemperature(temperature float32) {
64+
func (t *ToolAgent[I, T, O]) SetTemperature(temperature float64) {
6565
t.start.temperature = temperature
6666
t.end.temperature = temperature
6767
}
6868

69+
func (t *ToolAgent[I, T, O]) SetTopP(topP float64) {
70+
t.start.topP = topP
71+
t.end.topP = topP
72+
}
73+
74+
func (t *ToolAgent[I, T, O]) SetTopK(topK int) {
75+
t.start.topK = topK
76+
t.end.topK = topK
77+
}
78+
6979
func (t *ToolAgent[I, T, O]) SetMaxTokens(maxTokens int) {
7080
t.start.maxTokens = maxTokens
7181
t.end.maxTokens = maxTokens

components/embedder/providers/cohere/embeder.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ func (p *Embedder) Embed(ctx context.Context, text string, embedding *embedder.E
4747
respV := resp.GetEmbeddingsFloats()
4848
if usage != nil && respV.Meta != nil && respV.Meta.Tokens != nil {
4949
if v := respV.Meta.Tokens.InputTokens; v != nil {
50-
usage.InputTokens = int(*v)
50+
usage.InputTokens = int64(*v)
5151
}
5252
if v := respV.Meta.Tokens.OutputTokens; v != nil {
53-
usage.OutputTokens = int(*v)
53+
usage.OutputTokens = int64(*v)
5454
}
5555
}
5656
if len(respV.Embeddings) == 0 {
@@ -77,10 +77,10 @@ func (p *Embedder) BatchEmbed(ctx context.Context, parts []string, usage *compon
7777
respV := resp.GetEmbeddingsFloats()
7878
if usage != nil && respV.Meta != nil && respV.Meta.Tokens != nil {
7979
if v := respV.Meta.Tokens.InputTokens; v != nil {
80-
usage.InputTokens = int(*v)
80+
usage.InputTokens = int64(*v)
8181
}
8282
if v := respV.Meta.Tokens.OutputTokens; v != nil {
83-
usage.OutputTokens = int(*v)
83+
usage.OutputTokens = int64(*v)
8484
}
8585
}
8686
ret := make([]embedder.Embedding, 0, len(respV.Embeddings))

components/embedder/providers/openai/embedder.go

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package openai
22

33
import (
44
"context"
5+
"errors"
56

6-
openai "github.com/sashabaranov/go-openai"
7+
"github.com/openai/openai-go"
78

89
"github.com/bububa/atomic-agents/components"
910
"github.com/bububa/atomic-agents/components/embedder"
@@ -34,22 +35,24 @@ func New(client *openai.Client, opts ...embedder.Option) *Embedder {
3435

3536
func (p *Embedder) Embed(ctx context.Context, text string, embedding *embedder.Embedding, usage *components.LLMUsage) error {
3637
// Create an EmbeddingRequest for the user query
37-
req := openai.EmbeddingRequest{
38-
Input: []string{text},
38+
req := openai.EmbeddingNewParams{
39+
Input: openai.EmbeddingNewParamsInputUnion{
40+
OfString: openai.String(text),
41+
},
3942
Model: openai.EmbeddingModel(p.Model()),
4043
}
41-
resp, err := p.CreateEmbeddings(ctx, &req)
44+
resp, err := p.Embeddings.New(ctx, req)
4245
if err != nil {
4346
return err
4447
}
4548
if usage != nil {
46-
usage.InputTokens = int(resp.Usage.TotalTokens)
49+
usage.InputTokens = resp.Usage.TotalTokens
4750
}
4851
if len(resp.Data) == 0 {
4952
return nil
5053
}
5154
ret := resp.Data[0]
52-
embedding.Object = ret.Object
55+
embedding.Object = text
5356
embedding.Embedding = make([]float64, 0, len(ret.Embedding))
5457
for _, v := range ret.Embedding {
5558
embedding.Embedding = append(embedding.Embedding, float64(v))
@@ -60,16 +63,18 @@ func (p *Embedder) Embed(ctx context.Context, text string, embedding *embedder.E
6063

6164
func (p *Embedder) BatchEmbed(ctx context.Context, parts []string, usage *components.LLMUsage) ([]embedder.Embedding, error) {
6265
// Create an EmbeddingRequest for the user query
63-
req := openai.EmbeddingRequest{
64-
Input: parts,
66+
req := openai.EmbeddingNewParams{
67+
Input: openai.EmbeddingNewParamsInputUnion{
68+
OfArrayOfStrings: parts,
69+
},
6570
Model: openai.EmbeddingModel(p.Model()),
6671
}
67-
resp, err := p.CreateEmbeddings(ctx, &req)
72+
resp, err := p.Embeddings.New(ctx, req)
6873
if err != nil {
6974
return nil, err
7075
}
7176
if usage != nil {
72-
usage.InputTokens = int(resp.Usage.TotalTokens)
77+
usage.InputTokens = resp.Usage.TotalTokens
7378
}
7479
ret := make([]embedder.Embedding, 0, len(resp.Data))
7580
for _, v := range resp.Data {
@@ -78,20 +83,17 @@ func (p *Embedder) BatchEmbed(ctx context.Context, parts []string, usage *compon
7883
embeddings = append(embeddings, float64(e))
7984
}
8085
ret = append(ret, embedder.Embedding{
81-
Object: v.Object,
86+
Object: parts[int(v.Index)],
8287
Embedding: embeddings,
83-
Index: v.Index,
88+
Index: int(v.Index),
8489
})
8590
}
8691
return ret, nil
8792
}
8893

8994
func convertToOpenAI(src *embedder.Embedding, dist *openai.Embedding) {
90-
embeddings := make([]float32, 0, len(src.Embedding))
91-
for _, e := range src.Embedding {
92-
embeddings = append(embeddings, float32(e))
93-
}
94-
dist.Embedding = embeddings
95+
dist.Embedding = make([]float64, len(src.Embedding))
96+
copy(dist.Embedding, src.Embedding)
9597
}
9698

9799
// DotProduct calculates the dot product of the embedding vector with another
@@ -103,9 +105,12 @@ func (p *Embedder) DotProduct(ctx context.Context, target, query *embedder.Embed
103105
convertToOpenAI(target, t)
104106
q := new(openai.Embedding)
105107
convertToOpenAI(query, q)
106-
ret, err := t.DotProduct(q)
107-
if err != nil {
108-
return 0, err
108+
if len(t.Embedding) != len(q.Embedding) {
109+
return 0, errors.New("vector length mismatch")
110+
}
111+
var dotProduct float64
112+
for i := range t.Embedding {
113+
dotProduct += t.Embedding[i] * q.Embedding[i]
109114
}
110-
return float64(ret), nil
115+
return dotProduct, nil
111116
}

components/embedder/providers/voyageai/embedder.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (p *Embedder) Embed(ctx context.Context, text string, embedding *embedder.E
4141
return err
4242
}
4343
if usage != nil {
44-
usage.InputTokens = int(resp.Usage.TotalTokens)
44+
usage.InputTokens = int64(resp.Usage.TotalTokens)
4545
}
4646
if len(resp.Data) == 0 {
4747
return nil
@@ -64,7 +64,7 @@ func (p *Embedder) BatchEmbed(ctx context.Context, parts []string, usage *compon
6464
return nil, err
6565
}
6666
if usage != nil {
67-
usage.InputTokens = int(resp.Usage.TotalTokens)
67+
usage.InputTokens = int64(resp.Usage.TotalTokens)
6868
}
6969
ret := make([]embedder.Embedding, 0, len(resp.Data))
7070
for _, v := range resp.Data {

components/message.go

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ import (
1616
cohere "github.com/cohere-ai/cohere-go/v2"
1717
"github.com/gabriel-vasile/mimetype"
1818
anthropic "github.com/liushuangls/go-anthropic/v2"
19+
"github.com/openai/openai-go"
20+
"github.com/openai/openai-go/shared/constant"
1921
"github.com/rs/xid"
20-
openai "github.com/sashabaranov/go-openai"
2122
gemini "google.golang.org/genai"
2223

2324
"github.com/bububa/atomic-agents/schema"
@@ -50,7 +51,7 @@ type LLMResponse struct {
5051
}
5152

5253
// FromOpenAI convnert response from openai
53-
func (r *LLMResponse) FromOpenAI(v *openai.ChatCompletionResponse) {
54+
func (r *LLMResponse) FromOpenAI(v *openai.ChatCompletion) {
5455
r.ID = v.ID
5556
r.Role = AssistantRole
5657
r.Model = v.Model
@@ -67,8 +68,8 @@ func (r *LLMResponse) FromAnthropic(v *anthropic.MessagesResponse) {
6768
r.Role = AssistantRole
6869
r.Model = string(v.Model)
6970
r.Usage = &LLMUsage{
70-
InputTokens: v.Usage.InputTokens,
71-
OutputTokens: v.Usage.OutputTokens,
71+
InputTokens: int64(v.Usage.InputTokens),
72+
OutputTokens: int64(v.Usage.OutputTokens),
7273
}
7374
r.Details = v.Content
7475
}
@@ -83,10 +84,10 @@ func (r *LLMResponse) FromCohere(v *cohere.NonStreamedChatResponse) {
8384
if usage := meta.Tokens; usage != nil {
8485
r.Usage = new(LLMUsage)
8586
if usage.InputTokens != nil {
86-
r.Usage.InputTokens = int(*usage.InputTokens)
87+
r.Usage.InputTokens = int64(*usage.InputTokens)
8788
}
8889
if usage.OutputTokens != nil {
89-
r.Usage.OutputTokens = int(*usage.OutputTokens)
90+
r.Usage.OutputTokens = int64(*usage.OutputTokens)
9091
}
9192
}
9293
if version := meta.ApiVersion; version != nil {
@@ -100,15 +101,15 @@ func (r *LLMResponse) FromGemini(v *gemini.GenerateContentResponse) {
100101
r.Role = AssistantRole
101102
if v.UsageMetadata != nil && (v.UsageMetadata.PromptTokenCount > 0 || v.UsageMetadata.CandidatesTokenCount > 0) {
102103
r.Usage = new(LLMUsage)
103-
r.Usage.InputTokens = int(v.UsageMetadata.PromptTokenCount)
104-
r.Usage.OutputTokens = int(v.UsageMetadata.CachedContentTokenCount)
104+
r.Usage.InputTokens = int64(v.UsageMetadata.PromptTokenCount)
105+
r.Usage.OutputTokens = int64(v.UsageMetadata.CachedContentTokenCount)
105106
}
106107
r.Details = v.Candidates
107108
}
108109

109110
type LLMUsage struct {
110-
InputTokens int `json:"input_tokens,omitempty"`
111-
OutputTokens int `json:"output_tokens,omitempty"`
111+
InputTokens int64 `json:"input_tokens,omitempty"`
112+
OutputTokens int64 `json:"output_tokens,omitempty"`
112113
}
113114

114115
func (u *LLMUsage) Merge(v *LLMUsage) {
@@ -256,12 +257,12 @@ func (m Message) TryAttachChunkPrompt(idx int) string {
256257
}
257258

258259
// ToOpenAI convert message to openai ChatCompletionMessage
259-
func (m Message) ToOpenAI(dist *openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
260+
func (m Message) ToOpenAI(dist *openai.ChatCompletionMessageParamUnion) []openai.ChatCompletionMessageParamUnion {
260261
m.toOpenAI(dist, 0)
261262
if l := len(m.Chunks()); l > 0 {
262-
list := make([]openai.ChatCompletionMessage, 0, l)
263+
list := make([]openai.ChatCompletionMessageParamUnion, 0, l)
263264
for idx := range l {
264-
var llmMsg openai.ChatCompletionMessage
265+
var llmMsg openai.ChatCompletionMessageParamUnion
265266
if err := m.toOpenAI(&llmMsg, idx+1); err == nil {
266267
list = append(list, llmMsg)
267268
}
@@ -271,7 +272,7 @@ func (m Message) ToOpenAI(dist *openai.ChatCompletionMessage) []openai.ChatCompl
271272
return nil
272273
}
273274

274-
func (m Message) toOpenAI(dist *openai.ChatCompletionMessage, idx int) error {
275+
func (m Message) toOpenAI(dist *openai.ChatCompletionMessageParamUnion, idx int) error {
275276
src := m
276277
chunks := m.Chunks()
277278
if idx > 0 {
@@ -281,35 +282,40 @@ func (m Message) toOpenAI(dist *openai.ChatCompletionMessage, idx int) error {
281282
return errors.New("invalid chunk index")
282283
}
283284
}
284-
dist.Role = m.role
285285
txt := m.TryAttachChunkPrompt(idx)
286-
if attachement := src.Attachement(); attachement != nil && (len(attachement.ImageURLs) > 0 || len(attachement.VideoURLs) > 0) {
287-
dist.MultiContent = make([]openai.ChatMessagePart, 0, len(attachement.ImageURLs)+len(attachement.VideoURLs)+1)
288-
dist.MultiContent = append(dist.MultiContent, openai.ChatMessagePart{
289-
Type: openai.ChatMessagePartTypeText,
290-
Text: txt,
291-
})
286+
if attachement := src.Attachement(); m.role == UserRole && attachement != nil && (len(attachement.ImageURLs) > 0 || len(attachement.VideoURLs) > 0) {
287+
contents := make([]openai.ChatCompletionContentPartUnionParam, 0, len(attachement.ImageURLs)+len(attachement.VideoURLs)+1)
288+
contents = append(contents, openai.TextContentPart(txt))
292289
for _, imageURL := range attachement.ImageURLs {
293-
dist.MultiContent = append(dist.MultiContent, openai.ChatMessagePart{
294-
Type: openai.ChatMessagePartTypeImageURL,
295-
ImageURL: &openai.ChatMessageImageURL{
296-
URL: imageURL,
297-
},
298-
})
290+
contents = append(contents, openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{
291+
URL: imageURL,
292+
}))
299293
}
300294
for _, videoURL := range attachement.VideoURLs {
301-
dist.MultiContent = append(dist.MultiContent, openai.ChatMessagePart{
302-
Type: "video_url",
303-
VideoURL: &openai.ChatMessageVideoURL{
304-
URL: videoURL,
305-
},
306-
Video: &openai.ChatMessageVideo{
307-
URL: videoURL,
295+
videoParam := &openai.ChatCompletionContentPartImageParam{
296+
Type: constant.ImageURL("video_url"),
297+
}
298+
videoParam.SetExtraFields(map[string]any{
299+
"video_url": openai.ImageURL{
300+
URL: videoURL,
301+
Detail: openai.ImageURLDetailAuto,
308302
},
309303
})
304+
part := openai.ChatCompletionContentPartUnionParam{
305+
OfImageURL: videoParam,
306+
}
307+
contents = append(contents, part)
310308
}
311-
} else {
312-
dist.Content = txt
309+
*dist = openai.UserMessage(contents)
310+
return nil
311+
}
312+
switch m.role {
313+
case SystemRole:
314+
*dist = openai.SystemMessage(txt)
315+
case AssistantRole:
316+
*dist = openai.AssistantMessage(txt)
317+
case UserRole:
318+
*dist = openai.UserMessage(txt)
313319
}
314320
return nil
315321
}

examples/common_utils.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import (
88
cohereClient "github.com/cohere-ai/cohere-go/v2/client"
99
cohereOption "github.com/cohere-ai/cohere-go/v2/option"
1010
anthropic "github.com/liushuangls/go-anthropic/v2"
11-
openai "github.com/sashabaranov/go-openai"
11+
"github.com/openai/openai-go"
12+
"github.com/openai/openai-go/option"
1213
)
1314

1415
func NewInstructor(provider instructor.Provider, modes ...instructor.Mode) instructor.Instructor {
@@ -39,11 +40,13 @@ func NewInstructor(provider instructor.Provider, modes ...instructor.Mode) instr
3940
default:
4041
authToken := os.Getenv("OPENAI_API_KEY")
4142
baseURL := os.Getenv("OPENAI_BASE_URL")
42-
cfg := openai.DefaultConfig(authToken)
43+
opts := make([]option.RequestOption, 0, 2)
44+
45+
opts = append(opts, option.WithAPIKey(authToken))
4346
if baseURL != "" {
44-
cfg.BaseURL = baseURL
47+
opts = append(opts, option.WithBaseURL(baseURL))
4548
}
46-
clt := openai.NewClientWithConfig(cfg)
47-
return instructors.FromOpenAI(clt, instructor.WithMode(mode), instructor.WithMaxRetries(1), instructor.WithValidation())
49+
clt := openai.NewClient(opts...)
50+
return instructors.FromOpenAI(&clt, instructor.WithMode(mode), instructor.WithMaxRetries(1), instructor.WithValidation(), instructor.WithVerbose())
4851
}
4952
}

0 commit comments

Comments
 (0)