Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pytest/unit/test_private_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@


@pytest.mark.parametrize('generate_params,expected_type,key_size', [
({}, rsa.RSAPrivateKey, 2048),
({'type': 'EC'}, ec.EllipticCurvePrivateKey, 384),
(
{},
rsa.RSAPrivateKey,
2048,
),
(
{'type': 'EC'},
ec.EllipticCurvePrivateKey,
384,
),
(
{
'type': 'RSA',
'key_length': 4096,
},
rsa.RSAPrivateKey,
4096
4096,
),
])
def test_generating_private_key(generate_params, expected_type, key_size):
Expand Down
8 changes: 4 additions & 4 deletions truenas_acme_utils/issue_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import josepy as jose
from acme import errors, messages

from .client_utils import get_acme_client_and_key
from .client_utils import ACMEClientAndKeyData, get_acme_client_and_key
from .event import send_event
from .exceptions import CallError

Expand All @@ -15,13 +15,13 @@


def issue_certificate(
acme_client_key_payload: dict, csr: str, authenticator_mapping_copy: dict, progress_base: int = 25
):
acme_client_key_payload: ACMEClientAndKeyData, csr: str, authenticator_mapping_copy: dict, progress_base: int = 25
) -> messages.OrderResource:
# Authenticator mapping should be a valid mapping of domain to authenticator object
acme_client, key = get_acme_client_and_key(acme_client_key_payload)
try:
# perform operations and have a cert issued
order = acme_client.new_order(csr)
order = acme_client.new_order(csr.encode())
except messages.Error as e:
raise CallError(f'Failed to issue a new order for Certificate : {e}')
else:
Expand Down
2 changes: 1 addition & 1 deletion truenas_crypto_utils/generate_certs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def generate_certificate(data: dict) -> tuple[str, str]:
else:
issuer = None

cert = add_extensions(generate_builder(builder_data), data.get('cert_extensions'), key, issuer)
cert = add_extensions(generate_builder(builder_data), data.get('cert_extensions', {}), key, issuer)

cert = cert.sign(
ca_key or key, retrieve_signing_algorithm(data, ca_key or key), default_backend()
Expand Down
1 change: 0 additions & 1 deletion truenas_crypto_utils/generate_self_signed.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def generate_self_signed_certificate() -> tuple[str, str]:
'san': normalize_san(['localhost'])
})
key = generate_private_key({
'serialize': False,
'key_length': 2048,
'type': 'RSA'
})
Expand Down
2 changes: 1 addition & 1 deletion truenas_crypto_utils/generate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def generate_builder(options: dict) -> x509.CertificateBuilder | x509.Certificat
return cert


def normalize_san(san_list: list) -> list:
def normalize_san(san_list: list[str] | None) -> list[list[str]]:
# TODO: ADD MORE TYPES WRT RFC'S
normalized = []
for count, san in enumerate(san_list or []):
Expand Down
48 changes: 19 additions & 29 deletions truenas_crypto_utils/key.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,53 @@
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, ed448, rsa
from typing import Literal, overload

from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization

from .read import load_private_key
from .read import GeneratedPrivateKey, PrivateKey, load_private_key
from .utils import EC_CURVE_DEFAULT


def retrieve_signing_algorithm(data: dict, signing_key: (
ed25519.Ed25519PrivateKey |
ed448.Ed448PrivateKey |
rsa.RSAPrivateKey |
dsa.DSAPrivateKey |
ec.EllipticCurvePrivateKey
)):
def retrieve_signing_algorithm(data: dict, signing_key: PrivateKey):
if isinstance(signing_key, Ed25519PrivateKey):
return None
else:
return getattr(hashes, data.get('digest_algorithm') or 'SHA256')()


def generate_private_key(options: dict) -> (
str,
ed25519.Ed25519PrivateKey |
ed448.Ed448PrivateKey |
rsa.RSAPrivateKey |
dsa.DSAPrivateKey |
ec.EllipticCurvePrivateKey
):
@overload
def generate_private_key(options: dict, *, serialize: Literal[True]) -> str: ...


@overload
def generate_private_key(options: dict, *, serialize: Literal[False] = False) -> GeneratedPrivateKey: ...


def generate_private_key(options: dict, *, serialize: bool = False) -> GeneratedPrivateKey | str:
# We should make sure to return in PEM format
# Reason for using PKCS8
# https://stackoverflow.com/questions/48958304/pkcs1-and-pkcs8-format-for-rsa-private-key
options.setdefault('serialize', False)
options.setdefault('key_length', 2048)
options.setdefault('type', 'RSA')
options.setdefault('curve', EC_CURVE_DEFAULT)

if options.get('type') == 'EC':
if options['type'] == 'EC':
if options['curve'] == 'ed25519':
key = Ed25519PrivateKey.generate()
else:
key = ec.generate_private_key(
getattr(ec, options.get('curve')),
getattr(ec, options['curve'])(),
default_backend()
)
else:
key = rsa.generate_private_key(
public_exponent=65537,
key_size=options.get('key_length'),
key_size=options['key_length'],
backend=default_backend()
)

if options.get('serialize'):
if serialize:
return key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
Expand All @@ -67,13 +63,7 @@ def export_private_key(buffer: str, passphrase: str | None = None) -> str | None
return export_private_key_object(key)


def export_private_key_object(key: (
ed25519.Ed25519PrivateKey |
ed448.Ed448PrivateKey |
rsa.RSAPrivateKey |
dsa.DSAPrivateKey |
ec.EllipticCurvePrivateKey
)) -> str:
def export_private_key_object(key: PrivateKey) -> str:
return key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
Expand Down
13 changes: 7 additions & 6 deletions truenas_crypto_utils/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re

from contextlib import suppress
from typing import TypeAlias
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, ed448, rsa
Expand All @@ -13,10 +14,13 @@
from .utils import RE_CERTIFICATE


GeneratedPrivateKey: TypeAlias = ed25519.Ed25519PrivateKey | rsa.RSAPrivateKey | ec.EllipticCurvePrivateKey
PrivateKey: TypeAlias = GeneratedPrivateKey | ed448.Ed448PrivateKey | dsa.DSAPrivateKey

logger = logging.getLogger(__name__)


def parse_cert_date_string(date_value: str) -> str:
def parse_cert_date_string(date_value: bytes | str) -> str:
t1 = dateutil.parser.parse(date_value)
t2 = t1.astimezone(dateutil.tz.tzlocal())
return t2.ctime()
Expand Down Expand Up @@ -129,17 +133,14 @@ def parse_name_components(obj: crypto.X509Name) -> str:

def load_certificate_request(csr: str) -> dict:
try:
csr_obj = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
csr_obj = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr.encode())
except crypto.Error:
return {}
else:
return get_x509_subject(csr_obj)


def load_private_key(key_string: str, passphrase: str | None = None) -> (
ed25519.Ed25519PrivateKey | ed448.Ed448PrivateKey | rsa.RSAPrivateKey |
dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey
):
def load_private_key(key_string: str, passphrase: str | None = None) -> PrivateKey:
with suppress(ValueError, TypeError, AttributeError):
return serialization.load_pem_private_key(
key_string.encode(),
Expand Down