diff --git a/command_test.go b/command_test.go index e4fa951..638e4ac 100644 --- a/command_test.go +++ b/command_test.go @@ -49,8 +49,14 @@ func sampleCommand(t *testing.T) *serpent.Command { Use: "root [subcommand]", Options: serpent.OptionSet{ serpent.Option{ - Name: "verbose", - Flag: "verbose", + Name: "verbose", + Flag: "verbose", + Default: "false", + Value: serpent.BoolOf(&verbose), + }, + serpent.Option{ + Name: "verbose-old", + Flag: "verbode-old", Value: serpent.BoolOf(&verbose), }, serpent.Option{ @@ -742,6 +748,12 @@ func TestCommand_DefaultsOverride(t *testing.T) { Value: serpent.StringOf(&got), YAML: "url", }, + { + Name: "url-deprecated", + Flag: "url-deprecated", + Env: "URL_DEPRECATED", + Value: serpent.StringOf(&got), + }, { Name: "config", Flag: "config", @@ -790,6 +802,17 @@ func TestCommand_DefaultsOverride(t *testing.T) { inv.Args = []string{"--config", fi.Name(), "--url", "good.com"} }) + test("EnvOverYAML", "good.com", func(t *testing.T, inv *serpent.Invocation) { + fi, err := os.CreateTemp(t.TempDir(), "config.yaml") + require.NoError(t, err) + defer fi.Close() + + _, err = fi.WriteString("url: bad.com") + require.NoError(t, err) + + inv.Environ.Set("URL", "good.com") + }) + test("YAMLOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) { fi, err := os.CreateTemp(t.TempDir(), "config.yaml") require.NoError(t, err) @@ -800,4 +823,83 @@ func TestCommand_DefaultsOverride(t *testing.T) { inv.Args = []string{"--config", fi.Name()} }) + + test("AltFlagOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) { + inv.Args = []string{"--url-deprecated", "good.com"} + }) +} + +func TestCommand_OptionsWithSharedValue(t *testing.T) { + t.Parallel() + + var got string + makeCmd := func(def, altDef string) *serpent.Command { + got = "" + return &serpent.Command{ + Options: serpent.OptionSet{ + { + Name: "url", + Flag: "url", + Env: "URL", + Default: def, + Value: serpent.StringOf(&got), + }, + { + Name: "alt-url", + Flag: "alt-url", + Env: "ALT_URL", + Default: altDef, + Value: serpent.StringOf(&got), + }, + }, + Handler: (func(i *serpent.Invocation) error { + return nil + }), + } + } + + // Check proper value propagation. + err := makeCmd("def.com", "def.com").Invoke().Run() + require.NoError(t, err, "default values are same") + require.Equal(t, "def.com", got) + + err = makeCmd("def.com", "").Invoke().Run() + require.NoError(t, err, "other default value is empty") + require.Equal(t, "def.com", got) + + err = makeCmd("def.com", "").Invoke("--url", "sup").Run() + require.NoError(t, err) + require.Equal(t, "sup", got) + + err = makeCmd("def.com", "").Invoke("--alt-url", "hup").Run() + require.NoError(t, err) + require.Equal(t, "hup", got) + + // Both flags are given, last wins. + err = makeCmd("def.com", "").Invoke("--url", "sup", "--alt-url", "hup").Run() + require.NoError(t, err) + require.Equal(t, "hup", got) + + // Both flags are given, last wins #2. + err = makeCmd("", "def.com").Invoke("--alt-url", "hup", "--url", "sup").Run() + require.NoError(t, err) + require.Equal(t, "sup", got) + + // Both flags are given, option type priority wins. + inv := makeCmd("def.com", "").Invoke("--alt-url", "hup") + inv.Environ.Set("URL", "sup") + err = inv.Run() + require.NoError(t, err) + require.Equal(t, "hup", got) + + // Both flags are given, option type priority wins #2. + inv = makeCmd("", "def.com").Invoke("--url", "sup") + inv.Environ.Set("ALT_URL", "hup") + err = inv.Run() + require.NoError(t, err) + require.Equal(t, "sup", got) + + // Catch invalid configuration. + err = makeCmd("def.com", "alt-def.com").Invoke().Run() + require.Error(t, err, "default values are different") } diff --git a/option.go b/option.go index 2780fc6..e23214c 100644 --- a/option.go +++ b/option.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "os" + "slices" "strings" "github.com/hashicorp/go-multierror" @@ -21,6 +22,14 @@ const ( ValueSourceDefault ValueSource = "default" ) +var valueSourcePriority = []ValueSource{ + ValueSourceFlag, + ValueSourceEnv, + ValueSourceYAML, + ValueSourceDefault, + ValueSourceNone, +} + // Option is a configuration option for a CLI application. type Option struct { Name string `json:"name,omitempty"` @@ -305,16 +314,12 @@ func (optSet *OptionSet) SetDefaults() error { var merr *multierror.Error - for i, opt := range *optSet { - // Skip values that may have already been set by the user. - if opt.ValueSource != ValueSourceNone { - continue - } - - if opt.Default == "" { - continue - } - + // It's common to have multiple options with the same value to + // handle deprecation. We group the options by value so that we + // don't let other options overwrite user input. + groupByValue := make(map[pflag.Value][]*Option) + for i := range *optSet { + opt := &(*optSet)[i] if opt.Value == nil { merr = multierror.Append( merr, @@ -325,13 +330,69 @@ func (optSet *OptionSet) SetDefaults() error { ) continue } - (*optSet)[i].ValueSource = ValueSourceDefault - if err := opt.Value.Set(opt.Default); err != nil { + groupByValue[opt.Value] = append(groupByValue[opt.Value], opt) + } + + // Sorts by value source, then a default value being set. + sortOptionByValueSourcePriorityOrDefault := func(a, b *Option) int { + if a.ValueSource != b.ValueSource { + return slices.Index(valueSourcePriority, a.ValueSource) - slices.Index(valueSourcePriority, b.ValueSource) + } + if a.Default != b.Default { + if a.Default == "" { + return 1 + } + if b.Default == "" { + return -1 + } + } + return 0 + } + for _, opts := range groupByValue { + // Sort the options by priority and whether or not a default is + // set. This won't affect the value but represents correctness + // from whence the value originated. + slices.SortFunc(opts, sortOptionByValueSourcePriorityOrDefault) + + // If the first option has a value source, then we don't need to + // set the default, but mark the source for all options. + if opts[0].ValueSource != ValueSourceNone { + for _, opt := range opts[1:] { + opt.ValueSource = opts[0].ValueSource + } + continue + } + + var optWithDefault *Option + for _, opt := range opts { + if opt.Default == "" { + continue + } + if optWithDefault != nil && optWithDefault.Default != opt.Default { + merr = multierror.Append( + merr, + xerrors.Errorf( + "parse %q: multiple defaults set for the same value: %q and %q (%q)", + opt.Name, opt.Default, optWithDefault.Default, optWithDefault.Name, + ), + ) + continue + } + optWithDefault = opt + } + if optWithDefault == nil { + continue + } + if err := optWithDefault.Value.Set(optWithDefault.Default); err != nil { merr = multierror.Append( - merr, xerrors.Errorf("parse %q: %w", opt.Name, err), + merr, xerrors.Errorf("parse %q: %w", optWithDefault.Name, err), ) } + for _, opt := range opts { + opt.ValueSource = ValueSourceDefault + } } + return merr.ErrorOrNil() } diff --git a/yaml.go b/yaml.go index 37af35a..bfa9d25 100644 --- a/yaml.go +++ b/yaml.go @@ -213,8 +213,10 @@ func (o *Option) setFromYAMLNode(n *yaml.Node) error { // We treat empty values as nil for consistency with other option // mechanisms. if len(n.Content) == 0 { - o.Value = nil - return nil + if o.Value == nil { + return nil + } + return o.Value.Set("") } return n.Decode(o.Value) case yaml.MappingNode: