diff --git a/pytest/unit/test_private_key.py b/pytest/unit/test_private_key.py index 21bb09d..f38fa1f 100644 --- a/pytest/unit/test_private_key.py +++ b/pytest/unit/test_private_key.py @@ -8,15 +8,23 @@ @pytest.mark.parametrize('generate_params,expected_type,key_size', [ - ({}, rsa.RSAPrivateKey, 2048), - ({'type': 'EC'}, ec.EllipticCurvePrivateKey, 384), + ( + {}, + rsa.RSAPrivateKey, + 2048, + ), + ( + {'type': 'EC'}, + ec.EllipticCurvePrivateKey, + 384, + ), ( { 'type': 'RSA', 'key_length': 4096, }, rsa.RSAPrivateKey, - 4096 + 4096, ), ]) def test_generating_private_key(generate_params, expected_type, key_size): diff --git a/truenas_acme_utils/issue_cert.py b/truenas_acme_utils/issue_cert.py index 664974d..ef2d928 100644 --- a/truenas_acme_utils/issue_cert.py +++ b/truenas_acme_utils/issue_cert.py @@ -6,7 +6,7 @@ import josepy as jose from acme import errors, messages -from .client_utils import get_acme_client_and_key +from .client_utils import ACMEClientAndKeyData, get_acme_client_and_key from .event import send_event from .exceptions import CallError @@ -15,13 +15,13 @@ def issue_certificate( - acme_client_key_payload: dict, csr: str, authenticator_mapping_copy: dict, progress_base: int = 25 -): + acme_client_key_payload: ACMEClientAndKeyData, csr: str, authenticator_mapping_copy: dict, progress_base: int = 25 +) -> messages.OrderResource: # Authenticator mapping should be a valid mapping of domain to authenticator object acme_client, key = get_acme_client_and_key(acme_client_key_payload) try: # perform operations and have a cert issued - order = acme_client.new_order(csr) + order = acme_client.new_order(csr.encode()) except messages.Error as e: raise CallError(f'Failed to issue a new order for Certificate : {e}') else: diff --git a/truenas_crypto_utils/generate_certs.py b/truenas_crypto_utils/generate_certs.py index 437a8b7..712ac84 100644 --- a/truenas_crypto_utils/generate_certs.py +++ b/truenas_crypto_utils/generate_certs.py @@ -40,7 +40,7 @@ def generate_certificate(data: dict) -> tuple[str, str]: else: issuer = None - cert = add_extensions(generate_builder(builder_data), data.get('cert_extensions'), key, issuer) + cert = add_extensions(generate_builder(builder_data), data.get('cert_extensions', {}), key, issuer) cert = cert.sign( ca_key or key, retrieve_signing_algorithm(data, ca_key or key), default_backend() diff --git a/truenas_crypto_utils/generate_self_signed.py b/truenas_crypto_utils/generate_self_signed.py index afaa8fe..3ebb9fa 100644 --- a/truenas_crypto_utils/generate_self_signed.py +++ b/truenas_crypto_utils/generate_self_signed.py @@ -21,7 +21,6 @@ def generate_self_signed_certificate() -> tuple[str, str]: 'san': normalize_san(['localhost']) }) key = generate_private_key({ - 'serialize': False, 'key_length': 2048, 'type': 'RSA' }) diff --git a/truenas_crypto_utils/generate_utils.py b/truenas_crypto_utils/generate_utils.py index da57203..4f5817f 100644 --- a/truenas_crypto_utils/generate_utils.py +++ b/truenas_crypto_utils/generate_utils.py @@ -53,7 +53,7 @@ def generate_builder(options: dict) -> x509.CertificateBuilder | x509.Certificat return cert -def normalize_san(san_list: list) -> list: +def normalize_san(san_list: list[str] | None) -> list[list[str]]: # TODO: ADD MORE TYPES WRT RFC'S normalized = [] for count, san in enumerate(san_list or []): diff --git a/truenas_crypto_utils/key.py b/truenas_crypto_utils/key.py index 4bf23af..00b5dd7 100644 --- a/truenas_crypto_utils/key.py +++ b/truenas_crypto_utils/key.py @@ -1,57 +1,53 @@ -from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, ed448, rsa +from typing import Literal, overload + +from cryptography.hazmat.primitives.asymmetric import ec, rsa from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization -from .read import load_private_key +from .read import GeneratedPrivateKey, PrivateKey, load_private_key from .utils import EC_CURVE_DEFAULT -def retrieve_signing_algorithm(data: dict, signing_key: ( - ed25519.Ed25519PrivateKey | - ed448.Ed448PrivateKey | - rsa.RSAPrivateKey | - dsa.DSAPrivateKey | - ec.EllipticCurvePrivateKey -)): +def retrieve_signing_algorithm(data: dict, signing_key: PrivateKey): if isinstance(signing_key, Ed25519PrivateKey): return None else: return getattr(hashes, data.get('digest_algorithm') or 'SHA256')() -def generate_private_key(options: dict) -> ( - str, - ed25519.Ed25519PrivateKey | - ed448.Ed448PrivateKey | - rsa.RSAPrivateKey | - dsa.DSAPrivateKey | - ec.EllipticCurvePrivateKey -): +@overload +def generate_private_key(options: dict, *, serialize: Literal[True]) -> str: ... + + +@overload +def generate_private_key(options: dict, *, serialize: Literal[False] = False) -> GeneratedPrivateKey: ... + + +def generate_private_key(options: dict, *, serialize: bool = False) -> GeneratedPrivateKey | str: # We should make sure to return in PEM format # Reason for using PKCS8 # https://stackoverflow.com/questions/48958304/pkcs1-and-pkcs8-format-for-rsa-private-key - options.setdefault('serialize', False) options.setdefault('key_length', 2048) options.setdefault('type', 'RSA') options.setdefault('curve', EC_CURVE_DEFAULT) - if options.get('type') == 'EC': + if options['type'] == 'EC': if options['curve'] == 'ed25519': key = Ed25519PrivateKey.generate() else: key = ec.generate_private_key( - getattr(ec, options.get('curve')), + getattr(ec, options['curve'])(), default_backend() ) else: key = rsa.generate_private_key( public_exponent=65537, - key_size=options.get('key_length'), + key_size=options['key_length'], backend=default_backend() ) - if options.get('serialize'): + if serialize: return key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, @@ -67,13 +63,7 @@ def export_private_key(buffer: str, passphrase: str | None = None) -> str | None return export_private_key_object(key) -def export_private_key_object(key: ( - ed25519.Ed25519PrivateKey | - ed448.Ed448PrivateKey | - rsa.RSAPrivateKey | - dsa.DSAPrivateKey | - ec.EllipticCurvePrivateKey -)) -> str: +def export_private_key_object(key: PrivateKey) -> str: return key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, diff --git a/truenas_crypto_utils/read.py b/truenas_crypto_utils/read.py index 1a9d1b4..24fc650 100644 --- a/truenas_crypto_utils/read.py +++ b/truenas_crypto_utils/read.py @@ -5,6 +5,7 @@ import re from contextlib import suppress +from typing import TypeAlias from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, ed448, rsa @@ -13,10 +14,13 @@ from .utils import RE_CERTIFICATE +GeneratedPrivateKey: TypeAlias = ed25519.Ed25519PrivateKey | rsa.RSAPrivateKey | ec.EllipticCurvePrivateKey +PrivateKey: TypeAlias = GeneratedPrivateKey | ed448.Ed448PrivateKey | dsa.DSAPrivateKey + logger = logging.getLogger(__name__) -def parse_cert_date_string(date_value: str) -> str: +def parse_cert_date_string(date_value: bytes | str) -> str: t1 = dateutil.parser.parse(date_value) t2 = t1.astimezone(dateutil.tz.tzlocal()) return t2.ctime() @@ -129,17 +133,14 @@ def parse_name_components(obj: crypto.X509Name) -> str: def load_certificate_request(csr: str) -> dict: try: - csr_obj = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr) + csr_obj = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr.encode()) except crypto.Error: return {} else: return get_x509_subject(csr_obj) -def load_private_key(key_string: str, passphrase: str | None = None) -> ( - ed25519.Ed25519PrivateKey | ed448.Ed448PrivateKey | rsa.RSAPrivateKey | - dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey -): +def load_private_key(key_string: str, passphrase: str | None = None) -> PrivateKey: with suppress(ValueError, TypeError, AttributeError): return serialization.load_pem_private_key( key_string.encode(),