diff --git a/common/resource/fx.go b/common/resource/fx.go index 252f062aa52..00e9baaac9f 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -109,6 +109,7 @@ var Module = fx.Options( var DefaultOptions = fx.Options( fx.Provide(RPCFactoryProvider), + fx.Provide(PerServiceDialOptionsProvider), fx.Provide(ArchivalMetadataProvider), fx.Provide(ArchiverProviderProvider), fx.Provide(ThrottledLoggerProvider), @@ -337,6 +338,10 @@ func DCRedirectionPolicyProvider(cfg *config.Config) config.DCRedirectionPolicy return cfg.DCRedirectionPolicy } +func PerServiceDialOptionsProvider() map[primitives.ServiceName][]grpc.DialOption { + return map[primitives.ServiceName][]grpc.DialOption{} +} + func RPCFactoryProvider( cfg *config.Config, svcName primitives.ServiceName, @@ -345,6 +350,7 @@ func RPCFactoryProvider( tlsConfigProvider encryption.TLSConfigProvider, resolver *membership.GRPCResolver, tracingStatsHandler telemetry.ClientStatsHandler, + perServiceDialOptions map[primitives.ServiceName][]grpc.DialOption, monitor membership.Monitor, dc *dynamicconfig.Collection, ) (common.RPCFactory, error) { @@ -370,6 +376,7 @@ func RPCFactoryProvider( frontendHTTPPort, frontendTLSConfig, options, + perServiceDialOptions, monitor, ) factory.EnableInternodeServerKeepalive = enableServerKeepalive diff --git a/common/rpc/rpc.go b/common/rpc/rpc.go index 5b119ee49e5..663f5464781 100644 --- a/common/rpc/rpc.go +++ b/common/rpc/rpc.go @@ -42,10 +42,11 @@ type RPCFactory struct { frontendHTTPPort int frontendTLSConfig *tls.Config - grpcListener func() net.Listener - tlsFactory encryption.TLSConfigProvider - dialOptions []grpc.DialOption - monitor membership.Monitor + grpcListener func() net.Listener + tlsFactory encryption.TLSConfigProvider + commonDialOptions []grpc.DialOption + perServiceDialOptions map[primitives.ServiceName][]grpc.DialOption + monitor membership.Monitor // A OnceValues wrapper for createLocalFrontendHTTPClient. localFrontendClient func() (*common.FrontendHTTPClient, error) interNodeGrpcConnections cache.Cache @@ -67,21 +68,23 @@ func NewFactory( frontendHTTPURL string, frontendHTTPPort int, frontendTLSConfig *tls.Config, - dialOptions []grpc.DialOption, + commonDialOptions []grpc.DialOption, + perServiceDialOptions map[primitives.ServiceName][]grpc.DialOption, monitor membership.Monitor, ) *RPCFactory { f := &RPCFactory{ - config: cfg, - serviceName: sName, - logger: logger, - metricsHandler: metricsHandler, - frontendURL: frontendURL, - frontendHTTPURL: frontendHTTPURL, - frontendHTTPPort: frontendHTTPPort, - frontendTLSConfig: frontendTLSConfig, - tlsFactory: tlsProvider, - dialOptions: dialOptions, - monitor: monitor, + config: cfg, + serviceName: sName, + logger: logger, + metricsHandler: metricsHandler, + frontendURL: frontendURL, + frontendHTTPURL: frontendHTTPURL, + frontendHTTPPort: frontendHTTPPort, + frontendTLSConfig: frontendTLSConfig, + tlsFactory: tlsProvider, + commonDialOptions: commonDialOptions, + perServiceDialOptions: perServiceDialOptions, + monitor: monitor, } f.grpcListener = sync.OnceValue(f.createGRPCListener) f.localFrontendClient = sync.OnceValues(f.createLocalFrontendHTTPClient) @@ -214,13 +217,16 @@ func (d *RPCFactory) CreateRemoteFrontendGRPCConnection(rpcAddress string) *grpc } } keepAliveOption := d.getClientKeepAliveConfig(primitives.FrontendService) + additionalDialOptions := append([]grpc.DialOption{}, d.perServiceDialOptions[primitives.FrontendService]...) - return d.dial(rpcAddress, tlsClientConfig, keepAliveOption) + return d.dial(rpcAddress, tlsClientConfig, append(additionalDialOptions, keepAliveOption)...) } // CreateLocalFrontendGRPCConnection creates connection for internal frontend calls func (d *RPCFactory) CreateLocalFrontendGRPCConnection() *grpc.ClientConn { - return d.dial(d.frontendURL, d.frontendTLSConfig) + additionalDialOptions := append([]grpc.DialOption{}, d.perServiceDialOptions[primitives.InternalFrontendService]...) + + return d.dial(d.frontendURL, d.frontendTLSConfig, additionalDialOptions...) } // createInternodeGRPCConnection creates connection for gRPC calls @@ -237,7 +243,8 @@ func (d *RPCFactory) createInternodeGRPCConnection(hostName string, serviceName return nil } } - c := d.dial(hostName, tlsClientConfig, d.getClientKeepAliveConfig(serviceName)) + additionalDialOptions := append([]grpc.DialOption{}, d.perServiceDialOptions[serviceName]...) + c := d.dial(hostName, tlsClientConfig, append(additionalDialOptions, d.getClientKeepAliveConfig(serviceName))...) d.interNodeGrpcConnections.Put(hostName, c) return c } @@ -251,7 +258,7 @@ func (d *RPCFactory) CreateMatchingGRPCConnection(rpcAddress string) *grpc.Clien } func (d *RPCFactory) dial(hostName string, tlsClientConfig *tls.Config, dialOptions ...grpc.DialOption) *grpc.ClientConn { - dialOptions = append(d.dialOptions, dialOptions...) + dialOptions = append(d.commonDialOptions, dialOptions...) connection, err := Dial(hostName, tlsClientConfig, d.logger, d.metricsHandler, dialOptions...) if err != nil { d.logger.Fatal("Failed to create gRPC connection", tag.Error(err)) diff --git a/common/rpc/test/http_test.go b/common/rpc/test/http_test.go index f362b0bede7..aa159b0b638 100644 --- a/common/rpc/test/http_test.go +++ b/common/rpc/test/http_test.go @@ -12,6 +12,7 @@ import ( "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/rpc" "go.uber.org/mock/gomock" + "google.golang.org/grpc" ) func TestCreateLocalFrontendHTTPClient_UsingMembership(t *testing.T) { @@ -41,6 +42,7 @@ func TestCreateLocalFrontendHTTPClient_UsingMembership(t *testing.T) { int(port), nil, // No TLS nil, + map[primitives.ServiceName][]grpc.DialOption{}, monitor, ) @@ -72,6 +74,7 @@ func TestCreateLocalFrontendHTTPClient_UsingFixedHostPort(t *testing.T) { 0, // Port is unused nil, // No TLS nil, + map[primitives.ServiceName][]grpc.DialOption{}, nil, // monitor should not be used ) @@ -104,6 +107,7 @@ func TestCreateLocalFrontendHTTPClient_UsingFixedHostPort_AndTLS(t *testing.T) { 0, // Port is unused tlsConfig, nil, + map[primitives.ServiceName][]grpc.DialOption{}, nil, // monitor should not be used ) diff --git a/common/rpc/test/rpc_localstore_tls_test.go b/common/rpc/test/rpc_localstore_tls_test.go index cc31873c346..a05a8224e4e 100644 --- a/common/rpc/test/rpc_localstore_tls_test.go +++ b/common/rpc/test/rpc_localstore_tls_test.go @@ -110,7 +110,7 @@ func (s *localStoreRPCSuite) SetupSuite() { provider, err := encryption.NewTLSConfigProviderFromConfig(serverCfgInsecure.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) - insecureFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil) + insecureFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil, nil) s.NotNil(insecureFactory) s.insecureRPCFactory = i(insecureFactory) @@ -320,26 +320,26 @@ func (s *localStoreRPCSuite) setupFrontend() { s.NoError(err) tlsConfig, err := provider.GetFrontendClientConfig() s.NoError(err) - frontendMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + frontendMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(frontendMutualTLSFactory) provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreServerTLS.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) - frontendServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil) + frontendServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil, nil) s.NotNil(frontendServerTLSFactory) provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSSystemWorker.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - frontendSystemWorkerMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + frontendSystemWorkerMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(frontendSystemWorkerMutualTLSFactory) provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSWithRefresh.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - frontendMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + frontendMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(frontendMutualTLSRefreshFactory) s.frontendMutualTLSRPCFactory = f(frontendMutualTLSFactory) @@ -356,7 +356,7 @@ func (s *localStoreRPCSuite) setupFrontend() { s.NoError(err) tlsConfig, err = s.dynamicConfigProvider.GetFrontendClientConfig() s.NoError(err) - dynamicServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, s.dynamicConfigProvider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + dynamicServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, s.dynamicConfigProvider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.frontendDynamicTLSFactory = f(dynamicServerTLSFactory) s.internodeDynamicTLSFactory = i(dynamicServerTLSFactory) @@ -366,7 +366,7 @@ func (s *localStoreRPCSuite) setupFrontend() { s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - frontendRootCAForceTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + frontendRootCAForceTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(frontendServerTLSFactory) s.frontendConfigRootCAForceTLSFactory = f(frontendRootCAForceTLSFactory) @@ -374,7 +374,7 @@ func (s *localStoreRPCSuite) setupFrontend() { s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - remoteClusterMutualTLSRPCFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + remoteClusterMutualTLSRPCFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(remoteClusterMutualTLSRPCFactory) s.remoteClusterMutualTLSRPCFactory = r(remoteClusterMutualTLSRPCFactory) } @@ -412,28 +412,28 @@ func (s *localStoreRPCSuite) setupInternode() { s.NoError(err) tlsConfig, err := provider.GetFrontendClientConfig() s.NoError(err) - internodeMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + internodeMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(internodeMutualTLSFactory) provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreServerTLS.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - internodeServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + internodeServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(internodeServerTLSFactory) provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreAltMutualTLS.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - internodeMutualAltTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + internodeMutualAltTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(internodeMutualAltTLSFactory) provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSWithRefresh.TLS, metrics.NoopMetricsHandler, s.logger, nil) s.NoError(err) tlsConfig, err = provider.GetFrontendClientConfig() s.NoError(err) - internodeMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil) + internodeMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil) s.NotNil(internodeMutualTLSRefreshFactory) s.internodeMutualTLSRPCFactory = i(internodeMutualTLSFactory) diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 7279aa26975..c63e7dd1c64 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -767,6 +767,7 @@ func (c *TemporalImpl) newRPCFactory( int(httpPort), frontendTLSConfig, options, + map[primitives.ServiceName][]grpc.DialOption{}, monitor, ), nil }