diff --git a/temporal/service_helpers.py b/temporal/service_helpers.py index b72bb45..976a6bb 100644 --- a/temporal/service_helpers.py +++ b/temporal/service_helpers.py @@ -1,13 +1,42 @@ import os import socket +import ssl +from typing import List +from dataclasses import dataclass, field from grpclib.client import Channel from temporal.api.workflowservice.v1 import WorkflowServiceStub -def create_workflow_service(host: str, port: int, timeout: float) -> WorkflowServiceStub: - channel = Channel(host=host, port=port) +@dataclass +class TLSOptions: + alpn_protocols: List = field(default_factory=lambda: ['h2']) + ca_cert: str = None + client_cert: str = None + client_key: str = None + ciphers: str = "ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20" + npn_protocols: List = field(default_factory=lambda: ['h2']) + + +def create_secure_context(tls_options: TLSOptions) -> ssl.SSLContext: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_cert_chain(str(tls_options.client_cert), str(tls_options.client_key)) + ctx.load_verify_locations(str(tls_options.ca_cert)) + ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ctx.set_ciphers(tls_options.ciphers) + ctx.set_alpn_protocols(tls_options.alpn_protocols) + try: + ctx.set_npn_protocols(tls_options.npn_protocols) + except NotImplementedError: + pass + return ctx + + +def create_workflow_service(host: str, port: int, timeout: float, tls_options: TLSOptions) -> WorkflowServiceStub: + ssl_context = create_secure_context(tls_options) if tls_options is not None else None + channel = Channel(host=host, port=port, ssl=ssl_context) return WorkflowServiceStub(channel, timeout=timeout) diff --git a/temporal/workflow.py b/temporal/workflow.py index 7cd896a..9603e87 100644 --- a/temporal/workflow.py +++ b/temporal/workflow.py @@ -25,7 +25,7 @@ from .exception_handling import deserialize_exception from .exceptions import WorkflowFailureException, ActivityFailureException, QueryRejectedException, \ QueryFailureException, WorkflowOperationException -from .service_helpers import create_workflow_service, get_identity +from .service_helpers import create_workflow_service, get_identity, TLSOptions T = TypeVar('T') @@ -157,8 +157,9 @@ class WorkflowClient: @classmethod def new_client(cls, host: str = "localhost", port: int = 7233, namespace: str = "", options: WorkflowClientOptions = None, timeout: int = DEFAULT_SOCKET_TIMEOUT_SECONDS, - data_converter: DataConverter = DEFAULT_DATA_CONVERTER_INSTANCE) -> WorkflowClient: - service = create_workflow_service(host, port, timeout=timeout) + data_converter: DataConverter = DEFAULT_DATA_CONVERTER_INSTANCE, + tls_options: TLSOptions = None) -> WorkflowClient: + service = create_workflow_service(host, port, timeout=timeout, tls_options=tls_options) return cls(service=service, namespace=namespace, options=options, data_converter=data_converter) @classmethod