Skip to content

Commit 28c6289

Browse files
authored
Merge pull request #9 from truenas/NAS-137879
NAS-137879 / 26.04 / Add ARI implementation
2 parents 8e89c55 + 3432537 commit 28c6289

File tree

4 files changed

+190
-5
lines changed

4 files changed

+190
-5
lines changed

truenas_acme_utils/ari.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import json
2+
import time
3+
from datetime import datetime
4+
5+
import requests
6+
7+
8+
# ARI retry configuration per RFC 9773 Section 4.3.3
9+
DEFAULT_RETRY_AFTER = 21600 # 6 hours default
10+
MIN_RETRY_AFTER = 60 # 1 minute minimum
11+
MAX_RETRY_AFTER = 86400 # 1 day maximum
12+
MAX_RETRIES = 3 # Max temporary error retries
13+
14+
15+
def fetch_renewal_info(ari_endpoint: str, cert_id: str, retries: int = MAX_RETRIES, timeout: int = 30) -> dict:
16+
"""
17+
Fetch renewal information from ACME server per RFC 9773 Section 4
18+
19+
Implements exponential backoff for temporary errors per RFC 9773 Section 4.3.3
20+
21+
:param ari_endpoint: RenewalInfo endpoint URL
22+
:param cert_id: Unique certificate identifier from get_cert_id()
23+
:param retries: Number of retries remaining for temporary errors
24+
:param timeout: Request timeout in seconds
25+
:return: Dict with error field (None if success), suggestedWindow (start/end datetimes),
26+
optional explanationURL, retry_after
27+
"""
28+
url = f'{ari_endpoint.rstrip("/")}/{cert_id}'
29+
backoff_delay = 1
30+
response = None
31+
32+
for attempt in range(retries + 1):
33+
try:
34+
response = requests.get(url, timeout=timeout)
35+
36+
# Check for HTTP 409 alreadyReplaced error (RFC 9773 Section 7.4)
37+
if response.status_code == 409:
38+
return {'error': 'Certificate has already been marked as replaced'}
39+
40+
# Handle 5xx server errors as temporary (RFC 9773 Section 4.3.3)
41+
if 500 <= response.status_code < 600:
42+
if attempt < retries:
43+
time.sleep(backoff_delay)
44+
backoff_delay *= 2
45+
continue
46+
return {'error': f'ARI server error after {retries + 1} attempts: HTTP {response.status_code}'}
47+
48+
if response.status_code not in (200, 201, 204):
49+
return {'error': f'ARI request failed: HTTP {response.status_code}'}
50+
51+
data = response.json()
52+
break
53+
54+
except (ConnectionError, TimeoutError, requests.exceptions.RequestException) as e:
55+
if attempt < retries:
56+
time.sleep(backoff_delay)
57+
backoff_delay *= 2
58+
continue
59+
60+
return {'error': f'ARI request failed after {retries + 1} attempts: {e}'}
61+
except json.JSONDecodeError as e:
62+
return {'error': f'Invalid JSON response: {e}'}
63+
except Exception as e:
64+
return {'error': f'ARI request failed: {e}'}
65+
66+
if 'suggestedWindow' not in data:
67+
return {'error': 'Invalid ARI response: missing suggestedWindow'}
68+
69+
window = data['suggestedWindow']
70+
if 'start' not in window or 'end' not in window:
71+
return {'error': 'Invalid suggestedWindow: missing start or end'}
72+
73+
try:
74+
start = datetime.fromisoformat(window['start'].replace('Z', '+00:00'))
75+
end = datetime.fromisoformat(window['end'].replace('Z', '+00:00'))
76+
except (ValueError, TypeError) as e:
77+
return {'error': f'Invalid date format in suggestedWindow: {e}'}
78+
79+
result = {
80+
'error': None,
81+
'suggested_window': {'start': start, 'end': end},
82+
'retry_after': None,
83+
'explanation_url': data.get('explanationURL'),
84+
}
85+
86+
# Parse Retry-After header per RFC 9773 Section 4.3
87+
if response and 'Retry-After' in response.headers:
88+
try:
89+
retry_after = int(response.headers['Retry-After'])
90+
# Clamp to reasonable limits per RFC 9773 Section 4.3.2
91+
retry_after = max(MIN_RETRY_AFTER, min(retry_after, MAX_RETRY_AFTER))
92+
result['retry_after'] = retry_after
93+
except ValueError:
94+
result['retry_after'] = DEFAULT_RETRY_AFTER
95+
else:
96+
# Use default if not provided per RFC 9773 Section 4.3.3
97+
result['retry_after'] = DEFAULT_RETRY_AFTER
98+
99+
return result

truenas_acme_utils/client_utils.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import typing
33

44
import josepy as jose
5-
from acme import client, messages
5+
from acme import client, crypto_util, messages
6+
from cryptography import x509
7+
8+
9+
class NewOrder(messages.NewOrder):
10+
replaces: str = jose.field('replaces', omitempty=True)
611

712

813
class BodyDict(typing.TypedDict):
@@ -17,6 +22,7 @@ class ACMEClientAndKeyData(typing.TypedDict):
1722
new_nonce_uri: str
1823
new_order_uri: str
1924
revoke_cert_uri: str
25+
renewal_info: str | None
2026
body: BodyDict
2127

2228

@@ -29,6 +35,7 @@ def get_acme_client_and_key(data: ACMEClientAndKeyData) -> tuple[client.ClientV2
2935
- new_nonce_uri: str
3036
- new_order_uri: str
3137
- revoke_cert_uri: str
38+
- renewal_info: str (optional)
3239
- body: dict
3340
- status: str
3441
- key: dict
@@ -58,7 +65,50 @@ def get_acme_client_and_key(data: ACMEClientAndKeyData) -> tuple[client.ClientV2
5865
'newAccount': data['new_account_uri'],
5966
'newNonce': data['new_nonce_uri'],
6067
'newOrder': data['new_order_uri'],
61-
'revokeCert': data['revoke_cert_uri']
68+
'revokeCert': data['revoke_cert_uri'],
69+
**({'renewalInfo': data['renewal_info']} if data.get('renewal_info') else {}),
6270
}),
6371
client.ClientNetwork(key, account=registration)
6472
), key
73+
74+
75+
def acme_order(
76+
acme_client: client.ClientV2, csr_pem: bytes, replaces_cert_id: str | None = None,
77+
) -> messages.OrderResource:
78+
csr = x509.load_pem_x509_csr(csr_pem)
79+
dnsNames = crypto_util.get_names_from_subject_and_extensions(csr.subject, csr.extensions)
80+
try:
81+
san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName)
82+
except x509.ExtensionNotFound:
83+
ipNames = []
84+
else:
85+
ipNames = san_ext.value.get_values_for_type(x509.IPAddress)
86+
87+
identifiers = []
88+
for name in dnsNames:
89+
identifiers.append(messages.Identifier(typ=messages.IDENTIFIER_FQDN, value=name))
90+
91+
for ip in ipNames:
92+
identifiers.append(messages.Identifier(typ=messages.IDENTIFIER_IP, value=str(ip)))
93+
94+
payload = {'identifiers': identifiers}
95+
if replaces_cert_id:
96+
payload['replaces'] = replaces_cert_id
97+
98+
order = NewOrder(**payload)
99+
response = acme_client._post(acme_client.directory['newOrder'], order)
100+
body = messages.Order.from_json(response.json())
101+
102+
authorizations = []
103+
# pylint has trouble understanding our josepy based objects which use
104+
# things like custom metaclass logic. body.authorizations should be a
105+
# list of strings containing URLs so let's disable this check here.
106+
for url in body.authorizations: # pylint: disable=not-an-iterable
107+
authorizations.append(acme_client._authzr_from_response(acme_client._post_as_get(url), uri=url))
108+
109+
return messages.OrderResource(
110+
body=body,
111+
uri=response.headers.get('Location'),
112+
authorizations=authorizations,
113+
csr_pem=csr_pem,
114+
)

truenas_acme_utils/issue_cert.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import josepy as jose
77
from acme import errors, messages
88

9-
from .client_utils import ACMEClientAndKeyData, get_acme_client_and_key
9+
from .client_utils import ACMEClientAndKeyData, acme_order, get_acme_client_and_key
1010
from .event import send_event
1111
from .exceptions import CallError
1212

@@ -15,13 +15,15 @@
1515

1616

1717
def issue_certificate(
18-
acme_client_key_payload: ACMEClientAndKeyData, csr: str, authenticator_mapping_copy: dict, progress_base: int = 25
18+
acme_client_key_payload: ACMEClientAndKeyData, csr: str, authenticator_mapping_copy: dict, progress_base: int = 25,
19+
cert_renewal_id: str | None = None,
1920
) -> messages.OrderResource:
21+
# cert_id is the ID of the certificate being replaced if any
2022
# Authenticator mapping should be a valid mapping of domain to authenticator object
2123
acme_client, key = get_acme_client_and_key(acme_client_key_payload)
2224
try:
2325
# perform operations and have a cert issued
24-
order = acme_client.new_order(csr.encode())
26+
order = acme_order(acme_client, csr.encode(), cert_renewal_id)
2527
except messages.Error as e:
2628
raise CallError(f'Failed to issue a new order for Certificate : {e}')
2729
else:

truenas_crypto_utils/read.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import datetime
23
import dateutil
34
import dateutil.parser
@@ -6,9 +7,11 @@
67

78
from contextlib import suppress
89
from typing import TypeAlias
10+
from cryptography import x509
911
from cryptography.hazmat.backends import default_backend
1012
from cryptography.hazmat.primitives import serialization
1113
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, ed448, rsa
14+
from cryptography.x509.oid import ExtensionOID
1215
from OpenSSL import crypto
1316

1417
from .utils import RE_CERTIFICATE
@@ -20,6 +23,37 @@
2023
logger = logging.getLogger(__name__)
2124

2225

26+
def _b64url(b: bytes) -> str:
27+
return base64.urlsafe_b64encode(b).decode().rstrip("=")
28+
29+
30+
def _serial_value_bytes(n: int) -> bytes:
31+
# DER INTEGER value bytes for non-negative n
32+
if n == 0:
33+
return b'\x00'
34+
v = n.to_bytes((n.bit_length() + 7) // 8, 'big')
35+
return v if not (v[0] & 0x80) else b'\x00' + v
36+
37+
38+
def get_cert_id(cert_str: str) -> str:
39+
"""
40+
ARI cert_id per RFC 9773 §4.1
41+
format: base64url(AKI.keyIdentifier) + "." + base64url(serial INTEGER value bytes)
42+
"""
43+
cert = x509.load_pem_x509_certificate(cert_str.encode(), default_backend())
44+
try:
45+
aki_ext = cert.extensions.get_extension_for_oid(ExtensionOID.AUTHORITY_KEY_IDENTIFIER)
46+
except x509.ExtensionNotFound as e:
47+
raise ValueError('Certificate missing Authority Key Identifier (AKI)') from e
48+
49+
if aki_ext.value.key_identifier is None:
50+
raise ValueError('AKI keyIdentifier is None')
51+
52+
aki_b64 = _b64url(aki_ext.value.key_identifier)
53+
serial_b64 = _b64url(_serial_value_bytes(cert.serial_number))
54+
return f'{aki_b64}.{serial_b64}'
55+
56+
2357
def parse_cert_date_string(date_value: bytes | str) -> str:
2458
t1 = dateutil.parser.parse(date_value)
2559
t2 = t1.astimezone(dateutil.tz.tzlocal())

0 commit comments

Comments
 (0)