diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index a6630a0e44..7ed6f23d82 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -86,7 +86,8 @@ func Dial(ctx context.Context, opts ...option.ClientOption) (*grpc.ClientConn, e return o.GRPCConnPool.Conn(), nil } if o.IsNewAuthLibraryEnabled() { - pool, err := dialPoolNewAuth(ctx, true, 1, o) + poolSize := 1 + pool, err := dialPoolNewAuth(ctx, true, &poolSize, opts) if err != nil { return nil, err } @@ -108,7 +109,8 @@ func DialInsecure(ctx context.Context, opts ...option.ClientOption) (*grpc.Clien return nil, err } if o.IsNewAuthLibraryEnabled() { - pool, err := dialPoolNewAuth(ctx, false, 1, o) + poolSize := 1 + pool, err := dialPoolNewAuth(ctx, false, &poolSize, opts) if err != nil { return nil, err } @@ -137,7 +139,7 @@ func DialPool(ctx context.Context, opts ...option.ClientOption) (ConnPool, error if o.GRPCConn != nil { return &singleConnPool{o.GRPCConn}, nil } - pool, err := dialPoolNewAuth(ctx, true, o.GRPCConnPoolSize, o) + pool, err := dialPoolNewAuth(ctx, true, nil, opts) if err != nil { return nil, err } @@ -174,7 +176,23 @@ func DialPool(ctx context.Context, opts ...option.ClientOption) (ConnPool, error } // dialPoolNewAuth is an adapter to call new auth library. -func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *internal.DialSettings) (grpctransport.GRPCClientConnPool, error) { +func dialPoolNewAuth(ctx context.Context, secure bool, poolSize *int, opts []option.ClientOption) (grpctransport.GRPCClientConnPool, error) { + authGRPCOpts, err := AuthGRPCOptions(opts, poolSize) + if err != nil { + return nil, err + } + return dialContextNewAuth(ctx, secure, authGRPCOpts) +} + +// AuthGRPCOptions is an adapter converting []option.ClientOption to +// cloud.google.com/go/auth/grpctransport.Options. If a non-nil +// poolSizeOverride pointer is provided, it will be used instead of the +// option.WithGRPCConnectionPool value in opts. +func AuthGRPCOptions(opts []option.ClientOption, poolSizeOverride *int) (*grpctransport.Options, error) { + ds, err := processAndValidateOpts(opts) + if err != nil { + return nil, err + } // honor options if set var creds *auth.Credentials if ds.InternalCredentials != nil { @@ -220,13 +238,16 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna defaultEndpointTemplate = ds.DefaultEndpoint } - pool, err := dialContextNewAuth(ctx, secure, &grpctransport.Options{ + if poolSizeOverride == nil { + poolSizeOverride = &ds.GRPCConnPoolSize + } + return &grpctransport.Options{ DisableTelemetry: ds.TelemetryDisabled, DisableAuthentication: ds.NoAuth, Endpoint: ds.Endpoint, Metadata: metadata, GRPCDialOpts: prepareDialOptsNewAuth(ds), - PoolSize: poolSize, + PoolSize: *poolSizeOverride, Credentials: creds, ClientCertProvider: ds.ClientCertSource, APIKey: ds.APIKey, @@ -251,8 +272,7 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna }, UniverseDomain: ds.UniverseDomain, Logger: ds.Logger, - }) - return pool, err + }, nil } func prepareDialOptsNewAuth(ds *internal.DialSettings) []grpc.DialOption { diff --git a/transport/grpc/dial_test.go b/transport/grpc/dial_test.go index 63b0a1901f..e70dbd149e 100644 --- a/transport/grpc/dial_test.go +++ b/transport/grpc/dial_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/oauth2/google" "google.golang.org/api/internal" + "google.golang.org/api/option" "google.golang.org/grpc" ) @@ -38,15 +39,15 @@ func TestDial(t *testing.T) { func TestDialPoolNewAuthDialOptions(t *testing.T) { oldDialContextNewAuth := dialContextNewAuth - var wantNumOpts int - var universeDomain string + var wantNumDialOpts int + var wantUniverseDomain string // Replace package var in order to assert DialContext args. dialContextNewAuth = func(ctx context.Context, secure bool, opts *grpctransport.Options) (grpctransport.GRPCClientConnPool, error) { - if len(opts.GRPCDialOpts) != wantNumOpts { - t.Fatalf("got: %d, want: %d", len(opts.GRPCDialOpts), wantNumOpts) + if len(opts.GRPCDialOpts) != wantNumDialOpts { + t.Fatalf("got: %d, want: %d", len(opts.GRPCDialOpts), wantNumDialOpts) } - if opts.UniverseDomain != universeDomain { - t.Fatalf("got: %q, want: %q", opts.UniverseDomain, universeDomain) + if opts.UniverseDomain != wantUniverseDomain { + t.Fatalf("got: %q, want: %q", opts.UniverseDomain, wantUniverseDomain) } return nil, nil } @@ -55,34 +56,31 @@ func TestDialPoolNewAuthDialOptions(t *testing.T) { }() for _, testcase := range []struct { - name string - ds *internal.DialSettings - wantNumOpts int + name string + ds []option.ClientOption + wantUniverseDomain string + wantNumDialOpts int }{ { - name: "no dial options", - ds: &internal.DialSettings{}, - wantNumOpts: 0, + name: "no dial options", + ds: make([]option.ClientOption, 0), }, { - name: "with user agent", - ds: &internal.DialSettings{ - UserAgent: "test", - }, - wantNumOpts: 1, + name: "with user agent", + ds: []option.ClientOption{option.WithUserAgent("test")}, + wantNumDialOpts: 1, }, { - name: "universe domain", - ds: &internal.DialSettings{ - UniverseDomain: "example.com", - }, - wantNumOpts: 0, + name: "universe domain", + ds: []option.ClientOption{option.WithUniverseDomain("example.com")}, + wantUniverseDomain: "example.com", }, } { t.Run(testcase.name, func(t *testing.T) { - wantNumOpts = testcase.wantNumOpts - universeDomain = testcase.ds.UniverseDomain - dialPoolNewAuth(context.Background(), false, 1, testcase.ds) + wantNumDialOpts = testcase.wantNumDialOpts + wantUniverseDomain = testcase.wantUniverseDomain + poolSize := 1 + dialPoolNewAuth(context.Background(), false, &poolSize, testcase.ds) }) } }