diff --git a/Cargo.toml b/Cargo.toml index 429d7c9..d9fca9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,9 +62,10 @@ rand = "0.8.5" rayon = "1.10.0" sha2 = "0.10.7" sha3 = "0.10.8" +thiserror = "2.0.12" +zerocopy = "0.8" zeroize = "1.8.1" - # Un-comment below for latest arkworks libraries. # [patch.crates-io] # ark-std = { git = "https://github.com/arkworks-rs/utils" } diff --git a/spongefish-pow/src/blake3.rs b/spongefish-pow/src/blake3.rs index b5d5181..3e17ddf 100644 --- a/spongefish-pow/src/blake3.rs +++ b/spongefish-pow/src/blake3.rs @@ -202,7 +202,7 @@ impl Blake3PoW { #[cfg(test)] mod tests { - use spongefish::{DefaultHash, DomainSeparator}; + use spongefish::DefaultHash; use super::*; use crate::{ diff --git a/spongefish-pow/src/lib.rs b/spongefish-pow/src/lib.rs index 38c0483..e786dfb 100644 --- a/spongefish-pow/src/lib.rs +++ b/spongefish-pow/src/lib.rs @@ -45,11 +45,11 @@ where Self: BytesToUnitSerialize + UnitToBytes, { fn challenge_pow(&mut self, bits: f64) -> ProofResult<()> { - let challenge = self.challenge_bytes()?; + let challenge = self.challenge_bytes(); let nonce = S::new(challenge, bits) .solve() .ok_or(ProofError::InvalidProof)?; - self.add_bytes(&nonce.to_be_bytes())?; + self.add_bytes(&nonce.to_be_bytes()); Ok(()) } } @@ -61,7 +61,7 @@ where Self: BytesToUnitDeserialize + UnitToBytes, { fn challenge_pow(&mut self, bits: f64) -> ProofResult<()> { - let challenge = self.challenge_bytes()?; + let challenge = self.challenge_bytes(); let nonce = u64::from_be_bytes(self.next_bytes()?); if S::new(challenge, bits).check(nonce) { Ok(()) diff --git a/spongefish/Cargo.toml b/spongefish/Cargo.toml index 6aea4eb..d956c3f 100644 --- a/spongefish/Cargo.toml +++ b/spongefish/Cargo.toml @@ -13,6 +13,7 @@ license = "BSD-3-Clause" workspace = true [dependencies] +zerocopy = { workspace = true } zeroize = { workspace = true, features = ["zeroize_derive"] } rand = { workspace = true, features = ["getrandom"] } digest = { workspace = true } @@ -24,6 +25,8 @@ ark-ec = { workspace = true, optional = true } ark-serialize = { workspace = true, features = ["std"], optional = true } group = { workspace = true, optional = true } hex = { workspace = true } +thiserror.workspace = true +sha3.workspace = true [features] default = [] diff --git a/spongefish/src/codecs/arkworks_algebra/deserialize.rs b/spongefish/src/codecs/arkworks_algebra/deserialize.rs index 05842e4..0efcf29 100644 --- a/spongefish/src/codecs/arkworks_algebra/deserialize.rs +++ b/spongefish/src/codecs/arkworks_algebra/deserialize.rs @@ -4,7 +4,7 @@ use ark_ec::{ CurveGroup, }; use ark_ff::{Field, Fp, FpConfig}; -use ark_serialize::CanonicalDeserialize; +use ark_serialize::{CanonicalDeserialize, SerializationError}; use super::{FieldToUnitDeserialize, GroupToUnitDeserialize}; use crate::{ @@ -67,7 +67,7 @@ where for o in output.iter_mut() { let o_affine = EdwardsAffine::deserialize_compressed(&mut self.narg_string)?; *o = o_affine.into(); - self.public_units(&[o.x, o.y])?; + self.public_units(&[o.x, o.y]); } Ok(()) } @@ -83,13 +83,14 @@ where for o in output.iter_mut() { let o_affine = SWAffine::deserialize_compressed(&mut self.narg_string)?; *o = o_affine.into(); - self.public_units(&[o.x, o.y])?; + self.public_units(&[o.x, o.y]); } Ok(()) } } #[cfg(test)] +#[cfg(feature = "disable")] mod tests { use ark_bls12_381::G1Projective; use ark_curve25519::EdwardsProjective; @@ -99,8 +100,8 @@ mod tests { use super::*; use crate::{ - codecs::arkworks_algebra::{FieldDomainSeparator, GroupDomainSeparator}, - DefaultHash, DomainSeparator, + codecs::arkworks_algebra::{FieldPattern, GroupPattern}, + DefaultHash, }; /// Custom field for testing: BabyBear diff --git a/spongefish/src/codecs/arkworks_algebra/domain_separator.rs b/spongefish/src/codecs/arkworks_algebra/domain_separator.rs index 6f5ddc6..6db56b8 100644 --- a/spongefish/src/codecs/arkworks_algebra/domain_separator.rs +++ b/spongefish/src/codecs/arkworks_algebra/domain_separator.rs @@ -1,90 +1,114 @@ use ark_ec::CurveGroup; use ark_ff::{Field, Fp, FpConfig, PrimeField}; -use super::{ - ByteDomainSeparator, DomainSeparator, DuplexSpongeInterface, FieldDomainSeparator, - GroupDomainSeparator, +use super::{FieldPattern, GroupPattern}; +use crate::{ + codecs::{ + bytes::{self, Pattern as _}, + bytes_modp, bytes_uniform_modp, + unit::{self, Pattern as _}, + }, + pattern::{self, Label, Length, Pattern as _, PatternState}, }; -use crate::codecs::{bytes_modp, bytes_uniform_modp}; -impl FieldDomainSeparator for DomainSeparator +impl FieldPattern for PatternState where F: Field, - H: DuplexSpongeInterface, { - fn add_scalars(self, count: usize, label: &str) -> Self { - self.add_bytes( + fn message_scalars(&mut self, label: Label, count: usize) { + self.begin_message::(label, Length::Fixed(count)); + self.message_bytes( + "base-field-coefficients-little-endian", count * F::extension_degree() as usize * bytes_modp(F::BasePrimeField::MODULUS_BIT_SIZE), - label, - ) + ); + self.end_message::(label, Length::Fixed(count)); } - fn challenge_scalars(self, count: usize, label: &str) -> Self { + fn challenge_scalars(&mut self, label: Label, count: usize) { + self.begin_challenge::(label, Length::Fixed(count)); self.challenge_bytes( + "base-field-coefficients-little-endian", count * F::extension_degree() as usize * bytes_uniform_modp(F::BasePrimeField::MODULUS_BIT_SIZE), - label, - ) + ); + self.end_challenge::(label, Length::Fixed(count)); } } -impl FieldDomainSeparator for DomainSeparator> +impl FieldPattern for PatternState> where F: Field>, C: FpConfig, - H: DuplexSpongeInterface>, { - fn add_scalars(self, count: usize, label: &str) -> Self { - self.absorb(count * F::extension_degree() as usize, label) + fn message_scalars(&mut self, label: Label, count: usize) { + self.begin_message::(label, Length::Fixed(count)); + self.message_units( + "base-field-coefficients", + count * F::extension_degree() as usize, + ); + self.end_message::(label, Length::Fixed(count)); } - fn challenge_scalars(self, count: usize, label: &str) -> Self { - self.squeeze(count * F::extension_degree() as usize, label) + fn challenge_scalars(&mut self, label: Label, count: usize) { + self.begin_challenge::(label, Length::Fixed(count)); + self.challenge_units( + "base-field-coefficients", + count * F::extension_degree() as usize, + ); + self.end_challenge::(label, Length::Fixed(count)); } } -impl ByteDomainSeparator for DomainSeparator> +/// Implementation where `Unit = Fp` +impl bytes::Pattern for PatternState> where C: FpConfig, - H: DuplexSpongeInterface>, { /// Add `count` bytes to the transcript, encoding each of them as an element of the field `Fp`. - fn add_bytes(self, count: usize, label: &str) -> Self { - self.absorb(count, label) + fn public_bytes(&mut self, label: Label, size: usize) { + self.begin_public::(label, Length::Fixed(size)); + self.public_units("units", size); + self.end_public::(label, Length::Fixed(size)) } - fn hint(self, label: &str) -> Self { - self.hint(label) + /// Add `count` bytes to the transcript, encoding each of them as an element of the field `Fp`. + fn message_bytes(&mut self, label: Label, size: usize) { + self.begin_message::(label, Length::Fixed(size)); + self.message_units("units", size); + self.end_message::(label, Length::Fixed(size)) } - fn challenge_bytes(self, count: usize, label: &str) -> Self { + fn challenge_bytes(&mut self, label: Label, size: usize) { + self.begin_challenge::(label, Length::Fixed(size)); let n = crate::codecs::random_bits_in_random_modp(Fp::::MODULUS) / 8; - self.squeeze(count.div_ceil(n), label) + self.challenge_units("units", size.div_ceil(n)); + self.end_challenge::(label, Length::Fixed(size)) } } -impl GroupDomainSeparator for DomainSeparator +impl GroupPattern for PatternState where G: CurveGroup, - H: DuplexSpongeInterface, { - fn add_points(self, count: usize, label: &str) -> Self { - self.add_bytes(count * G::default().compressed_size(), label) + fn message_points(&mut self, label: Label, count: usize) { + self.begin_message::(label, Length::Fixed(count)); + self.message_bytes("serialized-group", count * G::default().compressed_size()); + self.end_message::(label, Length::Fixed(count)); } } -impl GroupDomainSeparator for DomainSeparator> +impl GroupPattern for PatternState> where G: CurveGroup>, - H: DuplexSpongeInterface>, C: FpConfig, - Self: FieldDomainSeparator>, { - fn add_points(self, count: usize, label: &str) -> Self { - self.absorb(count * 2, label) + fn message_points(&mut self, label: Label, count: usize) { + self.begin_message::(label, Length::Fixed(count)); + self.message_units("coordinates", count * 2); + self.end_message::(label, Length::Fixed(count)); } } @@ -98,7 +122,10 @@ mod tests { }; use super::*; - use crate::DefaultHash; + use crate::{ + pattern::{InteractionPattern, Pattern}, + DefaultHash, + }; /// Configuration for the BabyBear field (modulus = 2^31 - 2^27 + 1, generator = 21). #[derive(MontConfig)] @@ -153,21 +180,22 @@ mod tests { // .absorb_scalars(1, "resp"); // // OPTION 2 - fn add_schnorr_domain_separator>( - ) -> DomainSeparator + fn add_schnorr_domain_separator(pattern: &mut P) where - DomainSeparator: GroupDomainSeparator + FieldDomainSeparator, + P: pattern::Pattern + unit::Pattern + FieldPattern + GroupPattern, { - DomainSeparator::new("github.com/mmaker/spongefish") - .add_points(1, "g") - .add_points(1, "pk") - .ratchet() - .add_points(1, "com") - .challenge_scalars(1, "chal") - .add_scalars(1, "resp") + pattern.begin_protocol::<()>("github.com/mmaker/spongefish"); + pattern.message_points("g", 1); + pattern.message_points("pk", 1); + pattern.ratchet(); + pattern.message_points("com", 1); + pattern.challenge_scalars("chal", 1); + pattern.message_scalars("resp", 1); + pattern.end_protocol::<()>("github.com/mmaker/spongefish"); } - let domain_separator = - add_schnorr_domain_separator::(); + let mut pattern = PatternState::::new(); + add_schnorr_domain_separator::<_, ark_curve25519::EdwardsProjective>(&mut pattern); + let pattern = pattern.finalize(); // OPTION 3 (extra type, trait extensions should be on DomainSeparator or AlgebraicDomainSeparator?) // let domain_separator = @@ -180,298 +208,327 @@ mod tests { // .add_scalars(1, "resp"); assert_eq!( - domain_separator.as_bytes(), - b"github.com/mmaker/spongefish\0A32g\0A32pk\0R\0A32com\0S47chal\0A32resp" - ); - } - - #[test] - fn test_scalar_vs_byte_equivalence_field_modp() { - type F = Fr; - - let label = "same-scalar"; - // Compute number of bytes needed to represent one scalar - let scalar_bytes = bytes_modp(F::MODULUS_BIT_SIZE); - - // Add one scalar to the transcript using the scalar API - let scalar_sep = >::add_scalars( - DomainSeparator::::new("label"), - 1, - label, - ); - - // Add the same number of bytes directly - let byte_sep = DomainSeparator::::new("label").add_bytes(scalar_bytes, label); - - // Ensure the encodings are equal - assert_eq!(scalar_sep.as_bytes(), byte_sep.as_bytes()); - } - - #[test] - fn test_challenge_scalars_vs_bytes_equivalence() { - type F = Fr; - - let label = "challenge"; - // Compute the number of bytes needed for one uniform scalar - let uniform_bytes = bytes_uniform_modp(F::MODULUS_BIT_SIZE); - - // Request 2 scalar challenges - let sep_scalar = >::challenge_scalars( - DomainSeparator::::new("L"), - 2, - label, - ); - - // Request 2 * bytes directly - let sep_bytes = - DomainSeparator::::new("L").challenge_bytes(2 * uniform_bytes, label); - - // Ensure the encodings match - assert_eq!(sep_scalar.as_bytes(), sep_bytes.as_bytes()); - } - - #[test] - fn test_domain_separator_fq2_bytes_are_expected() { - type F = Fq2; - - // Construct the separator with one Fq2 absorbed and one Fq2 squeezed - let sep = >::challenge_scalars( - >::add_scalars( - DomainSeparator::::new("ark-fq2"), - 1, - "a", - ), - 1, - "b", - ); - - // Explanation of the expected encoding: - // "ark-fq2" → domain label - // "\0A96a" → absorb 96 bytes (Fq2 = 2 × 48) - // "\0S126b" → squeeze 126 bytes (Fq2 = 2 × 63 uniform bytes) - let expected_bytes = b"ark-fq2\0A96a\0S126b"; - assert_eq!(sep.as_bytes(), expected_bytes); - } - - #[test] - fn test_group_point_encoding_vs_bytes_direct() { - type G = Curve; - - // Add 2 group elements to the transcript (compressed size = 32 each) - let point_sep = >::add_points( - DomainSeparator::::new("G"), - 2, - "X", - ); - - // Add 64 raw bytes directly instead - let byte_sep = DomainSeparator::::new("G").add_bytes(64, "X"); - - // Ensure they are equivalent - assert_eq!(point_sep.as_bytes(), byte_sep.as_bytes()); - } - - #[test] - fn test_domain_separator_determinism() { - type G = Curve; - type F = Fr; - - // First sequence: add group point, absorb scalar, squeeze scalar - let add_pts = >::add_points( - DomainSeparator::::new("proof"), - 1, - "pk", - ); - let add_scalars = - >::add_scalars(add_pts, 1, "x"); - let d1 = - >::challenge_scalars(add_scalars, 1, "y"); - - // Repeat the same sequence again - let add_pts_2 = >::add_points( - DomainSeparator::::new("proof"), - 1, - "pk", - ); - let add_scalars_2 = - >::add_scalars(add_pts_2, 1, "x"); - let d2 = - >::challenge_scalars(add_scalars_2, 1, "y"); - - // Resulting byte encodings must be the same - assert_eq!(d1.as_bytes(), d2.as_bytes()); - } - - #[test] - fn test_group_and_field_mixed_usage_structure() { - type G = Curve; - type F = Fr; - - // Add one group element (compressed 32 bytes) - let step1 = >::add_points( - DomainSeparator::::new("joint"), - 1, - "pk", - ); - // Ratchet separator state - let step2 = step1.ratchet(); - // Add two scalars → 2 × 32 bytes = 64 - let step3 = >::add_scalars(step2, 2, "resp"); - // Squeeze one scalar → 32 bytes uniform modp = 47 - let sep = >::challenge_scalars(step3, 1, "c"); - - // "joint" → domain label - // "\0A32pk" → absorb 32 bytes (1 group point) - // "\0R" → ratchet - // "\0A64resp" → absorb 64 bytes (2 Fr) - // "\0S47c" → squeeze 47 bytes (1 Fr uniform) - assert_eq!(sep.as_bytes(), b"joint\0A32pk\0R\0A64resp\0S47c"); - } - - #[test] - fn test_field_domain_separator_for_custom_fp() { - #[derive(MontConfig)] - #[modulus = "18446744069414584321"] - #[generator = "7"] - pub struct FConfig64; - pub type Field64 = Fp64>; - - pub type Field64_2 = Fp2; - pub struct F2Config64; - impl Fp2Config for F2Config64 { - type Fp = Field64; - - const NONRESIDUE: Self::Fp = MontFp!("7"); - - const FROBENIUS_COEFF_FP2_C1: &'static [Self::Fp] = &[ - // Fq(7)**(((q^0) - 1) / 2) - MontFp!("1"), - // Fq(7)**(((q^1) - 1) / 2) - MontFp!("18446744069414584320"), - ]; - } - - // First absorb 3 Field64 elements: - // - Fp64 has MODULUS_BIT_SIZE = 64 - // - bytes_modp(64) = 8 - // - 3 scalars × 8 bytes = 24 bytes - // → \0A24foo - let sep = >::add_scalars( - DomainSeparator::new("test-fp"), - 3, - "foo", + format!("{pattern}"), + r#"Spongefish Transcript (28 interactions) +00 Begin Protocol github.com/mmaker/spongefish None () +01 Begin Message g Fixed(1) ark_ec::models::twisted_edwards::group::Projective +02 Begin Message serialized-group Fixed(32) u8 +03 Atomic Message units Fixed(32) u8 +04 End Message serialized-group Fixed(32) u8 +05 End Message g Fixed(1) ark_ec::models::twisted_edwards::group::Projective +06 Begin Message pk Fixed(1) ark_ec::models::twisted_edwards::group::Projective +07 Begin Message serialized-group Fixed(32) u8 +08 Atomic Message units Fixed(32) u8 +09 End Message serialized-group Fixed(32) u8 +10 End Message pk Fixed(1) ark_ec::models::twisted_edwards::group::Projective +11 Atomic Protocol ratchet None () +12 Begin Message com Fixed(1) ark_ec::models::twisted_edwards::group::Projective +13 Begin Message serialized-group Fixed(32) u8 +14 Atomic Message units Fixed(32) u8 +15 End Message serialized-group Fixed(32) u8 +16 End Message com Fixed(1) ark_ec::models::twisted_edwards::group::Projective +17 Begin Challenge chal Fixed(1) ark_ff::fields::models::fp::Fp, 4> +18 Begin Challenge base-field-coefficients-little-endian Fixed(47) u8 +19 Atomic Challenge units Fixed(47) u8 +20 End Challenge base-field-coefficients-little-endian Fixed(47) u8 +21 End Challenge chal Fixed(1) ark_ff::fields::models::fp::Fp, 4> +22 Begin Message resp Fixed(1) ark_ff::fields::models::fp::Fp, 4> +23 Begin Message base-field-coefficients-little-endian Fixed(32) u8 +24 Atomic Message units Fixed(32) u8 +25 End Message base-field-coefficients-little-endian Fixed(32) u8 +26 End Message resp Fixed(1) ark_ff::fields::models::fp::Fp, 4> +27 End Protocol github.com/mmaker/spongefish None () +"# ); - - // Then squeeze 1 Field64_2 element: - // - Fp2 has extension_degree = 2 (since it's two Field64 elements) - // - bytes_uniform_modp(64) = 24 - // - 2 × 24 = 48 bytes - // → \0S48bar - let sep = - >::challenge_scalars(sep, 1, "bar"); - - // Final byte encoding is: - // - "test-fp" domain label - // - \0A24foo → absorb 24 bytes labeled "foo" - // - \0S48bar → squeeze 48 bytes labeled "bar" - let expected = b"test-fp\0A24foo\0S48bar"; - assert_eq!(sep.as_bytes(), expected); - } - - #[test] - fn test_add_scalars_babybear() { - // Test absorption of scalars from the base field BabyBear. - // - BabyBear has extension degree = 1 - // - Field size: 2^31 - 2^27 + 1 → 31 bits → bytes_modp(31) = 4 - // - 2 scalars * 1 * 4 = 8 bytes absorbed - // - "A" prefix indicates absorption in the domain separator - let sep = >::add_scalars( - DomainSeparator::new("babybear"), - 2, - "foo", - ); - - let expected = b"babybear\0A8foo"; - assert_eq!(sep.as_bytes(), expected); - } - - #[test] - fn test_challenge_scalars_babybear() { - // Test squeezing of scalars from the base field BabyBear. - // - BabyBear has extension degree = 1 - // - bytes_uniform_modp(31) = 5 - // - 3 scalars * 1 * 5 = 15 bytes squeezed - // - "S" prefix indicates squeezing in the domain separator - let sep = >::challenge_scalars( - DomainSeparator::new("bb"), - 3, - "bar", - ); - - let expected = b"bb\0S57bar"; - assert_eq!(sep.as_bytes(), expected); } - #[test] - fn test_add_scalars_quadratic_ext_field() { - // Test absorption of scalars from a quadratic extension field (BabyBear2 = Fp2 over BabyBear). - // - Extension degree = 2 - // - Base field bits = 31 → bytes_modp(31) = 4 - // - 2 scalars * 2 * 4 = 16 bytes absorbed - let sep = >::add_scalars( - DomainSeparator::new("ext"), - 2, - "a", - ); - - let expected = b"ext\0A16a"; - assert_eq!(sep.as_bytes(), expected); - } - - #[test] - fn test_challenge_scalars_quadratic_ext_field() { - // Test squeezing of scalars from a quadratic extension field (BabyBear2 = Fp2 over BabyBear). - // - Extension degree = 2 - // - bytes_uniform_modp(31) = 19 - // - 1 scalar * 2 * 19 = 38 bytes squeezed - let sep = >::challenge_scalars( - DomainSeparator::new("ext2"), - 1, - "b", - ); - - let expected = b"ext2\0S38b"; - assert_eq!(sep.as_bytes(), expected); - } - - #[test] - fn test_add_scalars_quartic_ext_field() { - // Test absorption of scalars from a quartic extension field (BabyBear4 = Fp4 over BabyBear). - // - Extension degree = 4 - // - Base field bits = 31 → bytes_modp(31) = 4 - // - 2 scalars * 4 * 4 = 32 bytes absorbed - let sep = >::add_scalars( - DomainSeparator::new("ext"), - 2, - "a", - ); - let expected = b"ext\0A32a"; - assert_eq!(sep.as_bytes(), expected); - } - - #[test] - fn test_challenge_scalars_quartic_ext_field() { - // Test squeezing of scalars from a quartic extension field (BabyBear4 = Fp4 over BabyBear). - // - Extension degree = 4 - // - bytes_uniform_modp(31) = 19 - // - 1 scalar * 4 * 19 = 76 bytes squeezed - let sep = >::challenge_scalars( - DomainSeparator::new("ext2"), - 1, - "b", - ); - - let expected = b"ext2\0S76b"; - assert_eq!(sep.as_bytes(), expected); - } + // #[test] + // fn test_scalar_vs_byte_equivalence_field_modp() { + // type F = Fr; + + // let label = "same-scalar"; + // // Compute number of bytes needed to represent one scalar + // let scalar_bytes = bytes_modp(F::MODULUS_BIT_SIZE); + + // // Add one scalar to the transcript using the scalar API + // let scalar_sep = >::add_scalars( + // DomainSeparator::::new("label"), + // 1, + // label, + // ); + + // // Add the same number of bytes directly + // let byte_sep = DomainSeparator::::new("label").add_bytes(scalar_bytes, label); + + // // Ensure the encodings are equal + // assert_eq!(scalar_sep.as_bytes(), byte_sep.as_bytes()); + // } + + // #[test] + // fn test_challenge_scalars_vs_bytes_equivalence() { + // type F = Fr; + + // let label = "challenge"; + // // Compute the number of bytes needed for one uniform scalar + // let uniform_bytes = bytes_uniform_modp(F::MODULUS_BIT_SIZE); + + // // Request 2 scalar challenges + // let sep_scalar = >::challenge_scalars( + // DomainSeparator::::new("L"), + // 2, + // label, + // ); + + // // Request 2 * bytes directly + // let sep_bytes = + // DomainSeparator::::new("L").challenge_bytes(2 * uniform_bytes, label); + + // // Ensure the encodings match + // assert_eq!(sep_scalar.as_bytes(), sep_bytes.as_bytes()); + // } + + // #[test] + // fn test_domain_separator_fq2_bytes_are_expected() { + // type F = Fq2; + + // // Construct the separator with one Fq2 absorbed and one Fq2 squeezed + // let sep = >::challenge_scalars( + // >::add_scalars( + // DomainSeparator::::new("ark-fq2"), + // 1, + // "a", + // ), + // 1, + // "b", + // ); + + // // Explanation of the expected encoding: + // // "ark-fq2" → domain label + // // "\0A96a" → absorb 96 bytes (Fq2 = 2 × 48) + // // "\0S126b" → squeeze 126 bytes (Fq2 = 2 × 63 uniform bytes) + // let expected_bytes = b"ark-fq2\0A96a\0S126b"; + // assert_eq!(sep.as_bytes(), expected_bytes); + // } + + // #[test] + // fn test_group_point_encoding_vs_bytes_direct() { + // type G = Curve; + + // // Add 2 group elements to the transcript (compressed size = 32 each) + // let point_sep = >::add_points( + // DomainSeparator::::new("G"), + // 2, + // "X", + // ); + + // // Add 64 raw bytes directly instead + // let byte_sep = DomainSeparator::::new("G").add_bytes(64, "X"); + + // // Ensure they are equivalent + // assert_eq!(point_sep.as_bytes(), byte_sep.as_bytes()); + // } + + // #[test] + // fn test_domain_separator_determinism() { + // type G = Curve; + // type F = Fr; + + // // First sequence: add group point, absorb scalar, squeeze scalar + // let add_pts = >::add_points( + // DomainSeparator::::new("proof"), + // 1, + // "pk", + // ); + // let add_scalars = + // >::add_scalars(add_pts, 1, "x"); + // let d1 = + // >::challenge_scalars(add_scalars, 1, "y"); + + // // Repeat the same sequence again + // let add_pts_2 = >::add_points( + // DomainSeparator::::new("proof"), + // 1, + // "pk", + // ); + // let add_scalars_2 = + // >::add_scalars(add_pts_2, 1, "x"); + // let d2 = + // >::challenge_scalars(add_scalars_2, 1, "y"); + + // // Resulting byte encodings must be the same + // assert_eq!(d1.as_bytes(), d2.as_bytes()); + // } + + // #[test] + // fn test_group_and_field_mixed_usage_structure() { + // type G = Curve; + // type F = Fr; + + // // Add one group element (compressed 32 bytes) + // let step1 = >::add_points( + // DomainSeparator::::new("joint"), + // 1, + // "pk", + // ); + // // Ratchet separator state + // let step2 = step1.ratchet(); + // // Add two scalars → 2 × 32 bytes = 64 + // let step3 = >::add_scalars(step2, 2, "resp"); + // // Squeeze one scalar → 32 bytes uniform modp = 47 + // let sep = >::challenge_scalars(step3, 1, "c"); + + // // "joint" → domain label + // // "\0A32pk" → absorb 32 bytes (1 group point) + // // "\0R" → ratchet + // // "\0A64resp" → absorb 64 bytes (2 Fr) + // // "\0S47c" → squeeze 47 bytes (1 Fr uniform) + // assert_eq!(sep.as_bytes(), b"joint\0A32pk\0R\0A64resp\0S47c"); + // } + + // #[test] + // fn test_field_domain_separator_for_custom_fp() { + // #[derive(MontConfig)] + // #[modulus = "18446744069414584321"] + // #[generator = "7"] + // pub struct FConfig64; + // pub type Field64 = Fp64>; + + // pub type Field64_2 = Fp2; + // pub struct F2Config64; + // impl Fp2Config for F2Config64 { + // type Fp = Field64; + + // const NONRESIDUE: Self::Fp = MontFp!("7"); + + // const FROBENIUS_COEFF_FP2_C1: &'static [Self::Fp] = &[ + // // Fq(7)**(((q^0) - 1) / 2) + // MontFp!("1"), + // // Fq(7)**(((q^1) - 1) / 2) + // MontFp!("18446744069414584320"), + // ]; + // } + + // // First absorb 3 Field64 elements: + // // - Fp64 has MODULUS_BIT_SIZE = 64 + // // - bytes_modp(64) = 8 + // // - 3 scalars × 8 bytes = 24 bytes + // // → \0A24foo + // let sep = >::add_scalars( + // DomainSeparator::new("test-fp"), + // 3, + // "foo", + // ); + + // // Then squeeze 1 Field64_2 element: + // // - Fp2 has extension_degree = 2 (since it's two Field64 elements) + // // - bytes_uniform_modp(64) = 24 + // // - 2 × 24 = 48 bytes + // // → \0S48bar + // let sep = + // >::challenge_scalars(sep, 1, "bar"); + + // // Final byte encoding is: + // // - "test-fp" domain label + // // - \0A24foo → absorb 24 bytes labeled "foo" + // // - \0S48bar → squeeze 48 bytes labeled "bar" + // let expected = b"test-fp\0A24foo\0S48bar"; + // assert_eq!(sep.as_bytes(), expected); + // } + + // #[test] + // fn test_add_scalars_babybear() { + // // Test absorption of scalars from the base field BabyBear. + // // - BabyBear has extension degree = 1 + // // - Field size: 2^31 - 2^27 + 1 → 31 bits → bytes_modp(31) = 4 + // // - 2 scalars * 1 * 4 = 8 bytes absorbed + // // - "A" prefix indicates absorption in the domain separator + // let sep = >::add_scalars( + // DomainSeparator::new("babybear"), + // 2, + // "foo", + // ); + + // let expected = b"babybear\0A8foo"; + // assert_eq!(sep.as_bytes(), expected); + // } + + // #[test] + // fn test_challenge_scalars_babybear() { + // // Test squeezing of scalars from the base field BabyBear. + // // - BabyBear has extension degree = 1 + // // - bytes_uniform_modp(31) = 5 + // // - 3 scalars * 1 * 5 = 15 bytes squeezed + // // - "S" prefix indicates squeezing in the domain separator + // let sep = >::challenge_scalars( + // DomainSeparator::new("bb"), + // 3, + // "bar", + // ); + + // let expected = b"bb\0S57bar"; + // assert_eq!(sep.as_bytes(), expected); + // } + + // #[test] + // fn test_add_scalars_quadratic_ext_field() { + // // Test absorption of scalars from a quadratic extension field (BabyBear2 = Fp2 over BabyBear). + // // - Extension degree = 2 + // // - Base field bits = 31 → bytes_modp(31) = 4 + // // - 2 scalars * 2 * 4 = 16 bytes absorbed + // let sep = >::add_scalars( + // DomainSeparator::new("ext"), + // 2, + // "a", + // ); + + // let expected = b"ext\0A16a"; + // assert_eq!(sep.as_bytes(), expected); + // } + + // #[test] + // fn test_challenge_scalars_quadratic_ext_field() { + // // Test squeezing of scalars from a quadratic extension field (BabyBear2 = Fp2 over BabyBear). + // // - Extension degree = 2 + // // - bytes_uniform_modp(31) = 19 + // // - 1 scalar * 2 * 19 = 38 bytes squeezed + // let sep = >::challenge_scalars( + // DomainSeparator::new("ext2"), + // 1, + // "b", + // ); + + // let expected = b"ext2\0S38b"; + // assert_eq!(sep.as_bytes(), expected); + // } + + // #[test] + // fn test_add_scalars_quartic_ext_field() { + // // Test absorption of scalars from a quartic extension field (BabyBear4 = Fp4 over BabyBear). + // // - Extension degree = 4 + // // - Base field bits = 31 → bytes_modp(31) = 4 + // // - 2 scalars * 4 * 4 = 32 bytes absorbed + // let sep = >::add_scalars( + // DomainSeparator::new("ext"), + // 2, + // "a", + // ); + // let expected = b"ext\0A32a"; + // assert_eq!(sep.as_bytes(), expected); + // } + + // #[test] + // fn test_challenge_scalars_quartic_ext_field() { + // // Test squeezing of scalars from a quartic extension field (BabyBear4 = Fp4 over BabyBear). + // // - Extension degree = 4 + // // - bytes_uniform_modp(31) = 19 + // // - 1 scalar * 4 * 19 = 76 bytes squeezed + // let sep = >::challenge_scalars( + // DomainSeparator::new("ext2"), + // 1, + // "b", + // ); + + // let expected = b"ext2\0S76b"; + // assert_eq!(sep.as_bytes(), expected); + // } } diff --git a/spongefish/src/codecs/arkworks_algebra/mod.rs b/spongefish/src/codecs/arkworks_algebra/mod.rs index 43431d8..e12ad06 100644 --- a/spongefish/src/codecs/arkworks_algebra/mod.rs +++ b/spongefish/src/codecs/arkworks_algebra/mod.rs @@ -128,12 +128,11 @@ mod deserialize; mod prover_messages; /// Tests for arkworks. -#[cfg(test)] -mod tests; - +// #[cfg(test)] +// mod tests; pub use crate::{ - duplex_sponge::Unit, traits::*, DomainSeparator, DuplexSpongeInterface, - HashStateWithInstructions, ProofError, ProofResult, ProverState, VerifierState, + duplex_sponge::Unit, traits::*, DuplexSpongeInterface, ProofError, ProofResult, ProverState, + VerifierState, }; super::traits::field_traits!(ark_ff::Field); diff --git a/spongefish/src/codecs/arkworks_algebra/prover_messages.rs b/spongefish/src/codecs/arkworks_algebra/prover_messages.rs index 126989a..07afd74 100644 --- a/spongefish/src/codecs/arkworks_algebra/prover_messages.rs +++ b/spongefish/src/codecs/arkworks_algebra/prover_messages.rs @@ -5,17 +5,16 @@ use rand::{CryptoRng, RngCore}; use super::{CommonFieldToUnit, CommonGroupToUnit, FieldToUnitSerialize, GroupToUnitSerialize}; use crate::{ - BytesToUnitDeserialize, BytesToUnitSerialize, CommonUnitToBytes, DomainSeparatorMismatch, - DuplexSpongeInterface, ProofResult, ProverState, Unit, UnitTranscript, VerifierState, + BytesToUnitDeserialize, BytesToUnitSerialize, CommonUnitToBytes, DuplexSpongeInterface, + ProofResult, ProverState, Unit, UnitTranscript, VerifierState, }; impl FieldToUnitSerialize for ProverState { - fn add_scalars(&mut self, input: &[F]) -> ProofResult<()> { + fn add_scalars(&mut self, input: &[F]) { let serialized = self.public_scalars(input); - self.narg_string.extend(serialized?); - Ok(()) + self.narg_string.extend(serialized); } } @@ -26,12 +25,13 @@ impl< const N: usize, > FieldToUnitSerialize> for ProverState, R> { - fn add_scalars(&mut self, input: &[Fp]) -> ProofResult<()> { - self.public_units(input)?; + fn add_scalars(&mut self, input: &[Fp]) { + self.public_units(input); for i in input { - i.serialize_compressed(&mut self.narg_string)?; + // Serialization should be infallible. + i.serialize_compressed(&mut self.narg_string) + .expect("Serialization failed"); } - Ok(()) } } @@ -42,10 +42,9 @@ where R: RngCore + CryptoRng, Self: CommonGroupToUnit>, { - fn add_points(&mut self, input: &[G]) -> ProofResult<()> { + fn add_points(&mut self, input: &[G]) { let serialized = self.public_points(input); - self.narg_string.extend(serialized?); - Ok(()) + self.narg_string.extend(serialized); } } @@ -57,12 +56,12 @@ where R: RngCore + CryptoRng, Self: CommonGroupToUnit + FieldToUnitSerialize, { - fn add_points(&mut self, input: &[G]) -> ProofResult<()> { - self.public_points(input).map(|_| ())?; + fn add_points(&mut self, input: &[G]) { + self.public_points(input); for i in input { - i.serialize_compressed(&mut self.narg_string)?; + i.serialize_compressed(&mut self.narg_string) + .expect("Serialization failed"); } - Ok(()) } } @@ -72,10 +71,9 @@ where C: FpConfig, R: RngCore + CryptoRng, { - fn add_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch> { - self.public_bytes(input)?; + fn add_bytes(&mut self, input: &[u8]) { + self.public_bytes(input); self.narg_string.extend(input); - Ok(()) } } @@ -84,13 +82,15 @@ where H: DuplexSpongeInterface>, C: FpConfig, { - fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), DomainSeparatorMismatch> { + fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), std::io::Error> { u8::read(&mut self.narg_string, input)?; - self.public_bytes(input) + self.public_bytes(input); + Ok(()) } } #[cfg(test)] +#[cfg(feature = "disable")] mod tests { use ark_bls12_381::Fr; use ark_curve25519::EdwardsProjective; @@ -99,10 +99,8 @@ mod tests { use super::*; use crate::{ - codecs::arkworks_algebra::{ - FieldDomainSeparator, FieldToUnitSerialize, GroupDomainSeparator, - }, - ByteDomainSeparator, DefaultHash, DomainSeparator, + codecs::arkworks_algebra::{FieldPattern, FieldToUnitSerialize, GroupPattern}, + ByteDomainSeparator, DefaultHash, }; /// Curve used for tests diff --git a/spongefish/src/codecs/arkworks_algebra/tests.rs b/spongefish/src/codecs/arkworks_algebra/tests.rs index cc00cc3..7988035 100644 --- a/spongefish/src/codecs/arkworks_algebra/tests.rs +++ b/spongefish/src/codecs/arkworks_algebra/tests.rs @@ -2,7 +2,7 @@ use ark_ff::Field; use crate::{ ByteDomainSeparator, BytesToUnitDeserialize, BytesToUnitSerialize, DefaultHash, - DomainSeparator, DuplexSpongeInterface, ProofResult, Unit, UnitToBytes, UnitTranscript, + DuplexSpongeInterface, ProofResult, Unit, UnitToBytes, UnitTranscript, }; /// Test that the algebraic hashes do use the IV generated from the domain separator. diff --git a/spongefish/src/codecs/arkworks_algebra/verifier_messages.rs b/spongefish/src/codecs/arkworks_algebra/verifier_messages.rs index 4ff67ca..89b6723 100644 --- a/spongefish/src/codecs/arkworks_algebra/verifier_messages.rs +++ b/spongefish/src/codecs/arkworks_algebra/verifier_messages.rs @@ -7,8 +7,8 @@ use rand::{CryptoRng, RngCore}; use super::{CommonFieldToUnit, CommonGroupToUnit, UnitToField}; use crate::{ - codecs::bytes_uniform_modp, CommonUnitToBytes, DomainSeparatorMismatch, DuplexSpongeInterface, - ProofError, ProofResult, ProverState, Unit, UnitToBytes, UnitTranscript, VerifierState, + codecs::bytes_uniform_modp, CommonUnitToBytes, DuplexSpongeInterface, ProofError, ProofResult, + ProverState, Unit, UnitToBytes, UnitTranscript, VerifierState, }; // Implementation of basic traits for bridging arkworks and spongefish @@ -46,13 +46,15 @@ where { type Repr = Vec; - fn public_points(&mut self, input: &[G]) -> ProofResult { + fn public_points(&mut self, input: &[G]) -> Self::Repr { let mut buf = Vec::new(); for i in input { - i.serialize_compressed(&mut buf)?; + // Serialization should be infallible + i.serialize_compressed(&mut buf) + .expect("Serialization failed."); } - self.public_bytes(&buf)?; - Ok(buf) + self.public_bytes(&buf); + buf } } @@ -63,13 +65,15 @@ where { type Repr = Vec; - fn public_scalars(&mut self, input: &[F]) -> ProofResult { + fn public_scalars(&mut self, input: &[F]) -> Self::Repr { let mut buf = Vec::new(); for i in input { - i.serialize_compressed(&mut buf)?; + // Writing to buffer should be infallible + i.serialize_compressed(&mut buf) + .expect("Serialization failed."); } - self.public_bytes(&buf)?; - Ok(buf) + self.public_bytes(&buf); + buf } } @@ -78,19 +82,19 @@ where F: Field, T: UnitTranscript, { - fn fill_challenge_scalars(&mut self, output: &mut [F]) -> ProofResult<()> { + fn fill_challenge_scalars(&mut self, output: &mut [F]) { let base_field_size = bytes_uniform_modp(F::BasePrimeField::MODULUS_BIT_SIZE); let mut buf = vec![0u8; F::extension_degree() as usize * base_field_size]; for o in output.iter_mut() { - self.fill_challenge_bytes(&mut buf)?; + self.fill_challenge_bytes(&mut buf); *o = F::from_base_prime_field_elems( buf.chunks(base_field_size) .map(F::BasePrimeField::from_be_bytes_mod_order), ) .expect("Could not convert"); } - Ok(()) + () } } @@ -99,9 +103,8 @@ where C: FpConfig, H: DuplexSpongeInterface>, { - fn fill_challenge_scalars(&mut self, output: &mut [Fp]) -> ProofResult<()> { - self.fill_challenge_units(output) - .map_err(ProofError::InvalidDomainSeparator) + fn fill_challenge_scalars(&mut self, output: &mut [Fp]) { + self.fill_challenge_units(output); } } @@ -111,9 +114,8 @@ where H: DuplexSpongeInterface>, R: CryptoRng + RngCore, { - fn fill_challenge_scalars(&mut self, output: &mut [Fp]) -> ProofResult<()> { - self.fill_challenge_units(output) - .map_err(ProofError::InvalidDomainSeparator) + fn fill_challenge_scalars(&mut self, output: &mut [Fp]) { + self.fill_challenge_units(output); } } @@ -128,13 +130,13 @@ where { type Repr = (); - fn public_scalars(&mut self, input: &[F]) -> ProofResult { + fn public_scalars(&mut self, input: &[F]) -> Self::Repr { let flattened: Vec<_> = input .iter() .flat_map(Field::to_base_prime_field_elements) .collect(); - self.public_units(&flattened)?; - Ok(()) + self.public_units(&flattened); + () } } @@ -165,13 +167,13 @@ where { type Repr = (); - fn public_scalars(&mut self, input: &[F]) -> ProofResult { + fn public_scalars(&mut self, input: &[F]) -> Self::Repr { let flattened: Vec<_> = input .iter() .flat_map(Field::to_base_prime_field_elements) .collect(); - self.public_units(&flattened)?; - Ok(()) + self.public_units(&flattened); + () } } @@ -184,12 +186,12 @@ where { type Repr = (); - fn public_points(&mut self, input: &[G]) -> ProofResult { + fn public_points(&mut self, input: &[G]) -> Self::Repr { for point in input { let (x, y) = point.into_affine().xy().unwrap(); - self.public_units(&[x, y])?; + self.public_units(&[x, y]); } - Ok(()) + () } } @@ -201,12 +203,12 @@ where { type Repr = (); - fn public_points(&mut self, input: &[G]) -> ProofResult { + fn public_points(&mut self, input: &[G]) -> Self::Repr { for point in input { let (x, y) = point.into_affine().xy().unwrap(); - self.public_units(&[x, y])?; + self.public_units(&[x, y]); } - Ok(()) + () } } @@ -217,11 +219,10 @@ where C: FpConfig, H: DuplexSpongeInterface>, { - fn public_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch> { + fn public_bytes(&mut self, input: &[u8]) { for &byte in input { - self.public_units(&[Fp::from(byte)])?; + self.public_units(&[Fp::from(byte)]); } - Ok(()) } } @@ -231,11 +232,10 @@ where H: DuplexSpongeInterface>, R: CryptoRng + rand::RngCore, { - fn public_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch> { + fn public_bytes(&mut self, input: &[u8]) { for &byte in input { - self.public_units(&[Fp::from(byte)])?; + self.public_units(&[Fp::from(byte)]); } - Ok(()) } } @@ -245,21 +245,19 @@ where H: DuplexSpongeInterface>, R: CryptoRng + RngCore, { - fn fill_challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), DomainSeparatorMismatch> { - if output.is_empty() { - Ok(()) - } else { + fn fill_challenge_bytes(&mut self, output: &mut [u8]) { + if !output.is_empty() { let len_good = usize::min( crate::codecs::random_bytes_in_random_modp(Fp::::MODULUS), output.len(), ); let mut tmp = [Fp::from(0); 1]; - self.fill_challenge_units(&mut tmp)?; + self.fill_challenge_units(&mut tmp); let buf = tmp[0].into_bigint().to_bytes_le(); output[..len_good].copy_from_slice(&buf[..len_good]); // recursively fill the rest of the buffer - self.fill_challenge_bytes(&mut output[len_good..]) + self.fill_challenge_bytes(&mut output[len_good..]); } } } @@ -270,26 +268,25 @@ where C: FpConfig, H: DuplexSpongeInterface>, { - fn fill_challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), DomainSeparatorMismatch> { - if output.is_empty() { - Ok(()) - } else { + fn fill_challenge_bytes(&mut self, output: &mut [u8]) { + if !output.is_empty() { let len_good = usize::min( crate::codecs::random_bytes_in_random_modp(Fp::::MODULUS), output.len(), ); let mut tmp = [Fp::from(0); 1]; - self.fill_challenge_units(&mut tmp)?; + self.fill_challenge_units(&mut tmp); let buf = tmp[0].into_bigint().to_bytes_le(); output[..len_good].copy_from_slice(&buf[..len_good]); // recursively fill the rest of the buffer - self.fill_challenge_bytes(&mut output[len_good..]) + self.fill_challenge_bytes(&mut output[len_good..]); } } } #[cfg(test)] +#[cfg(feature = "disable")] mod tests { use ark_curve25519::EdwardsProjective as Curve; use ark_ec::PrimeGroup; @@ -297,8 +294,8 @@ mod tests { use super::*; use crate::{ - codecs::arkworks_algebra::{FieldDomainSeparator, GroupDomainSeparator}, - DefaultHash, DomainSeparator, + codecs::arkworks_algebra::{FieldPattern, GroupPattern}, + DefaultHash, }; /// Configuration for the BabyBear field (modulus = 2^31 - 2^27 + 1, generator = 21). diff --git a/spongefish/src/codecs/bytes.rs b/spongefish/src/codecs/bytes.rs new file mode 100644 index 0000000..0dc7ba6 --- /dev/null +++ b/spongefish/src/codecs/bytes.rs @@ -0,0 +1,32 @@ +use crate::{ + codecs::unit::Pattern as _, + pattern::{Label, Length, Pattern as _, PatternState}, +}; + +/// Traits for patterns that handle byte arrays in a transcript. +pub trait Pattern { + fn public_bytes(&mut self, label: Label, size: usize); + fn message_bytes(&mut self, label: Label, size: usize); + fn challenge_bytes(&mut self, label: Label, size: usize); +} + +/// Implementation where `Unit = u8` +impl Pattern for PatternState { + fn public_bytes(&mut self, label: Label, size: usize) { + self.begin_public::(label, Length::Fixed(size)); + self.public_units("units", size); + self.end_public::(label, Length::Fixed(size)) + } + + fn message_bytes(&mut self, label: Label, size: usize) { + self.begin_message::(label, Length::Fixed(size)); + self.message_units("units", size); + self.end_message::(label, Length::Fixed(size)) + } + + fn challenge_bytes(&mut self, label: Label, size: usize) { + self.begin_challenge::(label, Length::Fixed(size)); + self.challenge_units("units", size); + self.end_challenge::(label, Length::Fixed(size)) + } +} diff --git a/spongefish/src/codecs/mod.rs b/spongefish/src/codecs/mod.rs index 525e2b0..ebf07ae 100644 --- a/spongefish/src/codecs/mod.rs +++ b/spongefish/src/codecs/mod.rs @@ -1,5 +1,8 @@ //! Bindings to some popular libraries using zero-knowledge. +pub mod bytes; +pub mod unit; + /// Extension traits macros, for both arkworks and group. #[cfg(any(feature = "arkworks-algebra", feature = "zkcrypto-group"))] mod traits; @@ -54,6 +57,6 @@ pub(super) const fn bytes_modp(modulus_bits: u32) -> usize { (modulus_bits as usize).div_ceil(8) } -/// Unit-tests for inter-operability among libraries. -#[cfg(all(test, feature = "arkworks-algebra", feature = "zkcrypto-group"))] -mod tests; +// /// Unit-tests for inter-operability among libraries. +// #[cfg(all(test, feature = "arkworks-algebra", feature = "zkcrypto-group"))] +// mod tests; diff --git a/spongefish/src/codecs/traits.rs b/spongefish/src/codecs/traits.rs index 3236b5d..08273ba 100644 --- a/spongefish/src/codecs/traits.rs +++ b/spongefish/src/codecs/traits.rs @@ -1,11 +1,9 @@ macro_rules! field_traits { ($Field:path) => { /// Absorb and squeeze field elements to the domain separator. - pub trait FieldDomainSeparator { - #[must_use] - fn add_scalars(self, count: usize, label: &str) -> Self; - #[must_use] - fn challenge_scalars(self, count: usize, label: &str) -> Self; + pub trait FieldPattern { + fn message_scalars(&mut self, label: $crate::pattern::Label, count: usize); + fn challenge_scalars(&mut self, label: $crate::pattern::Label, count: usize); } /// Interpret verifier messages as uniformly distributed field elements. @@ -13,24 +11,24 @@ macro_rules! field_traits { /// The implementation of this trait **MUST** ensure that the field elements /// are uniformly distributed and valid. pub trait UnitToField { - fn fill_challenge_scalars(&mut self, output: &mut [F]) -> $crate::ProofResult<()>; + fn fill_challenge_scalars(&mut self, output: &mut [F]); - fn challenge_scalars(&mut self) -> crate::ProofResult<[F; N]> { + fn challenge_scalars(&mut self) -> [F; N] { let mut output = [F::default(); N]; - self.fill_challenge_scalars(&mut output)?; - Ok(output) + self.fill_challenge_scalars(&mut output); + output } } /// Add field elements as shared public information. pub trait CommonFieldToUnit { type Repr; - fn public_scalars(&mut self, input: &[F]) -> crate::ProofResult; + fn public_scalars(&mut self, input: &[F]) -> Self::Repr; } /// Add field elements to the protocol transcript. pub trait FieldToUnitSerialize: CommonFieldToUnit { - fn add_scalars(&mut self, input: &[F]) -> crate::ProofResult<()>; + fn add_scalars(&mut self, input: &[F]); } /// Deserialize field elements from the protocol transcript. @@ -49,18 +47,16 @@ macro_rules! field_traits { }; } -#[macro_export] macro_rules! group_traits { ($Group:path, Scalar: $Field:path) => { /// Send group elements in the domain separator. - pub trait GroupDomainSeparator { - #[must_use] - fn add_points(self, count: usize, label: &str) -> Self; + pub trait GroupPattern { + fn message_points(&mut self, label: $crate::pattern::Label, count: usize); } /// Adds a new prover message consisting of an EC element. pub trait GroupToUnitSerialize: CommonGroupToUnit { - fn add_points(&mut self, input: &[G]) -> $crate::ProofResult<()>; + fn add_points(&mut self, input: &[G]); } /// Receive (and deserialize) group elements from the domain separator. @@ -87,7 +83,7 @@ macro_rules! group_traits { type Repr; /// Incorporate group elements into the proof without adding them to the final protocol transcript. - fn public_points(&mut self, input: &[G]) -> $crate::ProofResult; + fn public_points(&mut self, input: &[G]) -> Self::Repr; } }; } diff --git a/spongefish/src/codecs/unit.rs b/spongefish/src/codecs/unit.rs new file mode 100644 index 0000000..7de10b6 --- /dev/null +++ b/spongefish/src/codecs/unit.rs @@ -0,0 +1,15 @@ +use crate::{pattern::Label, Unit}; + +pub trait Pattern { + type Unit: Unit; + + fn ratchet(&mut self); + fn public_unit(&mut self, label: Label); + fn public_units(&mut self, label: Label, size: usize); + fn message_unit(&mut self, label: Label); + fn message_units(&mut self, label: Label, size: usize); + fn challenge_unit(&mut self, label: Label); + fn challenge_units(&mut self, label: Label, size: usize); + fn hint_bytes(&mut self, label: Label, size: usize); + fn hint_bytes_dynamic(&mut self, label: Label); +} diff --git a/spongefish/src/codecs/zkcrypto_group/domain_separator.rs b/spongefish/src/codecs/zkcrypto_group/domain_separator.rs index 48aafd1..cd44a51 100644 --- a/spongefish/src/codecs/zkcrypto_group/domain_separator.rs +++ b/spongefish/src/codecs/zkcrypto_group/domain_separator.rs @@ -1,33 +1,39 @@ use group::{ff::PrimeField, Group, GroupEncoding}; -use super::{FieldDomainSeparator, GroupDomainSeparator}; +use super::{FieldPattern, GroupPattern}; use crate::{ - codecs::{bytes_modp, bytes_uniform_modp}, - ByteDomainSeparator, DomainSeparator, DuplexSpongeInterface, + codecs::{bytes, bytes_modp, bytes_uniform_modp}, + pattern::{self, Label, Length}, }; -impl FieldDomainSeparator for DomainSeparator +impl FieldPattern for P where + P: pattern::Pattern + bytes::Pattern, F: PrimeField, - H: DuplexSpongeInterface, { - fn add_scalars(self, count: usize, label: &str) -> Self { - self.add_bytes(count * bytes_modp(F::NUM_BITS), label) + fn message_scalars(&mut self, label: Label, count: usize) { + self.begin_message::(label, Length::Fixed(count)); + self.message_bytes("bytes", count * bytes_modp(F::NUM_BITS)); + self.end_message::(label, Length::Fixed(count)); } - fn challenge_scalars(self, count: usize, label: &str) -> Self { - self.challenge_bytes(count * bytes_uniform_modp(F::NUM_BITS), label) + fn challenge_scalars(&mut self, label: Label, count: usize) { + self.begin_challenge::(label, Length::Fixed(count)); + self.challenge_bytes("bytes", count * bytes_uniform_modp(F::NUM_BITS)); + self.end_challenge::(label, Length::Fixed(count)); } } -impl GroupDomainSeparator for DomainSeparator +impl GroupPattern for P where + P: pattern::Pattern + bytes::Pattern, G: Group + GroupEncoding, G::Repr: AsRef<[u8]>, - H: DuplexSpongeInterface, { - fn add_points(self, count: usize, label: &str) -> Self { + fn message_points(&mut self, label: Label, count: usize) { + self.begin_message::(label, Length::Fixed(count)); let n = G::Repr::default().as_ref().len(); - self.add_bytes(count * n, label) + self.message_bytes("bytes", count * n); + self.end_message::(label, Length::Fixed(count)); } } diff --git a/spongefish/src/codecs/zkcrypto_group/prover_messages.rs b/spongefish/src/codecs/zkcrypto_group/prover_messages.rs index 982b88c..c3d8dcb 100644 --- a/spongefish/src/codecs/zkcrypto_group/prover_messages.rs +++ b/spongefish/src/codecs/zkcrypto_group/prover_messages.rs @@ -2,9 +2,7 @@ use group::{ff::PrimeField, Group, GroupEncoding}; use rand::{CryptoRng, RngCore}; use super::{CommonFieldToUnit, CommonGroupToUnit, FieldToUnitSerialize, GroupToUnitSerialize}; -use crate::{ - BytesToUnitSerialize, CommonUnitToBytes, DuplexSpongeInterface, ProofResult, ProverState, -}; +use crate::{BytesToUnitSerialize, CommonUnitToBytes, DuplexSpongeInterface, ProverState}; impl FieldToUnitSerialize for ProverState where @@ -12,10 +10,10 @@ where H: DuplexSpongeInterface, R: RngCore + CryptoRng, { - fn add_scalars(&mut self, input: &[F]) -> ProofResult<()> { - let serialized = self.public_scalars(input); - self.narg_string.extend(serialized?); - Ok(()) + fn add_scalars(&mut self, input: &[F]) { + let mut buf = Vec::new(); + input.iter().for_each(|i| buf.extend(i.to_repr().as_ref())); + self.add_bytes(&buf); } } @@ -27,13 +25,13 @@ where R: RngCore + CryptoRng, { type Repr = Vec; - fn public_points(&mut self, input: &[G]) -> crate::ProofResult { + fn public_points(&mut self, input: &[G]) -> Self::Repr { let mut buf = Vec::new(); for p in input { buf.extend_from_slice(::to_bytes(p).as_ref()); } - self.add_bytes(&buf)?; - Ok(buf) + self.public_bytes(&buf); + buf } } @@ -44,10 +42,12 @@ where H: DuplexSpongeInterface, R: RngCore + CryptoRng, { - fn add_points(&mut self, input: &[G]) -> crate::ProofResult<()> { - let serialized = self.public_points(input); - self.narg_string.extend(serialized?); - Ok(()) + fn add_points(&mut self, input: &[G]) { + let mut buf = Vec::new(); + for p in input { + buf.extend_from_slice(::to_bytes(p).as_ref()); + } + self.add_bytes(&buf); } } @@ -58,10 +58,10 @@ where { type Repr = Vec; - fn public_scalars(&mut self, input: &[F]) -> ProofResult { + fn public_scalars(&mut self, input: &[F]) -> Self::Repr { let mut buf = Vec::new(); input.iter().for_each(|i| buf.extend(i.to_repr().as_ref())); - self.public_bytes(&buf)?; - Ok(buf) + self.public_bytes(&buf); + buf } } diff --git a/spongefish/src/codecs/zkcrypto_group/verifier_messages.rs b/spongefish/src/codecs/zkcrypto_group/verifier_messages.rs index 9cc00c3..28848f6 100644 --- a/spongefish/src/codecs/zkcrypto_group/verifier_messages.rs +++ b/spongefish/src/codecs/zkcrypto_group/verifier_messages.rs @@ -1,7 +1,7 @@ use group::ff::PrimeField; use super::UnitToField; -use crate::{codecs::bytes_uniform_modp, ProofResult, UnitToBytes}; +use crate::{codecs::bytes_uniform_modp, UnitToBytes}; /// Convert a byte array to a field element. /// @@ -20,14 +20,12 @@ where F: PrimeField, T: UnitToBytes, { - fn fill_challenge_scalars(&mut self, output: &mut [F]) -> ProofResult<()> { + fn fill_challenge_scalars(&mut self, output: &mut [F]) { let mut buf = vec![0; bytes_uniform_modp(F::NUM_BITS)]; for o in output { - self.fill_challenge_bytes(&mut buf)?; + self.fill_challenge_bytes(&mut buf); *o = from_bytes_mod_order(&buf); } - - Ok(()) } } diff --git a/spongefish/src/domain_separator.rs b/spongefish/src/domain_separator.rs deleted file mode 100644 index a9d0ceb..0000000 --- a/spongefish/src/domain_separator.rs +++ /dev/null @@ -1,553 +0,0 @@ -// XXX. before, absorb and squeeze were accepting arguments of type -// use ::core::num::NonZeroUsize; -// which was a pain to use -// (plain integers don't cast to NonZeroUsize automatically) - -use std::{collections::VecDeque, marker::PhantomData}; - -use super::{ - duplex_sponge::{DuplexSpongeInterface, Unit}, - errors::DomainSeparatorMismatch, -}; -use crate::ByteDomainSeparator; - -/// This is the separator between operations in the domain separator -/// and as such is the only forbidden character in labels. -const SEP_BYTE: &str = "\0"; - -/// The domain separator of an interactive protocol. -/// -/// An domain separator is a string that specifies the protocol in a simple, -/// non-ambiguous, human-readable format. A typical example is the following: -/// -/// ```text -/// domain-separator A32generator A32public-key R A32commitment S32challenge A32response -/// ``` -/// The domain-separator is a user-specified string uniquely identifying the end-user application (to avoid cross-protocol attacks). -/// The letter `A` indicates the absorption of a public input (an `ABSORB`), while the letter `S` indicates the squeezing (a `SQUEEZE`) of a challenge. -/// The letter `R` indicates a ratcheting operation: ratcheting means invoking the hash function even on an incomplete block. -/// It provides forward secrecy and allows it to start from a clean rate. -/// After the operation type, is the number of elements in base 10 that are being absorbed/squeezed. -/// Then, follows the label associated with the element being absorbed/squeezed. This often comes from the underlying description of the protocol. The label cannot start with a digit or contain the NULL byte. -/// -/// ## Guarantees -/// -/// The struct [`DomainSeparator`] guarantees the creation of a valid domain separator string, whose lengths are coherent with the types described in the protocol. No information about the types themselves is stored in an domain separator. -/// This means that [`ProverState`][`crate::ProverState`] or [`VerifierState`][`crate::VerifierState`] instances can generate successfully a protocol transcript respecting the length constraint but not the types. See [issue #6](https://github.com/arkworks-rs/spongefish/issues/6) for a discussion on the topic. -#[derive(Clone)] -pub struct DomainSeparator -where - U: Unit, - H: DuplexSpongeInterface, -{ - /// Encoded domain separator string representing the sequence of sponge operations. - io: String, - /// Marker for the sponge hash function and unit type used. - _hash: PhantomData<(H, U)>, -} - -/// Sponge operations. -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub enum Op { - /// Indicates absorption of `usize` lanes. - /// - /// In a tag, absorb is indicated with 'A'. - Absorb(usize), - /// Indicates processing of out-of-band message - /// from prover to verifier. - /// - /// This is useful for e.g. adding merkle proofs to the proof. - Hint, - /// Indicates squeezing of `usize` lanes. - /// - /// In a tag, squeeze is indicated with 'S'. - Squeeze(usize), - /// Indicates a ratchet operation. - /// - /// For sponge functions, we squeeze sizeof(capacity) lanes - /// and initialize a new state filling the capacity. - /// This allows for a more efficient preprocessing, and for removal of - /// private information stored in the rate. - Ratchet, -} - -impl Op { - /// Create a new OP from the portion of a tag. - fn new(id: char, count: Option) -> Result { - match (id, count) { - ('A', Some(c)) if c > 0 => Ok(Self::Absorb(c)), - ('H', None | Some(0)) => Ok(Self::Hint), - ('R', None | Some(0)) => Ok(Self::Ratchet), - ('S', Some(c)) if c > 0 => Ok(Self::Squeeze(c)), - _ => Err("Invalid tag".into()), - } - } -} - -impl, U: Unit> DomainSeparator { - #[must_use] - pub const fn from_string(io: String) -> Self { - Self { - io, - _hash: PhantomData, - } - } - - /// Create a new DomainSeparator with the domain separator. - #[must_use] - pub fn new(session_identifier: &str) -> Self { - assert!( - !session_identifier.contains(SEP_BYTE), - "Domain separator cannot contain the separator BYTE." - ); - Self::from_string(session_identifier.to_string()) - } - - /// Absorb `count` native elements. - #[must_use] - pub fn absorb(self, count: usize, label: &str) -> Self { - assert!(count > 0, "Count must be positive."); - assert!( - !label.contains(SEP_BYTE), - "Label cannot contain the separator BYTE." - ); - assert!( - label - .chars() - .next() - .is_none_or(|char| !char.is_ascii_digit()), - "Label cannot start with a digit." - ); - - Self::from_string(self.io + SEP_BYTE + &format!("A{count}") + label) - } - - /// Hint `count` native elements. - #[must_use] - pub fn hint(self, label: &str) -> Self { - assert!( - !label.contains(SEP_BYTE), - "Label cannot contain the separator BYTE." - ); - - Self::from_string(self.io + SEP_BYTE + "H" + label) - } - - /// Squeeze `count` native elements. - #[must_use] - pub fn squeeze(self, count: usize, label: &str) -> Self { - assert!(count > 0, "Count must be positive."); - assert!( - !label.contains(SEP_BYTE), - "Label cannot contain the separator BYTE." - ); - assert!( - label - .chars() - .next() - .is_none_or(|char| !char.is_ascii_digit()), - "Label cannot start with a digit." - ); - - Self::from_string(self.io + SEP_BYTE + &format!("S{count}") + label) - } - - /// Ratchet the state. - #[must_use] - pub fn ratchet(self) -> Self { - Self::from_string(self.io + SEP_BYTE + "R") - } - - /// Return the domain separator as bytes. - #[must_use] - pub fn as_bytes(&self) -> &[u8] { - self.io.as_bytes() - } - - /// Parse the givern domain separator into a sequence of [`Op`]'s. - pub(crate) fn finalize(&self) -> VecDeque { - // Guaranteed to succeed as instances are all valid domain_separators - Self::parse_domsep(self.io.as_bytes()) - .expect("Internal error. Please submit issue to m@orru.net") - } - - fn parse_domsep(domain_separator: &[u8]) -> Result, DomainSeparatorMismatch> { - let mut stack = VecDeque::new(); - - // skip the domain separator - for part in domain_separator - .split(|&b| b == SEP_BYTE.as_bytes()[0]) - .skip(1) - { - let next_id = part[0] as char; - let next_length = part[1..] - .iter() - .take_while(|x| x.is_ascii_digit()) - .fold(0, |acc, x| acc * 10 + (x - b'0') as usize); - - // check that next_length != 0 is performed internally on Op::new - let next_op = Op::new(next_id, Some(next_length))?; - stack.push_back(next_op); - } - - // consecutive calls are merged into one - match stack.pop_front() { - None => Ok(stack), - Some(x) => Ok(Self::simplify_stack([x].into(), stack)), - } - } - - fn simplify_stack(mut dst: VecDeque, mut stack: VecDeque) -> VecDeque { - while let Some(next) = stack.pop_front() { - match (dst.pop_back(), next) { - (Some(Op::Squeeze(a)), Op::Squeeze(b)) => dst.push_back(Op::Squeeze(a + b)), - (Some(Op::Absorb(a)), Op::Absorb(b)) => dst.push_back(Op::Absorb(a + b)), - (Some(prev), next) => { - dst.push_back(prev); - dst.push_back(next); - } - (None, next) => dst.push_back(next), - } - } - dst - } - - /// Create an [`crate::ProverState`] instance from the domain separator. - #[must_use] - pub fn to_prover_state(&self) -> crate::ProverState { - self.into() - } - - /// Create a [`crate::VerifierState`] instance from the domain separator and the protocol transcript (bytes). - #[must_use] - pub fn to_verifier_state<'a>(&self, transcript: &'a [u8]) -> crate::VerifierState<'a, H, U> { - crate::VerifierState::new(self, transcript) - } -} - -impl> core::fmt::Debug for DomainSeparator { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - // Ensure that the state isn't accidentally logged - write!(f, "DomainSeparator({:?})", self.io) - } -} - -impl ByteDomainSeparator for DomainSeparator { - #[inline] - fn add_bytes(self, count: usize, label: &str) -> Self { - self.absorb(count, label) - } - - fn hint(self, label: &str) -> Self { - self.hint(label) - } - - #[inline] - fn challenge_bytes(self, count: usize, label: &str) -> Self { - self.squeeze(count, label) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::DefaultHash; - - pub type H = DefaultHash; - - #[test] - fn test_op_new_invalid_cases() { - assert!(Op::new('A', Some(0)).is_err()); // absorb with zero - assert!(Op::new('H', Some(1)).is_err()); // hint with size - assert!(Op::new('S', Some(0)).is_err()); // squeeze with zero - assert!(Op::new('X', Some(1)).is_err()); // invalid op char - assert!(Op::new('R', Some(5)).is_err()); // R doesn't support > 0 - assert!(Op::new('R', Some(0)).is_ok()); // ratchet with 0 - assert!(Op::new('R', None).is_ok()); // ratchet with None - } - - #[test] - fn test_domain_separator_new_and_bytes() { - let ds = DomainSeparator::::new("session"); - assert_eq!(ds.as_bytes(), b"session"); - } - - #[test] - #[should_panic] - fn test_new_with_separator_byte_panics() { - // This should panic because "\0" is forbidden in the session identifier. - let _ = DomainSeparator::::new("invalid\0session"); - } - - #[test] - fn test_domain_separator_absorb_and_squeeze() { - let ds = DomainSeparator::::new("proto") - .absorb(2, "input") - .squeeze(1, "challenge"); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(2), Op::Squeeze(1)]); - } - - #[test] - fn test_domain_separator_ratcheting() { - let ds = DomainSeparator::::new("session").ratchet(); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Ratchet]); - } - - #[test] - fn test_absorb_return_value_format() { - let ds = DomainSeparator::::new("proto").absorb(3, "input"); - let expected_str = "proto\0A3input"; // initial + SEP + absorb op + label - assert_eq!(ds.as_bytes(), expected_str.as_bytes()); - } - - #[test] - #[should_panic] - fn test_absorb_zero_panics() { - let _ = DomainSeparator::::new("x").absorb(0, "label"); - } - - #[test] - #[should_panic] - fn test_label_with_separator_byte_panics() { - let _ = DomainSeparator::::new("x").absorb(1, "bad\0label"); - } - - #[test] - #[should_panic] - fn test_label_starts_with_digit_panics() { - let _ = DomainSeparator::::new("x").absorb(1, "1label"); - } - - #[test] - fn test_merge_consecutive_absorbs_and_squeezes() { - let ds = DomainSeparator::::new("merge") - .absorb(1, "a") - .absorb(2, "b") - .squeeze(3, "c") - .squeeze(1, "d"); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(3), Op::Squeeze(4)]); - } - - #[test] - fn test_parse_domsep_multiple_ops() { - let tag = "main\0A1x\0A2y\0S3z\0R\0S2w"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!( - ops, - vec![Op::Absorb(3), Op::Squeeze(3), Op::Ratchet, Op::Squeeze(2)] - ); - } - - #[test] - fn test_byte_domain_separator_trait_impl() { - let ds = DomainSeparator::::new("x") - .add_bytes(1, "a") - .challenge_bytes(2, "b"); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(1), Op::Squeeze(2)]); - } - - #[test] - fn test_empty_operations() { - let ds = DomainSeparator::::new("tag"); - let ops = ds.finalize(); - assert!(ops.is_empty()); - } - - #[test] - fn test_consecutive_ratchets_preserved() { - let ds = DomainSeparator::::new("r").ratchet().ratchet(); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Ratchet, Op::Ratchet]); - } - - #[test] - fn test_unicode_labels() { - let ds = DomainSeparator::::new("emoji") - .absorb(1, "🦀") - .squeeze(1, "🎯"); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(1), Op::Squeeze(1)]); - } - - #[test] - fn test_large_counts_and_labels() { - let label = "x".repeat(100); - let ds = DomainSeparator::::new("big") - .absorb(12345, &label) - .squeeze(54321, &label); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(12345), Op::Squeeze(54321)]); - } - - #[test] - fn test_malformed_tag_parsing_fails() { - // Missing count - let broken = "proto\0Ax"; - let ds = DomainSeparator::::from_string(broken.to_string()); - let res = DomainSeparator::::parse_domsep(ds.as_bytes()); - assert!(res.is_err()); - } - - #[test] - fn test_simplify_stack_keeps_unlike_ops() { - let tag = "test\0A2x\0S3y\0A1z"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(2), Op::Squeeze(3), Op::Absorb(1)]); - } - - #[test] - fn test_round_trip_operations() { - let ds1 = DomainSeparator::::new("foo") - .absorb(2, "a") - .squeeze(3, "b") - .ratchet(); - let ops1 = ds1.finalize(); - - let tag = std::str::from_utf8(ds1.as_bytes()).unwrap(); - let ds2 = DomainSeparator::::from_string(tag.to_string()); - let ops2 = ds2.finalize(); - - assert_eq!(ops1, ops2); - } - - #[test] - fn test_squeeze_returns_correct_string() { - let ds = DomainSeparator::::new("proto").squeeze(4, "challenge"); - let expected_str = "proto\0S4challenge"; - assert_eq!(ds.as_bytes(), expected_str.as_bytes()); - } - - #[test] - #[should_panic] - fn test_squeeze_zero_count_panics() { - let _ = DomainSeparator::::new("proto").squeeze(0, "label"); - } - - #[test] - #[should_panic] - fn test_squeeze_label_with_null_byte_panics() { - let _ = DomainSeparator::::new("proto").squeeze(2, "bad\0label"); - } - - #[test] - #[should_panic] - fn test_squeeze_label_starts_with_digit_panics() { - let _ = DomainSeparator::::new("proto").squeeze(2, "1invalid"); - } - - #[test] - fn test_multiple_squeeze_chaining() { - let ds = DomainSeparator::::new("proto") - .squeeze(1, "first") - .squeeze(2, "second"); - let expected_str = "proto\0S1first\0S2second"; - assert_eq!(ds.as_bytes(), expected_str.as_bytes()); - } - - #[test] - fn test_ratchet_returns_correct_self() { - let ds = DomainSeparator::::new("proto"); - let ratcheted = ds.ratchet(); - let expected_str = "proto\0R"; - assert_eq!(ratcheted.as_bytes(), expected_str.as_bytes()); - } - - #[test] - fn test_finalize_mixed_ops_order_preserved() { - let tag = "zkp\0A1a\0S1b\0A2c\0S3d\0R\0A4e\0S1f"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!( - ops, - vec![ - Op::Absorb(1), - Op::Squeeze(1), - Op::Absorb(2), - Op::Squeeze(3), - Op::Ratchet, - Op::Absorb(4), - Op::Squeeze(1), - ] - ); - } - - #[test] - fn test_finalize_large_values_and_merge() { - let tag = "main\0A5a\0A10b\0S8c\0S2d"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(15), Op::Squeeze(10)]); - } - - #[test] - fn test_finalize_merge_and_breaks() { - let tag = "example\0A2x\0A1y\0R\0A3z\0S4u\0S1v"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!( - ops, - vec![Op::Absorb(3), Op::Ratchet, Op::Absorb(3), Op::Squeeze(5),] - ); - } - - #[test] - fn test_finalize_only_ratchets() { - let tag = "onlyratchets\0R\0R\0R"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Ratchet, Op::Ratchet, Op::Ratchet]); - } - - #[test] - fn test_finalize_complex_merge_boundaries() { - let tag = "demo\0A1a\0A1b\0S2c\0S2d\0A3e\0S1f\0Hd"; - let ds = DomainSeparator::::from_string(tag.to_string()); - let ops = ds.finalize(); - assert_eq!( - ops, - vec![ - Op::Absorb(2), // A1a + A1b - Op::Squeeze(4), // S2c + S2d - Op::Absorb(3), // A3e - Op::Squeeze(1), // S1f - Op::Hint, // Hd - ] - ); - } - - #[test] - fn test_hint_is_parsed_correctly() { - let ds = DomainSeparator::::new("hint_test").hint("my_hint"); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Hint]); - } - - #[test] - fn test_hint_format_is_correct_in_bytes() { - let ds = DomainSeparator::::new("proto").hint("my_hint"); - let expected = b"proto\0Hmy_hint"; - assert_eq!(ds.as_bytes(), expected); - } - - #[test] - #[should_panic] - fn test_hint_label_with_null_byte_panics() { - let _ = DomainSeparator::::new("x").hint("bad\0hint"); - } - - #[test] - fn test_hint_combined_with_absorb_and_squeeze() { - let ds = DomainSeparator::::new("combo") - .absorb(1, "x") - .hint("meta") - .squeeze(2, "y"); - let ops = ds.finalize(); - assert_eq!(ops, vec![Op::Absorb(1), Op::Hint, Op::Squeeze(2)]); - } -} diff --git a/spongefish/src/duplex_sponge/mod.rs b/spongefish/src/duplex_sponge/mod.rs index de12788..f423f15 100644 --- a/spongefish/src/duplex_sponge/mod.rs +++ b/spongefish/src/duplex_sponge/mod.rs @@ -60,7 +60,7 @@ pub trait Permutation: Zeroize + Default + Clone + AsRef<[Self::U]> + AsMut<[Sel } /// A cryptographic sponge. -#[derive(Clone, Default, Zeroize, ZeroizeOnDrop)] +#[derive(Clone, PartialEq, Eq, Default, Zeroize, ZeroizeOnDrop)] pub struct DuplexSponge { permutation: C, absorb_pos: usize, diff --git a/spongefish/src/errors.rs b/spongefish/src/errors.rs index 0784384..3554064 100644 --- a/spongefish/src/errors.rs +++ b/spongefish/src/errors.rs @@ -17,19 +17,13 @@ /// An error to signal that the verification equation has failed. Destined for end users. /// /// A [`core::Result::Result`] wrapper called [`ProofResult`] (having error fixed to [`ProofError`]) is also provided. -use std::{borrow::Borrow, error::Error, fmt::Display}; - -/// Signals a domain separator is inconsistent with the description provided. -#[derive(Debug, Clone)] -pub struct DomainSeparatorMismatch(String); +use std::{error::Error, fmt::Display}; /// An error happened when creating or verifying a proof. #[derive(Debug, Clone)] pub enum ProofError { /// Signals the verification equation has failed. InvalidProof, - /// The domain separator specified mismatches the protocol execution. - InvalidDomainSeparator(DomainSeparatorMismatch), /// Serialization/Deserialization led to errors. SerializationError, } @@ -37,45 +31,19 @@ pub enum ProofError { /// The result type when trying to prove or verify a proof using Fiat-Shamir. pub type ProofResult = Result; -impl Display for DomainSeparatorMismatch { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } -} - impl Display for ProofError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::SerializationError => write!(f, "Serialization Error"), - Self::InvalidDomainSeparator(e) => e.fmt(f), Self::InvalidProof => write!(f, "Invalid proof"), } } } -impl Error for DomainSeparatorMismatch {} impl Error for ProofError {} -impl From<&str> for DomainSeparatorMismatch { - fn from(s: &str) -> Self { - s.to_string().into() - } -} - -impl From for DomainSeparatorMismatch { - fn from(s: String) -> Self { - Self(s) - } -} - -impl> From for ProofError { - fn from(value: B) -> Self { - Self::InvalidDomainSeparator(value.borrow().clone()) - } -} - -impl From for DomainSeparatorMismatch { - fn from(value: std::io::Error) -> Self { - Self(value.to_string()) +impl From for ProofError { + fn from(_value: std::io::Error) -> Self { + Self::SerializationError } } diff --git a/spongefish/src/keccak.rs b/spongefish/src/keccak.rs index 21da588..49eb3d9 100644 --- a/spongefish/src/keccak.rs +++ b/spongefish/src/keccak.rs @@ -2,55 +2,60 @@ //! Despite internally we use the same permutation function, //! we build a duplex sponge in overwrite mode //! on the top of it using the `DuplexSponge` trait. +use std::fmt::Debug; + +use zerocopy::IntoBytes; use zeroize::{Zeroize, ZeroizeOnDrop}; use crate::duplex_sponge::{DuplexSponge, Permutation}; /// A duplex sponge based on the permutation [`keccak::f1600`] /// using [`DuplexSponge`]. -pub type Keccak = DuplexSponge; - -fn transmute_state(st: &mut AlignedKeccakF1600) -> &mut [u64; 25] { - unsafe { &mut *std::ptr::from_mut::(st).cast::<[u64; 25]>() } -} - -/// This is a wrapper around 200-byte buffer that's always 8-byte aligned -/// to make pointers to it safely convertible to pointers to [u64; 25] -/// (since u64 words must be 8-byte aligned) -#[derive(Clone, Zeroize, ZeroizeOnDrop)] -#[repr(align(8))] -pub struct AlignedKeccakF1600([u8; 200]); - -impl Permutation for AlignedKeccakF1600 { +/// +/// **Warning**: This function is not SHA3. +/// Despite internally we use the same permutation function, +/// we build a duplex sponge in overwrite mode +/// on the top of it using the `DuplexSponge` trait. +pub type Keccak = DuplexSponge; + +/// Keccak permutation internal state: 25 64-bit words, +/// or equivalently 200 bytes in little-endian order. +#[derive(Clone, PartialEq, Eq, Default, Zeroize, ZeroizeOnDrop)] +pub struct KeccakF1600([u64; 25]); + +impl Permutation for KeccakF1600 { type U = u8; const N: usize = 136 + 64; const R: usize = 136; - fn new(tag: [u8; 32]) -> Self { + fn new(iv: [u8; 32]) -> Self { let mut state = Self::default(); - state.0[Self::R..Self::R + 32].copy_from_slice(&tag); + state.as_mut()[Self::R..Self::R + 32].copy_from_slice(&iv); state } fn permute(&mut self) { - keccak::f1600(transmute_state(self)); + keccak::f1600(&mut self.0); } } -impl Default for AlignedKeccakF1600 { - fn default() -> Self { - Self([0u8; Self::N]) +impl AsRef<[u8]> for KeccakF1600 { + fn as_ref(&self) -> &[u8] { + self.0.as_bytes() } } -impl AsRef<[u8]> for AlignedKeccakF1600 { - fn as_ref(&self) -> &[u8] { - &self.0 +impl AsMut<[u8]> for KeccakF1600 { + fn as_mut(&mut self) -> &mut [u8] { + self.0.as_mut_bytes() } } -impl AsMut<[u8]> for AlignedKeccakF1600 { - fn as_mut(&mut self) -> &mut [u8] { - &mut self.0 +/// Censored version of Debug +impl Debug for KeccakF1600 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AlignedKeccakF1600") + .field(&"") + .finish() } } diff --git a/spongefish/src/lib.rs b/spongefish/src/lib.rs index 0189a29..aef3914 100644 --- a/spongefish/src/lib.rs +++ b/spongefish/src/lib.rs @@ -138,24 +138,21 @@ pub mod keccak; /// APIs for common zkp libraries. pub mod codecs; -/// domain separator -mod domain_separator; -/// Prover's internal state and transcript generation. -mod prover; -/// SAFE API. -mod sho; + +pub mod pattern; /// Unit-tests. #[cfg(test)] mod tests; +/// Prover's internal state and transcript generation. +mod prover; + /// Traits for byte support. pub mod traits; -pub use domain_separator::DomainSeparator; pub use duplex_sponge::{legacy::DigestBridge, DuplexSpongeInterface, Unit}; -pub use errors::{DomainSeparatorMismatch, ProofError, ProofResult}; +pub use errors::{ProofError, ProofResult}; pub use prover::ProverState; -pub use sho::HashStateWithInstructions; pub use traits::*; pub use verifier::VerifierState; diff --git a/spongefish/src/pattern/interaction.rs b/spongefish/src/pattern/interaction.rs new file mode 100644 index 0000000..8d5f05f --- /dev/null +++ b/spongefish/src/pattern/interaction.rs @@ -0,0 +1,175 @@ +use core::{any::type_name, fmt::Display}; + +/// A single abstract prover-verifier interaction. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub struct Interaction { + /// Hierarchical nesting of the interactions. + hierarchy: Hierarchy, + /// The kind of interaction. + kind: Kind, + /// A label identifying the purpose of the value. + label: Label, + /// The Rust name of the type of the value. + /// + /// We use [`core::any::type_name`] to verify value types intead of [`core::any::TypeID`] since + /// the latter only supports types with a `'static` lifetime. The downside of `type_name` is + /// that it is slightly less precise in that it can create more type collisions. But this is + /// acceptable here as it only serves as an additional check and as debug information. + type_name: &'static str, + /// Length of the value. + length: Length, +} + +/// Labels for interactions. +pub type Label = &'static str; + +/// Kinds of prover-verifier interactions +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub enum Kind { + /// A protocol containing mixed interactions. + Protocol, + /// A public message prover and verifier agree on. + Public, + /// A message send in-band from prover to verifier. + Message, + /// A hint send out-of-band from prover to verifier. + Hint, + /// A challenge issued by the verifier. + Challenge, +} + +/// Kinds of prover-verifier interactions +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub enum Hierarchy { + /// A single interaction. + Atomic, + /// Start of a sub-protocol. + Begin, + /// End of a sub-protocol. + End, +} + +/// Length of values involved in interactions. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub enum Length { + /// No length information. + None, + /// A single value. + Scalar, + /// A fixed number of values. + Fixed(usize), + /// A dynamic number of values. + Dynamic, +} + +impl Interaction { + #[must_use] + pub fn new(hierarchy: Hierarchy, kind: Kind, label: Label, length: Length) -> Self { + let type_name = type_name::(); + Self { + hierarchy, + kind, + label, + type_name, + length, + } + } + + #[must_use] + pub const fn hierarchy(&self) -> Hierarchy { + self.hierarchy + } + + #[must_use] + pub const fn kind(&self) -> Kind { + self.kind + } + + /// Returns `true` if this is a `Hierarchy::End` that closes the provided + /// `Hierarchy::Begin`. + #[must_use] + pub(super) fn closes(&self, other: &Self) -> bool { + self.hierarchy == Hierarchy::End + && other.hierarchy == Hierarchy::Begin + && self.kind == other.kind + && self.label == other.label + && self.type_name == other.type_name + && self.length == other.length + } +} + +impl Display for Interaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if f.alternate() { + // Domain separator mode: stable unambiguous format. + write!(f, "{} {}", self.hierarchy, self.kind)?; + // Length prefixed strings for labels to disambiguate + write!(f, " {} {}", self.label.len(), self.label)?; + write!(f, " {}", self.length) + // Leave out type names for domain separators. + } else { + write!( + f, + "{} {} {} {} {}", + self.hierarchy, self.kind, self.label, self.length, self.type_name, + ) + } + } +} + +impl Display for Hierarchy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Atomic => write!(f, "Atomic"), + Self::Begin => write!(f, "Begin"), + Self::End => write!(f, "End"), + } + } +} + +impl Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Protocol => write!(f, "Protocol"), + Self::Public => write!(f, "Public"), + Self::Message => write!(f, "Message"), + Self::Hint => write!(f, "Hint"), + Self::Challenge => write!(f, "Challenge"), + } + } +} + +impl Display for Length { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::None => write!(f, "None"), + Self::Scalar => write!(f, "Scalar"), + Self::Fixed(size) => write!(f, "Fixed({size})"), + Self::Dynamic => write!(f, "Dynamic"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sizes() { + dbg!(size_of::()); + assert!(size_of::() < 80); + } + + #[test] + fn test_domain_separator() { + let interaction = Interaction::new::>( + Hierarchy::Atomic, + Kind::Message, + "test-message", + Length::Scalar, + ); + let result = format!("{interaction:#}"); + let expected = "Atomic Message 12 test-message Scalar"; + assert_eq!(result, expected); + } +} diff --git a/spongefish/src/pattern/interaction_pattern.rs b/spongefish/src/pattern/interaction_pattern.rs new file mode 100644 index 0000000..591644b --- /dev/null +++ b/spongefish/src/pattern/interaction_pattern.rs @@ -0,0 +1,190 @@ +use core::fmt::Display; + +use thiserror::Error; + +use super::{interaction::Hierarchy, Interaction, Kind}; + +/// Abstract transcript containing prover-verifier interactions +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug, Default)] +pub struct InteractionPattern { + interactions: Vec, +} + +/// Errors when validating a transcript. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Error)] +pub enum TranscriptError { + #[error("Missing Begin for {end} at {position}")] + MissingBegin { position: usize, end: Interaction }, + #[error( + "Invalid kind {interaction} at {interaction_position} for {begin} at {begin_position}" + )] + InvalidKind { + begin_position: usize, + begin: Interaction, + interaction_position: usize, + interaction: Interaction, + }, + #[error("Mismatch {begin} at {begin_position} for {end} at {end_position}")] + MismatchedBeginEnd { + begin_position: usize, + begin: Interaction, + end_position: usize, + end: Interaction, + }, + #[error("Missing End for {begin} at {position}")] + MissingEnd { position: usize, begin: Interaction }, +} + +impl InteractionPattern { + pub fn new(interactions: Vec) -> Result { + let result = Self { interactions }; + result.validate()?; + Ok(result) + } + + #[must_use] + #[allow(clippy::missing_const_for_fn)] // False positive + pub fn interactions(&self) -> &[Interaction] { + &self.interactions + } + + /// Generate a unique identifier for the protocol. + /// + /// It is created by taking the SHA3 hash of a stable unambiguous + /// string representation of the transcript interactions. + // TODO: A more neutral implementation would use ASN.1 DER. + #[must_use] + pub fn domain_separator(&self) -> [u8; 32] { + use sha3::{Digest, Sha3_256}; + let mut hasher = Sha3_256::new(); + // Use Display in `alternate` mode for stable unambiguous representation. + hasher.update(format!("{self:#}").as_bytes()); + let result = hasher.finalize(); + result.into() + } + + /// Validate the transcript. + /// + /// A valid transcript has: + /// + /// - Matching [`InteractionHierachy::Begin`] and [`InteractionHierachy::End`] interactions + /// creating a nested hierarchy. + /// - Nested interactions are the same [`InteractionKind`] as the last [`InteractionHierachy::Begin`] interaction, except for [`InteractionKind::Protocol`] which can contain any [`InteractionKind`]. + fn validate(&self) -> Result<(), TranscriptError> { + let mut stack = Vec::new(); + for (position, interaction) in self.interactions.iter().enumerate() { + match interaction.hierarchy() { + Hierarchy::Begin => stack.push((position, interaction)), + Hierarchy::End => { + let Some((position, begin)) = stack.pop() else { + dbg!(); + return Err(TranscriptError::MissingBegin { + position, + end: interaction.clone(), + }); + }; + if !interaction.closes(begin) { + return Err(TranscriptError::MismatchedBeginEnd { + begin_position: position, + begin: begin.clone(), + end_position: self.interactions.len(), + end: interaction.clone(), + }); + } + } + Hierarchy::Atomic => { + let Some((begin_position, begin)) = stack.last().copied() else { + continue; + }; + if begin.kind() != Kind::Protocol && begin.kind() != interaction.kind() { + return Err(TranscriptError::InvalidKind { + begin_position, + begin: begin.clone(), + interaction_position: position, + interaction: interaction.clone(), + }); + } + } + } + } + if let Some((position, begin)) = stack.pop() { + return Err(TranscriptError::MissingEnd { + position, + begin: begin.clone(), + }); + } + Ok(()) + } +} + +/// Creates a human readable representation of the transcript. +/// +/// When called in alternate mode `{:#}` it will be a stable format suitable as domain separator. +impl Display for InteractionPattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Write the total interactions up front so no prefix string can be a valid domain separator. + let length = self.interactions.len(); + let width = length.saturating_sub(1).to_string().len(); + writeln!(f, "Spongefish Transcript ({length} interactions)")?; + let mut indentation = 0; + for (position, interaction) in self.interactions.iter().enumerate() { + write!(f, "{position:0>width$} ")?; + if interaction.hierarchy() == Hierarchy::End { + indentation -= 1; + } + for _ in 0..indentation { + write!(f, " ")?; + } + if f.alternate() { + writeln!(f, "{interaction:#}")?; + } else { + writeln!(f, "{interaction}")?; + } + if interaction.hierarchy() == Hierarchy::Begin { + indentation += 1; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pattern::Length; + + #[test] + fn test_size() { + dbg!(size_of::()); + assert!(size_of::() < 170); + } + + #[test] + fn test_domain_separator() { + let transcript = InteractionPattern::new(vec![ + Interaction::new::(Hierarchy::Begin, Kind::Protocol, "test", Length::None), + Interaction::new::>( + Hierarchy::Atomic, + Kind::Message, + "test-message", + Length::Scalar, + ), + Interaction::new::(Hierarchy::End, Kind::Protocol, "test", Length::None), + ]) + .unwrap(); + + let result = format!("{transcript:#}"); + let expected = r"Spongefish Transcript (3 interactions) +0 Begin Protocol 4 test None +1 Atomic Message 12 test-message Scalar +2 End Protocol 4 test None +"; + assert_eq!(result, expected); + + let result = transcript.domain_separator(); + assert_eq!( + hex::encode(result), + "33daf542c95b80a2b01be277d9d0f9b6d5bee823c5c3a0dcca71e614a5a783e3" + ); + } +} diff --git a/spongefish/src/pattern/mod.rs b/spongefish/src/pattern/mod.rs new file mode 100644 index 0000000..cf5b054 --- /dev/null +++ b/spongefish/src/pattern/mod.rs @@ -0,0 +1,317 @@ +//! Abstract interaction patterns for interactive protocols. + +mod interaction; +mod interaction_pattern; +mod pattern_player; +mod pattern_state; + +pub use self::{ + interaction::{Hierarchy, Interaction, Kind, Label, Length}, + interaction_pattern::{InteractionPattern, TranscriptError}, + pattern_player::PatternPlayer, + pattern_state::PatternState, +}; + +/// Trait for objects that implement hierarchy operations. +/// +/// It does not offer any [`Kind::Atomic`] operations, these need to be implemented specifically. +pub trait Pattern { + /// End a transcript without finalizing it. + /// + /// # Panics + /// + /// Panics only if the interaction is already finalized or aborted. + fn abort(&mut self); + + /// Begin of a group of interactions. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn begin(&mut self, label: Label, kind: Kind, length: Length); + + /// End of a group of interactions. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn end(&mut self, label: Label, kind: Kind, length: Length); + + /// Begin of a subprotocol. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn begin_protocol(&mut self, label: Label) { + self.begin::(label, Kind::Protocol, Length::None); + } + + /// End of a subprotocol. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn end_protocol(&mut self, label: Label) { + self.end::(label, Kind::Protocol, Length::None); + } + + /// Begin of a public message interaction. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn begin_public(&mut self, label: Label, length: Length) { + self.begin::(label, Kind::Public, length); + } + + /// End of a public message interaction. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn end_public(&mut self, label: Label, length: Length) { + self.end::(label, Kind::Public, length); + } + + /// Begin of a message interaction. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn begin_message(&mut self, label: Label, length: Length) { + self.begin::(label, Kind::Message, length); + } + + /// End of a message interaction. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn end_message(&mut self, label: Label, length: Length) { + self.end::(label, Kind::Message, length); + } + + /// Begin of a hint interaction. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn begin_hint(&mut self, label: Label, length: Length) { + self.begin::(label, Kind::Hint, length); + } + + /// End of a hint interaction.. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn end_hint(&mut self, label: Label, length: Length) { + self.end::(label, Kind::Hint, length); + } + + /// Begin of a challenge interaction.. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn begin_challenge(&mut self, label: Label, length: Length) { + self.begin::(label, Kind::Challenge, length); + } + + /// End of a challenge interaction.. + /// + /// # Panics + /// + /// Panics if the interaction violates interaction pattern consistency rules. + fn end_challenge(&mut self, label: Label, length: Length) { + self.end::(label, Kind::Challenge, length); + } +} + +/// Aliases offered for convenience. +pub use Pattern as Common; +/// Aliases offered for convenience. +pub use Pattern as Verifier; +/// Aliases offered for convenience. +pub use Pattern as Prover; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_record_playback() { + // Record a new pattern + let mut pattern = PatternState::::new(); + pattern.begin_protocol::<()>("Example protocol"); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + pattern.end_protocol::<()>("Example protocol"); + let pattern = pattern.finalize(); + + // Play it back exactly + let mut playback = PatternPlayer::new(pattern.into()); + playback.begin_protocol::<()>("Example protocol"); + playback.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + playback.end_protocol::<()>("Example protocol"); + playback.finalize(); + } + + #[test] + #[should_panic(expected = "Dropped unfinalized transcript.")] + fn panics_if_playback_not_finalized() { + let mut pattern = PatternState::::new(); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + let pattern = pattern.finalize(); + + let mut playback = PatternPlayer::new(pattern.into()); + playback.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + } + + #[test] + #[should_panic( + expected = "Mismatched begin and end: Begin Protocol Example protocol None (), End Protocol Invalid example protocol None ()" + )] + fn panics_if_record_begin_end_mismatch() { + let mut pattern = PatternState::::new(); + pattern.begin_protocol::<()>("Example protocol"); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + pattern.end_protocol::<()>("Invalid example protocol"); + let _pattern = pattern.finalize(); + } + #[test] + #[should_panic( + expected = "Error validating interaction pattern: Missing End for Begin Protocol Example protocol None () at 0" + )] + fn panics_if_record_unmatched_begin() { + let mut pattern = PatternState::::new(); + pattern.begin_protocol::<()>("Example protocol"); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + let _pattern = pattern.finalize(); + } + + #[test] + #[should_panic( + expected = "Received interaction Atomic Challenge nonce Scalar f64, but expected Atomic Challenge nonce Scalar u64" + )] + fn panics_if_type_mismatch() { + let mut pattern = PatternState::::new(); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + let pattern = pattern.finalize(); + + let mut playback = PatternPlayer::new(pattern.into()); + playback.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + playback.finalize(); + } + + #[test] + #[should_panic( + expected = "Received interaction Atomic Public nonce Scalar f64, but expected Atomic Message nonce Scalar u64" + )] + fn panics_if_kind_mismatch() { + let mut pattern = PatternState::::new(); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Message, + "nonce", + Length::Scalar, + )); + let pattern = pattern.finalize(); + + let mut playback = PatternPlayer::new(pattern.into()); + playback.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Public, + "nonce", + Length::Scalar, + )); + playback.finalize(); + } + + #[test] + #[should_panic( + expected = "Received interaction Atomic Challenge invalid Scalar f64, but expected Atomic Challenge nonce Scalar u64" + )] + fn panics_if_label_mismatch() { + let mut pattern = PatternState::::new(); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + let pattern = pattern.finalize(); + + let mut playback = PatternPlayer::new(pattern.into()); + playback.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "invalid", + Length::Scalar, + )); + playback.finalize(); + } + + #[test] + #[should_panic( + expected = "Received interaction Atomic Challenge nonce Fixed(1) f64, but expected Atomic Challenge nonce Scalar u64" + )] + fn panics_if_length_mismatch() { + let mut pattern = PatternState::::new(); + pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Scalar, + )); + let pattern = pattern.finalize(); + + let mut playback = PatternPlayer::new(pattern.into()); + playback.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "nonce", + Length::Fixed(1), + )); + playback.finalize(); + } +} diff --git a/spongefish/src/pattern/pattern_player.rs b/spongefish/src/pattern/pattern_player.rs new file mode 100644 index 0000000..7cea592 --- /dev/null +++ b/spongefish/src/pattern/pattern_player.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use super::{Interaction, InteractionPattern, Kind, Label, Length}; +use crate::pattern::Hierarchy; + +/// Play back an interaction pattern and make sure all interactions match up. +/// +/// # Panics +/// +/// Panics on [`Drop`] if there are unfinished interactions. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct PatternPlayer { + /// Shared reference to the transcript. + pattern: Arc, + /// Current position in the interaction pattern. + position: usize, + /// Whether the transcript playback has been finalized. + finalized: bool, +} + +impl PatternPlayer { + #[must_use] + pub const fn new(pattern: Arc) -> Self { + Self { + pattern, + position: 0, + finalized: false, + } + } + + /// Finalize the sequence of interactions. Returns an error if there + /// are unfinished interactions. + /// + /// # Panics + /// + /// Panics if the transcript is already finalized or if there are expected interactions left. + pub fn finalize(mut self) { + assert!(self.position <= self.pattern.interactions().len()); + assert!(!self.finalized, "Transcript is already finalized."); + assert!( + self.position >= self.pattern.interactions().len(), + "Transcript not finished, expecting {}", + self.pattern.interactions()[self.position] + ); + self.finalized = true; + } + + /// Play the next interaction in the pattern. + /// + /// # Panics + /// + /// Panics if the transcript is already finalized or if the interaction does not match the expected one. + pub fn interact(&mut self, interaction: Interaction) { + assert!(!self.finalized, "Transcript is already finalized."); + let Some(expected) = self.pattern.interactions().get(self.position) else { + self.finalized = true; + panic!("Received interaction, but no more expected interactions: {interaction}"); + }; + if expected != &interaction { + self.finalized = true; + panic!("Received interaction {interaction}, but expected {expected}"); + } + self.position += 1; + } +} + +impl Drop for PatternPlayer { + fn drop(&mut self) { + assert!(self.finalized, "Dropped unfinalized transcript."); + } +} + +impl super::Pattern for PatternPlayer { + fn abort(&mut self) { + assert!(!self.finalized, "Transcript is already finalized."); + self.finalized = true; + } + + fn begin(&mut self, label: Label, kind: Kind, length: Length) { + self.interact(Interaction::new::(Hierarchy::Begin, kind, label, length)); + } + + fn end(&mut self, label: Label, kind: Kind, length: Length) { + self.interact(Interaction::new::(Hierarchy::End, kind, label, length)); + } +} diff --git a/spongefish/src/pattern/pattern_state.rs b/spongefish/src/pattern/pattern_state.rs new file mode 100644 index 0000000..92ef6c0 --- /dev/null +++ b/spongefish/src/pattern/pattern_state.rs @@ -0,0 +1,206 @@ +use std::marker::PhantomData; + +use super::{Hierarchy, Interaction, InteractionPattern, Kind, Label, Length}; +use crate::{codecs::unit, Unit}; + +/// Records an interaction pattern. +/// +/// # Panics +/// +/// Panics on [`Drop`] if there are unfinished interactions. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Default)] +pub struct PatternState +where + U: Unit, +{ + /// Recorded interactions. + interactions: Vec, + /// Whether the transcript playback has been finalized. + finalized: bool, + _unit: PhantomData, +} + +impl PatternState +where + U: Unit, +{ + #[must_use] + pub const fn new() -> Self { + Self { + interactions: Vec::new(), + finalized: false, + _unit: PhantomData, + } + } + + #[must_use] + pub fn finalize(self) -> InteractionPattern { + assert!(!self.finalized, "Transcript is already finalized."); + match InteractionPattern::new(self.interactions) { + Ok(transcript) => transcript, + Err(e) => panic!("Error validating interaction pattern: {e}"), + } + } + + /// Add a new interaction to the pattern. + /// + /// # Panics + /// + /// Panics if + /// - the interaction does not match the parent kind and + /// the parent kind is not [`Kind::Protocol`], + /// - the it is an [`Hierarchy::End`], but there is either no + /// [`Hierarchy::Begin`] or it does not match the end.being + pub fn interact(&mut self, interaction: Interaction) { + assert!(!self.finalized, "Transcript is already finalized."); + if let Some(begin) = self.last_open_begin() { + // Check if the new interaction is of a permissible kind. + assert!( + begin.kind() == Kind::Protocol || begin.kind() == interaction.kind(), + "Invalid interaction kind: expected {}, got {}", + begin.kind(), + interaction.kind() + ); + // Check if it is a matching End to the current Begin + assert!( + interaction.hierarchy() != Hierarchy::End || interaction.closes(begin), + "Mismatched begin and end: {begin}, {interaction}" + ); + } else { + // No unclosed Begin interaction. Make sure this is not an end. + assert!( + interaction.hierarchy() != Hierarchy::End, + "Missing begin for {interaction}" + ); + } + + // All good, append + self.interactions.push(interaction); + } + + /// Return the last unclosed [`Hierachy::Begin`] interaction. + fn last_open_begin(&self) -> Option<&Interaction> { + // Reverse search to find matching begin + let mut stack = 0; + for interaction in self.interactions.iter().rev() { + match interaction.hierarchy() { + Hierarchy::End => stack += 1, + Hierarchy::Begin => { + if stack == 0 { + return Some(interaction); + } + stack -= 1; + } + _ => {} + } + } + None + } +} + +impl super::Pattern for PatternState +where + U: Unit, +{ + fn abort(&mut self) { + assert!(!self.finalized, "Transcript is already finalized."); + self.finalized = true; + } + + fn begin(&mut self, label: Label, kind: Kind, length: Length) { + self.interact(Interaction::new::(Hierarchy::Begin, kind, label, length)); + } + + fn end(&mut self, label: Label, kind: Kind, length: Length) { + self.interact(Interaction::new::(Hierarchy::End, kind, label, length)); + } +} + +// TODO: We will turn this into `unit::Pattern` later. +impl unit::Pattern for PatternState +where + U: Unit, +{ + type Unit = U; + + fn ratchet(&mut self) { + self.interact(Interaction::new::<()>( + Hierarchy::Atomic, + Kind::Protocol, + "ratchet", + Length::None, + )); + } + + fn public_unit(&mut self, label: Label) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Public, + label, + Length::Scalar, + )); + } + + fn public_units(&mut self, label: Label, size: usize) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Public, + label, + Length::Fixed(size), + )); + } + + fn message_unit(&mut self, label: Label) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Message, + label, + Length::Scalar, + )); + } + + fn message_units(&mut self, label: Label, size: usize) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Message, + label, + Length::Fixed(size), + )); + } + + fn challenge_unit(&mut self, label: Label) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + label, + Length::Scalar, + )); + } + + fn challenge_units(&mut self, label: Label, size: usize) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + label, + Length::Fixed(size), + )); + } + + fn hint_bytes(&mut self, label: Label, size: usize) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Hint, + label, + Length::Fixed(size), + )); + } + + fn hint_bytes_dynamic(&mut self, label: Label) { + self.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Hint, + label, + Length::Dynamic, + )); + } +} diff --git a/spongefish/src/prover.rs b/spongefish/src/prover.rs index aa45fdd..b17ab93 100644 --- a/spongefish/src/prover.rs +++ b/spongefish/src/prover.rs @@ -1,12 +1,13 @@ +use std::{marker::PhantomData, sync::Arc}; + use rand::{CryptoRng, RngCore}; +use zeroize::Zeroize; -use super::{ - duplex_sponge::DuplexSpongeInterface, keccak::Keccak, DefaultHash, DefaultRng, - DomainSeparatorMismatch, -}; +use super::{duplex_sponge::DuplexSpongeInterface, keccak::Keccak, DefaultHash, DefaultRng}; use crate::{ - duplex_sponge::Unit, BytesToUnitSerialize, DomainSeparator, HashStateWithInstructions, - UnitTranscript, + duplex_sponge::Unit, + pattern::{Hierarchy, Interaction, InteractionPattern, Kind, Length, Pattern, PatternPlayer}, + BytesToUnitSerialize, UnitTranscript, }; /// [`ProverState`] is the prover state of an interactive proof (IP) system. @@ -29,12 +30,16 @@ where H: DuplexSpongeInterface, R: RngCore + CryptoRng, { + /// The interaction pattern being followed. + pub(crate) pattern: PatternPlayer, /// The randomness state of the prover. pub(crate) rng: ProverPrivateRng, /// The public coins for the protocol - pub(crate) hash_state: HashStateWithInstructions, + pub(crate) duplex_sponge: H, /// The encoded data. pub(crate) narg_string: Vec, + /// Unit type + pub(crate) _unit_type: PhantomData, } /// A cryptographically-secure random number generator that is bound to the protocol transcript. @@ -87,39 +92,61 @@ where H: DuplexSpongeInterface, R: RngCore + CryptoRng, { - pub fn new(domain_separator: &DomainSeparator, csrng: R) -> Self { - let hash_state = HashStateWithInstructions::new(domain_separator); + pub fn new(pattern: Arc, csrng: R) -> Self { + let iv = pattern.domain_separator(); let mut duplex_sponge = Keccak::default(); - duplex_sponge.absorb_unchecked(domain_separator.as_bytes()); + duplex_sponge.absorb_unchecked(&iv); let rng = ProverPrivateRng { ds: duplex_sponge, csrng, }; Self { + pattern: PatternPlayer::new(pattern), rng, - hash_state, + duplex_sponge: H::new(iv), narg_string: Vec::new(), + _unit_type: PhantomData, } } - pub fn hint_bytes(&mut self, hint: &[u8]) -> Result<(), DomainSeparatorMismatch> { - self.hash_state.hint()?; + /// Abort the proof without completing. + pub fn abort(mut self) { + self.pattern.abort(); + self.duplex_sponge.zeroize(); + self.rng.ds.zeroize(); + self.narg_string.zeroize(); + } + + /// Finish the proof and return the proof bytes. + pub fn finalize(mut self) -> Vec { + self.pattern.finalize(); + self.duplex_sponge.zeroize(); + self.rng.ds.zeroize(); + self.narg_string + } + + pub fn hint_bytes(&mut self, hint: &[u8]) { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Hint, + "hint_bytes", + Length::Dynamic, + )); let len = u32::try_from(hint.len()).expect("Hint size out of bounds"); self.narg_string.extend_from_slice(&len.to_le_bytes()); self.narg_string.extend_from_slice(hint); - Ok(()) } } -impl From<&DomainSeparator> for ProverState +impl From<&InteractionPattern> for ProverState where U: Unit, H: DuplexSpongeInterface, { - fn from(domain_separator: &DomainSeparator) -> Self { - Self::new(domain_separator, DefaultRng::default()) + fn from(pattern: &InteractionPattern) -> Self { + Self::new(Arc::new(pattern.clone()), DefaultRng::default()) } } @@ -142,19 +169,29 @@ where /// let result = prover_state.add_units(b"1tbsp every 10 liters"); /// assert!(result.is_err()) /// ``` - pub fn add_units(&mut self, input: &[U]) -> Result<(), DomainSeparatorMismatch> { + pub fn add_units(&mut self, input: &[U]) { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Message, + "units", + Length::Fixed(input.len()), + )); + self.duplex_sponge.absorb_unchecked(input); let old_len = self.narg_string.len(); - self.hash_state.absorb(input)?; // write never fails on Vec U::write(input, &mut self.narg_string).unwrap(); self.rng.ds.absorb_unchecked(&self.narg_string[old_len..]); - - Ok(()) } /// Ratchet the verifier's state. - pub fn ratchet(&mut self) -> Result<(), DomainSeparatorMismatch> { - self.hash_state.ratchet() + pub fn ratchet(&mut self) { + self.pattern.interact(Interaction::new::<()>( + Hierarchy::Atomic, + Kind::Protocol, + "ratchet", + Length::None, + )); + self.duplex_sponge.ratchet_unchecked(); } /// Return a reference to the random number generator associated to the protocol transcript. @@ -212,16 +249,30 @@ where /// assert!(prover_state.public_bytes(&[0u8; 20]).is_ok()); /// assert_eq!(prover_state.narg_string(), b""); /// ``` - fn public_units(&mut self, input: &[U]) -> Result<(), DomainSeparatorMismatch> { - let len = self.narg_string.len(); - self.add_units(input)?; - self.narg_string.truncate(len); - Ok(()) + fn public_units(&mut self, input: &[U]) { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Public, + "public_units", + Length::Fixed(input.len()), + )); + self.duplex_sponge.absorb_unchecked(input); + let old_len = self.narg_string.len(); + // write never fails on Vec + U::write(input, &mut self.narg_string).unwrap(); + self.rng.ds.absorb_unchecked(&self.narg_string[old_len..]); + self.narg_string.truncate(old_len); } /// Fill a slice with uniformly-distributed challenges from the verifier. - fn fill_challenge_units(&mut self, output: &mut [U]) -> Result<(), DomainSeparatorMismatch> { - self.hash_state.squeeze(output) + fn fill_challenge_units(&mut self, output: &mut [U]) { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "fill_challenge_units", + Length::Fixed(output.len()), + )); + self.duplex_sponge.squeeze_unchecked(output); } } @@ -234,7 +285,7 @@ where R: RngCore + CryptoRng, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.hash_state.fmt(f) + self.pattern.fmt(f) } } @@ -243,203 +294,235 @@ where H: DuplexSpongeInterface, R: RngCore + CryptoRng, { - fn add_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch> { - self.add_units(input) + fn add_bytes(&mut self, input: &[u8]) { + self.pattern + .begin_message::("bytes", Length::Fixed(input.len())); + self.add_units(input); + self.pattern + .end_message::("bytes", Length::Fixed(input.len())); } } #[cfg(test)] mod tests { use super::*; + use crate::{ + codecs::{bytes::Pattern as _, unit::Pattern as _}, + pattern::PatternState, + }; #[test] fn test_prover_state_add_units_and_rng_differs() { - let domsep = DomainSeparator::::new("test").absorb(4, "data"); - let mut pstate = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.message_bytes("bytes", 4); + let pattern = pattern.finalize(); + + let mut pstate: ProverState = ProverState::from(&pattern); - pstate.add_bytes(&[1, 2, 3, 4]).unwrap(); + pstate.add_bytes(&[1, 2, 3, 4]); let mut buf = [0u8; 8]; pstate.rng().fill_bytes(&mut buf); assert_ne!(buf, [0; 8]); + let _proof = pstate.finalize(); } #[test] fn test_prover_state_public_units_does_not_affect_narg() { - let domsep = DomainSeparator::::new("test").absorb(4, "data"); - let mut pstate = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.public_units("public_units", 4); + let pattern = pattern.finalize(); + let mut pstate: ProverState = ProverState::from(&pattern); - pstate.public_units(&[1, 2, 3, 4]).unwrap(); + pstate.public_units(&[1, 2, 3, 4]); assert_eq!(pstate.narg_string(), b""); + let _proof = pstate.finalize(); } #[test] fn test_prover_state_ratcheting_changes_rng_output() { - let domsep = DomainSeparator::::new("test").ratchet(); - let mut pstate = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.ratchet(); + let pattern = pattern.finalize(); + let mut pstate: ProverState = ProverState::from(&pattern); let mut buf1 = [0u8; 4]; pstate.rng().fill_bytes(&mut buf1); - - pstate.ratchet().unwrap(); - + pstate.ratchet(); let mut buf2 = [0u8; 4]; pstate.rng().fill_bytes(&mut buf2); + // TODO: This test is broken. You'd expect these to be different even without the ratchet. assert_ne!(buf1, buf2); + let _proof = pstate.finalize(); } #[test] fn test_add_units_appends_to_narg_string() { - let domsep = DomainSeparator::::new("test").absorb(3, "msg"); - let mut pstate = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 3); + let pattern = pattern.finalize(); + let mut pstate: ProverState = ProverState::from(&pattern); + let input = [42, 43, 44]; - assert!(pstate.add_units(&input).is_ok()); - assert_eq!(pstate.narg_string(), &input); + pstate.add_units(&input); + let proof = pstate.finalize(); + assert_eq!(proof, &input); } #[test] - fn test_add_units_too_many_elements_should_error() { - let domsep = DomainSeparator::::new("test").absorb(2, "short"); - let mut pstate = ProverState::from(&domsep); - - let result = pstate.add_units(&[1, 2, 3]); - assert!(result.is_err()); + #[should_panic( + expected = "Received interaction Atomic Message units Fixed(3) u8, but expected Atomic Message units Fixed(2) u8" + )] + fn test_add_units_too_many_elements_should_panic() { + let mut pattern = PatternState::::new(); + pattern.message_units("units", 2); + let pattern = pattern.finalize(); + + let mut pstate: ProverState = ProverState::from(&pattern); + pstate.add_units(&[1, 2, 3]); } #[test] fn test_ratchet_works_when_expected() { - let domsep = DomainSeparator::::new("test").ratchet(); - let mut pstate = ProverState::from(&domsep); - assert!(pstate.ratchet().is_ok()); - } + let mut pattern = PatternState::::new(); + pattern.ratchet(); + let pattern = pattern.finalize(); - #[test] - fn test_ratchet_fails_when_not_expected() { - let domsep = DomainSeparator::::new("test").absorb(1, "bad"); - let mut pstate = ProverState::from(&domsep); - assert!(pstate.ratchet().is_err()); + let mut pstate: ProverState = ProverState::from(&pattern); + pstate.ratchet(); + let _proof = pstate.finalize(); } #[test] - fn test_public_units_does_not_update_transcript() { - let domsep = DomainSeparator::::new("test").absorb(2, "p"); - let mut pstate = ProverState::from(&domsep); - let _ = pstate.public_units(&[0xaa, 0xbb]); + #[should_panic( + expected = "Received interaction Atomic Protocol ratchet None (), but expected Atomic Message units Fixed(4) u8" + )] + fn test_ratchet_fails_when_not_expected() { + let mut pattern = PatternState::::new(); + pattern.message_units("units", 4); + let pattern = pattern.finalize(); - assert_eq!(pstate.narg_string(), b""); + let mut pstate: ProverState = ProverState::from(&pattern); + pstate.ratchet(); + let _proof = pstate.finalize(); } #[test] fn test_fill_challenge_units() { - let domsep = DomainSeparator::::new("test").squeeze(8, "ch"); - let mut pstate = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.challenge_units("fill_challenge_units", 8); + let pattern = pattern.finalize(); + let mut pstate: ProverState = ProverState::from(&pattern); let mut out = [0u8; 8]; - let _ = pstate.fill_challenge_units(&mut out); - assert_eq!(out, [77, 249, 17, 180, 176, 109, 121, 62]); + pstate.fill_challenge_units(&mut out); + assert_eq!(out, [62, 110, 82, 217, 159, 135, 60, 9]); + let _proof = pstate.finalize(); } #[test] fn test_rng_entropy_changes_with_transcript() { - let domsep = DomainSeparator::::new("t").absorb(3, "init"); - let mut p1 = ProverState::from(&domsep); - let mut p2 = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.message_bytes("bytes", 3); + let pattern = pattern.finalize(); + + let mut p1: ProverState = ProverState::from(&pattern); + let mut p2: ProverState = ProverState::from(&pattern); let mut a = [0u8; 16]; let mut b = [0u8; 16]; p1.rng().fill_bytes(&mut a); - p2.add_units(&[1, 2, 3]).unwrap(); + p2.add_bytes(&[1, 2, 3]); p2.rng().fill_bytes(&mut b); assert_ne!(a, b); + p1.abort(); + p2.abort(); } #[test] fn test_add_units_multiple_accumulates() { - let domsep = DomainSeparator::::new("t") - .absorb(2, "a") - .absorb(3, "b"); - let mut p = ProverState::from(&domsep); - - p.add_units(&[10, 11]).unwrap(); - p.add_units(&[20, 21, 22]).unwrap(); - - assert_eq!(p.narg_string(), &[10, 11, 20, 21, 22]); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 2); + pattern.message_units("units", 3); + let pattern = pattern.finalize(); + + let mut p: ProverState = ProverState::from(&pattern); + p.add_units(&[10, 11]); + p.add_units(&[20, 21, 22]); + assert_eq!(p.finalize(), &[10, 11, 20, 21, 22]); } #[test] fn test_narg_string_round_trip_check() { - let domsep = DomainSeparator::::new("t").absorb(5, "data"); - let mut p = ProverState::from(&domsep); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 5); + let pattern = pattern.finalize(); + let mut p: ProverState = ProverState::from(&pattern); let msg = b"zkp42"; - p.add_units(msg).unwrap(); - - let encoded = p.narg_string(); - assert_eq!(encoded, msg); + p.add_units(msg); + assert_eq!(p.finalize(), msg); } #[test] fn test_hint_bytes_appends_hint_length_and_data() { - let domsep: DomainSeparator = - DomainSeparator::new("hint_test").hint("proof_hint"); - let mut prover = domsep.to_prover_state(); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); + let mut prover: ProverState = ProverState::from(&pattern); let hint = b"abc123"; - prover.hint_bytes(hint).unwrap(); - - // Explanation: - // - `hint` is "abc123", which has 6 bytes. - // - The protocol encodes this as a 4-byte *little-endian* length prefix: 6 = 0x00000006 → [6, 0, 0, 0] - // - Then it appends the hint bytes: b"abc123" - // - So the full expected value is: + prover.hint_bytes(hint); let expected = [6, 0, 0, 0, b'a', b'b', b'c', b'1', b'2', b'3']; - - assert_eq!(prover.narg_string(), &expected); + assert_eq!(prover.finalize(), &expected); } #[test] fn test_hint_bytes_empty_hint_is_encoded_correctly() { - let domsep: DomainSeparator = DomainSeparator::new("empty_hint").hint("empty"); - let mut prover = domsep.to_prover_state(); - - prover.hint_bytes(b"").unwrap(); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); - // Length = 0 encoded as 4 zero bytes - assert_eq!(prover.narg_string(), &[0, 0, 0, 0]); + let mut prover: ProverState = ProverState::from(&pattern); + prover.hint_bytes(b""); + assert_eq!(prover.finalize(), &[0, 0, 0, 0]); } #[test] + #[should_panic( + expected = "Received interaction, but no more expected interactions: Atomic Hint hint_bytes Dynamic u8" + )] fn test_hint_bytes_fails_if_hint_op_missing() { - let domsep: DomainSeparator = DomainSeparator::new("no_hint"); - let mut prover = domsep.to_prover_state(); - - // DomainSeparator contains no hint operation - let result = prover.hint_bytes(b"some_hint"); - assert!( - result.is_err(), - "Should error if no hint op in domain separator" - ); + let pattern = PatternState::::new().finalize(); + + let mut prover: ProverState = ProverState::from(&pattern); + // indicate a hint without a matching hint_bytes interaction + prover.hint_bytes(b"some_hint"); } #[test] fn test_hint_bytes_is_deterministic() { - let domsep: DomainSeparator = DomainSeparator::new("det_hint").hint("same"); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); let hint = b"zkproof_hint"; - let mut prover1 = domsep.to_prover_state(); - let mut prover2 = domsep.to_prover_state(); + let mut prover1: ProverState = ProverState::from(&pattern); + let mut prover2: ProverState = ProverState::from(&pattern); - prover1.hint_bytes(hint).unwrap(); - prover2.hint_bytes(hint).unwrap(); + prover1.hint_bytes(hint); + prover2.hint_bytes(hint); assert_eq!( prover1.narg_string(), prover2.narg_string(), "Encoding should be deterministic" ); + let _proof1 = prover1.finalize(); + let _proof2 = prover2.finalize(); } } diff --git a/spongefish/src/sho.rs b/spongefish/src/sho.rs deleted file mode 100644 index f694546..0000000 --- a/spongefish/src/sho.rs +++ /dev/null @@ -1,412 +0,0 @@ -use core::{fmt, marker::PhantomData}; -use std::collections::vec_deque::VecDeque; - -use super::{ - domain_separator::{DomainSeparator, Op}, - duplex_sponge::{DuplexSpongeInterface, Unit}, - errors::DomainSeparatorMismatch, - keccak::Keccak, -}; - -/// A stateful hash object that interfaces with duplex interfaces. -#[derive(Clone)] -pub struct HashStateWithInstructions -where - U: Unit, - H: DuplexSpongeInterface, -{ - /// The internal duplex sponge used for absorbing and squeezing data. - ds: H, - /// A stack of expected sponge operations. - stack: VecDeque, - /// Marker to associate the unit type `U` without storing a value. - _unit: PhantomData, -} - -impl> HashStateWithInstructions { - /// Initialise a stateful hash object, - /// setting up the state of the sponge function and parsing the tag string. - #[must_use] - pub fn new(domain_separator: &DomainSeparator) -> Self { - let stack = domain_separator.finalize(); - let tag = Self::generate_tag(domain_separator.as_bytes()); - Self::unchecked_load_with_stack(tag, stack) - } - - /// Finish the block and compress the state. - pub fn ratchet(&mut self) -> Result<(), DomainSeparatorMismatch> { - match self.stack.pop_front() { - Some(Op::Ratchet) => { - self.ds.ratchet_unchecked(); - Ok(()) - } - Some(op) => Err(format!("Expected Ratchet, got {op:?}").into()), - None => Err("Expected Ratchet, but stack is empty".into()), - } - } - - /// Ratchet and return the sponge state. - pub fn preprocess(self) -> Result<&'static [U], DomainSeparatorMismatch> { - unimplemented!() - // self.ratchet()?; - // Ok(self.sponge.tag().clone()) - } - - /// Perform secure absorption of the elements in `input`. - /// - /// Absorb calls can be batched together, or provided separately for streaming-friendly protocols. - pub fn absorb(&mut self, input: &[U]) -> Result<(), DomainSeparatorMismatch> { - match self.stack.pop_front() { - Some(Op::Absorb(length)) if length >= input.len() => { - if length > input.len() { - self.stack.push_front(Op::Absorb(length - input.len())); - } - self.ds.absorb_unchecked(input); - Ok(()) - } - None => { - self.stack.clear(); - Err(format!( - "Invalid tag. Stack empty, got {:?}", - Op::Absorb(input.len()) - ) - .into()) - } - Some(op) => { - self.stack.clear(); - Err(format!( - "Invalid tag. Got {:?}, expected {:?}", - Op::Absorb(input.len()), - op - ) - .into()) - } - } - } - - /// Send or receive a hint from the proof stream. - pub fn hint(&mut self) -> Result<(), DomainSeparatorMismatch> { - match self.stack.pop_front() { - Some(Op::Hint) => Ok(()), - Some(op) => Err(format!("Invalid tag. Got Op::Hint, expected {op:?}",).into()), - None => Err(format!("Invalid tag. Stack empty, got {:?}", Op::Hint).into()), - } - } - - /// Perform a secure squeeze operation, filling the output buffer with uniformly random bytes. - /// - /// For byte-oriented sponges, this operation is equivalent to the squeeze operation. - /// However, for algebraic hashes, this operation is non-trivial. - /// This function provides no guarantee of streaming-friendliness. - pub fn squeeze(&mut self, output: &mut [U]) -> Result<(), DomainSeparatorMismatch> { - match self.stack.pop_front() { - Some(Op::Squeeze(length)) if output.len() <= length => { - self.ds.squeeze_unchecked(output); - if length != output.len() { - self.stack.push_front(Op::Squeeze(length - output.len())); - } - Ok(()) - } - None => { - self.stack.clear(); - Err(format!( - "Invalid tag. Stack empty, got {:?}", - Op::Squeeze(output.len()) - ) - .into()) - } - Some(op) => { - self.stack.clear(); - Err(format!( - "Invalid tag. Got {:?}, expected {:?}. The stack remaining is: {:?}", - Op::Squeeze(output.len()), - op, - self.stack - ) - .into()) - } - } - } - - fn generate_tag(iop_bytes: &[u8]) -> [u8; 32] { - let mut keccak = Keccak::default(); - keccak.absorb_unchecked(iop_bytes); - let mut tag = [0u8; 32]; - keccak.squeeze_unchecked(&mut tag); - tag - } - - fn unchecked_load_with_stack(tag: [u8; 32], stack: VecDeque) -> Self { - Self { - ds: H::new(tag), - stack, - _unit: PhantomData, - } - } - - #[cfg(test)] - pub const fn ds(&self) -> &H { - &self.ds - } -} - -impl> Drop for HashStateWithInstructions { - /// Destroy the sponge state. - fn drop(&mut self) { - // it's a bit violent to panic here, - // because any other issue in the protocol transcript causing `Safe` to get out of scope - // (like another panic) will pollute the traceback. - // debug_assert!(self.stack.is_empty()); - if !self.stack.is_empty() { - eprintln!("Unfinished operations:\n {:?}", self.stack); - } - // XXX. is the compiler going to optimize this out? - self.ds.zeroize(); - } -} - -impl> fmt::Debug for HashStateWithInstructions { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Ensure that the state isn't accidentally logged, - // but provide the remaining domain separator for debugging. - write!( - f, - "Sponge in duplex mode with committed verifier operations: {:?}", - self.stack - ) - } -} - -impl, B: core::borrow::Borrow>> From - for HashStateWithInstructions -{ - fn from(value: B) -> Self { - Self::new(value.borrow()) - } -} - -#[cfg(test)] -#[allow(clippy::bool_assert_comparison)] -mod tests { - use std::{cell::RefCell, rc::Rc}; - - use super::*; - - #[derive(Default, Clone)] - pub struct DummySponge { - pub absorbed: Rc>>, - pub squeezed: Rc>>, - pub ratcheted: Rc>, - } - - impl zeroize::Zeroize for DummySponge { - fn zeroize(&mut self) { - self.absorbed.borrow_mut().clear(); - self.squeezed.borrow_mut().clear(); - *self.ratcheted.borrow_mut() = false; - } - } - - impl DummySponge { - fn new_inner() -> Self { - Self { - absorbed: Rc::new(RefCell::new(Vec::new())), - squeezed: Rc::new(RefCell::new(Vec::new())), - ratcheted: Rc::new(RefCell::new(false)), - } - } - } - - impl DuplexSpongeInterface for DummySponge { - fn new(_iv: [u8; 32]) -> Self { - Self::new_inner() - } - - fn absorb_unchecked(&mut self, input: &[u8]) -> &mut Self { - self.absorbed.borrow_mut().extend_from_slice(input); - self - } - - fn squeeze_unchecked(&mut self, output: &mut [u8]) -> &mut Self { - for (i, byte) in output.iter_mut().enumerate() { - *byte = i as u8; // Dummy output - } - self.squeezed.borrow_mut().extend_from_slice(output); - self - } - - fn ratchet_unchecked(&mut self) -> &mut Self { - *self.ratcheted.borrow_mut() = true; - self - } - } - - #[test] - fn test_absorb_works_and_modifies_stack() { - let domsep = DomainSeparator::::new("test").absorb(2, "x"); - let mut state = HashStateWithInstructions::::new(&domsep); - - assert_eq!(state.stack.len(), 1); - - let result = state.absorb(&[1, 2]); - assert!(result.is_ok()); - - assert_eq!(state.stack.len(), 0); - let inner = state.ds.absorbed.borrow(); - assert_eq!(&*inner, &[1, 2]); - } - - #[test] - fn test_absorb_too_much_returns_error() { - let domsep = DomainSeparator::::new("test").absorb(2, "x"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let result = state.absorb(&[1, 2, 3]); - assert!(result.is_err()); - } - - #[test] - fn test_squeeze_works() { - let domsep = DomainSeparator::::new("test").squeeze(3, "y"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let mut out = [0u8; 3]; - let result = state.squeeze(&mut out); - assert!(result.is_ok()); - assert_eq!(out, [0, 1, 2]); - } - - #[test] - fn test_squeeze_with_leftover_updates_stack() { - let domsep = DomainSeparator::::new("test").squeeze(4, "z"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let mut out = [0u8; 2]; - let result = state.squeeze(&mut out); - assert!(result.is_ok()); - - assert_eq!(state.stack.front(), Some(&Op::Squeeze(2))); - } - - #[test] - fn test_ratchet_correct_op() { - let domsep = DomainSeparator::::new("test").ratchet(); - let mut state = HashStateWithInstructions::::new(&domsep); - - let result = state.ratchet(); - assert!(result.is_ok()); - assert_eq!(*state.ds.ratcheted.borrow(), true); - } - - #[test] - fn test_ratchet_wrong_op_returns_error() { - let domsep = DomainSeparator::::new("test").absorb(1, "oops"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let result = state.ratchet(); - assert!(result.is_err()); - assert!(state.stack.is_empty()); - } - - #[test] - fn test_multiple_absorbs_deplete_stack_properly() { - let domsep = DomainSeparator::::new("test").absorb(5, "a"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let res1 = state.absorb(&[1, 2]); - assert!(res1.is_ok()); - assert_eq!(state.stack.front(), Some(&Op::Absorb(3))); - - let res2 = state.absorb(&[3, 4, 5]); - assert!(res2.is_ok()); - assert!(state.stack.is_empty()); - - assert_eq!(&*state.ds.absorbed.borrow(), &[1, 2, 3, 4, 5]); - } - - #[test] - fn test_multiple_squeeze_deplete_stack_properly() { - let domsep = DomainSeparator::::new("test").squeeze(5, "z"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let mut out1 = [0u8; 2]; - assert!(state.squeeze(&mut out1).is_ok()); - assert_eq!(state.stack.front(), Some(&Op::Squeeze(3))); - - let mut out2 = [0u8; 3]; - assert!(state.squeeze(&mut out2).is_ok()); - assert!(state.stack.is_empty()); - assert_eq!(&*state.ds.squeezed.borrow(), &[0, 1, 0, 1, 2]); - } - - #[test] - fn test_absorb_then_wrong_squeeze_clears_stack() { - let domsep = DomainSeparator::::new("test").absorb(3, "in"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let mut out = [0u8; 1]; - let result = state.squeeze(&mut out); - assert!(result.is_err()); - assert!(state.stack.is_empty()); - } - - #[test] - fn test_absorb_exact_then_too_much() { - let domsep = DomainSeparator::::new("test").absorb(2, "x"); - let mut state = HashStateWithInstructions::::new(&domsep); - - assert!(state.absorb(&[10, 20]).is_ok()); - assert!(state.absorb(&[30]).is_err()); // no ops left - assert!(state.stack.is_empty()); - } - - #[test] - fn test_from_impl_constructs_hash_state() { - let domsep = DomainSeparator::::new("from").absorb(1, "in"); - let state = HashStateWithInstructions::::from(&domsep); - - assert_eq!(state.stack.len(), 1); - assert_eq!(state.stack.front(), Some(&Op::Absorb(1))); - } - - #[test] - fn test_generate_tag_is_deterministic() { - let ds1 = DomainSeparator::::new("session1").absorb(1, "x"); - let ds2 = DomainSeparator::::new("session1").absorb(1, "x"); - - let tag1 = HashStateWithInstructions::::new(&ds1); - let tag2 = HashStateWithInstructions::::new(&ds2); - - assert_eq!(&*tag1.ds.absorbed.borrow(), &*tag2.ds.absorbed.borrow()); - } - - #[test] - fn test_hint_works_and_removes_stack_entry() { - let domsep = DomainSeparator::::new("test").hint("hint"); - let mut state = HashStateWithInstructions::::new(&domsep); - - assert_eq!(state.stack.len(), 1); - let result = state.hint(); - assert!(result.is_ok()); - assert!(state.stack.is_empty()); - } - - #[test] - fn test_hint_wrong_op_errors_and_clears_stack() { - let domsep = DomainSeparator::::new("test").absorb(1, "x"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let result = state.hint(); // Should expect Op::Hint, but see Op::Absorb - assert!(result.is_err()); - assert!(state.stack.is_empty()); - } - - #[test] - fn test_hint_on_empty_stack_errors() { - let domsep = DomainSeparator::::new("test"); - let mut state = HashStateWithInstructions::::new(&domsep); - - let result = state.hint(); // Stack is empty - assert!(result.is_err()); - } -} diff --git a/spongefish/src/tests.rs b/spongefish/src/tests.rs index 9238874..720441e 100644 --- a/spongefish/src/tests.rs +++ b/spongefish/src/tests.rs @@ -1,28 +1,25 @@ +use std::sync::Arc; + use rand::RngCore; use crate::{ - duplex_sponge::legacy::DigestBridge, keccak::Keccak, BytesToUnitDeserialize, - BytesToUnitSerialize, CommonUnitToBytes, DomainSeparator, DuplexSpongeInterface, - HashStateWithInstructions, ProverState, UnitToBytes, + codecs::unit::Pattern as UnitPattern, + duplex_sponge::legacy::DigestBridge, + keccak::Keccak, + pattern::{Length, Pattern, PatternState}, + traits::{BytesToUnitSerialize, UnitToBytes}, + DuplexSpongeInterface, ProverState, UnitTranscript, VerifierState, }; type Sha2 = DigestBridge; type Blake2b512 = DigestBridge; type Blake2s256 = DigestBridge; -/// How should a protocol without actual IO be handled? -#[test] -fn test_domain_separator() { - // test that the byte separator is always added - let domain_separator = DomainSeparator::::new("example.com"); - assert!(domain_separator.as_bytes().starts_with(b"example.com")); -} - /// Test ProverState's rng is not doing completely stupid things. #[test] fn test_prover_rng_basic() { - let domain_separator = DomainSeparator::::new("example.com"); - let mut prover_state = domain_separator.to_prover_state(); + let pattern = PatternState::::new().finalize(); + let mut prover_state: ProverState = ProverState::from(&pattern); let rng = prover_state.rng(); let mut random_bytes = [0u8; 32]; @@ -33,191 +30,226 @@ fn test_prover_rng_basic() { assert_ne!(random_u32, 0); assert_ne!(random_u64, 0); assert!(random_bytes.iter().any(|&x| x != random_bytes[0])); + let _proof = prover_state.finalize(); } /// Test adding of public bytes and non-public elements to the transcript. #[test] -fn test_prover_bytewriter() { - let domain_separator = DomainSeparator::::new("example.com").absorb(1, "🥕"); - let mut prover_state = domain_separator.to_prover_state(); - assert!(prover_state.add_bytes(&[0u8]).is_ok()); - assert!(prover_state.add_bytes(&[1u8]).is_err()); - assert_eq!( - prover_state.narg_string(), - b"\0", - "Protocol Transcript survives errors" - ); +fn test_prover_bytewriter_correct() { + // Expect exactly one add_bytes call. + let mut pattern = PatternState::::new(); + pattern.begin_message::("bytes", Length::Fixed(1)); + pattern.message_units("units", 1); + pattern.end_message::("bytes", Length::Fixed(1)); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.add_bytes(&[0u8]); + let proof = prover_state.finalize(); + assert_eq!(hex::encode(proof), "00"); +} - let mut prover_state = domain_separator.to_prover_state(); - assert!(prover_state.public_bytes(&[0u8]).is_ok()); - assert_eq!(prover_state.narg_string(), b""); +#[test] +#[should_panic( + expected = "Received interaction, but no more expected interactions: Begin Message bytes Fixed(1) u8" +)] +fn test_prover_bytewriter_invalid() { + // Expect exactly one add_bytes call. + let mut pattern = PatternState::::new(); + pattern.begin_message::("bytes", Length::Fixed(1)); + pattern.message_units("units", 1); + pattern.end_message::("bytes", Length::Fixed(1)); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.add_bytes(&[0u8]); + prover_state.add_bytes(&[1u8]); +} + +#[test] +#[should_panic( + expected = "Received interaction, but no more expected interactions: Atomic Public public_units Fixed(1) u8" +)] +fn test_prover_public_units_invalid() { + // Expect exactly one add_bytes call. + let mut pattern = PatternState::::new(); + pattern.public_units("public_units", 1); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.public_units(&[0u8]); + prover_state.public_units(&[1u8]); } -/// A protocol flow that does not match the DomainSeparator should fail. +/// A protocol flow whose pattern does not match should panic. #[test] +#[should_panic( + expected = "Received interaction Atomic Challenge fill_challenge_units Fixed(16) u8, but expected Atomic Message units Fixed(3) u8" +)] fn test_invalid_domsep_sequence() { - let domain_separator = DomainSeparator::new("example.com") - .absorb(3, "") - .squeeze(1, ""); - let mut verifier_state = HashStateWithInstructions::::new(&domain_separator); - assert!(verifier_state.squeeze(&mut [0u8; 16]).is_err()); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 3); + pattern.challenge_units("challenge_units", 1); + let pattern = pattern.finalize(); + + let mut verifier_state: VerifierState = VerifierState::new(Arc::new(pattern), &[]); + // This should panic due to pattern mismatch. + verifier_state.fill_challenge_bytes(&mut [0u8; 16]); } -// Hiding for now. Should it panic ? -// /// A protocol whose domain separator is not finished should panic. -// #[test] -// #[should_panic] -// fn test_unfinished_domsep() { -// let iop = DomainSeparator::new("example.com").absorb(3, "").squeeze(1, ""); -// let _verifier_challenges = VerifierState::::new(&iop); -// } +/// A protocol whose domain separator is not finished should panic. +#[test] +#[should_panic(expected = "Dropped unfinalized transcript.")] +fn test_unfinished_domsep() { + let mut pattern = PatternState::::new(); + pattern.message_units("elt", 3); + pattern.challenge_units("another_elt", 16); + let pattern = pattern.finalize(); + + let mut _verifier: VerifierState = VerifierState::new(pattern.into(), b""); +} -/// Challenges from the same transcript should be equal. +/// The domain separator tag should be deterministic. #[test] fn test_deterministic() { - let domain_separator = DomainSeparator::new("example.com") - .absorb(3, "elt") - .squeeze(16, "another_elt"); - let mut first_sponge = HashStateWithInstructions::::new(&domain_separator); - let mut second_sponge = HashStateWithInstructions::::new(&domain_separator); - - let mut first = [0u8; 16]; - let mut second = [0u8; 16]; - - first_sponge.absorb(b"123").unwrap(); - second_sponge.absorb(b"123").unwrap(); - - first_sponge.squeeze(&mut first).unwrap(); - second_sponge.squeeze(&mut second).unwrap(); - assert_eq!(first, second); + let mut pattern = PatternState::::new(); + pattern.message_units("elt", 3); + pattern.challenge_units("another_elt", 16); + let pattern = pattern.finalize(); + + let iv1 = pattern.domain_separator(); + let iv2 = pattern.domain_separator(); + assert_eq!(iv1, iv2); } -/// Basic scatistical test to check that the squeezed output looks random. +/// Basic check that the domain separator tag has some non-zero byte. #[test] fn test_statistics() { - let domain_separator = DomainSeparator::new("example.com") - .absorb(4, "statement") - .ratchet() - .squeeze(2048, "gee"); - let mut verifier_state = HashStateWithInstructions::::new(&domain_separator); - verifier_state.absorb(b"seed").unwrap(); - verifier_state.ratchet().unwrap(); - let mut output = [0u8; 2048]; - verifier_state.squeeze(&mut output).unwrap(); - - let frequencies = (0u8..=255) - .map(|i| output.iter().filter(|&&x| x == i).count()) - .collect::>(); - // each element should appear roughly 8 times on average. Checking we're not too far from that. - assert!(frequencies.iter().all(|&x| x < 32 && x > 0)); + let pattern = PatternState::::new().finalize(); + let iv = pattern.domain_separator(); + assert!(iv.iter().any(|&b| b != 0)); } #[test] fn test_transcript_readwrite() { - let domain_separator = DomainSeparator::::new("domain separator") - .absorb(10, "hello") - .squeeze(10, "world"); - - let mut prover_state = domain_separator.to_prover_state(); - prover_state - .add_units(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - .unwrap(); - let prover_challenges = prover_state.challenge_bytes::<10>().unwrap(); - let transcript = prover_state.narg_string(); - - let mut verifier_state = domain_separator.to_verifier_state(transcript); + // Pattern for prover and verifier sequence: add_units, fill_challenge_units, two fill_next_units, then fill_challenge_units + let mut pattern = PatternState::::new(); + pattern.message_units("units", 10); + pattern.challenge_units("fill_challenge_units", 10); + pattern.message_units("units", 5); + pattern.message_units("units", 5); + pattern.challenge_units("fill_challenge_units", 10); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.add_units(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + assert_eq!( + hex::encode(prover_state.challenge_bytes::<10>()), + "0ccd176155e008b158ad" + ); + prover_state.add_units(&[10, 11, 12, 13, 14]); + prover_state.add_units(&[15, 16, 17, 18, 19]); + assert_eq!( + hex::encode(prover_state.challenge_bytes::<10>()), + "0f691da125269385ceea" + ); + let proof = prover_state.finalize(); + assert_eq!( + hex::encode(&proof), + "000102030405060708090a0b0c0d0e0f10111213" + ); + + let mut verifier_state: VerifierState = VerifierState::new(Arc::new(pattern), &proof); + let mut input = [0u8; 10]; + verifier_state.fill_next_units(&mut input).unwrap(); + assert_eq!(input, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + assert_eq!( + hex::encode(verifier_state.challenge_bytes::<10>()), + "0ccd176155e008b158ad" + ); let mut input = [0u8; 5]; verifier_state.fill_next_units(&mut input).unwrap(); - assert_eq!(input, [0, 1, 2, 3, 4]); + assert_eq!(input, [10, 11, 12, 13, 14]); verifier_state.fill_next_units(&mut input).unwrap(); - assert_eq!(input, [5, 6, 7, 8, 9]); - let verifier_challenges = verifier_state.challenge_bytes::<10>().unwrap(); - assert_eq!(verifier_challenges, prover_challenges); + assert_eq!(input, [15, 16, 17, 18, 19]); + assert_eq!( + hex::encode(verifier_state.challenge_bytes::<10>()), + "0f691da125269385ceea" + ); + verifier_state.finalize(); } /// An IO that is not fully finished should fail. +/// An IO that is not fully finished should panic. #[test] #[should_panic] fn test_incomplete_domsep() { - let domain_separator = DomainSeparator::::new("domain separator") - .absorb(10, "hello") - .squeeze(1, "nop"); - - let mut prover_state = domain_separator.to_prover_state(); - prover_state - .add_units(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - .unwrap(); - prover_state.fill_challenge_bytes(&mut [0u8; 10]).unwrap(); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 10); + pattern.challenge_units("fill_challenge_units", 1); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.add_units(&[0u8; 10]); + // This should panic due to pattern mismatch length + prover_state.fill_challenge_bytes(&mut [0u8; 10]); } /// The user should respect the domain separator even with empty length. +/// The user should respect the pattern even with empty operations. #[test] fn test_prover_empty_absorb() { - let domain_separator = DomainSeparator::::new("domain separator") - .absorb(1, "in") - .squeeze(1, "something"); - - assert!(domain_separator - .to_prover_state() - .fill_challenge_bytes(&mut [0u8; 1]) - .is_err()); - assert!(domain_separator - .to_verifier_state(b"") - .next_bytes::<1>() - .is_err()); + // Pattern expects one add_units and one challenge + let mut pattern = PatternState::::new(); + pattern.message_units("units", 0); + pattern.challenge_units("fill_challenge_units", 0); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.add_units(b""); + let _challenge = prover_state.challenge_bytes::<0>(); + let proof = prover_state.finalize(); + assert!(proof.is_empty()); + + let mut vstate: VerifierState = VerifierState::new(Arc::new(pattern), &proof); + let mut out = [0_u8; 0]; + vstate.fill_next_units(&mut out).unwrap(); + let _challenge = vstate.challenge_bytes::<0>(); + vstate.finalize(); } -/// Absorbs and squeeze over byte-Units should be streamable. -fn test_streaming_absorb_and_squeeze() +/// Absorbs and squeeze over byte-Units +fn test_absorb_and_squeeze() where ProverState: BytesToUnitSerialize + UnitToBytes, { let bytes = b"yellow submarine"; - let domain_separator = DomainSeparator::::new("domain separator") - .absorb(16, "some bytes") - .squeeze(16, "control challenge") - .absorb(1, "level 2: use this as a prng stream") - .squeeze(1024, "that's a long challenge"); - - let mut prover_state = domain_separator.to_prover_state(); - prover_state.add_bytes(bytes).unwrap(); - let control_chal = prover_state.challenge_bytes::<16>().unwrap(); - let control_transcript = prover_state.narg_string(); - - let mut stream_prover_state = domain_separator.to_prover_state(); - stream_prover_state.add_bytes(&bytes[..10]).unwrap(); - stream_prover_state.add_bytes(&bytes[10..]).unwrap(); - let first_chal = stream_prover_state.challenge_bytes::<8>().unwrap(); - let second_chal = stream_prover_state.challenge_bytes::<8>().unwrap(); - let transcript = stream_prover_state.narg_string(); - - assert_eq!(transcript, control_transcript); - assert_eq!(&first_chal[..], &control_chal[..8]); - assert_eq!(&second_chal[..], &control_chal[8..]); - - prover_state.add_bytes(&[0x42]).unwrap(); - stream_prover_state.add_bytes(&[0x42]).unwrap(); - - let control_chal = prover_state.challenge_bytes::<1024>().unwrap(); - for control_chunk in control_chal.chunks(16) { - let chunk = stream_prover_state.challenge_bytes::<16>().unwrap(); - assert_eq!(control_chunk, &chunk[..]); - } + let mut pattern = PatternState::::new(); + pattern.begin_message::("bytes", Length::Fixed(16)); + pattern.message_units("units", 16); + pattern.end_message::("bytes", Length::Fixed(16)); + pattern.challenge_units("fill_challenge_units", 16); + let pattern = pattern.finalize(); + + let mut prover_state: ProverState = ProverState::from(&pattern); + prover_state.add_bytes(bytes); + let _challenge = prover_state.challenge_bytes::<16>(); + let _proof = prover_state.finalize(); } #[test] -fn test_streaming_sha2() { - test_streaming_absorb_and_squeeze::(); +fn test_sha2() { + test_absorb_and_squeeze::(); } #[test] -fn test_streaming_blake2() { - test_streaming_absorb_and_squeeze::(); - test_streaming_absorb_and_squeeze::(); +fn test_blake2() { + test_absorb_and_squeeze::(); + test_absorb_and_squeeze::(); } #[test] -fn test_streaming_keccak() { - test_streaming_absorb_and_squeeze::(); +fn test_keccak() { + test_absorb_and_squeeze::(); } diff --git a/spongefish/src/traits.rs b/spongefish/src/traits.rs index 90ac2d7..b521cb3 100644 --- a/spongefish/src/traits.rs +++ b/spongefish/src/traits.rs @@ -1,4 +1,4 @@ -use crate::{errors::DomainSeparatorMismatch, Unit}; +use crate::Unit; /// Absorbing and squeezing native elements from the sponge. /// @@ -6,9 +6,9 @@ use crate::{errors::DomainSeparatorMismatch, Unit}; /// Implementors of this trait are expected to make sure that the unit type `U` matches /// the one used by the internal sponge. pub trait UnitTranscript { - fn public_units(&mut self, input: &[U]) -> Result<(), DomainSeparatorMismatch>; + fn public_units(&mut self, input: &[U]); - fn fill_challenge_units(&mut self, output: &mut [U]) -> Result<(), DomainSeparatorMismatch>; + fn fill_challenge_units(&mut self, output: &mut [U]); } /// Absorbing bytes from the sponge, without reading or writing them into the protocol transcript. @@ -19,7 +19,7 @@ pub trait UnitTranscript { /// For instance, in the case of algebraic sponges operating over a field $\mathbb{F}_p$, we do not expect /// the implementation to cache field elements filling $\ceil{\log_2(p)}$ bytes. pub trait CommonUnitToBytes { - fn public_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch>; + fn public_bytes(&mut self, input: &[u8]); } /// Squeezing bytes from the sponge. @@ -30,12 +30,12 @@ pub trait CommonUnitToBytes { /// - `u8` implementations are assumed to be streaming-friendly, that is: `implementor.fill_challenge_bytes(&mut out[..1]); implementor.fill_challenge_bytes(&mut out[1..]);` is expected to be equivalent to `implementor.fill_challenge_bytes(&mut out);`. /// - $\mathbb{F}_p$ implementations are expected to provide no such guarantee. In addition, we expect the implementation to return bytes that are uniformly distributed. In particular, note that the most significant bytes of a $\mod p$ element are not uniformly distributed. The number of bytes good to be used can be discovered playing with [our scripts](https://github.com/arkworks-rs/spongefish/blob/main/spongefish/scripts/useful_bits_modp.py). pub trait UnitToBytes { - fn fill_challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), DomainSeparatorMismatch>; + fn fill_challenge_bytes(&mut self, output: &mut [u8]); - fn challenge_bytes(&mut self) -> Result<[u8; N], DomainSeparatorMismatch> { + fn challenge_bytes(&mut self) -> [u8; N] { let mut output = [0u8; N]; - self.fill_challenge_bytes(&mut output)?; - Ok(output) + self.fill_challenge_bytes(&mut output); + output } } @@ -46,9 +46,9 @@ pub trait UnitToBytes { pub trait ByteTranscript: CommonUnitToBytes + UnitToBytes {} pub trait BytesToUnitDeserialize { - fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), DomainSeparatorMismatch>; + fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), std::io::Error>; - fn next_bytes(&mut self) -> Result<[u8; N], DomainSeparatorMismatch> { + fn next_bytes(&mut self) -> Result<[u8; N], std::io::Error> { let mut input = [0u8; N]; self.fill_next_bytes(&mut input)?; Ok(input) @@ -56,7 +56,7 @@ pub trait BytesToUnitDeserialize { } pub trait BytesToUnitSerialize { - fn add_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch>; + fn add_bytes(&mut self, input: &[u8]); } /// Methods for adding bytes to the [`DomainSeparator`](crate::DomainSeparator), properly counting group elements. @@ -71,14 +71,14 @@ pub trait ByteDomainSeparator { impl> CommonUnitToBytes for T { #[inline] - fn public_bytes(&mut self, input: &[u8]) -> Result<(), DomainSeparatorMismatch> { - self.public_units(input) + fn public_bytes(&mut self, input: &[u8]) { + self.public_units(input); } } impl> UnitToBytes for T { #[inline] - fn fill_challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), DomainSeparatorMismatch> { - self.fill_challenge_units(output) + fn fill_challenge_bytes(&mut self, output: &mut [u8]) { + self.fill_challenge_units(output); } } diff --git a/spongefish/src/verifier.rs b/spongefish/src/verifier.rs index ecb433e..0df6b12 100644 --- a/spongefish/src/verifier.rs +++ b/spongefish/src/verifier.rs @@ -1,8 +1,8 @@ +use std::{marker::PhantomData, sync::Arc}; + use crate::{ - domain_separator::DomainSeparator, duplex_sponge::{DuplexSpongeInterface, Unit}, - errors::DomainSeparatorMismatch, - sho::HashStateWithInstructions, + pattern::{Hierarchy, Interaction, InteractionPattern, Kind, Length, Pattern, PatternPlayer}, traits::{BytesToUnitDeserialize, UnitTranscript}, DefaultHash, }; @@ -17,8 +17,10 @@ where H: DuplexSpongeInterface, U: Unit, { - pub(crate) hash_state: HashStateWithInstructions, + pub(crate) pattern: PatternPlayer, + pub(crate) duplex_sponge: H, pub(crate) narg_string: &'a [u8], + pub(crate) _unit_type: PhantomData, } impl<'a, U: Unit, H: DuplexSpongeInterface> VerifierState<'a, H, U> { @@ -39,29 +41,45 @@ impl<'a, U: Unit, H: DuplexSpongeInterface> VerifierState<'a, H, U> { /// assert_ne!(challenge.unwrap(), [0; 32]); /// ``` #[must_use] - pub fn new(domain_separator: &DomainSeparator, narg_string: &'a [u8]) -> Self { - let hash_state = HashStateWithInstructions::new(domain_separator); + pub fn new(pattern: Arc, narg_string: &'a [u8]) -> Self { + let iv = pattern.domain_separator(); Self { - hash_state, + pattern: PatternPlayer::new(pattern), + duplex_sponge: H::new(iv), narg_string, + _unit_type: PhantomData, } } /// Read `input.len()` elements from the NARG string. #[inline] - pub fn fill_next_units(&mut self, input: &mut [U]) -> Result<(), DomainSeparatorMismatch> { + pub fn fill_next_units(&mut self, input: &mut [U]) -> Result<(), std::io::Error> { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Message, + "units", + Length::Fixed(input.len()), + )); U::read(&mut self.narg_string, input)?; - self.hash_state.absorb(input)?; + self.duplex_sponge.absorb_unchecked(input); Ok(()) } /// Read a hint from the NARG string. Returns the number of units read. - pub fn hint_bytes(&mut self) -> Result<&'a [u8], DomainSeparatorMismatch> { - self.hash_state.hint()?; + pub fn hint_bytes(&mut self) -> Result<&'a [u8], std::io::Error> { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Hint, + "hint_bytes", + Length::Dynamic, + )); // Ensure at least 4 bytes are available for the length prefix if self.narg_string.len() < 4 { - return Err("Insufficient transcript remaining for hint".into()); + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Insufficient transcript remaining for hint", + )); } // Read 4-byte little-endian length prefix @@ -70,11 +88,13 @@ impl<'a, U: Unit, H: DuplexSpongeInterface> VerifierState<'a, H, U> { // Ensure the rest of the slice has `len` bytes if rest.len() < len { - return Err(format!( - "Insufficient transcript remaining, got {}, need {len}", - rest.len() - ) - .into()); + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!( + "Insufficient transcript remaining, got {}, need {len}", + rest.len() + ), + )); } // Split the hint and advance the transcript @@ -86,52 +106,84 @@ impl<'a, U: Unit, H: DuplexSpongeInterface> VerifierState<'a, H, U> { /// Signals the end of the statement. #[inline] - pub fn ratchet(&mut self) -> Result<(), DomainSeparatorMismatch> { - self.hash_state.ratchet() + pub fn ratchet(&mut self) { + self.pattern.interact(Interaction::new::<()>( + Hierarchy::Atomic, + Kind::Protocol, + "ratchet", + Length::None, + )); + self.duplex_sponge.ratchet_unchecked(); } - /// Signals the end of the statement and returns the (compressed) sponge state. - #[inline] - pub fn preprocess(self) -> Result<&'static [U], DomainSeparatorMismatch> { - self.hash_state.preprocess() + /// Abort the verifier session without completing playback. + /// + /// Any remaining expected interactions are discarded. + pub fn abort(mut self) { + self.pattern.abort(); + } + + /// Finalize the verifier session, asserting all interactions were consumed. + pub fn finalize(self) { + self.pattern.finalize(); } } impl, U: Unit> UnitTranscript for VerifierState<'_, H, U> { /// Add native elements to the sponge without writing them to the NARG string. #[inline] - fn public_units(&mut self, input: &[U]) -> Result<(), DomainSeparatorMismatch> { - self.hash_state.absorb(input) + fn public_units(&mut self, input: &[U]) { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Public, + "public_units", + Length::Fixed(input.len()), + )); + self.duplex_sponge.absorb_unchecked(input); } /// Fill `input` with units sampled uniformly at random. #[inline] - fn fill_challenge_units(&mut self, input: &mut [U]) -> Result<(), DomainSeparatorMismatch> { - self.hash_state.squeeze(input) + fn fill_challenge_units(&mut self, input: &mut [U]) { + self.pattern.interact(Interaction::new::( + Hierarchy::Atomic, + Kind::Challenge, + "fill_challenge_units", + Length::Fixed(input.len()), + )); + self.duplex_sponge.squeeze_unchecked(input); } } impl, U: Unit> core::fmt::Debug for VerifierState<'_, H, U> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_tuple("VerifierState") - .field(&self.hash_state) - .finish() + f.debug_tuple("VerifierState").field(&self.pattern).finish() } } impl> BytesToUnitDeserialize for VerifierState<'_, H, u8> { /// Read the next `input.len()` bytes from the NARG string and return them. #[inline] - fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), DomainSeparatorMismatch> { - self.fill_next_units(input) + fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), std::io::Error> { + self.pattern + .begin_message::("bytes", Length::Fixed(input.len())); + self.fill_next_units(input)?; + self.pattern + .end_message::("bytes", Length::Fixed(input.len())); + Ok(()) } } #[cfg(test)] mod tests { - use std::{cell::RefCell, rc::Rc}; + use std::{cell::RefCell, rc::Rc, sync::Arc}; use super::*; + use crate::{ + codecs::{bytes::Pattern as _, unit::Pattern}, + pattern::PatternState, + ProverState, + }; #[derive(Default, Clone)] pub struct DummySponge { @@ -184,152 +236,176 @@ mod tests { #[test] fn test_new_verifier_state_constructs_correctly() { - let ds = DomainSeparator::::new("test"); + let pattern = PatternState::::new().finalize(); let transcript = b"abc"; - let vs = VerifierState::::new(&ds, transcript); + let vs = VerifierState::::new(Arc::new(pattern), transcript); assert_eq!(vs.narg_string, b"abc"); + vs.finalize(); } #[test] fn test_fill_next_units_reads_and_absorbs() { - let ds = DomainSeparator::::new("x").absorb(3, "input"); - let mut vs = VerifierState::::new(&ds, b"abc"); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 3); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), b"abc"); let mut buf = [0u8; 3]; - let res = vs.fill_next_units(&mut buf); - assert!(res.is_ok()); + assert!(vs.fill_next_units(&mut buf).is_ok()); assert_eq!(buf, *b"abc"); - assert_eq!(*vs.hash_state.ds().absorbed.borrow(), b"abc"); + assert_eq!(*vs.duplex_sponge.absorbed.borrow(), b"abc"); + vs.finalize(); } #[test] fn test_fill_next_units_with_insufficient_data_errors() { - let ds = DomainSeparator::::new("x").absorb(4, "fail"); - let mut vs = VerifierState::::new(&ds, b"xy"); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 4); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), b"xy"); let mut buf = [0u8; 4]; - let res = vs.fill_next_units(&mut buf); - assert!(res.is_err()); + assert!(vs.fill_next_units(&mut buf).is_err()); + vs.abort(); } #[test] fn test_ratcheting_success() { - let ds = DomainSeparator::::new("x").ratchet(); - let mut vs = VerifierState::::new(&ds, &[]); - assert!(vs.ratchet().is_ok()); - assert!(*vs.hash_state.ds().ratcheted.borrow()); + let mut pattern = PatternState::::new(); + pattern.ratchet(); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), &[]); + vs.ratchet(); + assert!(*vs.duplex_sponge.ratcheted.borrow()); + vs.finalize(); } #[test] + #[should_panic( + expected = "Received interaction Atomic Protocol ratchet None (), but expected Atomic Message units Fixed(1) u8" + )] fn test_ratcheting_wrong_op_errors() { - let ds = DomainSeparator::::new("x").absorb(1, "wrong"); - let mut vs = VerifierState::::new(&ds, b"z"); - assert!(vs.ratchet().is_err()); + let mut pattern = PatternState::::new(); + pattern.message_units("units", 1); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), &[]); + vs.ratchet(); } #[test] fn test_unit_transcript_public_units() { - let ds = DomainSeparator::::new("x").absorb(2, "public"); - let mut vs = VerifierState::::new(&ds, b".."); - assert!(vs.public_units(&[1, 2]).is_ok()); - assert_eq!(*vs.hash_state.ds().absorbed.borrow(), &[1, 2]); + let mut pattern = PatternState::::new(); + pattern.public_units("public_units", 2); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), b".."); + vs.public_units(&[1, 2]); + assert_eq!(*vs.duplex_sponge.absorbed.borrow(), &[1, 2]); + vs.finalize(); } #[test] fn test_unit_transcript_fill_challenge_units() { - let ds = DomainSeparator::::new("x").squeeze(4, "c"); - let mut vs = VerifierState::::new(&ds, b"abcd"); + let mut pattern = PatternState::::new(); + pattern.challenge_units("fill_challenge_units", 4); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), b"abcd"); let mut out = [0u8; 4]; - assert!(vs.fill_challenge_units(&mut out).is_ok()); + vs.fill_challenge_units(&mut out); assert_eq!(out, [0, 1, 2, 3]); + vs.finalize(); } #[test] fn test_fill_next_bytes_impl() { - let ds = DomainSeparator::::new("x").absorb(3, "byte"); - let mut vs = VerifierState::::new(&ds, b"xyz"); + let mut pattern = PatternState::::new(); + pattern.message_bytes("bytes", 3); + let pattern = pattern.finalize(); + + let mut vs = VerifierState::::new(Arc::new(pattern), b"xyz"); let mut out = [0u8; 3]; assert!(vs.fill_next_bytes(&mut out).is_ok()); assert_eq!(out, *b"xyz"); + vs.finalize(); } #[test] fn test_hint_bytes_verifier_valid_hint() { - // Domain separator commits to a hint - let domsep: DomainSeparator = DomainSeparator::new("valid").hint("hint"); - - let mut prover = domsep.to_prover_state(); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); let hint = b"abc123"; - prover.hint_bytes(hint).unwrap(); - - let narg = prover.narg_string(); + let mut prover: ProverState = ProverState::from(&pattern); + prover.hint_bytes(hint); + let narg = prover.finalize(); + assert_eq!(hex::encode(&narg), "06000000616263313233"); - let mut verifier = domsep.to_verifier_state(narg); - let result = verifier.hint_bytes().unwrap(); + let mut vs: VerifierState = VerifierState::new(Arc::new(pattern.clone()), &narg); + let result = vs.hint_bytes().unwrap(); assert_eq!(result, hint); + vs.finalize(); } #[test] fn test_hint_bytes_verifier_empty_hint() { - // Commit to a hint instruction - let domsep: DomainSeparator = DomainSeparator::new("empty").hint("hint"); - - let mut prover = domsep.to_prover_state(); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); let hint = b""; - prover.hint_bytes(hint).unwrap(); - - let narg = prover.narg_string(); + let mut prover: ProverState = ProverState::from(&pattern); + prover.hint_bytes(hint); + let narg = prover.finalize(); - let mut verifier = domsep.to_verifier_state(narg); - let result = verifier.hint_bytes().unwrap(); + let mut vs: VerifierState = VerifierState::new(Arc::new(pattern.clone()), &narg); + let result = vs.hint_bytes().unwrap(); assert_eq!(result, b""); + vs.finalize(); } #[test] + #[should_panic( + expected = "Received interaction, but no more expected interactions: Atomic Hint hint_bytes Dynamic u8" + )] fn test_hint_bytes_verifier_no_hint_op() { - // No hint instruction in domain separator - let domsep: DomainSeparator = DomainSeparator::new("nohint"); + let pattern = PatternState::::new().finalize(); // Manually construct a hint buffer (length = 6, followed by bytes) - let mut narg = vec![6, 0, 0, 0]; // length prefix for 6 - narg.extend_from_slice(b"abc123"); + let narg = hex::decode("06000000616263313233").unwrap(); - let mut verifier = domsep.to_verifier_state(&narg); - - assert!(verifier.hint_bytes().is_err()); + let mut vs: VerifierState = VerifierState::new(Arc::new(pattern), &narg); + vs.hint_bytes().unwrap(); } #[test] fn test_hint_bytes_verifier_length_prefix_too_short() { - // Valid hint domain separator - let domsep: DomainSeparator = DomainSeparator::new("short").hint("hint"); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); // Provide only 3 bytes, which is not enough for a u32 length let narg = &[1, 2, 3]; // less than 4 bytes - let mut verifier = domsep.to_verifier_state(narg); - - let err = verifier.hint_bytes().unwrap_err(); - assert!( - format!("{err}").contains("Insufficient transcript remaining for hint"), - "Expected error for short prefix, got: {err}" - ); + let mut vs: VerifierState = VerifierState::new(Arc::new(pattern), narg); + let err = vs.hint_bytes().unwrap_err(); + assert!(format!("{err}").contains("Insufficient transcript remaining for hint")); + vs.abort(); } #[test] fn test_hint_bytes_verifier_declared_hint_too_long() { - // Valid hint domain separator - let domsep: DomainSeparator = DomainSeparator::new("loverflow").hint("hint"); - - // Prefix says "5 bytes", but we only supply 2 - let narg = &[5, 0, 0, 0, b'a', b'b']; - - let mut verifier = domsep.to_verifier_state(narg); - - let err = verifier.hint_bytes().unwrap_err(); - assert!( - format!("{err}").contains("Insufficient transcript remaining"), - "Expected error for hint length > actual NARG bytes, got: {err}" - ); + let mut pattern = PatternState::::new(); + pattern.hint_bytes_dynamic("hint_bytes"); + let pattern = pattern.finalize(); + + let narg = [5u8, 0, 0, 0, b'a', b'b']; + let mut vs: VerifierState = VerifierState::new(Arc::new(pattern), &narg); + let err = vs.hint_bytes().unwrap_err(); + assert!(format!("{err}").contains("Insufficient transcript remaining")); + vs.abort(); } }