Skip to content

Commit ec5b971

Browse files
authored
Add parameter to filter cloud provider. (#765)
* Add parameter to filter cloud provider. Signed-off-by: Rodrigo Reis <[email protected]> * goimports. Signed-off-by: Rodrigo Reis <[email protected]> * typo Signed-off-by: Rodrigo Reis <[email protected]> * Use enum for parameter, improve provider naming. Signed-off-by: Rodrigo Reis <[email protected]> * Test clamped retryer. Signed-off-by: Rodrigo Reis <[email protected]> * Test copy values. Signed-off-by: Rodrigo Reis <[email protected]> * Test cluster version. Signed-off-by: Rodrigo Reis <[email protected]> * goimports. Signed-off-by: Rodrigo Reis <[email protected]> --------- Signed-off-by: Rodrigo Reis <[email protected]>
1 parent de834d6 commit ec5b971

File tree

12 files changed

+341
-26
lines changed

12 files changed

+341
-26
lines changed

api/cluster_version_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package api
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestParseVersion(t *testing.T) {
10+
for _, tc := range []struct {
11+
name string
12+
version string
13+
expected string
14+
}{
15+
{
16+
name: "empty",
17+
version: "",
18+
expected: "#",
19+
},
20+
{
21+
name: "simple",
22+
version: "foo#bar",
23+
expected: "foo#bar",
24+
},
25+
{
26+
name: "missing hash",
27+
version: "foo",
28+
expected: "#",
29+
},
30+
{
31+
name: "missing version",
32+
version: "#bar",
33+
expected: "#bar",
34+
},
35+
} {
36+
t.Run(tc.name, func(t *testing.T) {
37+
version := ParseVersion(tc.version)
38+
result := version.String()
39+
require.Equal(t, tc.expected, result)
40+
})
41+
}
42+
}

cmd/clm/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ func main() {
137137
AccountFilter: cfg.AccountFilter,
138138
Interval: cfg.Interval,
139139
DryRun: cfg.DryRun,
140+
Providers: cfg.Providers,
140141
ConcurrentUpdates: cfg.ConcurrentUpdates,
141142
}
142143

config/config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const (
2727
defaultDrainForceEvictInterval = "5m"
2828
defaultDrainPollInterval = "30s"
2929
defaultUpdateStrategy = "clc"
30+
defaultProvider = "zalando-aws"
3031
)
3132

3233
var defaultWorkdir = path.Join(os.TempDir(), "clm-workdir")
@@ -47,6 +48,7 @@ type LifecycleManagerConfig struct {
4748
Listen string
4849
Workdir string
4950
Directory string
51+
Providers []string
5052
ConfigSources []string
5153
SSHPrivateKeyFile string
5254
CredentialsDir string
@@ -89,6 +91,10 @@ func (cfg *LifecycleManagerConfig) ParseFlags() string {
8991
kingpin.Flag("dry-run", "Don't make any changes, just print.").BoolVar(&cfg.DryRun)
9092
kingpin.Flag("listen", "Address to listen at, e.g. :9090 or 0.0.0.0:9090").Default(defaultListener).StringVar(&cfg.Listen)
9193
kingpin.Flag("workdir", "Path to working directory used for storing channel configurations.").Default(defaultWorkdir).StringVar(&cfg.Workdir)
94+
kingpin.Flag(
95+
"provider",
96+
"Cloud provider. Defaults to single provider \"zalando-aws\".",
97+
).Default(defaultProvider).EnumsVar(&cfg.Providers, "zalando-aws")
9298
kingpin.Flag("config-source", "Config source specification (NAME:dir:PATH or NAME:git:URL). At least one is required.").StringsVar(&cfg.ConfigSources)
9399
kingpin.Flag("directory", "Use a single directory as a config source (for local/development use)").StringVar(&cfg.Directory)
94100
kingpin.Flag("concurrent-updates", "Number of updates allowed to run in parallel.").Default(defaultConcurrentUpdates).UintVar(&cfg.ConcurrentUpdates)

controller/controller.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ var (
3333
type Options struct {
3434
Interval time.Duration
3535
AccountFilter config.IncludeExcludeFilter
36+
Providers []string
3637
DryRun bool
3738
ConcurrentUpdates uint
3839
EnvironmentOrder []string
@@ -44,6 +45,7 @@ type Controller struct {
4445
execManager *command.ExecManager
4546
registry registry.Registry
4647
provisioner provisioner.Provisioner
48+
providers []string
4749
channelConfigSourcer channel.ConfigSource
4850
interval time.Duration
4951
dryRun bool
@@ -58,6 +60,7 @@ func New(logger *log.Entry, execManager *command.ExecManager, registry registry.
5860
execManager: execManager,
5961
registry: registry,
6062
provisioner: provisioner,
63+
providers: options.Providers,
6164
channelConfigSourcer: channel.NewCachingSource(channelConfigSourcer),
6265
interval: options.Interval,
6366
dryRun: options.DryRun,
@@ -116,7 +119,11 @@ func (c *Controller) refresh() error {
116119
return err
117120
}
118121

119-
clusters, err := c.registry.ListClusters(registry.Filter{})
122+
clusters, err := c.registry.ListClusters(
123+
registry.Filter{
124+
Providers: c.providers,
125+
},
126+
)
120127
if err != nil {
121128
return err
122129
}

pkg/aws/retryer_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package aws
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/aws/aws-sdk-go/aws/client/metadata"
8+
"github.com/aws/aws-sdk-go/aws/request"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestShouldRetry(t *testing.T) {
13+
for _, tc := range []struct {
14+
caseName string
15+
maxRetries int
16+
maxRetryInterval time.Duration
17+
request *request.Request
18+
expected bool
19+
}{
20+
{
21+
caseName: "should not retry",
22+
maxRetries: 1,
23+
request: &request.Request{},
24+
expected: false,
25+
},
26+
{
27+
caseName: "should retry with metadata service",
28+
maxRetries: 1,
29+
request: &request.Request{
30+
ClientInfo: metadata.ClientInfo{
31+
ServiceName: "ec2metadata",
32+
},
33+
},
34+
expected: true,
35+
},
36+
{
37+
caseName: "should not retry with metadata service",
38+
maxRetries: 1,
39+
request: &request.Request{
40+
ClientInfo: metadata.ClientInfo{
41+
ServiceName: "ec2metadata",
42+
},
43+
RetryCount: 8,
44+
},
45+
expected: false,
46+
},
47+
} {
48+
t.Run(tc.caseName, func(t *testing.T) {
49+
retryer := NewClampedRetryer(tc.maxRetries, time.Second)
50+
51+
res := retryer.ShouldRetry(tc.request)
52+
require.Equal(t, tc.expected, res)
53+
})
54+
}
55+
}
56+
57+
func TestRetryRules(t *testing.T) {
58+
for _, tc := range []struct {
59+
caseName string
60+
maxRetryInterval time.Duration
61+
request *request.Request
62+
expectedLessOrEqual time.Duration
63+
}{
64+
{
65+
caseName: "should return max retry interval",
66+
maxRetryInterval: time.Millisecond,
67+
request: &request.Request{},
68+
expectedLessOrEqual: time.Millisecond,
69+
},
70+
{
71+
caseName: "should not return max retry interval",
72+
maxRetryInterval: time.Second,
73+
request: &request.Request{
74+
ClientInfo: metadata.ClientInfo{
75+
ServiceName: "ec2metadata",
76+
},
77+
},
78+
expectedLessOrEqual: time.Second / 2,
79+
},
80+
} {
81+
t.Run(tc.caseName, func(t *testing.T) {
82+
retryer := NewClampedRetryer(1, tc.maxRetryInterval)
83+
84+
res := retryer.RetryRules(tc.request)
85+
require.LessOrEqual(t, res, tc.expectedLessOrEqual)
86+
})
87+
}
88+
}

pkg/util/copy_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package util
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestCopyValues(t *testing.T) {
10+
for _, tc := range []struct {
11+
name string
12+
value map[string]interface{}
13+
}{
14+
{
15+
name: "empty",
16+
value: map[string]interface{}{},
17+
},
18+
{
19+
name: "simple",
20+
value: map[string]interface{}{
21+
"foo": "bar",
22+
},
23+
},
24+
{
25+
name: "nested",
26+
value: map[string]interface{}{
27+
"foo": map[string]interface{}{
28+
"bar": "baz",
29+
},
30+
},
31+
},
32+
} {
33+
t.Run(tc.name, func(t *testing.T) {
34+
result := CopyValues(tc.value)
35+
require.Equal(t, tc.value, result)
36+
})
37+
}
38+
}

provisioner/clusterpy.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
)
3333

3434
const (
35-
providerID = "zalando-aws"
3635
etcdStackFileName = "stack.yaml"
3736
clusterStackFileName = "cluster.yaml"
3837
etcdStackNameDefault = "etcd-cluster-etcd"
@@ -103,7 +102,7 @@ func NewClusterpyProvisioner(execManager *command.ExecManager, tokenSource oauth
103102
}
104103

105104
func (p *clusterpyProvisioner) Supports(cluster *api.Cluster) bool {
106-
return cluster.Provider == providerID
105+
return cluster.Provider == string(ZalandoAWSProvider)
107106
}
108107

109108
func (p *clusterpyProvisioner) updateDefaults(cluster *api.Cluster, channelConfig channel.Config, adapter *awsAdapter, instanceTypes *awsUtils.InstanceTypes) error {
@@ -176,6 +175,10 @@ func (p *clusterpyProvisioner) propagateConfigItemsToNodePools(cluster *api.Clus
176175
// Provision provisions/updates a cluster on AWS. Provision is an idempotent
177176
// operation for the same input.
178177
func (p *clusterpyProvisioner) Provision(ctx context.Context, logger *log.Entry, cluster *api.Cluster, channelConfig channel.Config) error {
178+
if !p.Supports(cluster) {
179+
return ErrProviderNotSupported
180+
}
181+
179182
instanceTypes, awsAdapter, updater, err := p.prepareProvision(logger, cluster, channelConfig)
180183
if err != nil {
181184
return err
@@ -599,7 +602,7 @@ func selectSubnetIDs(subnets []*ec2.Subnet) *AZInfo {
599602

600603
// Decommission decommissions a cluster provisioned in AWS.
601604
func (p *clusterpyProvisioner) Decommission(ctx context.Context, logger *log.Entry, cluster *api.Cluster) error {
602-
if cluster.Provider != providerID {
605+
if !p.Supports(cluster) {
603606
return ErrProviderNotSupported
604607
}
605608

@@ -779,10 +782,6 @@ func (p *clusterpyProvisioner) setupAWSAdapter(logger *log.Entry, cluster *api.C
779782
// prepares to provision a cluster by initializing the aws adapter.
780783
// TODO: this is doing a lot of things to glue everything together, this should be refactored.
781784
func (p *clusterpyProvisioner) prepareProvision(logger *log.Entry, cluster *api.Cluster, channelConfig channel.Config) (*awsUtils.InstanceTypes, *awsAdapter, updatestrategy.UpdateStrategy, error) {
782-
if cluster.Provider != providerID {
783-
return nil, nil, nil, ErrProviderNotSupported
784-
}
785-
786785
logger.Infof("clusterpy: Prepare for provisioning cluster %s (%s)..", cluster.ID, cluster.LifecycleStatus)
787786

788787
adapter, err := p.setupAWSAdapter(logger, cluster)

provisioner/clusterpy_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package provisioner
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"net/http"
@@ -517,3 +518,22 @@ func TestWaitForAPIServer(t *testing.T) {
517518
})
518519
}
519520
}
521+
522+
func TestProvisionDoesNotSupportProvider(t *testing.T) {
523+
cluster := &api.Cluster{
524+
Provider: "zalando-eks",
525+
}
526+
527+
p := clusterpyProvisioner{}
528+
err := p.Provision(context.TODO(), nil, cluster, nil)
529+
assert.Equal(t, ErrProviderNotSupported, err)
530+
}
531+
func TestDecommissionDoesNotSupportProvider(t *testing.T) {
532+
cluster := &api.Cluster{
533+
Provider: "zalando-eks",
534+
}
535+
536+
p := clusterpyProvisioner{}
537+
err := p.Decommission(context.TODO(), nil, cluster)
538+
assert.Equal(t, ErrProviderNotSupported, err)
539+
}

provisioner/provisioner.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,31 @@ import (
1111
log "github.com/sirupsen/logrus"
1212
)
1313

14+
type (
15+
// A provider ID is a string that identifies a cluster provider.
16+
ProviderID string
17+
18+
// Options is the options that can be passed to a provisioner when initialized.
19+
Options struct {
20+
DryRun bool
21+
ApplyOnly bool
22+
UpdateStrategy config.UpdateStrategy
23+
RemoveVolumes bool
24+
ManageEtcdStack bool
25+
}
26+
)
27+
28+
const (
29+
// ZalandoAWS Provider is the provider ID for Zalando managed AWS clusters.
30+
ZalandoAWSProvider ProviderID = "zalando-aws"
31+
)
32+
1433
var (
1534
// ErrProviderNotSupported is the error returned from porvisioners if
1635
// they don't support the cluster provider defined.
1736
ErrProviderNotSupported = errors.New("unsupported provider type")
1837
)
1938

20-
// Options is the options that can be passed to a provisioner when initialized.
21-
type Options struct {
22-
DryRun bool
23-
ApplyOnly bool
24-
UpdateStrategy config.UpdateStrategy
25-
RemoveVolumes bool
26-
ManageEtcdStack bool
27-
}
28-
2939
// Provisioner is an interface describing how to provision or decommission
3040
// clusters.
3141
type Provisioner interface {

registry/http.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,18 @@ func (r *httpRegistry) ListClusters(filter Filter) ([]*api.Cluster, error) {
6363
var result []*api.Cluster
6464

6565
for _, cluster := range resp.Payload.Items {
66-
if filter.LifecycleStatus == nil || *cluster.LifecycleStatus == *filter.LifecycleStatus {
67-
c, err := convertFromClusterModel(cluster)
68-
if err != nil {
69-
return nil, err
70-
}
71-
if account, ok := accounts[c.InfrastructureAccount]; ok {
72-
c.Owner = *account.Owner
73-
}
74-
result = append(result, c)
66+
if !filter.Includes(cluster) {
67+
continue
7568
}
69+
70+
c, err := convertFromClusterModel(cluster)
71+
if err != nil {
72+
return nil, err
73+
}
74+
if account, ok := accounts[c.InfrastructureAccount]; ok {
75+
c.Owner = *account.Owner
76+
}
77+
result = append(result, c)
7678
}
7779

7880
return result, nil

0 commit comments

Comments
 (0)