diff --git a/README.md b/README.md
index 3112fd18..aa4c61e8 100644
--- a/README.md
+++ b/README.md
@@ -174,6 +174,12 @@ compiledGraph.Invoke(ctx, input, WithCallbacks(handler).DesignateNode("node_1"))
 - Developers can easily create custom callback handlers, add them during graph run via options, and they will be invoked during graph run.
 - Graph can also inject aspects to those component implementations that do not support callbacks on their own.
 
+## Swarm Agent Feature
+
+- The Swarm Agent feature lets swarm agents be added to the multi-agent system.
+- Swarm agents can be configured and added to the graph, enabling more complex and distributed agent interactions.
+- Swarm agents can generate and stream responses, just like the other agents in the system.
+
 # Eino Framework Structure
 
 ![](.github/static/img/eino/eino_framework.jpeg)
@@ -215,4 +221,4 @@ Please do **not** create a public GitHub issue.
 
 ## License
 
-This project is licensed under the [Apache-2.0 License](LICENSE.txt).
\ No newline at end of file
+This project is licensed under the [Apache-2.0 License](LICENSE.txt).
diff --git a/flow/agent/multiagent/host/compose.go b/flow/agent/multiagent/host/compose.go
index 83f61087..a2ffe13d 100644
--- a/flow/agent/multiagent/host/compose.go
+++ b/flow/agent/multiagent/host/compose.go
@@ -88,6 +88,27 @@ func NewMultiAgent(ctx context.Context, config *MultiAgentConfig) (*MultiAgent,
 		agentMap[specialist.Name] = true
 	}
 
+	for i := range config.SwarmAgents {
+		swarmAgent := config.SwarmAgents[i]
+
+		agentTools = append(agentTools, &schema.ToolInfo{
+			Name: swarmAgent.Name,
+			Desc: swarmAgent.IntendedUse,
+			ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
+				"reason": {
+					Type: schema.String,
+					Desc: "the reason to call this tool",
+				},
+			}),
+		})
+
+		if err := addSwarmAgent(swarmAgent, g); err != nil {
+			return nil, err
+		}
+
+		agentMap[swarmAgent.Name] = true
+	}
+
-	if err := addHostAgent(config.Host.ChatModel, hostPrompt, agentTools, g); err != nil {
+	if err := addHostAgent(config.Host.ChatModel, hostPrompt, agentTools, agentMap, g); err != nil {
 		return nil, err
 	}
@@ -149,6 +170,24 @@ func addSpecialistAgent(specialist *Specialist, g *compose.Graph[[]*schema.Messa
 	return g.AddEdge(specialist.Name, compose.END)
 }
 
+// addSwarmAgent adds one swarm agent as a lambda node and wires it to END.
+func addSwarmAgent(swarmAgent *SwarmAgent, g *compose.Graph[[]*schema.Message, *schema.Message]) error {
+	preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
+		return state.msgs, nil // replace the tool call message with input msgs stored in state
+	}
+
+	// AddLambdaNode expects a *compose.Lambda, not the agent itself, so wrap Invokable/Streamable.
+	lambda, err := compose.AnyLambda(swarmAgent.Invokable, swarmAgent.Streamable, nil, nil)
+	if err != nil {
+		return err
+	}
+
+	if err := g.AddLambdaNode(swarmAgent.Name, lambda, compose.WithStatePreHandler(preHandler), compose.WithNodeName(swarmAgent.Name)); err != nil {
+		return err
+	}
+
+	return g.AddEdge(swarmAgent.Name, compose.END)
+}
+
-func addHostAgent(model model.ChatModel, prompt string, agentTools []*schema.ToolInfo, g *compose.Graph[[]*schema.Message, *schema.Message]) error {
+func addHostAgent(model model.ChatModel, prompt string, agentTools []*schema.ToolInfo, agentMap map[string]bool, g *compose.Graph[[]*schema.Message, *schema.Message]) error {
 	if err := model.BindTools(agentTools); err != nil {
 		return err
@@ -168,7 +207,29 @@ func addHostAgent(model model.ChatModel, prompt string, agentTools []*schema.Too
 		return err
 	}
 
-	return g.AddEdge(compose.START, defaultHostNodeKey)
+	if err := g.AddEdge(compose.START, defaultHostNodeKey); err != nil {
+		return err
+	}
+
+	// Route the host's single tool call to the agent node whose name it carries.
+	branch := compose.NewGraphBranch(func(ctx context.Context, input []*schema.Message) (string, error) {
+		if len(input) != 1 {
+			return "", fmt.Errorf("host agent output %d messages, but expected 1", len(input))
+		}
+
+		if len(input[0].ToolCalls) != 1 {
+			return "", fmt.Errorf("host agent output %d tool calls, but expected 1", len(input[0].ToolCalls))
+		}
+
+		toolName := input[0].ToolCalls[0].Function.Name
+		if _, ok := agentMap[toolName]; ok {
+			return toolName, nil
+		}
+
+		return "", fmt.Errorf("unknown tool name: %s", toolName)
+	}, agentMap)
+
+	return g.AddBranch(defaultHostNodeKey, branch)
 }
 
 func addDirectAnswerBranch(convertorName string, g *compose.Graph[[]*schema.Message, *schema.Message],
diff --git a/flow/agent/multiagent/host/compose_test.go b/flow/agent/multiagent/host/compose_test.go
index 6ec577fa..2d7942c7 100644
--- a/flow/agent/multiagent/host/compose_test.go
+++ b/flow/agent/multiagent/host/compose_test.go
@@ -70,6 +70,30 @@ func TestHostMultiAgent(t *testing.T) {
 		},
 	}
 
+	swarmAgent := &SwarmAgent{
+		AgentMeta: AgentMeta{
+			Name:        "swarm agent",
+			IntendedUse: "do swarm stuff",
+		},
+		Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
+			return &schema.Message{
+				Role:    schema.Assistant,
+				Content: "swarm agent invoke answer",
+			}, nil
+		},
+		Streamable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) {
+			sr, sw := schema.Pipe[*schema.Message](0)
+			go func() {
+				sw.Send(&schema.Message{
+					Role:    schema.Assistant,
+					Content: "swarm agent stream answer",
+				}, nil)
+				sw.Close()
+			}()
+			return sr, nil
+		},
+	}
+
 	ctx := context.Background()
 
 	mockHostLLM.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes()
@@ -82,6 +106,9 @@ func TestHostMultiAgent(t *testing.T) {
 			specialist1,
 			specialist2,
 		},
+		SwarmAgents: []*SwarmAgent{
+			swarmAgent,
+		},
 	})
 
 	assert.NoError(t, err)
@@ -289,6 +316,9 @@ func TestHostMultiAgent(t *testing.T) {
 		Specialists: []*Specialist{
 			specialist1,
 			specialist2,
+		},
+		SwarmAgents: []*SwarmAgent{
+			swarmAgent,
 		},
 		StreamToolCallChecker: streamToolCallChecker,
 	})
@@ -397,6 +427,9 @@ func TestHostMultiAgent(t *testing.T) {
 			specialist1,
 			specialist2,
 		},
+		SwarmAgents: []*SwarmAgent{
+			swarmAgent,
+		},
 	})
 
 	assert.NoError(t, err)
diff --git a/flow/agent/multiagent/host/types.go b/flow/agent/multiagent/host/types.go
index c8584a9d..20d63139 100644
--- a/flow/agent/multiagent/host/types.go
+++ b/flow/agent/multiagent/host/types.go
@@ -72,6 +72,7 @@ func (ma *MultiAgent) HostNodeKey() string {
 type MultiAgentConfig struct {
 	Host        Host
 	Specialists []*Specialist
+	SwarmAgents []*SwarmAgent
 
 	Name string // the name of the host multi-agent
 
@@ -100,8 +101,8 @@ func (conf *MultiAgentConfig) validate() error {
 		return errors.New("host multi agent host ChatModel is nil")
 	}
 
-	if len(conf.Specialists) == 0 {
-		return errors.New("host multi agent specialists are empty")
+	if len(conf.Specialists) == 0 && len(conf.SwarmAgents) == 0 {
+		return errors.New("host multi agent specialists and swarm agents are empty")
 	}
 
 	for _, s := range conf.Specialists {
@@ -114,6 +115,12 @@ func (conf *MultiAgentConfig) validate() error {
 		}
 	}
 
+	for _, s := range conf.SwarmAgents {
+		if err := s.AgentMeta.validate(); err != nil {
+			return err
+		}
+	}
+
 	return nil
 }
@@ -158,6 +165,30 @@ type Specialist struct {
 	Streamable compose.Stream[[]*schema.Message, *schema.Message, agent.AgentOption]
 }
 
+// SwarmAgent is a swarm agent within a host multi-agent system.
+type SwarmAgent struct {
+	AgentMeta
+
+	Invokable  compose.Invoke[[]*schema.Message, *schema.Message, agent.AgentOption]
+	Streamable compose.Stream[[]*schema.Message, *schema.Message, agent.AgentOption]
+}
+
+// Invoke runs the swarm agent synchronously; it errors when no Invokable is configured.
+func (sa *SwarmAgent) Invoke(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
+	if sa.Invokable != nil {
+		return sa.Invokable(ctx, input, opts...)
+	}
+	return nil, errors.New("swarm agent does not support invocation")
+}
+
+// Stream runs the swarm agent in streaming mode; it errors when no Streamable is configured.
+func (sa *SwarmAgent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) {
+	if sa.Streamable != nil {
+		return sa.Streamable(ctx, input, opts...)
+	}
+	return nil, errors.New("swarm agent does not support streaming")
+}
+
 func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[*schema.Message]) (bool, error) {
 	defer sr.Close()
 
diff --git a/litellm/litellm_model.go b/litellm/litellm_model.go
new file mode 100644
index 00000000..8974fc72
--- /dev/null
+++ b/litellm/litellm_model.go
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package litellm
+
+import (
+	"errors"
+
+	"github.com/cloudwego/eino-ext/libs/acl/openai"
+)
+
+// LitellmModel represents a model that interacts with OpenAI-compatible APIs.
+type LitellmModel struct {
+	client *openai.Client
+}
+
+// NewLitellmModel creates a new instance of LitellmModel.
+func NewLitellmModel(apiKey string) (*LitellmModel, error) {
+	client, err := openai.NewClient(apiKey)
+	if err != nil {
+		return nil, err
+	}
+	return &LitellmModel{client: client}, nil
+}
+
+// GenerateText generates text based on the given prompt.
+func (m *LitellmModel) GenerateText(prompt string) (string, error) {
+	response, err := m.client.Completions.Create(prompt)
+	if err != nil {
+		return "", err
+	}
+	// guard against an empty choice list to avoid an index-out-of-range panic
+	if len(response.Choices) == 0 {
+		return "", errors.New("litellm: completion returned no choices")
+	}
+	return response.Choices[0].Text, nil
+}