diff --git a/cmd/run/run.go b/cmd/run/run.go
index 5668e06..f2c59e4 100644
--- a/cmd/run/run.go
+++ b/cmd/run/run.go
@@ -314,10 +314,6 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 			}
 
 			mp := ModelParameters{}
-			err = mp.PopulateFromFlags(cmd.Flags())
-			if err != nil {
-				return err
-			}
 
 			if pf != nil {
 				mp.maxTokens = pf.ModelParameters.MaxTokens
@@ -325,6 +321,11 @@
 				mp.topP = pf.ModelParameters.TopP
 			}
 
+			err = mp.PopulateFromFlags(cmd.Flags())
+			if err != nil {
+				return err
+			}
+
 			for {
 				prompt := ""
 				if initialPrompt != "" {
diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go
index 27cc468..9a258ac 100644
--- a/cmd/run/run_test.go
+++ b/cmd/run/run_test.go
@@ -226,4 +226,108 @@ messages:
 
 		require.Contains(t, out.String(), reply)
 	})
+
+	t.Run("cli flags override params set in the prompt.yaml file", func(t *testing.T) {
+		// Begin setup:
+		const yamlBody = `
+name: Example Prompt
+description: Example description
+model: openai/example-model
+modelParameters:
+  maxTokens: 300
+  temperature: 0.8
+  topP: 0.9
+messages:
+  - role: system
+    content: System message
+  - role: user
+    content: User message
+`
+
+		tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yaml")
+		require.NoError(t, err)
+		_, err = tmp.WriteString(yamlBody)
+		require.NoError(t, err)
+		require.NoError(t, tmp.Close())
+
+		client := azuremodels.NewMockClient()
+		modelSummary := &azuremodels.ModelSummary{
+			Name:      "example-model",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		modelSummary2 := &azuremodels.ModelSummary{
+			Name:      "example-model-4o-mini-plus",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+
+		client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+			return []*azuremodels.ModelSummary{modelSummary, modelSummary2}, nil
+		}
+
+		var capturedReq azuremodels.ChatCompletionOptions
+		reply := "hello"
+		chatCompletion := azuremodels.ChatCompletion{
+			Choices: []azuremodels.ChatChoice{{
+				Message: &azuremodels.ChatChoiceMessage{
+					Content: util.Ptr(reply),
+					Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+				},
+			}},
+		}
+
+		client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+			capturedReq = opt
+			return &azuremodels.ChatCompletionResponse{
+				Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+			}, nil
+		}
+
+		out := new(bytes.Buffer)
+		cfg := command.NewConfig(out, out, client, true, 100)
+		runCmd := NewRunCommand(cfg)
+
+		// End setup.
+		// ---
+		// We're finally ready to start making assertions.
+
+		// Test case 1: with no flags, the model params come from the YAML file.
+		runCmd.SetArgs([]string{
+			"--file", tmp.Name(),
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, "openai/example-model", capturedReq.Model)
+		require.Equal(t, 300, *capturedReq.MaxTokens)
+		require.Equal(t, 0.8, *capturedReq.Temperature)
+		require.Equal(t, 0.9, *capturedReq.TopP)
+
+		require.Equal(t, "System message", *capturedReq.Messages[0].Content)
+		require.Equal(t, "User message", *capturedReq.Messages[1].Content)
+
+		// Hooray!
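+
+		// With no flags passed, PopulateFromFlags runs after the YAML values
+		// are applied but leaves them alone (it presumably only writes values
+		// for flags that were explicitly set), so the request above reflects
+		// prompt.yaml exactly.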
+
+		// Test case 2: values from flags override the params from the YAML file.
+		runCmd = NewRunCommand(cfg)
+		runCmd.SetArgs([]string{
+			"openai/example-model-4o-mini-plus",
+			"--file", tmp.Name(),
+			"--max-tokens", "150",
+			"--temperature", "0.1",
+			"--top-p", "0.3",
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, "openai/example-model-4o-mini-plus", capturedReq.Model)
+		require.Equal(t, 150, *capturedReq.MaxTokens)
+		require.Equal(t, 0.1, *capturedReq.Temperature)
+		require.Equal(t, 0.3, *capturedReq.TopP)
+
+		require.Equal(t, "System message", *capturedReq.Messages[0].Content)
+		require.Equal(t, "User message", *capturedReq.Messages[1].Content)
+	})
 }
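For the reordering in run.go to give flags precedence without clobbering YAML values when a flag is absent, `PopulateFromFlags` has to skip flags the user never set. That method isn't shown in this diff; below is a minimal sketch of the pattern it presumably follows, using pflag's `Changed`. The `modelParams` type and `populateFromFlags` function here are hypothetical stand-ins, not the actual implementation.

```go
package main

import (
	"fmt"

	"github.com/spf13/pflag"
)

// modelParams mirrors the shape used in run.go; pointer fields are an
// assumption, so "unset" stays distinguishable from a zero value.
type modelParams struct {
	maxTokens   *int
	temperature *float64
	topP        *float64
}

// populateFromFlags writes a field only when its flag was explicitly set,
// so values copied from prompt.yaml survive unless overridden on the CLI.
func populateFromFlags(mp *modelParams, flags *pflag.FlagSet) error {
	if flags.Changed("max-tokens") {
		v, err := flags.GetInt("max-tokens")
		if err != nil {
			return err
		}
		mp.maxTokens = &v
	}
	if flags.Changed("temperature") {
		v, err := flags.GetFloat64("temperature")
		if err != nil {
			return err
		}
		mp.temperature = &v
	}
	if flags.Changed("top-p") {
		v, err := flags.GetFloat64("top-p")
		if err != nil {
			return err
		}
		mp.topP = &v
	}
	return nil
}

func main() {
	flags := pflag.NewFlagSet("run", pflag.ContinueOnError)
	flags.Int("max-tokens", 0, "limit on tokens in the response")
	flags.Float64("temperature", 0, "sampling temperature")
	flags.Float64("top-p", 0, "nucleus sampling parameter")
	_ = flags.Parse([]string{"--max-tokens", "150"})

	// Values "from prompt.yaml":
	yamlMax, yamlTemp, yamlTopP := 300, 0.8, 0.9
	mp := modelParams{maxTokens: &yamlMax, temperature: &yamlTemp, topP: &yamlTopP}

	// Flags applied last, as in the reordered run.go:
	if err := populateFromFlags(&mp, flags); err != nil {
		panic(err)
	}
	fmt.Println(*mp.maxTokens, *mp.temperature, *mp.topP) // 150 0.8 0.9
}
```

Whether the real `PopulateFromFlags` uses `Changed` or some other sentinel, the tests above pin down the observable behavior: YAML values win when no flag is passed, and explicitly passed flags win otherwise.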