diff --git a/cmd/run/run.go b/cmd/run/run.go index 5668e06..c46a451 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -258,19 +258,19 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { return err } + interactiveMode := true initialPrompt := "" - singleShot := false pipedContent := "" if len(args) > 1 { initialPrompt = strings.Join(args[1:], " ") - singleShot = true + interactiveMode = false } if isPipe(os.Stdin) { promptFromPipe, _ := io.ReadAll(os.Stdin) if len(promptFromPipe) > 0 { - singleShot = true + interactiveMode = false pipedContent = strings.TrimSpace(string(promptFromPipe)) if initialPrompt != "" { initialPrompt = initialPrompt + "\n" + pipedContent @@ -289,28 +289,26 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { systemPrompt: systemPrompt, } - // If a prompt file is passed, load the messages from the file, templating {{input}} from stdin - if pf != nil { + // If there is no prompt file, add the initialPrompt to the conversation. + // If a prompt file is passed, load the messages from the file, templating {{input}} + // using the initialPrompt. + if pf == nil { + conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt) + } else { + interactiveMode = false + for _, m := range pf.Messages { content := m.Content - if strings.ToLower(m.Role) == "user" { - content = strings.ReplaceAll(content, "{{input}}", initialPrompt) - } switch strings.ToLower(m.Role) { case "system": - if conversation.systemPrompt == "" { - conversation.systemPrompt = content - } else { - conversation.AddMessage(azuremodels.ChatMessageRoleSystem, content) - } + conversation.systemPrompt = content case "user": + content = strings.ReplaceAll(content, "{{input}}", initialPrompt) conversation.AddMessage(azuremodels.ChatMessageRoleUser, content) case "assistant": conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, content) } } - - initialPrompt = "" } mp := ModelParameters{} @@ -326,63 +324,15 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } for { - prompt := "" - if initialPrompt != "" { - prompt = initialPrompt - initialPrompt = "" - } - - if prompt == "" && pf == nil { - fmt.Printf(">>> ") - reader := bufio.NewReader(os.Stdin) - prompt, err = reader.ReadString('\n') - if err != nil { - return err - } - } - - prompt = strings.TrimSpace(prompt) - - if prompt == "" && pf == nil { - continue - } - - if strings.HasPrefix(prompt, "/") { - if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" { + if interactiveMode { + conversation, err = cmdHandler.ChatWithUser(conversation, mp) + if errors.Is(err, ErrExitChat) || errors.Is(err, io.EOF) { break + } else if err != nil { + return err } - - if prompt == "/parameters" { - cmdHandler.handleParametersPrompt(conversation, mp) - continue - } - - if prompt == "/reset" || prompt == "/clear" { - cmdHandler.handleResetPrompt(conversation) - continue - } - - if strings.HasPrefix(prompt, "/set ") { - cmdHandler.handleSetPrompt(prompt, mp) - continue - } - - if strings.HasPrefix(prompt, "/system-prompt ") { - conversation = cmdHandler.handleSystemPrompt(prompt, conversation) - continue - } - - if prompt == "/help" { - cmdHandler.handleHelpPrompt() - continue - } - - cmdHandler.handleUnrecognizedPrompt(prompt) - continue } - conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) - req := azuremodels.ChatCompletionOptions{ Messages: conversation.GetMessages(), Model: modelName, @@ -431,7 +381,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String()) - if singleShot || pf != nil { + if !interactiveMode { break } } @@ -618,3 +568,57 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice func (h *runCommandHandler) writeToOut(message string) { h.cfg.WriteToOut(message) } + +var ErrExitChat = errors.New("exiting chat") + +func (h *runCommandHandler) ChatWithUser(conversation Conversation, mp ModelParameters) (Conversation, error) { + fmt.Printf(">>> ") + reader := bufio.NewReader(os.Stdin) + + prompt, err := reader.ReadString('\n') + if err != nil { + return conversation, err + } + + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return conversation, nil + } + + if strings.HasPrefix(prompt, "/") { + if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" { + return conversation, ErrExitChat + } + + if prompt == "/parameters" { + h.handleParametersPrompt(conversation, mp) + return conversation, nil + } + + if prompt == "/reset" || prompt == "/clear" { + h.handleResetPrompt(conversation) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/set ") { + h.handleSetPrompt(prompt, mp) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/system-prompt ") { + conversation = h.handleSystemPrompt(prompt, conversation) + return conversation, nil + } + + if prompt == "/help" { + h.handleHelpPrompt() + return conversation, nil + } + + h.handleUnrecognizedPrompt(prompt) + return conversation, nil + } + + conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) + return conversation, nil +} diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 27cc468..e930e77 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -139,7 +139,7 @@ messages: _, err = runCmd.ExecuteC() require.NoError(t, err) - require.Equal(t, 3, len(capturedReq.Messages)) + require.Equal(t, 2, len(capturedReq.Messages)) require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content) @@ -220,7 +220,7 @@ messages: _, err = runCmd.ExecuteC() require.NoError(t, err) - require.Len(t, capturedReq.Messages, 3) + require.Len(t, capturedReq.Messages, 2) require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) require.Equal(t, initialPrompt+"\n"+piped, *capturedReq.Messages[1].Content) // {{input}} -> "Please summarize the provided text.\nHello there!"