Skip to content

Commit 82831e9

Browse files
authored
Minor refactor, simplified how "interactive" mode works (#48)
2 parents 12a2f2d + d1d49fb commit 82831e9

File tree

2 files changed

+75
-71
lines changed

2 files changed

+75
-71
lines changed

cmd/run/run.go

Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -258,19 +258,19 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
258258
return err
259259
}
260260

261+
interactiveMode := true
261262
initialPrompt := ""
262-
singleShot := false
263263
pipedContent := ""
264264

265265
if len(args) > 1 {
266266
initialPrompt = strings.Join(args[1:], " ")
267-
singleShot = true
267+
interactiveMode = false
268268
}
269269

270270
if isPipe(os.Stdin) {
271271
promptFromPipe, _ := io.ReadAll(os.Stdin)
272272
if len(promptFromPipe) > 0 {
273-
singleShot = true
273+
interactiveMode = false
274274
pipedContent = strings.TrimSpace(string(promptFromPipe))
275275
if initialPrompt != "" {
276276
initialPrompt = initialPrompt + "\n" + pipedContent
@@ -289,28 +289,26 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
289289
systemPrompt: systemPrompt,
290290
}
291291

292-
// If a prompt file is passed, load the messages from the file, templating {{input}} from stdin
293-
if pf != nil {
292+
// If there is no prompt file, add the initialPrompt to the conversation.
293+
// If a prompt file is passed, load the messages from the file, templating {{input}}
294+
// using the initialPrompt.
295+
if pf == nil {
296+
conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt)
297+
} else {
298+
interactiveMode = false
299+
294300
for _, m := range pf.Messages {
295301
content := m.Content
296-
if strings.ToLower(m.Role) == "user" {
297-
content = strings.ReplaceAll(content, "{{input}}", initialPrompt)
298-
}
299302
switch strings.ToLower(m.Role) {
300303
case "system":
301-
if conversation.systemPrompt == "" {
302-
conversation.systemPrompt = content
303-
} else {
304-
conversation.AddMessage(azuremodels.ChatMessageRoleSystem, content)
305-
}
304+
conversation.systemPrompt = content
306305
case "user":
306+
content = strings.ReplaceAll(content, "{{input}}", initialPrompt)
307307
conversation.AddMessage(azuremodels.ChatMessageRoleUser, content)
308308
case "assistant":
309309
conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, content)
310310
}
311311
}
312-
313-
initialPrompt = ""
314312
}
315313

316314
mp := ModelParameters{}
@@ -327,63 +325,15 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
327325
}
328326

329327
for {
330-
prompt := ""
331-
if initialPrompt != "" {
332-
prompt = initialPrompt
333-
initialPrompt = ""
334-
}
335-
336-
if prompt == "" && pf == nil {
337-
fmt.Printf(">>> ")
338-
reader := bufio.NewReader(os.Stdin)
339-
prompt, err = reader.ReadString('\n')
340-
if err != nil {
341-
return err
342-
}
343-
}
344-
345-
prompt = strings.TrimSpace(prompt)
346-
347-
if prompt == "" && pf == nil {
348-
continue
349-
}
350-
351-
if strings.HasPrefix(prompt, "/") {
352-
if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" {
328+
if interactiveMode {
329+
conversation, err = cmdHandler.ChatWithUser(conversation, mp)
330+
if errors.Is(err, ErrExitChat) || errors.Is(err, io.EOF) {
353331
break
332+
} else if err != nil {
333+
return err
354334
}
355-
356-
if prompt == "/parameters" {
357-
cmdHandler.handleParametersPrompt(conversation, mp)
358-
continue
359-
}
360-
361-
if prompt == "/reset" || prompt == "/clear" {
362-
cmdHandler.handleResetPrompt(conversation)
363-
continue
364-
}
365-
366-
if strings.HasPrefix(prompt, "/set ") {
367-
cmdHandler.handleSetPrompt(prompt, mp)
368-
continue
369-
}
370-
371-
if strings.HasPrefix(prompt, "/system-prompt ") {
372-
conversation = cmdHandler.handleSystemPrompt(prompt, conversation)
373-
continue
374-
}
375-
376-
if prompt == "/help" {
377-
cmdHandler.handleHelpPrompt()
378-
continue
379-
}
380-
381-
cmdHandler.handleUnrecognizedPrompt(prompt)
382-
continue
383335
}
384336

385-
conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt)
386-
387337
req := azuremodels.ChatCompletionOptions{
388338
Messages: conversation.GetMessages(),
389339
Model: modelName,
@@ -432,7 +382,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
432382

433383
conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String())
434384

435-
if singleShot || pf != nil {
385+
if !interactiveMode {
436386
break
437387
}
438388
}
@@ -619,3 +569,57 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice
619569
func (h *runCommandHandler) writeToOut(message string) {
620570
h.cfg.WriteToOut(message)
621571
}
572+
573+
var ErrExitChat = errors.New("exiting chat")
574+
575+
func (h *runCommandHandler) ChatWithUser(conversation Conversation, mp ModelParameters) (Conversation, error) {
576+
fmt.Printf(">>> ")
577+
reader := bufio.NewReader(os.Stdin)
578+
579+
prompt, err := reader.ReadString('\n')
580+
if err != nil {
581+
return conversation, err
582+
}
583+
584+
prompt = strings.TrimSpace(prompt)
585+
if prompt == "" {
586+
return conversation, nil
587+
}
588+
589+
if strings.HasPrefix(prompt, "/") {
590+
if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" {
591+
return conversation, ErrExitChat
592+
}
593+
594+
if prompt == "/parameters" {
595+
h.handleParametersPrompt(conversation, mp)
596+
return conversation, nil
597+
}
598+
599+
if prompt == "/reset" || prompt == "/clear" {
600+
h.handleResetPrompt(conversation)
601+
return conversation, nil
602+
}
603+
604+
if strings.HasPrefix(prompt, "/set ") {
605+
h.handleSetPrompt(prompt, mp)
606+
return conversation, nil
607+
}
608+
609+
if strings.HasPrefix(prompt, "/system-prompt ") {
610+
conversation = h.handleSystemPrompt(prompt, conversation)
611+
return conversation, nil
612+
}
613+
614+
if prompt == "/help" {
615+
h.handleHelpPrompt()
616+
return conversation, nil
617+
}
618+
619+
h.handleUnrecognizedPrompt(prompt)
620+
return conversation, nil
621+
}
622+
623+
conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt)
624+
return conversation, nil
625+
}

cmd/run/run_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ messages:
139139
_, err = runCmd.ExecuteC()
140140
require.NoError(t, err)
141141

142-
require.Equal(t, 3, len(capturedReq.Messages))
142+
require.Equal(t, 2, len(capturedReq.Messages))
143143
require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
144144
require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content)
145145

@@ -220,7 +220,7 @@ messages:
220220
_, err = runCmd.ExecuteC()
221221
require.NoError(t, err)
222222

223-
require.Len(t, capturedReq.Messages, 3)
223+
require.Len(t, capturedReq.Messages, 2)
224224
require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
225225
require.Equal(t, initialPrompt+"\n"+piped, *capturedReq.Messages[1].Content) // {{input}} -> "Please summarize the provided text.\nHello there!"
226226

0 commit comments

Comments
 (0)