Skip to content

Commit e0ae070

Browse files
committed
bug fixes, compatibility changes for typescript implementation - added "type" literal field in all JSON payloads for easier discrimination/parsing, changed format of signature to the raw integers r||s instead of DER representation. unfortunately very hard to find support in java/typescript for that
1 parent addaa25 commit e0ae070

File tree

7 files changed

+98
-35
lines changed

7 files changed

+98
-35
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "koi-net"
7-
version = "1.1.0-beta.5"
7+
version = "1.1.0-beta.6"
88
description = "Implementation of KOI-net protocol in Python"
99
authors = [
1010
{name = "Luke Miller", email = "[email protected]"}

src/koi_net/network/error_handler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from logging import getLogger
2-
from koi_net.protocol.errors import ErrorTypes
2+
from koi_net.protocol.errors import ErrorType
33
from koi_net.protocol.event import EventType
44
from rid_lib.types import KoiNetNode
55
from ..processor.interface import ProcessorInterface
@@ -36,15 +36,15 @@ def handle_connection_error(self, node: KoiNetNode):
3636

3737
def handle_protocol_error(
3838
self,
39-
error_type: ErrorTypes,
39+
error_type: ErrorType,
4040
node: KoiNetNode
4141
):
4242
logger.info(f"Handling protocol error {error_type} for node {node!r}")
4343
match error_type:
44-
case ErrorTypes.UnknownNode:
44+
case ErrorType.UnknownNode:
4545
logger.info("Peer doesn't know me, attempting handshake...")
4646
self.actor.handshake_with(node)
4747

48-
case ErrorTypes.InvalidKey: ...
49-
case ErrorTypes.InvalidSignature: ...
50-
case ErrorTypes.InvalidTarget: ...
48+
case ErrorType.InvalidKey: ...
49+
case ErrorType.InvalidSignature: ...
50+
case ErrorType.InvalidTarget: ...

src/koi_net/network/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def generate(self):
4141
logger.debug(f"Added edge {rid!r} ({edge_profile.source} -> {edge_profile.target})")
4242
logger.debug("Done")
4343

44-
def get_edge(self, source: KoiNetNode, target: KoiNetNode,) -> EdgeProfile | None:
44+
def get_edge(self, source: KoiNetNode, target: KoiNetNode,) -> KoiNetEdge | None:
4545
"""Returns edge RID given the RIDs of a source and target node."""
4646
if (source, target) in self.dg.edges:
4747
edge_data = self.dg.get_edge_data(source, target)

src/koi_net/protocol/api_models.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,60 @@
11
"""Pydantic models for request and response/payload objects in the KOI-net API."""
22

3-
from pydantic import BaseModel
3+
from typing import Literal
4+
from pydantic import BaseModel, Field
45
from rid_lib import RID, RIDType
56
from rid_lib.ext import Bundle, Manifest
67
from .event import Event
7-
from .errors import ErrorTypes
8+
from .errors import ErrorType
89

910

1011
# REQUEST MODELS
1112

1213
class PollEvents(BaseModel):
13-
rid: RID
14+
type: Literal["poll_events"] = Field("poll_events")
1415
limit: int = 0
1516

1617
class FetchRids(BaseModel):
18+
type: Literal["fetch_rids"] = Field("fetch_rids")
1719
rid_types: list[RIDType] = []
1820

1921
class FetchManifests(BaseModel):
22+
type: Literal["fetch_manifests"] = Field("fetch_manifests")
2023
rid_types: list[RIDType] = []
2124
rids: list[RID] = []
2225

2326
class FetchBundles(BaseModel):
27+
type: Literal["fetch_bundles"] = Field("fetch_bundles")
2428
rids: list[RID]
2529

2630

2731
# RESPONSE/PAYLOAD MODELS
2832

2933
class RidsPayload(BaseModel):
34+
type: Literal["rids_payload"] = Field("rids_payload")
3035
rids: list[RID]
3136

3237
class ManifestsPayload(BaseModel):
38+
type: Literal["manifests_payload"] = Field("manifests_payload")
3339
manifests: list[Manifest]
3440
not_found: list[RID] = []
3541

3642
class BundlesPayload(BaseModel):
43+
type: Literal["bundles_payload"] = Field("bundles_payload")
3744
bundles: list[Bundle]
3845
not_found: list[RID] = []
3946
deferred: list[RID] = []
4047

4148
class EventsPayload(BaseModel):
49+
type: Literal["events_payload"] = Field("events_payload")
4250
events: list[Event]
4351

4452

4553
# ERROR MODELS
4654

4755
class ErrorResponse(BaseModel):
48-
error: ErrorTypes
56+
type: Literal["error_response"] = Field("error_response")
57+
error: ErrorType
4958

5059
# TYPES
5160

src/koi_net/protocol/envelope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def verify_with(self, pub_key: PublicKey):
2727
)
2828

2929
logger.debug(f"Verifying envelope: {unsigned_envelope.model_dump_json()}")
30-
30+
3131
pub_key.verify(
3232
self.signature,
3333
unsigned_envelope.model_dump_json().encode()

src/koi_net/protocol/errors.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
from enum import StrEnum
22

33

4-
class ErrorTypes(StrEnum):
4+
class ErrorType(StrEnum):
55
UnknownNode = "unknown_node"
66
InvalidKey = "invalid_key"
77
InvalidSignature = "invalid_signature"
88
InvalidTarget = "invalid_target"
99

1010
class ProtocolError(Exception):
11-
error_type: ErrorTypes
11+
error_type: ErrorType
1212

1313
class UnknownNodeError(ProtocolError):
14-
error_type = ErrorTypes.UnknownNode
14+
error_type = ErrorType.UnknownNode
1515

1616
class InvalidKeyError(ProtocolError):
17-
error_type = ErrorTypes.InvalidKey
17+
error_type = ErrorType.InvalidKey
1818

1919
class InvalidSignatureError(ProtocolError):
20-
error_type = ErrorTypes.InvalidSignature
20+
error_type = ErrorType.InvalidSignature
2121

2222
class InvalidTargetError(ProtocolError):
23-
error_type = ErrorTypes.InvalidTarget
23+
error_type = ErrorType.InvalidTarget

src/koi_net/protocol/secure.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,64 @@
11
import logging
2-
from base64 import urlsafe_b64decode, urlsafe_b64encode
2+
from base64 import b64decode, b64encode
33
from cryptography.hazmat.primitives import hashes
44
from cryptography.hazmat.primitives.asymmetric import ec
55
from cryptography.hazmat.primitives import serialization
66
from rid_lib.ext.utils import sha256_hash
7+
from cryptography.hazmat.primitives.asymmetric.utils import (
8+
decode_dss_signature,
9+
encode_dss_signature
10+
)
711

812
logger = logging.getLogger(__name__)
913

1014

15+
def der_to_raw_signature(der_signature: bytes, curve=ec.SECP256R1()) -> bytes:
16+
"""Convert a DER-encoded signature to raw r||s format."""
17+
18+
# Decode the DER signature to get r and s
19+
r, s = decode_dss_signature(der_signature)
20+
21+
# Determine byte length based on curve bit size
22+
byte_length = (curve.key_size + 7) // 8
23+
24+
# Convert r and s to big-endian byte arrays of fixed length
25+
r_bytes = r.to_bytes(byte_length, byteorder='big')
26+
s_bytes = s.to_bytes(byte_length, byteorder='big')
27+
28+
# Concatenate r and s
29+
return r_bytes + s_bytes
30+
31+
32+
def raw_to_der_signature(raw_signature: bytes, curve=ec.SECP256R1()) -> bytes:
33+
"""Convert a raw r||s signature to DER format."""
34+
35+
# Determine byte length based on curve bit size
36+
byte_length = (curve.key_size + 7) // 8
37+
38+
# Split the raw signature into r and s components
39+
if len(raw_signature) != 2 * byte_length:
40+
raise ValueError(f"Raw signature must be {2 * byte_length} bytes for {curve.name}")
41+
42+
r_bytes = raw_signature[:byte_length]
43+
s_bytes = raw_signature[byte_length:]
44+
45+
# Convert bytes to integers
46+
r = int.from_bytes(r_bytes, byteorder='big')
47+
s = int.from_bytes(s_bytes, byteorder='big')
48+
49+
# Encode as DER
50+
return encode_dss_signature(r, s)
51+
52+
1153
class PrivateKey:
1254
priv_key: ec.EllipticCurvePrivateKey
1355

1456
def __init__(self, priv_key):
1557
self.priv_key = priv_key
16-
58+
1759
@classmethod
1860
def generate(cls):
19-
return cls(priv_key=ec.generate_private_key(ec.SECP192R1()))
61+
return cls(priv_key=ec.generate_private_key(ec.SECP256R1()))
2062

2163
def public_key(self) -> "PublicKey":
2264
return PublicKey(self.priv_key.public_key())
@@ -40,12 +82,14 @@ def to_pem(self, password: str) -> str:
4082
def sign(self, message: bytes) -> str:
4183
hashed_message = sha256_hash(message.decode())
4284

43-
signature = urlsafe_b64encode(
44-
self.priv_key.sign(
45-
data=message,
46-
signature_algorithm=ec.ECDSA(hashes.SHA256())
47-
)
48-
).decode()
85+
der_signature_bytes = self.priv_key.sign(
86+
data=message,
87+
signature_algorithm=ec.ECDSA(hashes.SHA256())
88+
)
89+
90+
raw_signature_bytes = der_to_raw_signature(der_signature_bytes)
91+
92+
signature = b64encode(raw_signature_bytes).decode()
4993

5094
logger.debug(f"Signing message with [{self.public_key().to_der()}]")
5195
logger.debug(f"hash: {hashed_message}")
@@ -78,29 +122,39 @@ def to_pem(self) -> str:
78122
def from_der(cls, pub_key_der: str):
79123
return cls(
80124
pub_key=serialization.load_der_public_key(
81-
data=urlsafe_b64decode(pub_key_der)
125+
data=b64decode(pub_key_der)
82126
)
83127
)
84128

85129
def to_der(self) -> str:
86-
return urlsafe_b64encode(
130+
return b64encode(
87131
self.pub_key.public_bytes(
88132
encoding=serialization.Encoding.DER,
89133
format=serialization.PublicFormat.SubjectPublicKeyInfo
90134
)
91135
).decode()
92136

137+
93138
def verify(self, signature: str, message: bytes) -> bool:
94-
hashed_message = sha256_hash(message.decode())
139+
# hashed_message = sha256_hash(message.decode())
140+
141+
# print(message.hex())
142+
# print()
143+
# print(hashed_message)
144+
# print()
145+
# print(message.decode())
95146

96-
logger.debug(f"Verifying message with [{self.to_der()}]")
97-
logger.debug(f"hash: {hashed_message}")
98-
logger.debug(f"signature: {signature}")
147+
# logger.debug(f"Verifying message with [{self.to_der()}]")
148+
# logger.debug(f"hash: {hashed_message}")
149+
# logger.debug(f"signature: {signature}")
150+
151+
raw_signature_bytes = b64decode(signature)
152+
der_signature_bytes = raw_to_der_signature(raw_signature_bytes)
99153

100154
# NOTE: throws cryptography.exceptions.InvalidSignature on failure
101155

102156
self.pub_key.verify(
103-
signature=urlsafe_b64decode(signature),
157+
signature=der_signature_bytes,
104158
data=message,
105159
signature_algorithm=ec.ECDSA(hashes.SHA256())
106160
)

0 commit comments

Comments
 (0)