Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 87 additions & 72 deletions adk/prebuilt/planexecute/plan_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ type Plan interface {
json.Unmarshaler
}

// NewPlan is a function type that creates a new Plan instance.
type NewPlan func(ctx context.Context) Plan
// PlanFactory is a function type that creates a new Plan instance.
type PlanFactory func(ctx context.Context) Plan

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 Breaking API Changes Detected

Package: github.com/cloudwego/eino/adk/prebuilt/planexecute

Incompatible changes:

  • ExecutionContext: removed
  • GenModelInputFn: removed
  • GenPlannerModelInputFn: changed from func(context.Context, []*github.com/cloudwego/eino/schema.Message) ([]*github.com/cloudwego/eino/schema.Message, error) to func(context.Context, []*github.com/cloudwego/eino/schema.Message, *PlannerConfig) ([]*github.com/cloudwego/eino/schema.Message, error)
  • NewPlan: removed
  • PlannerConfig.NewPlan: removed
  • ReplannerConfig.GenInputFn: changed from GenModelInputFn to GenReplannerModelInputFn
  • ReplannerConfig.NewPlan: removed
Review Guidelines

Please ensure that:

  • The changes are absolutely necessary
  • They are properly documented
  • Migration guides are provided if needed

⚠️ Please resolve this thread after reviewing the breaking changes.


// defaultPlan is the default implementation of the Plan interface.
// DefaultPlan is the default implementation of the Plan interface.
//
// JSON Schema:
//
Expand All @@ -68,27 +68,27 @@ type NewPlan func(ctx context.Context) Plan
// },
// "required": ["steps"]
// }
type defaultPlan struct {
type DefaultPlan struct {
// Steps contains the ordered list of actions to be taken.
// Each step should be clear, actionable, and arranged in a logical sequence.
Steps []string `json:"steps"`
}

// FirstStep returns the first step in the plan or an empty string if no steps exist.
func (p *defaultPlan) FirstStep() string {
func (p *DefaultPlan) FirstStep() string {
if len(p.Steps) == 0 {
return ""
}
return p.Steps[0]
}

func (p *defaultPlan) MarshalJSON() ([]byte, error) {
type planTyp defaultPlan
func (p *DefaultPlan) MarshalJSON() ([]byte, error) {
type planTyp DefaultPlan
return sonic.Marshal((*planTyp)(p))
}

func (p *defaultPlan) UnmarshalJSON(bytes []byte) error {
type planTyp defaultPlan
func (p *DefaultPlan) UnmarshalJSON(bytes []byte) error {
type planTyp DefaultPlan
return sonic.Unmarshal(bytes, (*planTyp)(p))
}

Expand Down Expand Up @@ -265,24 +265,23 @@ type PlannerConfig struct {
// Optional. If not provided, PlanToolInfo will be used as the default.
ToolInfo *schema.ToolInfo

// GenInputFn is a function that generates the input messages for the planner.
// Optional. If not provided, defaultGenPlannerInputFn will be used.
// GenInputFn generates input messages for the planner.
// Optional. Defaults to using PlannerPrompt as the template to render model input messages.
GenInputFn GenPlannerModelInputFn

// NewPlan creates a new Plan instance for JSON.
// The returned Plan will be used to unmarshal the model-generated JSON output.
// Optional. If not provided, defaultNewPlan will be used.
NewPlan NewPlan
// Factory creates Plan instances for JSON unmarshaling.
// Optional. Defaults to creating DefaultPlan instances.
Factory PlanFactory
}

// GenPlannerModelInputFn is a function type that generates input messages for the planner.
type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error)
type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message, cfg *PlannerConfig) ([]adk.Message, error)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个为啥入参是 planner config 啊


func defaultNewPlan(ctx context.Context) Plan {
return &defaultPlan{}
func defaultPlanFactory(ctx context.Context) Plan {
return &DefaultPlan{}
}

func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message, _ *PlannerConfig) ([]adk.Message, error) {
msgs, err := PlannerPrompt.Format(ctx, map[string]any{
"input": userInput,
})
Expand All @@ -293,10 +292,11 @@ func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]a
}

type planner struct {
cfg *PlannerConfig
toolCall bool
chatModel model.BaseChatModel
genInputFn GenPlannerModelInputFn
newPlan NewPlan
factory PlanFactory
}

func (p *planner) Name(_ context.Context) string {
Expand Down Expand Up @@ -333,7 +333,7 @@ func (p *planner) Run(ctx context.Context, input *adk.AgentInput,
generator.Close()
}()

msgs, err := p.genInputFn(ctx, input.Messages)
msgs, err := p.genInputFn(ctx, input.Messages, p.cfg)
if err != nil {
generator.Send(&adk.AgentEvent{Err: err})
return
Expand Down Expand Up @@ -401,7 +401,7 @@ func (p *planner) Run(ctx context.Context, input *adk.AgentInput,
} else {
planJSON = msg.Content
}
plan := p.newPlan(ctx)
plan := p.factory(ctx)
err = plan.UnmarshalJSON([]byte(planJSON))
if err != nil {
err = fmt.Errorf("unmarshal plan error: %w", err)
Expand Down Expand Up @@ -440,34 +440,34 @@ func NewPlanner(_ context.Context, cfg *PlannerConfig) (adk.Agent, error) {
return nil, err
}
}

inputFn := cfg.GenInputFn
if inputFn == nil {
inputFn = defaultGenPlannerInputFn
}

planParser := cfg.NewPlan
if planParser == nil {
planParser = defaultNewPlan
factory := cfg.Factory
if factory == nil {
factory = defaultPlanFactory
}

return &planner{
cfg: cfg,
toolCall: toolCall,
chatModel: chatModel,
genInputFn: inputFn,
newPlan: planParser,
factory: factory,
}, nil
}

// ExecutionContext is the input information for the executor and the planner.
// ExecutionContext is the input information for the executor and re-planner.
type ExecutionContext struct {
UserInput []adk.Message
Plan Plan
ExecutedSteps []ExecutedStep
}

// GenModelInputFn is a function that generates the input messages for the executor and the planner.
type GenModelInputFn func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error)
// GenExecutorModelInputFn is a function that generates the input messages for the executor.
type GenExecutorModelInputFn func(ctx context.Context, in *ExecutionContext, cfg *ExecutorConfig) ([]adk.Message, error)

// ExecutorConfig provides configuration options for creating an executor agent.
type ExecutorConfig struct {
Expand All @@ -482,9 +482,9 @@ type ExecutorConfig struct {
// Optional. Defaults to 20.
MaxIterations int

// GenInputFn generates the input messages for the Executor.
// Optional. If not provided, defaultGenExecutorInputFn will be used.
GenInputFn GenModelInputFn
// GenInputFn generates input messages for the executor.
// Optional. Defaults to using ExecutorPrompt as the template to render model input messages.
GenInputFn GenExecutorModelInputFn
}

type ExecutedStep struct {
Expand Down Expand Up @@ -525,7 +525,7 @@ func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) {
ExecutedSteps: executedSteps_,
}

msgs, err := genInputFn(ctx, in)
msgs, err := genInputFn(ctx, in, cfg)
if err != nil {
return nil, err
}
Expand All @@ -549,7 +549,7 @@ func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) {
return agent, nil
}

func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) {
func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext, _ *ExecutorConfig) ([]adk.Message, error) {

planContent, err := in.Plan.MarshalJSON()
if err != nil {
Expand All @@ -570,14 +570,18 @@ func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk
}

type replanner struct {
cfg *ReplannerConfig
chatModel model.ToolCallingChatModel
planTool *schema.ToolInfo
respondTool *schema.ToolInfo

genInputFn GenModelInputFn
newPlan NewPlan
genInputFn GenReplannerModelInputFn
factory PlanFactory
}

// GenReplannerModelInputFn is a function that generates the input messages for the re-planner.
type GenReplannerModelInputFn func(ctx context.Context, in *ExecutionContext, conf *ReplannerConfig) ([]adk.Message, error)

type ReplannerConfig struct {
// ChatModel is the model that supports tool calling capabilities.
// It will be configured with PlanTool and RespondTool to generate updated plans or responses.
Expand All @@ -591,14 +595,13 @@ type ReplannerConfig struct {
// Optional. If not provided, the default RespondToolInfo will be used.
RespondTool *schema.ToolInfo

// GenInputFn generates the input messages for the Replanner.
// Optional. If not provided, buildGenReplannerInputFn will be used.
GenInputFn GenModelInputFn
// GenInputFn generates input messages for the re-planner.
// Optional. Defaults to using ReplannerPrompt as the template to render model input messages.
GenInputFn GenReplannerModelInputFn

// NewPlan creates a new Plan instance.
// The returned Plan will be used to unmarshal the model-generated JSON output from PlanTool.
// Optional. If not provided, defaultNewPlan will be used.
NewPlan NewPlan
// Factory creates Plan instances for JSON unmarshaling.
// Optional. Defaults to creating DefaultPlan instances.
Factory PlanFactory
}

// formatInput formats the input messages into a string.
Expand Down Expand Up @@ -667,11 +670,8 @@ func (r *replanner) genInput(ctx context.Context) ([]adk.Message, error) {
Plan: plan_,
ExecutedSteps: executedSteps_,
}
genInputFn := r.genInputFn
if genInputFn == nil {
genInputFn = buildGenReplannerInputFn(r.planTool.Name, r.respondTool.Name)
}
msgs, err := genInputFn(ctx, in)

msgs, err := r.genInputFn(ctx, in, r.cfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -787,7 +787,7 @@ func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.Age
return
}

plan_ := r.newPlan(ctx)
plan_ := r.factory(ctx)
err = plan_.UnmarshalJSON([]byte(planMsg.ToolCalls[0].Function.Arguments))
if err != nil {
err = fmt.Errorf("unmarshal plan error: %w", err)
Expand All @@ -801,25 +801,34 @@ func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.Age
return iterator
}

func buildGenReplannerInputFn(planToolName, respondToolName string) GenModelInputFn {
return func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) {
planContent, err := in.Plan.MarshalJSON()
if err != nil {
return nil, err
}
msgs, err := ReplannerPrompt.Format(ctx, map[string]any{
"plan": string(planContent),
"input": formatInput(in.UserInput),
"executed_steps": formatExecutedSteps(in.ExecutedSteps),
"plan_tool": planToolName,
"respond_tool": respondToolName,
})
if err != nil {
return nil, err
}
func defaultGenReplannerInputFn(ctx context.Context, in *ExecutionContext, cfg *ReplannerConfig) ([]adk.Message, error) {
planContent, err := in.Plan.MarshalJSON()
if err != nil {
return nil, err
}

return msgs, nil
planToolName := PlanToolInfo.Name
if cfg.PlanTool != nil {
planToolName = cfg.PlanTool.Name
}

respondToolName := RespondToolInfo.Name
if cfg.RespondTool != nil {
respondToolName = cfg.RespondTool.Name
}

msgs, err := ReplannerPrompt.Format(ctx, map[string]any{
"plan": string(planContent),
"input": formatInput(in.UserInput),
"executed_steps": formatExecutedSteps(in.ExecutedSteps),
"plan_tool": planToolName,
"respond_tool": respondToolName,
})
if err != nil {
return nil, err
}

return msgs, nil
}

func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) {
Expand All @@ -838,17 +847,23 @@ func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) {
return nil, err
}

planParser := cfg.NewPlan
if planParser == nil {
planParser = defaultNewPlan
factory := cfg.Factory
if factory == nil {
factory = defaultPlanFactory
}

genInputFn := cfg.GenInputFn
if genInputFn == nil {
genInputFn = defaultGenReplannerInputFn
}

return &replanner{
cfg: cfg,
chatModel: chatModel,
planTool: planTool,
respondTool: respondTool,
genInputFn: cfg.GenInputFn,
newPlan: planParser,
genInputFn: genInputFn,
factory: factory,
}, nil
}

Expand Down
Loading
Loading