Skip to content

Retry #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions retry/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# retry

Simple retry package to get a retryable `*http.Client` or `http.RoundTripper`
that wraps [github.com/cenkalti/backoff/v5](github.com/cenkalti/backoff/v5).

## How to use

You can customise `retry.Config` as documented.

### Retryable `http.Client`

```go
import (
"github.com/smithy-security/pkg/retry"
)

...

client, err := retry.NewClient(retry.Config{
MaxRetries: 10,
})
...
```

### Retryable `http.RoundTripper`

```go
import (
"github.com/smithy-security/pkg/retry"
)

...

rt, err := retry.NewRoundTripper(retry.Config{
MaxRetries: 10,
})
...
```
11 changes: 11 additions & 0 deletions retry/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package retry

// NoopLogger is exported for testing.
type NoopLogger = noopLogger

var (
// DefaultRetryableStatusCodes is exported for testing only.
DefaultRetryableStatusCodes = defaultRetryableStatusCodes
// DefaultAcceptedStatusCodes is exported for testing only.
DefaultAcceptedStatusCodes = defaultAcceptedStatusCodes
)
17 changes: 17 additions & 0 deletions retry/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module github.com/smithy-security/pkg/retry

go 1.24.0

require (
github.com/cenkalti/backoff/v5 v5.0.2
github.com/stretchr/testify v1.9.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
25 changes: 25 additions & 0 deletions retry/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
11 changes: 11 additions & 0 deletions retry/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package retry

type noopLogger struct{}

func (n *noopLogger) Debug(msg string, keyvals ...any) {}
func (n *noopLogger) Info(msg string, keyvals ...any) {}
func (n *noopLogger) Warn(msg string, keyvals ...any) {}
func (n *noopLogger) Error(msg string, keyvals ...any) {}
func (n *noopLogger) With(args ...any) Logger {
return &noopLogger{}
}
219 changes: 219 additions & 0 deletions retry/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package retry

import (
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httputil"

"github.com/cenkalti/backoff/v5"
)

const defaultMaxRetries uint = 5

var (
defaultRetryableStatusCodes = map[int]struct{}{
http.StatusTooManyRequests: {},
http.StatusRequestTimeout: {},
http.StatusGatewayTimeout: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
}
defaultAcceptedStatusCodes = map[int]struct{}{
http.StatusCreated: {},
http.StatusAccepted: {},
http.StatusNoContent: {},
http.StatusOK: {},
}
)

type (
// Logger allows to inject a custom logger in the client.
Logger interface {
Error(msg string, keysAndValues ...interface{})
Info(msg string, keysAndValues ...interface{})
Debug(msg string, keysAndValues ...interface{})
Warn(msg string, keysAndValues ...interface{})
}

// NextRetryInSeconds allows customising the behaviour for the calculating the next retry.
NextRetryInSeconds func(currAttempt uint) int

// Config allows configuring the client.
Config struct {
// BaseClient allows to specify a base http.Client.
BaseClient *http.Client
// Logger allows to specify a custom logger. *slog.Logger will satisfy this.
Logger Logger
// NextRetryInSecondsFunc allows to specify a custom retry function.
// By default, exponential fibonacci like function is used.
NextRetryInSecondsFunc NextRetryInSeconds
// MaxRetries allows to specify the number of max retries before returning a fatal error.
// 5 is the default.
MaxRetries uint
// RetryableStatusCodes allows to specify the retryable status codes.
// defaultRetryableStatusCodes are the default.
RetryableStatusCodes map[int]struct{}
// AcceptedStatusCodes allows to specify the non-retryable status codes.
// defaultAcceptedStatusCodes are the default.
AcceptedStatusCodes map[int]struct{}
}

retry struct {
config Config
}
)

// Validate validates the client configuration.
func (c Config) Validate() error {
switch {
case c.MaxRetries == 0:
return errors.New("max retries is required")
case len(c.RetryableStatusCodes) == 0:
return errors.New("retryable status codes is required")
case len(c.AcceptedStatusCodes) == 0:
return errors.New("accepted status codes is required")
case c.BaseClient == nil:
return errors.New("base client is required")
case c.Logger == nil:
return errors.New("logger is required")
case c.NextRetryInSecondsFunc == nil:
return errors.New("next retry function is required")
}

return nil
}

func applyConfig(cfg Config) (Config, error) {
clonedCfg := cfg
if cfg.MaxRetries == 0 {
clonedCfg.MaxRetries = defaultMaxRetries
}

if len(clonedCfg.RetryableStatusCodes) == 0 {
clonedCfg.RetryableStatusCodes = defaultRetryableStatusCodes
}

if len(clonedCfg.AcceptedStatusCodes) == 0 {
clonedCfg.AcceptedStatusCodes = defaultAcceptedStatusCodes
}

if clonedCfg.BaseClient == nil {
clonedCfg.BaseClient = http.DefaultClient
}

if clonedCfg.Logger == nil {
clonedCfg.Logger = &noopLogger{}
}

if clonedCfg.NextRetryInSecondsFunc == nil {
clonedCfg.NextRetryInSecondsFunc = fibNextRetry
}

return clonedCfg, clonedCfg.Validate()
}

// NewClient returns a new http.Client with retry behaviour.
func NewClient(config Config) (*http.Client, error) {
config, err := applyConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to apply config: %w", err)
}

config.BaseClient.Transport = &retry{
config: config,
}

return config.BaseClient, nil
}

// NewRoundTripper returns a new http.RoundTripper with retry behaviour.
func NewRoundTripper(config Config) (http.RoundTripper, error) {
config, err := applyConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to apply config: %w", err)
}

return &retry{
config: config,
}, nil
}

// RoundTrip implements a http transport RoundTripper with retry capabilities.
func (re *retry) RoundTrip(req *http.Request) (*http.Response, error) {
var (
logger = re.config.Logger
currAttempt uint
retryableOp = func() (*http.Response, error) {
resp, err := re.config.BaseClient.Do(req)
switch {
case err != nil:
return resp, backoff.Permanent(err)
case resp == nil:
return nil, backoff.Permanent(errors.New("invalid nil response"))
}

_, isAcceptedStatus := re.config.AcceptedStatusCodes[resp.StatusCode]
_, isRetryableStatus := re.config.RetryableStatusCodes[resp.StatusCode]

switch {
case !isAcceptedStatus && currAttempt >= re.config.MaxRetries:
return resp, backoff.Permanent(
fmt.Errorf(
"maximum number of retries exceeded: %d",
currAttempt,
),
)
case !isAcceptedStatus && isRetryableStatus:
nextRetryInSeconds := re.config.NextRetryInSecondsFunc(currAttempt)

logger.Debug(
"retryable status code, retrying",
slog.Int("retry_in_seconds", nextRetryInSeconds),
slog.Int("curr_attempt", int(currAttempt)),
slog.Int("status_code", resp.StatusCode),
)
currAttempt++
return resp, backoff.RetryAfter(nextRetryInSeconds)
case !isAcceptedStatus && !isRetryableStatus:
bb, err := httputil.DumpResponse(resp, true)
if err != nil {
logger.Error(
"failed to dump response",
slog.String("error", err.Error()),
)
}
logger.Error(
"unexpected response",
slog.Int("status_code", resp.StatusCode),
slog.String("raw_body", string(bb)),
)
return resp, backoff.Permanent(fmt.Errorf("invalid status code: %d", resp.StatusCode))
}

return resp, nil
}
)

result, err := backoff.Retry(
req.Context(),
retryableOp,
)
if err != nil {
return result, fmt.Errorf("could not process backoff result: %w", err)
}

return result, nil
}

func fibNextRetry(attempt uint) int {
switch attempt {
case 0:
return 0
case 1:
return 1
default:
return fibNextRetry(attempt-1) + fibNextRetry(attempt-2)
}
}
Loading
Loading