Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/resource/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ var Module = fx.Options(

var DefaultOptions = fx.Options(
fx.Provide(RPCFactoryProvider),
fx.Provide(PerServiceDialOptionsProvider),
fx.Provide(ArchivalMetadataProvider),
fx.Provide(ArchiverProviderProvider),
fx.Provide(ThrottledLoggerProvider),
Expand Down Expand Up @@ -337,6 +338,10 @@ func DCRedirectionPolicyProvider(cfg *config.Config) config.DCRedirectionPolicy
return cfg.DCRedirectionPolicy
}

func PerServiceDialOptionsProvider() map[primitives.ServiceName][]grpc.DialOption {
return map[primitives.ServiceName][]grpc.DialOption{}
}

func RPCFactoryProvider(
cfg *config.Config,
svcName primitives.ServiceName,
Expand All @@ -345,6 +350,7 @@ func RPCFactoryProvider(
tlsConfigProvider encryption.TLSConfigProvider,
resolver *membership.GRPCResolver,
tracingStatsHandler telemetry.ClientStatsHandler,
perServiceDialOptions map[primitives.ServiceName][]grpc.DialOption,
monitor membership.Monitor,
dc *dynamicconfig.Collection,
) (common.RPCFactory, error) {
Expand All @@ -370,6 +376,7 @@ func RPCFactoryProvider(
frontendHTTPPort,
frontendTLSConfig,
options,
perServiceDialOptions,
monitor,
)
factory.EnableInternodeServerKeepalive = enableServerKeepalive
Expand Down
47 changes: 27 additions & 20 deletions common/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ type RPCFactory struct {
frontendHTTPPort int
frontendTLSConfig *tls.Config

grpcListener func() net.Listener
tlsFactory encryption.TLSConfigProvider
dialOptions []grpc.DialOption
monitor membership.Monitor
grpcListener func() net.Listener
tlsFactory encryption.TLSConfigProvider
commonDialOptions []grpc.DialOption
perServiceDialOptions map[primitives.ServiceName][]grpc.DialOption
monitor membership.Monitor
// A OnceValues wrapper for createLocalFrontendHTTPClient.
localFrontendClient func() (*common.FrontendHTTPClient, error)
interNodeGrpcConnections cache.Cache
Expand All @@ -67,21 +68,23 @@ func NewFactory(
frontendHTTPURL string,
frontendHTTPPort int,
frontendTLSConfig *tls.Config,
dialOptions []grpc.DialOption,
commonDialOptions []grpc.DialOption,
perServiceDialOptions map[primitives.ServiceName][]grpc.DialOption,
monitor membership.Monitor,
) *RPCFactory {
f := &RPCFactory{
config: cfg,
serviceName: sName,
logger: logger,
metricsHandler: metricsHandler,
frontendURL: frontendURL,
frontendHTTPURL: frontendHTTPURL,
frontendHTTPPort: frontendHTTPPort,
frontendTLSConfig: frontendTLSConfig,
tlsFactory: tlsProvider,
dialOptions: dialOptions,
monitor: monitor,
config: cfg,
serviceName: sName,
logger: logger,
metricsHandler: metricsHandler,
frontendURL: frontendURL,
frontendHTTPURL: frontendHTTPURL,
frontendHTTPPort: frontendHTTPPort,
frontendTLSConfig: frontendTLSConfig,
tlsFactory: tlsProvider,
commonDialOptions: commonDialOptions,
perServiceDialOptions: perServiceDialOptions,
monitor: monitor,
}
f.grpcListener = sync.OnceValue(f.createGRPCListener)
f.localFrontendClient = sync.OnceValues(f.createLocalFrontendHTTPClient)
Expand Down Expand Up @@ -214,13 +217,16 @@ func (d *RPCFactory) CreateRemoteFrontendGRPCConnection(rpcAddress string) *grpc
}
}
keepAliveOption := d.getClientKeepAliveConfig(primitives.FrontendService)
additionalDialOptions := append([]grpc.DialOption{}, d.perServiceDialOptions[primitives.FrontendService]...)

return d.dial(rpcAddress, tlsClientConfig, keepAliveOption)
return d.dial(rpcAddress, tlsClientConfig, append(additionalDialOptions, keepAliveOption)...)
}

// CreateLocalFrontendGRPCConnection creates connection for internal frontend calls
func (d *RPCFactory) CreateLocalFrontendGRPCConnection() *grpc.ClientConn {
return d.dial(d.frontendURL, d.frontendTLSConfig)
additionalDialOptions := append([]grpc.DialOption{}, d.perServiceDialOptions[primitives.InternalFrontendService]...)

return d.dial(d.frontendURL, d.frontendTLSConfig, additionalDialOptions...)
}

// createInternodeGRPCConnection creates connection for gRPC calls
Expand All @@ -237,7 +243,8 @@ func (d *RPCFactory) createInternodeGRPCConnection(hostName string, serviceName
return nil
}
}
c := d.dial(hostName, tlsClientConfig, d.getClientKeepAliveConfig(serviceName))
additionalDialOptions := append([]grpc.DialOption{}, d.perServiceDialOptions[serviceName]...)
c := d.dial(hostName, tlsClientConfig, append(additionalDialOptions, d.getClientKeepAliveConfig(serviceName))...)
d.interNodeGrpcConnections.Put(hostName, c)
return c
}
Expand All @@ -251,7 +258,7 @@ func (d *RPCFactory) CreateMatchingGRPCConnection(rpcAddress string) *grpc.Clien
}

func (d *RPCFactory) dial(hostName string, tlsClientConfig *tls.Config, dialOptions ...grpc.DialOption) *grpc.ClientConn {
dialOptions = append(d.dialOptions, dialOptions...)
dialOptions = append(d.commonDialOptions, dialOptions...)
connection, err := Dial(hostName, tlsClientConfig, d.logger, d.metricsHandler, dialOptions...)
if err != nil {
d.logger.Fatal("Failed to create gRPC connection", tag.Error(err))
Expand Down
4 changes: 4 additions & 0 deletions common/rpc/test/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/rpc"
"go.uber.org/mock/gomock"
"google.golang.org/grpc"
)

func TestCreateLocalFrontendHTTPClient_UsingMembership(t *testing.T) {
Expand Down Expand Up @@ -41,6 +42,7 @@ func TestCreateLocalFrontendHTTPClient_UsingMembership(t *testing.T) {
int(port),
nil, // No TLS
nil,
map[primitives.ServiceName][]grpc.DialOption{},
monitor,
)

Expand Down Expand Up @@ -72,6 +74,7 @@ func TestCreateLocalFrontendHTTPClient_UsingFixedHostPort(t *testing.T) {
0, // Port is unused
nil, // No TLS
nil,
map[primitives.ServiceName][]grpc.DialOption{},
nil, // monitor should not be used
)

Expand Down Expand Up @@ -104,6 +107,7 @@ func TestCreateLocalFrontendHTTPClient_UsingFixedHostPort_AndTLS(t *testing.T) {
0, // Port is unused
tlsConfig,
nil,
map[primitives.ServiceName][]grpc.DialOption{},
nil, // monitor should not be used
)

Expand Down
24 changes: 12 additions & 12 deletions common/rpc/test/rpc_localstore_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (s *localStoreRPCSuite) SetupSuite() {

provider, err := encryption.NewTLSConfigProviderFromConfig(serverCfgInsecure.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
insecureFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil)
insecureFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil, nil)
s.NotNil(insecureFactory)
s.insecureRPCFactory = i(insecureFactory)

Expand Down Expand Up @@ -320,26 +320,26 @@ func (s *localStoreRPCSuite) setupFrontend() {
s.NoError(err)
tlsConfig, err := provider.GetFrontendClientConfig()
s.NoError(err)
frontendMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
frontendMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(frontendMutualTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreServerTLS.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
frontendServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil)
frontendServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, nil, nil, nil, nil)
s.NotNil(frontendServerTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSSystemWorker.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
frontendSystemWorkerMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
frontendSystemWorkerMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(frontendSystemWorkerMutualTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSWithRefresh.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
frontendMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
frontendMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(frontendMutualTLSRefreshFactory)

s.frontendMutualTLSRPCFactory = f(frontendMutualTLSFactory)
Expand All @@ -356,7 +356,7 @@ func (s *localStoreRPCSuite) setupFrontend() {
s.NoError(err)
tlsConfig, err = s.dynamicConfigProvider.GetFrontendClientConfig()
s.NoError(err)
dynamicServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, s.dynamicConfigProvider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
dynamicServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, s.dynamicConfigProvider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.frontendDynamicTLSFactory = f(dynamicServerTLSFactory)
s.internodeDynamicTLSFactory = i(dynamicServerTLSFactory)

Expand All @@ -366,15 +366,15 @@ func (s *localStoreRPCSuite) setupFrontend() {
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
frontendRootCAForceTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
frontendRootCAForceTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(frontendServerTLSFactory)
s.frontendConfigRootCAForceTLSFactory = f(frontendRootCAForceTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSRemoteCluster.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
remoteClusterMutualTLSRPCFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
remoteClusterMutualTLSRPCFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(remoteClusterMutualTLSRPCFactory)
s.remoteClusterMutualTLSRPCFactory = r(remoteClusterMutualTLSRPCFactory)
}
Expand Down Expand Up @@ -412,28 +412,28 @@ func (s *localStoreRPCSuite) setupInternode() {
s.NoError(err)
tlsConfig, err := provider.GetFrontendClientConfig()
s.NoError(err)
internodeMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
internodeMutualTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(internodeMutualTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreServerTLS.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
internodeServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
internodeServerTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(internodeServerTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreAltMutualTLS.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
internodeMutualAltTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
internodeMutualAltTLSFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(internodeMutualAltTLSFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLSWithRefresh.TLS, metrics.NoopMetricsHandler, s.logger, nil)
s.NoError(err)
tlsConfig, err = provider.GetFrontendClientConfig()
s.NoError(err)
internodeMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil)
internodeMutualTLSRefreshFactory := rpc.NewFactory(cfg, "tester", s.logger, nil, provider, frontendURL, frontendHTTPURL, 0, tlsConfig, nil, nil, nil)
s.NotNil(internodeMutualTLSRefreshFactory)

s.internodeMutualTLSRPCFactory = i(internodeMutualTLSFactory)
Expand Down
1 change: 1 addition & 0 deletions tests/testcore/onebox.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ func (c *TemporalImpl) newRPCFactory(
int(httpPort),
frontendTLSConfig,
options,
map[primitives.ServiceName][]grpc.DialOption{},
monitor,
), nil
}
Expand Down
Loading