diff --git a/cmd/cmdoptions/client.go b/cmd/cmdoptions/client.go index 5446edb1..04b5e746 100644 --- a/cmd/cmdoptions/client.go +++ b/cmd/cmdoptions/client.go @@ -14,6 +14,8 @@ import ( "go.temporal.io/sdk/client" "go.temporal.io/sdk/converter" "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) const AUTH_HEADER_ENV_VAR = "TEMPORAL_OMES_AUTH_HEADER" @@ -46,6 +48,8 @@ type ClientOptions struct { AuthHeader string // Disable Host Verification DisableHostVerification bool + // API Key + APIKey string } // loadTLSConfig inits a TLS config from the provided cert and key files. @@ -95,6 +99,23 @@ func (c *ClientOptions) Dial(metrics *Metrics, logger *zap.SugaredLogger) (clien clientOptions.ConnectionOptions.TLS = tlsCfg clientOptions.Logger = NewZapAdapter(logger.Desugar()) clientOptions.MetricsHandler = metrics.NewHandler() + if c.APIKey != "" { + clientOptions.Credentials = client.NewAPIKeyStaticCredentials(c.APIKey) + clientOptions.ConnectionOptions.DialOptions = []grpc.DialOption{ + grpc.WithUnaryInterceptor( + func(ctx context.Context, method string, req any, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker( + metadata.AppendToOutgoingContext(ctx, "temporal-namespace", c.Namespace), + method, + req, + reply, + cc, + opts..., + ) + }, + ), + } + } var authHeader string if c.AuthHeader == "" { @@ -135,6 +156,7 @@ func (c *ClientOptions) AddCLIFlags(fs *pflag.FlagSet) { fs.StringVar(&c.ClientKeyPath, "tls-key-path", "", "Path to client private key") fs.BoolVar(&c.DisableHostVerification, "disable-tls-host-verification", false, "Disable TLS host verification") fs.StringVar(&c.TLSServerName, "tls-server-name", "", "TLS target server name") + fs.StringVar(&c.APIKey, "api-key", "", "API key for authentication") fs.StringVar(&c.AuthHeader, "auth-header", "", fmt.Sprintf("Authorization header value (can also be set via %s env var)", AUTH_HEADER_ENV_VAR)) } @@ -162,6 +184,9 @@ func (c *ClientOptions) ToFlags() (flags []string) { if c.TLSServerName != "" { flags = append(flags, "--tls-server-name", c.TLSServerName) } + if c.APIKey != "" { + flags = append(flags, "--api-key", c.APIKey) + } if c.AuthHeader != "" { flags = append(flags, "--auth-header", c.AuthHeader) } diff --git a/workers/dotnet/App.cs b/workers/dotnet/App.cs index b46fc6db..5dab9217 100644 --- a/workers/dotnet/App.cs +++ b/workers/dotnet/App.cs @@ -73,6 +73,10 @@ public static class App name: "--tls-key-path", description: "Path to a client key for TLS"); + private static readonly Option apiKeyOption = new( + name: "--api-key", + description: "API key for authentication"); + private static readonly Option promAddrOption = new Option( name: "--prom-listen-address", description: "Prometheus listen address"); @@ -106,6 +110,7 @@ private static Command CreateCommand() cmd.Add(useTLSOption); cmd.Add(clientCertPathOption); cmd.Add(clientKeyPathOption); + cmd.add(apiKeyOption); cmd.Add(promAddrOption); cmd.Add(promHandlerPathOption); cmd.SetHandler(RunCommandAsync); @@ -156,7 +161,12 @@ private static async Task RunCommandAsync(InvocationContext ctx) Runtime = runtime, Namespace = ctx.ParseResult.GetValueForOption(namespaceOption)!, Tls = tls, - LoggerFactory = loggerFactory + LoggerFactory = loggerFactory, + ApiKey = ctx.ParseResult.GetValueForOption(apiKeyOption), + RpcMetadata = new Dictionary() + { + ["temporal-namespace"] = ctx.ParseResult.GetValueForOption(namespaceOption)!, + } }); // Collect task queues to run workers for diff --git a/workers/java/io/temporal/omes/Main.java b/workers/java/io/temporal/omes/Main.java index f0e26d3c..35e0e258 100644 --- a/workers/java/io/temporal/omes/Main.java +++ b/workers/java/io/temporal/omes/Main.java @@ -8,7 +8,9 @@ import com.uber.m3.tally.RootScopeBuilder; import com.uber.m3.tally.Scope; import com.uber.m3.tally.StatsReporter; +import io.grpc.Metadata; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.grpc.stub.MetadataUtils; import io.micrometer.core.instrument.util.StringUtils; import io.micrometer.prometheus.PrometheusConfig; import io.micrometer.prometheus.PrometheusMeterRegistry; @@ -23,6 +25,10 @@ import io.temporal.worker.WorkerFactory; import io.temporal.worker.WorkerFactoryOptions; import io.temporal.worker.WorkerOptions; +import net.logstash.logback.encoder.LogstashEncoder; +import picocli.CommandLine; + +import javax.net.ssl.SSLException; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.InputStream; @@ -30,9 +36,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; -import javax.net.ssl.SSLException; -import net.logstash.logback.encoder.LogstashEncoder; -import picocli.CommandLine; @CommandLine.Command(name = "features", description = "Runs Java features") public class Main implements Runnable { @@ -86,6 +89,9 @@ public class Main implements Runnable { @CommandLine.Option(names = "--tls-key-path", description = "Path to a client key for TLS") private String clientKeyPath; + @CommandLine.Option(names = "--api-key", description = "API key for authentication") + private String apiKey; + // Metric parameters @CommandLine.Option( names = "--prom-listen-address", @@ -122,29 +128,36 @@ public class Main implements Runnable { @Override public void run() { + WorkflowServiceStubsOptions.Builder workflowServiceStubOptionsBuilder = WorkflowServiceStubsOptions.newBuilder(); // Configure TLS - SslContext sslContext = null; - if (StringUtils.isNotEmpty(clientCertPath)) { - if (StringUtils.isEmpty(clientKeyPath)) { + if (StringUtils.isNotEmpty(clientCertPath) || StringUtils.isNotEmpty(clientKeyPath)) { + if (StringUtils.isEmpty(clientKeyPath) || StringUtils.isEmpty(clientCertPath)) { throw new RuntimeException("Client key path must be specified since cert path is"); } try { InputStream clientCert = new FileInputStream(clientCertPath); InputStream clientKey = new FileInputStream(clientKeyPath); - sslContext = SimpleSslContextBuilder.forPKCS8(clientCert, clientKey).build(); + SslContext sslContext = SimpleSslContextBuilder.forPKCS8(clientCert, clientKey).build(); + workflowServiceStubOptionsBuilder.setSslContext(sslContext); } catch (FileNotFoundException | SSLException e) { throw new RuntimeException("Error loading certs", e); } - - } else if (StringUtils.isNotEmpty(clientKeyPath) && StringUtils.isEmpty(clientCertPath)) { - throw new RuntimeException("Client cert path must be specified since key path is"); } else if (isTlsEnabled) { - try { - sslContext = SimpleSslContextBuilder.noKeyOrCertChain().build(); - } catch (SSLException e) { - throw new RuntimeException(e); - } + workflowServiceStubOptionsBuilder.setEnableHttps(true); + } + // Configure API key + if (StringUtils.isNotEmpty(apiKey)) { + workflowServiceStubOptionsBuilder.addApiKey(() -> apiKey); + Metadata.Key TEMPORAL_NAMESPACE_HEADER_KEY = + Metadata.Key.of("temporal-namespace", Metadata.ASCII_STRING_MARSHALLER); + Metadata metadata = new Metadata(); + metadata.put(TEMPORAL_NAMESPACE_HEADER_KEY, namespace); + workflowServiceStubOptionsBuilder.setChannelInitializer( + (channel) -> { + channel.intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + }); + } // Configure logging @@ -174,9 +187,8 @@ public void run() { // Configure client WorkflowServiceStubs service = WorkflowServiceStubs.newServiceStubs( - WorkflowServiceStubsOptions.newBuilder() + workflowServiceStubOptionsBuilder .setTarget(serverAddress) - .setSslContext(sslContext) .setMetricsScope(scope) .build()); diff --git a/workers/python/main.py b/workers/python/main.py index 5ee9c58e..ea53bb3c 100644 --- a/workers/python/main.py +++ b/workers/python/main.py @@ -87,6 +87,7 @@ async def run(): "--tls-cert-path", default="", help="Path to client TLS certificate" ) parser.add_argument("--tls-key-path", default="", help="Path to client private key") + parser.add_argument("--api-key", help="API key for authentication") # Prometheus metric arguments parser.add_argument("--prom-listen-address", help="Prometheus listen address") parser.add_argument( @@ -144,6 +145,7 @@ async def run(): namespace=args.namespace, tls=tls_config, runtime=new_runtime, + api_key=args.api_key, ) # Collect task queues to run workers for (if there is a suffix end, we run diff --git a/workers/typescript/src/omes.ts b/workers/typescript/src/omes.ts index 680066ba..fe9a254c 100644 --- a/workers/typescript/src/omes.ts +++ b/workers/typescript/src/omes.ts @@ -39,6 +39,7 @@ async function run() { .option('--tls', 'Enable TLS') .option('--tls-cert-path ', 'Path to a client certificate for TLS') .option('--tls-key-path ', 'Path to a client key for TLS') + .option('--api-key ', 'API key for authentication') .option('--prom-listen-address ', 'Prometheus listen address') .option('--prom-handler-path ', 'Prometheus handler path', '/metrics'); @@ -122,6 +123,10 @@ async function run() { const connection = await NativeConnection.connect({ address: opts.serverAddress, tls: tlsConfig, + apiKey: opts.apiKey, + metadata: { + 'temporal-namespace': opts.namespace, + }, }); // Possibly create multiple workers if we are being asked to use multiple task queues