@@ -343,11 +343,24 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
343343// attempted. If overriding this, be sure to close the body if needed.
344344type ErrorHandler func (resp * http.Response , err error , numTries int ) (* http.Response , error )
345345
346+ type HTTPClient interface {
347+ // Do performs an HTTP request and returns an HTTP response.
348+ Do (* http.Request ) (* http.Response , error )
349+ // Done is called when the client is no longer needed.
350+ Done ()
351+ }
352+
353+ type HTTPClientFactory interface {
354+ // New returns an HTTP client to use for a request, including retries.
355+ New () HTTPClient
356+ }
357+
346358// Client is used to make HTTP requests. It adds additional functionality
347359// like automatic retries to tolerate minor outages.
348360type Client struct {
349- HTTPClient * http.Client // Internal HTTP client.
350- Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
361+ HTTPClient * http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used.
362+ HTTPClientFactory HTTPClientFactory
363+ Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
351364
352365 RetryWaitMin time.Duration // Minimum time to wait
353366 RetryWaitMax time.Duration // Maximum time to wait
@@ -372,19 +385,18 @@ type Client struct {
372385 ErrorHandler ErrorHandler
373386
374387 loggerInit sync.Once
375- clientInit sync.Once
376388}
377389
378390// NewClient creates a new Client with default settings.
379391func NewClient () * Client {
380392 return & Client {
381- HTTPClient : cleanhttp . DefaultPooledClient () ,
382- Logger : defaultLogger ,
383- RetryWaitMin : defaultRetryWaitMin ,
384- RetryWaitMax : defaultRetryWaitMax ,
385- RetryMax : defaultRetryMax ,
386- CheckRetry : DefaultRetryPolicy ,
387- Backoff : DefaultBackoff ,
393+ HTTPClientFactory : & CleanPooledClientFactory {} ,
394+ Logger : defaultLogger ,
395+ RetryWaitMin : defaultRetryWaitMin ,
396+ RetryWaitMax : defaultRetryWaitMax ,
397+ RetryMax : defaultRetryMax ,
398+ CheckRetry : DefaultRetryPolicy ,
399+ Backoff : DefaultBackoff ,
388400 }
389401}
390402
@@ -545,12 +557,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo
545557
546558// Do wraps calling an HTTP method with retries.
547559func (c * Client ) Do (req * Request ) (* http.Response , error ) {
548- c .clientInit .Do (func () {
549- if c .HTTPClient == nil {
550- c .HTTPClient = cleanhttp .DefaultPooledClient ()
551- }
552- })
553-
554560 logger := c .logger ()
555561
556562 if logger != nil {
@@ -562,6 +568,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
562568 }
563569 }
564570
571+ httpClient := c .getHTTPClient ()
572+ defer httpClient .Done ()
573+
565574 var resp * http.Response
566575 var attempt int
567576 var shouldRetry bool
@@ -574,7 +583,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
574583 if req .body != nil {
575584 body , err := req .body ()
576585 if err != nil {
577- c .HTTPClient .CloseIdleConnections ()
578586 return resp , err
579587 }
580588 if c , ok := body .(io.ReadCloser ); ok {
@@ -596,7 +604,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
596604 }
597605
598606 // Attempt the request
599- resp , doErr = c .HTTPClient .Do (req .Request )
607+
608+ resp , doErr = httpClient .Do (req .Request )
600609
601610 // Check if we should continue with retries.
602611 shouldRetry , checkErr = c .CheckRetry (req .Context (), resp , doErr )
@@ -657,7 +666,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
657666 select {
658667 case <- req .Context ().Done ():
659668 timer .Stop ()
660- c .HTTPClient .CloseIdleConnections ()
661669 return nil , req .Context ().Err ()
662670 case <- timer .C :
663671 }
@@ -673,8 +681,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
673681 return resp , nil
674682 }
675683
676- defer c .HTTPClient .CloseIdleConnections ()
677-
678684 err := doErr
679685 if checkErr != nil {
680686 err = checkErr
@@ -717,6 +723,19 @@ func (c *Client) drainBody(body io.ReadCloser) {
717723 }
718724}
719725
726+ func (c * Client ) getHTTPClient () HTTPClient {
727+ if c .HTTPClient != nil {
728+ return & idleConnectionsClosingClient {
729+ httpClient : c .HTTPClient ,
730+ }
731+ }
732+ clientFactory := c .HTTPClientFactory
733+ if clientFactory == nil {
734+ clientFactory = & CleanPooledClientFactory {}
735+ }
736+ return clientFactory .New ()
737+ }
738+
720739// Get is a shortcut for doing a GET request without making a new client.
721740func Get (url string ) (* http.Response , error ) {
722741 return defaultClient .Get (url )
@@ -779,3 +798,29 @@ func (c *Client) StandardClient() *http.Client {
779798 Transport : & RoundTripper {Client : c },
780799 }
781800}
801+
802+ var (
803+ _ HTTPClientFactory = & CleanPooledClientFactory {}
804+ _ HTTPClient = & idleConnectionsClosingClient {}
805+ )
806+
807+ type CleanPooledClientFactory struct {
808+ }
809+
810+ func (f * CleanPooledClientFactory ) New () HTTPClient {
811+ return & idleConnectionsClosingClient {
812+ httpClient : cleanhttp .DefaultPooledClient (),
813+ }
814+ }
815+
816+ type idleConnectionsClosingClient struct {
817+ httpClient * http.Client
818+ }
819+
820+ func (c * idleConnectionsClosingClient ) Do (req * http.Request ) (* http.Response , error ) {
821+ return c .httpClient .Do (req )
822+ }
823+
824+ func (c * idleConnectionsClosingClient ) Done () {
825+ c .httpClient .CloseIdleConnections ()
826+ }
0 commit comments