Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 24 additions & 53 deletions acapy_agent/wallet/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@

import json
import logging
from typing import Any, Mapping, Optional, Tuple
from typing import Any, Mapping, Optional

from marshmallow import fields
from pydid import DIDUrl, Resource, VerificationMethod
from pydid.verification_method import Ed25519VerificationKey2018, Multikey

from acapy_agent.wallet.keys.manager import key_type_from_multikey, multikey_to_verkey
from acapy_agent.wallet.keys.manager import (
MultikeyManager,
key_type_from_multikey,
multikey_to_verkey,
)

from ..core.profile import Profile
from ..messaging.jsonld.error import BadJWSHeaderError, InvalidVerificationMethod
from ..messaging.jsonld.error import BadJWSHeaderError
from ..messaging.models.base import BaseModel, BaseModelSchema
from ..resolver.did_resolver import DIDResolver
from .base import BaseWallet
from .default_verification_key_strategy import BaseVerificationKeyStrategy
from .key_type import ED25519, KeyType
from .util import b64_to_bytes, bytes_to_b64

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,19 +64,18 @@ async def jwt_sign(
verification_method = await verkey_strat.get_verification_method_id_for_did(
did, profile
)
else:
# We look up keys by did for now
did = DIDUrl.parse(verification_method).did
if not did:
raise ValueError("DID URL must be absolute")

async with profile.session() as session:
wallet = session.inject(BaseWallet)
did_info = await wallet.get_local_did(did_lookup_name(did))
key_manager = MultikeyManager(session)
key_info = await key_manager.resolve_and_bind_kid(verification_method)
multikey = key_info["multikey"]
key_type = key_type_from_multikey(multikey)
public_key_base58 = multikey_to_verkey(multikey)

header_alg = did_info.key_type.jws_algorithm
header_alg = key_type.jws_algorithm
if not header_alg:
raise ValueError(f"DID key type '{did_info.key_type}' cannot be used for JWS")
raise ValueError(f"DID key type '{key_type}' cannot be used for JWS")

if not headers.get("typ", None):
headers["typ"] = "JWT"
Expand All @@ -88,9 +87,9 @@ async def jwt_sign(
encoded_headers = dict_to_b64(headers)
encoded_payload = dict_to_b64(payload)

LOGGER.info(f"jwt sign: {did}")
LOGGER.info(f"jwt sign: {verification_method}")
sig_bytes = await wallet.sign_message(
f"{encoded_headers}.{encoded_payload}".encode(), did_info.verkey
f"{encoded_headers}.{encoded_payload}".encode(), public_key_base58
)

sig = bytes_to_b64(sig_bytes, urlsafe=True, pad=False)
Expand Down Expand Up @@ -138,38 +137,6 @@ class Meta:
error = fields.Str(required=False, metadata={"description": "Error text"})


async def resolve_public_key_by_kid_for_verify(
profile: Profile, kid: str
) -> Tuple[str, KeyType]:
"""Resolve public key verkey (base58 public key) and key type from a kid."""
resolver = profile.inject(DIDResolver)
vmethod: Resource = await resolver.dereference(
profile,
kid,
)

if not isinstance(vmethod, VerificationMethod):
raise InvalidVerificationMethod(
"Dereferenced resource is not a verification method"
)

if isinstance(vmethod, Ed25519VerificationKey2018):
verkey = vmethod.public_key_base58
ktyp = ED25519
return (verkey, ktyp)

if isinstance(vmethod, Multikey):
multikey = vmethod.public_key_multibase
verkey = multikey_to_verkey(multikey)
ktyp = key_type_from_multikey(multikey=multikey)
return (verkey, ktyp)

# unsupported
raise InvalidVerificationMethod(
f"Dereferenced method {type(vmethod).__name__} is not supported"
)


async def jwt_verify(profile: Profile, jwt: str) -> JWTVerifyResult:
"""Verify a JWT and return the headers and payload."""
encoded_headers, encoded_payload, encoded_signature = jwt.split(".", 3)
Expand All @@ -189,15 +156,19 @@ async def jwt_verify(profile: Profile, jwt: str) -> JWTVerifyResult:
decoded_signature = b64_to_bytes(encoded_signature, urlsafe=True)

async with profile.session() as session:
(verkey, ktyp) = await resolve_public_key_by_kid_for_verify(
profile, verification_method
key_manager = MultikeyManager(session)
multikey = await key_manager.resolve_multikey_from_verification_method_id(
verification_method
)
key_type = key_type_from_multikey(multikey)
public_key_base58 = multikey_to_verkey(multikey)

wallet = session.inject(BaseWallet)
valid = await wallet.verify_message(
f"{encoded_headers}.{encoded_payload}".encode(),
decoded_signature,
from_verkey=verkey,
key_type=ktyp,
from_verkey=public_key_base58,
key_type=key_type,
)

return JWTVerifyResult(headers, payload, valid, verification_method)
27 changes: 7 additions & 20 deletions acapy_agent/wallet/tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pytest

from acapy_agent.resolver.default.key import KeyDIDResolver

from ...resolver.did_resolver import DIDResolver
from ...resolver.tests.test_did_resolver import MockResolver
from ...utils.testing import create_test_profile
Expand All @@ -13,7 +15,7 @@
BaseVerificationKeyStrategy,
DefaultVerificationKeyStrategy,
)
from ..jwt import jwt_sign, jwt_verify, resolve_public_key_by_kid_for_verify
from ..jwt import jwt_sign, jwt_verify


class TestJWT(IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -92,6 +94,9 @@ async def asyncSetUp(self):
BaseVerificationKeyStrategy, DefaultVerificationKeyStrategy()
)
self.profile.context.injector.bind_instance(KeyTypes, KeyTypes())
self.profile.context.injector.bind_instance(
DIDResolver, DIDResolver([KeyDIDResolver()])
)

async def setUpTestingDid(self, key_type: KeyType) -> Tuple[str, str]:
async with self.profile.session() as session:
Expand Down Expand Up @@ -164,7 +169,7 @@ async def test_sign_x_invalid_verification_method(self):
verification_method = "did:key:zzzzgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"
with pytest.raises(Exception) as e_info:
await jwt_sign(self.profile, headers, payload, did, verification_method)
assert "Unknown DID" in str(e_info)
assert "DIDNotFound" in str(e_info)

async def test_verify_x_invalid_signed(self):
for key_type in [ED25519, P256]:
Expand All @@ -182,21 +187,3 @@ async def test_verify_x_invalid_signed(self):

with pytest.raises(Exception):
await jwt_verify(self.profile, signed)

async def test_resolve_public_key_by_kid_for_verify_ed25519(self):
(_, kid) = await self.setUpTestingDid(ED25519)
(key_bs58, key_type) = await resolve_public_key_by_kid_for_verify(
self.profile, kid
)

assert key_bs58 == "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"
assert key_type == ED25519

async def test_resolve_public_key_by_kid_for_verify_p256(self):
(_, kid) = await self.setUpTestingDid(P256)
(key_bs58, key_type) = await resolve_public_key_by_kid_for_verify(
self.profile, kid
)

assert key_bs58 == "tYbR5egjfja9D5ix1jjYGqfh5QPu73RcZ7UjQUXtargj"
assert key_type == P256
Loading