diff --git a/Cargo.lock b/Cargo.lock index 4c8e69187..5d91e18e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -677,6 +677,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lz4_flex" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a" + [[package]] name = "matchers" version = "0.1.0" @@ -1167,6 +1173,8 @@ dependencies = [ "starknet-crypto", "starknet-ff", "std-shims", + "stwo-compact-binary", + "stwo-compact-binary-derive", "test-log", "thiserror", "tracing", @@ -1195,6 +1203,24 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "stwo-compact-binary" +version = "0.1.1" +dependencies = [ + "lz4_flex", + "starknet-ff", + "std-shims", + "unsigned-varint", +] + +[[package]] +name = "stwo-compact-binary-derive" +version = "0.1.0" +dependencies = [ + "quote", + "syn 2.0.104", +] + [[package]] name = "stwo-constraint-framework" version = "0.1.1" @@ -1389,6 +1415,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unsigned-varint" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb066959b24b5196ae73cb057f45598450d2c5f71460e98c49b738086eff9c06" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index 848e64ae5..13d16fd4e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,8 @@ members = [ "crates/stwo", "crates/air-utils", "crates/air-utils-derive", + "crates/compact-binary", + "crates/compact-binary-derive", "crates/constraint-framework", "crates/examples", "crates/std-shims", @@ -32,6 +34,7 @@ rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } serde = { version = "1.0", 
default-features = false, features = ["derive"] } hashbrown = ">=0.15.2" std-shims = { path = "crates/std-shims", default-features = false } +unsigned-varint = "0.8" [profile.bench] codegen-units = 1 diff --git a/crates/compact-binary-derive/Cargo.toml b/crates/compact-binary-derive/Cargo.toml new file mode 100644 index 000000000..11b278cf8 --- /dev/null +++ b/crates/compact-binary-derive/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "stwo-compact-binary-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0.37" +syn = "2.0.90" diff --git a/crates/compact-binary-derive/src/lib.rs b/crates/compact-binary-derive/src/lib.rs new file mode 100644 index 000000000..629e7ec2c --- /dev/null +++ b/crates/compact-binary-derive/src/lib.rs @@ -0,0 +1,161 @@ +use proc_macro::TokenStream; +use quote::{quote, ToTokens}; +use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, Type}; + +/// Proc macro to automatically derive `CompactBinary` trait for structs. +#[proc_macro_derive(CompactBinary, attributes(zipped))] +pub fn derive_compact_binary(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree. + let input = parse_macro_input!(input as DeriveInput); + + let struct_name = input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + // Extract the fields of the struct. 
+ let fields = match input.data { + Data::Struct(ref data_struct) => match &data_struct.fields { + Fields::Named(ref fields_named) => &fields_named.named, + Fields::Unnamed(_) | Fields::Unit => { + return syn::Error::new_spanned( + struct_name, + "CompactBinary can only be derived for structs with named fields.", + ) + .to_compile_error() + .into(); + } + }, + _ => { + return syn::Error::new_spanned( + struct_name, + "CompactBinary can only be derived for structs.", + ) + .to_compile_error() + .into(); + } + }; + + // Check if MerkleHasher is present in the where clause or generics + let h_is_merklehasher = input + .generics + .where_clause + .as_ref() + .map(|wc| { + wc.predicates.iter().any(|pred| { + pred.to_token_stream() + .to_string() + .contains("H: MerkleHasher") + }) + }) + .unwrap_or(false) + || input.generics.params.iter().any(|param| { + if let syn::GenericParam::Type(ty) = param { + ty.bounds + .iter() + .any(|b| b.to_token_stream().to_string().contains("MerkleHasher")) + } else { + false + } + }); + + // Check if any field requires H bounds + let needs_h_bounds = fields.iter().any(|f| { + if let Type::Path(type_path) = &f.ty { + if let Some(seg) = type_path.path.segments.last() { + if let syn::PathArguments::AngleBracketed(ref args) = seg.arguments { + return args.args.iter().any(|arg| { + if let syn::GenericArgument::Type(Type::Path(type_path)) = arg { + type_path + .path + .segments + .last() + .is_some_and(|s| s.ident == "H") + } else { + false + } + }); + } + } + } + false + }); + + let mut where_clause = where_clause.cloned(); + // If MerkleHasher is present and H bounds are needed, add the necessary bounds. + if h_is_merklehasher && needs_h_bounds { + let pred: syn::WherePredicate = + parse_quote! { H::Hash: stwo::core::compact_binary::CompactBinary }; + if let Some(ref mut wc) = where_clause { + wc.predicates.push(pred); + } else { + where_clause = Some::( + parse_quote! 
{ where H::Hash: stwo::core::compact_binary::CompactBinary }, + ); + } + } + + // Generate code to serialize each field in the order they appear. + let compact_serialize_body = fields.iter().enumerate().map(|(i, f)| { + let field_name = &f.ident; + let field_type = &f.ty; + let is_zipped = f.attrs.iter().any(|attr| attr.path().is_ident("zipped")); + match is_zipped { + true => { + quote! { + usize::compact_serialize(&#i, output)?; + let #field_name = stwo::core::compact_binary::ZippedCompactBinary(&self.#field_name); + stwo::core::compact_binary::ZippedCompactBinary::<&#field_type>::compact_serialize(&#field_name, output)?; + } + } + false => { + quote! { + usize::compact_serialize(&#i, output)?; + stwo::core::compact_binary::CompactBinary::compact_serialize(&self.#field_name, output)?; + } + } + } + }); + + // Generate code to deserialize each field in the order they appear. + let compact_deserialize_let_bindings = fields.iter().enumerate().map(|(i, f)| { + let field_name = &f.ident; + let field_type = &f.ty; + let is_zipped = f.attrs.iter().any(|attr| attr.path().is_ident("zipped")); + match is_zipped { + true => { + quote! { + let input = stwo::core::compact_binary::strip_expected_tag(input, #i)?; + let (input, #field_name) = stwo::core::compact_binary::ZippedCompactBinary::<&#field_type>::compact_deserialize(input)?; + } + } + false => { + quote! { + let input = stwo::core::compact_binary::strip_expected_tag(input, #i)?; + let (input, #field_name) = stwo::core::compact_binary::CompactBinary::compact_deserialize(input)?; + } + } + } + }); + let compact_deserialize_struct_fields = fields.iter().map(|f| { + let field_name = &f.ident; + quote! { #field_name } + }); + + // Implement `CompactBinary` for the type. + let expanded = quote! 
{ + impl #impl_generics stwo::core::compact_binary::CompactBinary for #struct_name #ty_generics #where_clause { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), stwo::core::compact_binary::CompactSerializeError> { + u32::compact_serialize(&0, output)?; + #(#compact_serialize_body)* + Ok(()) + } + + fn compact_deserialize<'a>(mut input: &'a [u8]) -> Result<(&'a [u8], Self), stwo::core::compact_binary::CompactDeserializeError> { + let input = stwo::core::compact_binary::strip_expected_version(input, 0)?; + #(#compact_deserialize_let_bindings)* + Ok((input, Self { #(#compact_deserialize_struct_fields),* })) + } + } + }; + + TokenStream::from(expanded) +} diff --git a/crates/compact-binary/Cargo.toml b/crates/compact-binary/Cargo.toml new file mode 100644 index 000000000..5a94f8d5d --- /dev/null +++ b/crates/compact-binary/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "stwo-compact-binary" +version.workspace = true +edition.workspace = true + +[features] +default = ["std"] +std = [ + "std-shims/std", +] + +[dependencies] +starknet-ff = { version = "0.3.7", default-features = false, features = [ + "alloc", + "serde", +] } +unsigned-varint.workspace = true +lz4_flex = { version = "0.11", default-features = false } +std-shims.workspace = true + +[lib] +bench = false diff --git a/crates/compact-binary/README.md b/crates/compact-binary/README.md new file mode 100644 index 000000000..1dcc448d9 --- /dev/null +++ b/crates/compact-binary/README.md @@ -0,0 +1,61 @@ +# Compact binary format specs + +## Needs + +Stwo proofs can be serialized in two formats: + +- `json`: each field of each struct is serialized as base json format. +- `cairo-serde`: the proof is first converted into a `Vec`, that is then serialized in a serde json format. + +This crate implements a third proof serialization format, `compact-binary`, where the proof is serialized as a `Vec` in a compact way. 
+ +## Format description + +- Integers (`u32`, `u64` and `usize`) should be handled as VarInts +- Relevant fields should be compactified if possible or compressed. +- Structured data should have: + - version numbers, to be able to update the structure + - tags for each field, to be able to add new fields easily + +## Versioning description + +If we want to add or change a field of a struct `StructA`, while still being able to deserialize previous versions of this struct, we should: + +- Update `compact_serialize()` to serialize a new version number, and serialize the new struct +- Update `compact_deserialize()` to: + - Get the version of the deserialized struct + - Match on it and dispatch to the deserialization logic corresponding to this version + +## Implementation + +The current implementation consists of the following elements: + +- A `CompactBinary` trait in `crates/compact-binary/src/lib.rs`, along with helper functions and implementations for base structures. +- A `#[derive(CompactBinary)]` proc macro to implement the trait for structures composed of fields implementing it. Note that the proc macro is only expected to produce a `0` version; if a given structure is to be updated, its implementation should be done manually, while keeping backward compatibility with all previous serialization versions of this structure. See `crates/compact-binary-derive/src/lib.rs`. +Note that the proc macro supports the `#[zipped]` attribute to specify that a given field should be zipped (compressed with LZ4 compression). +- Error handling through the `CompactDeserializeError` enum and the `CompactSerializeError` struct. +- Implementations of the `CompactBinary` trait for structures in the `stwo` crate used for CairoProofs. + +In the [stwo-cairo repository](https://github.com/starkware-libs/stwo-cairo): + +- A `CompactBinary` implementation for `CairoProof` +- Argument handling for proof and verification in the CLI (added `--proof-format compact-binary`).
See `cairo-prove/src/main.rs` and `cairo-prove/src/args.rs`. + +## Tests and benchmarks + +We execute the example proof in the [stwo-cairo repository](https://github.com/starkware-libs/stwo-cairo): +```bash +cairo-prove/target/release/cairo-prove prove cairo-prove/example/target/dev/example.executable.json ./example_proof.compact_bin --arguments 10000 --proof-format compact-binary +cairo-prove/target/release/cairo-prove verify ./example_proof.compact_bin --proof-format compact-binary +``` + +**Note that we've adapted the serialize_proof_to_file() function to use serde_json without a JSON prettier to have more accurate results** + +For this example proof, here are the results: + +| File | Format | Size on disk (bytes) | Gain | +|------------------------------------|-------------------|---------------------:|---------:| +| example_proof.base_json | json | 2 528 114 | -- | +| example_proof.cairo_serde | cairo-serde | 2 448 494 | - 3.1 % | +| example_proof.compact_bin_unzipped | compact-binary | 834 606 | - 67.0 % | +| example_proof.compact_bin_zipped | compact-binary | 582 932 | - 76.9 % | diff --git a/crates/compact-binary/src/lib.rs b/crates/compact-binary/src/lib.rs new file mode 100644 index 000000000..61d25111d --- /dev/null +++ b/crates/compact-binary/src/lib.rs @@ -0,0 +1,276 @@ +#![cfg_attr(not(feature = "std"), no_std)] +use core::array; + +use lz4_flex::{compress_prepend_size, decompress_size_prepended}; +use starknet_ff::FieldElement; +use std_shims::Vec; +use unsigned_varint::encode::{u32_buffer, u64_buffer, usize_buffer}; +use unsigned_varint::{decode, encode}; + +/// Trait for types that can be serialized and deserialized in a compact binary format. +/// +/// ## Format guidelines +/// - Integers (`u32`, `u64`, and `usize`) should be handled as VarInts. 
+/// - Relevant `FieldElement` fields should be compactified if possible, or compressed. +/// - Structured data should have: +/// - version numbers, to be able to update the structure +/// - tags for each field, to be able to add new fields +/// +/// ## Struct Versioning +/// If we want to add or change a field of a struct `StructA`, while still being able to deserialize +/// previous versions of this struct, we should: +/// - Update `compact_serialize()` to serialize a new version number, and serialize the new struct +/// - Update `compact_deserialize()` to: +/// - Get the version of the deserialized struct +/// - Match on it and dispatch to the deserialization logic corresponding to this version +/// +/// ## Derive proc macro +/// The `#[derive(CompactBinary)]` proc macro can be used to implement the trait for structures +/// composed of fields implementing it. Note that the proc macro is only expected to produce a `0` +/// version; if a given structure is to be updated, its implementation should be done manually, +/// while keeping back-compatibility of all previous serialization versions for this structure. +/// The `#[zipped]` attribute can be used to specify that a given field should be zipped. +pub trait CompactBinary { + /// Serializes the object in compact binary format, appending the bytes to `output`. + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError>; + + /// Deserializes an object from `input`, returning the remaining bytes and the value. + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> + where + Self: Sized; +} + +/// Error enum for `CompactBinary` deserialization. +#[derive(Debug)] +pub enum CompactDeserializeError { + /// `UnexpectedVersion(got, expected)`: the version read differs from the expected one. + UnexpectedVersion(u32, u32), + /// `UnexpectedTag(got, expected)`: the field tag read differs from the expected one. + UnexpectedTag(usize, usize), + /// Generic decode error, e.g. when the input is malformed. + DecodeError, +} + +/// Error struct for CompactBinary serialization.
+#[derive(Debug)] +pub struct CompactSerializeError; + +/// Helper function to convert a byte slice into an array of a specific size from a closure if +/// possible. +pub fn buf_to_array_ctr V, V, const N: usize>( + buf: &[u8], + ctr: F, +) -> Option<(&[u8], V)> { + Some((&buf[N..], ctr(&buf.get(..N)?.try_into().ok()?))) +} + +/// Helper function to deserialize a struct's version and check it against an expected value. +pub fn strip_expected_version( + input: &[u8], + expected_version: u32, +) -> Result<&[u8], CompactDeserializeError> { + let (input, version) = u32::compact_deserialize(input)?; + if version != expected_version { + return Err(CompactDeserializeError::UnexpectedVersion( + version, + expected_version, + )); + } + Ok(input) +} + +/// Helper function to deserialize a field's tag and check it against an expected value. +pub fn strip_expected_tag( + input: &[u8], + expected_tag: usize, +) -> Result<&[u8], CompactDeserializeError> { + let (input, tag) = usize::compact_deserialize(input)?; + if tag != expected_tag { + return Err(CompactDeserializeError::UnexpectedTag(tag, expected_tag)); + } + Ok(input) +} + +/// A wrapper type for zipping and unzipping data during serialization and deserialization. 
+pub struct ZippedCompactBinary(pub T); + +impl ZippedCompactBinary<&T> { + pub fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let mut unzipped_data = Vec::new(); + T::compact_serialize(self.0, &mut unzipped_data)?; + let zipped_data = zip_bytes(&unzipped_data); + usize::compact_serialize(&zipped_data.len(), output)?; + output.extend_from_slice(&zipped_data); + Ok(()) + } + + pub fn compact_deserialize(input: &[u8]) -> Result<(&[u8], T), CompactDeserializeError> { + let (input, len) = usize::compact_deserialize(input)?; + let (zipped_data, input) = input.split_at(len); + let unzipped_data = unzip_bytes(zipped_data)?; + let (_, data) = T::compact_deserialize(&unzipped_data)?; + Ok((input, data)) + } +} + +/// Helper function for compressing bytes with LZ4 compression. +fn zip_bytes(input: &[u8]) -> Vec { + compress_prepend_size(input) +} + +/// Helper function for decompressing bytes with LZ4 decompression. +fn unzip_bytes(input: &[u8]) -> Result, CompactDeserializeError> { + decompress_size_prepended(input).map_err(|_| CompactDeserializeError::DecodeError) +} + +impl CompactBinary for u32 { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + output.extend_from_slice(encode::u32(*self, &mut u32_buffer())); + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (value, input) = + decode::u32(input).map_err(|_| CompactDeserializeError::DecodeError)?; + Ok((input, value)) + } +} + +impl CompactBinary for u64 { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + output.extend_from_slice(encode::u64(*self, &mut u64_buffer())); + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (value, input) = + decode::u64(input).map_err(|_| CompactDeserializeError::DecodeError)?; + Ok((input, value)) + } +} + +impl CompactBinary for usize { + fn 
compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + output.extend_from_slice(encode::usize(*self, &mut usize_buffer())); + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (value, input) = + decode::usize(input).map_err(|_| CompactDeserializeError::DecodeError)?; + Ok((input, value)) + } +} + +impl CompactBinary for Option { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + if let Some(value) = self { + output.push(b'1'); + value.compact_serialize(output)?; + } else { + output.push(b'0'); + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (first, input) = if input.is_empty() { + Err(CompactDeserializeError::DecodeError) + } else { + Ok((input[0], &input[1..])) + }?; + if first == b'1' { + let (input, value) = T::compact_deserialize(input)?; + Ok((input, Some(value))) + } else { + Ok((input, None)) + } + } +} + +impl CompactBinary for [T; N] { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + for v in self { + v.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let mut input = input; + let mut values = Vec::with_capacity(N); + for _ in 0..N { + let (updated_input, value) = T::compact_deserialize(input)?; + input = updated_input; + values.push(value); + } + Ok((input, array::from_fn(|i| values[i].clone()))) + } +} + +impl CompactBinary for Vec { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + self.len().compact_serialize(output)?; + for v in self { + v.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (mut input, len) = usize::compact_deserialize(input)?; + let mut values = Vec::with_capacity(len); + for _ in 
0..len { + let (updated_input, value) = T::compact_deserialize(input)?; + input = updated_input; + values.push(value); + } + Ok((input, values)) + } +} + +impl CompactBinary for (T0, T1) { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let (v0, v1) = self; + v0.compact_serialize(output)?; + v1.compact_serialize(output)?; + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, v0) = T0::compact_deserialize(input)?; + let (input, v1) = T1::compact_deserialize(input)?; + Ok((input, (v0, v1))) + } +} + +impl CompactBinary for (T0, T1, T2) { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let (v0, v1, v2) = self; + v0.compact_serialize(output)?; + v1.compact_serialize(output)?; + v2.compact_serialize(output)?; + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, v0) = T0::compact_deserialize(input)?; + let (input, v1) = T1::compact_deserialize(input)?; + let (input, v2) = T2::compact_deserialize(input)?; + Ok((input, (v0, v1, v2))) + } +} + +impl CompactBinary for FieldElement { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + output.extend_from_slice(&self.to_bytes_be()); + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, field_elem) = buf_to_array_ctr(input, FieldElement::from_bytes_be) + .ok_or(CompactDeserializeError::DecodeError)?; + let field_elem = field_elem.map_err(|_| CompactDeserializeError::DecodeError)?; + Ok((input, field_elem)) + } +} diff --git a/crates/stwo/Cargo.toml b/crates/stwo/Cargo.toml index 23b7bc83f..f07cf4b41 100644 --- a/crates/stwo/Cargo.toml +++ b/crates/stwo/Cargo.toml @@ -8,6 +8,7 @@ default = ["std"] std = [ "blake2/std", "blake3/std", + "stwo-compact-binary/std", "hex/std", "itertools/use_std", "indexmap/std", @@ -51,6 
+52,8 @@ serde.workspace = true tracing-subscriber.workspace = true hashbrown.workspace = true std-shims.workspace = true +stwo-compact-binary = { path = "../compact-binary", default-features = false } +stwo-compact-binary-derive = { path = "../compact-binary-derive" } [dev-dependencies] aligned = "0.4.2" diff --git a/crates/stwo/src/core/compact_binary/mod.rs b/crates/stwo/src/core/compact_binary/mod.rs new file mode 100644 index 000000000..43500253f --- /dev/null +++ b/crates/stwo/src/core/compact_binary/mod.rs @@ -0,0 +1,384 @@ +use std_shims::{vec, Vec}; +pub use stwo_compact_binary::{ + buf_to_array_ctr, strip_expected_tag, strip_expected_version, CompactBinary, + CompactDeserializeError, CompactSerializeError, ZippedCompactBinary, +}; +pub use stwo_compact_binary_derive::CompactBinary; + +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::{BaseField, P}; +use crate::core::fields::qm31::SecureField; +use crate::core::fri::{FriConfig, FriLayerProof, FriProof}; +use crate::core::pcs::quotients::CommitmentSchemeProof; +use crate::core::pcs::{PcsConfig, TreeVec}; +use crate::core::poly::line::LinePoly; +use crate::core::proof::StarkProof; +use crate::core::vcs::blake2_hash::Blake2sHash; +use crate::core::vcs::verifier::MerkleDecommitment; +use crate::core::vcs::MerkleHasher; +use crate::core::ColumnVec; + +#[cfg(test)] +mod tests; + +impl CompactBinary for BaseField { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + output.extend_from_slice(&self.0.to_be_bytes()); + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, u32_value) = buf_to_array_ctr(input, |v| u32::from_be_bytes(*v)) + .ok_or(CompactDeserializeError::DecodeError)?; + + if u32_value > P { + Err(CompactDeserializeError::DecodeError) + } else { + Ok((input, BaseField::from_u32_unchecked(u32_value))) + } + } +} + +impl CompactBinary for CM31 { + fn compact_serialize(&self, output: 
&mut Vec) -> Result<(), CompactSerializeError> { + self.0.compact_serialize(output)?; + self.1.compact_serialize(output)?; + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, u32_value_0) = BaseField::compact_deserialize(input)?; + let (input, u32_value_1) = BaseField::compact_deserialize(input)?; + Ok((input, CM31::from_m31(u32_value_0, u32_value_1))) + } +} + +impl CompactBinary for SecureField { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + self.0.compact_serialize(output)?; + self.1.compact_serialize(output)?; + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, m31_value_0) = BaseField::compact_deserialize(input)?; + let (input, m31_value_1) = BaseField::compact_deserialize(input)?; + let (input, m31_value_2) = BaseField::compact_deserialize(input)?; + let (input, m31_value_3) = BaseField::compact_deserialize(input)?; + Ok(( + input, + SecureField::from_m31(m31_value_0, m31_value_1, m31_value_2, m31_value_3), + )) + } +} + +impl CompactBinary for MerkleDecommitment +where + H::Hash: CompactBinary, +{ + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self { + hash_witness, + column_witness, + } = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = + vec![(0, hash_witness), (1, column_witness)]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, hash_witness) = Vec::::compact_deserialize(input)?; + let input = strip_expected_tag(input, 1)?; + let (input, column_witness) = 
Vec::::compact_deserialize(input)?; + Ok(( + input, + MerkleDecommitment { + hash_witness, + column_witness, + }, + )) + } +} + +impl CompactBinary for LinePoly { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let coeffs = self.clone().into_ordered_coefficients(); + coeffs.len().compact_serialize(output)?; + for coeff in &coeffs { + coeff.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (mut input, len) = usize::compact_deserialize(input)?; + let mut coeffs = Vec::with_capacity(len); + for _ in 0..len { + let (updated_input, coeff) = SecureField::compact_deserialize(input)?; + input = updated_input; + coeffs.push(coeff); + } + Ok((input, LinePoly::from_ordered_coefficients(coeffs))) + } +} + +impl CompactBinary for FriLayerProof +where + H::Hash: CompactBinary, +{ + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self { + fri_witness, + decommitment, + commitment, + } = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = + vec![(0, fri_witness), (1, decommitment), (2, commitment)]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, fri_witness) = Vec::::compact_deserialize(input)?; + let input = strip_expected_tag(input, 1)?; + let (input, decommitment) = MerkleDecommitment::compact_deserialize(input)?; + let input = strip_expected_tag(input, 2)?; + let (input, commitment) = H::Hash::compact_deserialize(input)?; + Ok(( + input, + FriLayerProof { + fri_witness, + decommitment, + commitment, + }, + )) + } +} + +impl CompactBinary for 
FriProof +where + H::Hash: CompactBinary, +{ + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self { + first_layer, + inner_layers, + last_layer_poly, + } = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = + vec![(0, first_layer), (1, inner_layers), (2, last_layer_poly)]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, first_layer) = FriLayerProof::compact_deserialize(input)?; + let input = strip_expected_tag(input, 1)?; + let (input, inner_layers) = Vec::>::compact_deserialize(input)?; + let input = strip_expected_tag(input, 2)?; + let (input, last_layer_poly) = LinePoly::compact_deserialize(input)?; + Ok(( + input, + FriProof { + first_layer, + inner_layers, + last_layer_poly, + }, + )) + } +} + +impl CompactBinary for FriConfig { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self { + log_blowup_factor, + log_last_layer_degree_bound, + n_queries, + } = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = vec![ + (0, log_blowup_factor), + (1, log_last_layer_degree_bound), + (2, n_queries), + ]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, log_blowup_factor) = u32::compact_deserialize(input)?; + let input = strip_expected_tag(input, 1)?; + let (input, 
log_last_layer_degree_bound) = u32::compact_deserialize(input)?; + let input = strip_expected_tag(input, 2)?; + let (input, n_queries) = usize::compact_deserialize(input)?; + Ok(( + input, + FriConfig { + log_blowup_factor, + log_last_layer_degree_bound, + n_queries, + }, + )) + } +} + +impl CompactBinary for PcsConfig { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self { + pow_bits, + fri_config, + } = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = vec![(0, pow_bits), (1, fri_config)]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, pow_bits) = u32::compact_deserialize(input)?; + let input = strip_expected_tag(input, 1)?; + let (input, fri_config) = FriConfig::compact_deserialize(input)?; + Ok(( + input, + PcsConfig { + pow_bits, + fri_config, + }, + )) + } +} + +impl CompactBinary for CommitmentSchemeProof +where + H::Hash: CompactBinary, +{ + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self { + config, + commitments, + sampled_values, + decommitments, + queried_values, + proof_of_work, + fri_proof, + } = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = vec![ + (0, config), + (1, &commitments.0), + (2, &sampled_values.0), + (3, &decommitments.0), + (4, &queried_values.0), + (5, proof_of_work), + (6, fri_proof), + ]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { 
+ let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, config) = PcsConfig::compact_deserialize(input)?; + let input = strip_expected_tag(input, 1)?; + let (input, commitments) = Vec::::compact_deserialize(input)?; + let input = strip_expected_tag(input, 2)?; + let (input, sampled_values) = + Vec::>>::compact_deserialize(input)?; + let input = strip_expected_tag(input, 3)?; + let (input, decommitments) = Vec::>::compact_deserialize(input)?; + let input = strip_expected_tag(input, 4)?; + let (input, queried_values) = Vec::>::compact_deserialize(input)?; + let input = strip_expected_tag(input, 5)?; + let (input, proof_of_work) = u64::compact_deserialize(input)?; + let input = strip_expected_tag(input, 6)?; + let (input, fri_proof) = FriProof::compact_deserialize(input)?; + Ok(( + input, + CommitmentSchemeProof { + config, + commitments: TreeVec::new(commitments), + sampled_values: TreeVec::new(sampled_values), + decommitments: TreeVec::new(decommitments), + queried_values: TreeVec::new(queried_values), + proof_of_work, + fri_proof, + }, + )) + } +} + +impl CompactBinary for StarkProof +where + H::Hash: CompactBinary, +{ + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + let Self(commitment_scheme_proof) = self; + let version = 0; + let to_serialize: Vec<(usize, &dyn CompactBinary)> = vec![(0, commitment_scheme_proof)]; + u32::compact_serialize(&version, output)?; + for (tag, value) in to_serialize { + usize::compact_serialize(&tag, output)?; + value.compact_serialize(output)?; + } + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let input = strip_expected_version(input, 0)?; + let input = strip_expected_tag(input, 0)?; + let (input, commitment_scheme_proof) = CommitmentSchemeProof::compact_deserialize(input)?; + Ok((input, StarkProof(commitment_scheme_proof))) + } +} + +impl CompactBinary for Blake2sHash { + fn 
compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + output.extend_from_slice(&self.0); + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, hash) = buf_to_array_ctr(input, |v| Blake2sHash(*v)) + .ok_or(CompactDeserializeError::DecodeError)?; + Ok((input, hash)) + } +} diff --git a/crates/stwo/src/core/compact_binary/tests.rs b/crates/stwo/src/core/compact_binary/tests.rs new file mode 100644 index 000000000..e05f85b8b --- /dev/null +++ b/crates/stwo/src/core/compact_binary/tests.rs @@ -0,0 +1,250 @@ +use num_traits::One; +use stwo_compact_binary::CompactBinary; + +use super::*; +use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; + +#[test] +fn test_base_field_serialization() { + let m1 = BaseField::from_u32_unchecked(42); + let m2 = BaseField::from_u32_unchecked(100); + let m3 = BaseField::from_u32_unchecked(0); + let m4 = BaseField::from_u32_unchecked(31); + let field = SecureField::from_m31(m1, m2, m3, m4); + let mut output = Vec::new(); + field.compact_serialize(&mut output).unwrap(); + let (remaining, deserialized) = SecureField::compact_deserialize(&output).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(deserialized, field); +} + +#[test] +fn test_pcs_config_serialization() { + let pcs_config = PcsConfig { + pow_bits: 5, + fri_config: FriConfig::new(0, 1, 3), + }; + + let mut output = Vec::new(); + pcs_config.compact_serialize(&mut output).unwrap(); + let (remaining, deserialized) = PcsConfig::compact_deserialize(&output).unwrap(); + assert!(remaining.is_empty()); + assert_eq_pcs_config(&pcs_config, &deserialized); +} + +#[test] +fn test_proof_serialization() { + let stark_proof: StarkProof = StarkProof(CommitmentSchemeProof { + config: PcsConfig::default(), + commitments: TreeVec::new(vec![Blake2sHash([0; 32])]), + sampled_values: TreeVec::new(vec![]), + decommitments: TreeVec::new(vec![MerkleDecommitment { + hash_witness: vec![Blake2sHash([0; 
32])], + column_witness: vec![BaseField::one()], + }]), + queried_values: TreeVec::new(vec![vec![BaseField::one()]]), + proof_of_work: 42, + fri_proof: FriProof { + first_layer: FriLayerProof { + fri_witness: vec![SecureField::one()], + decommitment: MerkleDecommitment { + hash_witness: vec![Blake2sHash([0; 32])], + column_witness: vec![BaseField::one()], + }, + commitment: Blake2sHash([0; 32]), + }, + inner_layers: vec![], + last_layer_poly: LinePoly::from_ordered_coefficients(vec![SecureField::one()]), + }, + }); + + let mut output = Vec::new(); + stark_proof.compact_serialize(&mut output).unwrap(); + let (remaining, deserialized) = + StarkProof::::compact_deserialize(&output).unwrap(); + assert!(remaining.is_empty()); + + assert_eq_pcs_config(&deserialized.0.config, &stark_proof.0.config); + + assert_eq!(deserialized.0.commitments.0, stark_proof.0.commitments.0); + assert_eq!( + deserialized.0.sampled_values.0, + stark_proof.0.sampled_values.0 + ); + assert_eq!( + deserialized.0.decommitments.0, + stark_proof.0.decommitments.0 + ); + assert_eq!( + deserialized.0.queried_values.0, + stark_proof.0.queried_values.0 + ); + assert_eq!(deserialized.0.proof_of_work, stark_proof.0.proof_of_work); + + assert_eq_fri_proof(&deserialized.0.fri_proof, &stark_proof.0.fri_proof); +} + +fn assert_eq_pcs_config(pcs_config1: &PcsConfig, pcs_config2: &PcsConfig) { + assert_eq!(pcs_config1.pow_bits, pcs_config2.pow_bits); + assert_eq!( + pcs_config1.fri_config.log_blowup_factor, + pcs_config2.fri_config.log_blowup_factor + ); + assert_eq!( + pcs_config1.fri_config.log_last_layer_degree_bound, + pcs_config2.fri_config.log_last_layer_degree_bound + ); + assert_eq!( + pcs_config1.fri_config.n_queries, + pcs_config2.fri_config.n_queries + ); +} + +fn assert_eq_fri_proof( + fri_proof1: &FriProof, + fri_proof2: &FriProof, +) { + assert_eq_fri_layer(&fri_proof1.first_layer, &fri_proof2.first_layer); + + assert_eq!(fri_proof1.inner_layers.len(), fri_proof2.inner_layers.len()); + for 
(layer1, layer2) in fri_proof1 + .inner_layers + .iter() + .zip(fri_proof2.inner_layers.iter()) + { + assert_eq_fri_layer(layer1, layer2); + } + assert_eq!(fri_proof1.last_layer_poly, fri_proof2.last_layer_poly); +} + +fn assert_eq_fri_layer( + fri_layer1: &FriLayerProof, + fri_layer2: &FriLayerProof, +) { + assert_eq!(fri_layer1.fri_witness, fri_layer2.fri_witness); + assert_eq!(fri_layer1.decommitment, fri_layer2.decommitment); + assert_eq!(fri_layer1.commitment, fri_layer2.commitment); +} + +// The tests in this module are for the `CompactBinary` derive macro: +// - a base test +// - a test with a zipped field +// - a test with a generic type `H: MerkleHasher` that requires +// the bound `H:Hash: CompactBinary` to be implemented +mod tests_derive { + use std_shims::Vec; + + use crate::core::compact_binary::{ + CompactBinary, CompactDeserializeError, CompactSerializeError, + }; + use crate::core::fields::m31::BaseField; + use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; + use crate::core::vcs::MerkleHasher; + use crate::{self as stwo}; + + // We first define three structs `TestStruct`, `TestStructZipped` and `TestStructGeneric` to + // test the derive macro + #[derive(CompactBinary, Debug, PartialEq, Eq)] + struct TestStruct { + base: [BaseField; 64], + } + + #[derive(CompactBinary, Debug, PartialEq, Eq)] + struct TestStructZipped { + #[zipped] + base: [BaseField; 64], + } + + #[derive(CompactBinary, Debug, PartialEq, Eq)] + struct TestStructGeneric { + hashed_base: HashedBaseField, + } + + #[derive(Debug, PartialEq, Eq)] + struct HashedBaseField { + base: BaseField, + hash: H::Hash, + } + + impl HashedBaseField { + fn new(base: BaseField) -> Self { + let hash = H::hash_node(None, &[base]); + Self { base, hash } + } + } + + impl CompactBinary for HashedBaseField + where + H::Hash: CompactBinary, + { + fn compact_serialize(&self, output: &mut Vec) -> Result<(), CompactSerializeError> { + self.base.compact_serialize(output)?; + 
self.hash.compact_serialize(output)?; + Ok(()) + } + + fn compact_deserialize(input: &[u8]) -> Result<(&[u8], Self), CompactDeserializeError> { + let (input, base) = BaseField::compact_deserialize(input)?; + let (input, hash) = H::Hash::compact_deserialize(input)?; + Ok((input, HashedBaseField { base, hash })) + } + } + + #[test] + fn test_proc_macro() { + let test_instance = TestStruct { + base: [BaseField::from_u32_unchecked(1654); 64], + }; + + let mut output = Vec::new(); + test_instance.compact_serialize(&mut output).unwrap(); + let (remaining, deserialized) = TestStruct::compact_deserialize(&output).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(deserialized, test_instance); + } + + #[test] + fn test_proc_macro_zipped_field() { + let test_instance_unzipped = TestStruct { + base: [BaseField::from_u32_unchecked(1654); 64], + }; + let mut output_unzipped = Vec::new(); + test_instance_unzipped + .compact_serialize(&mut output_unzipped) + .unwrap(); + + let test_instance_zipped = TestStructZipped { + base: [BaseField::from_u32_unchecked(1654); 64], + }; + let mut output_zipped = Vec::new(); + test_instance_zipped + .compact_serialize(&mut output_zipped) + .unwrap(); + + assert!( + output_zipped.len() < output_unzipped.len(), + "Zipped output should be smaller (on redundant data)" + ); + + let (remaining, deserialized) = + TestStructZipped::compact_deserialize(&output_zipped).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(deserialized, test_instance_zipped); + } + + #[test] + fn test_proc_macro_generic() { + let test_instance_generic: TestStructGeneric = TestStructGeneric { + hashed_base: HashedBaseField::new(BaseField::from_u32_unchecked(1)), + }; + + let mut output = Vec::new(); + test_instance_generic + .compact_serialize(&mut output) + .unwrap(); + let (remaining, deserialized) = + TestStructGeneric::::compact_deserialize(&output).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(deserialized, test_instance_generic); + } +} diff --git 
a/crates/stwo/src/core/mod.rs b/crates/stwo/src/core/mod.rs index bd99c01d5..fbe8d2935 100644 --- a/crates/stwo/src/core/mod.rs +++ b/crates/stwo/src/core/mod.rs @@ -4,6 +4,7 @@ use std_shims::Vec; pub mod air; pub mod channel; pub mod circle; +pub mod compact_binary; pub mod constraints; pub mod fft; pub mod fields; diff --git a/ensure-verifier-no_std/Cargo.lock b/ensure-verifier-no_std/Cargo.lock index 35f36eb10..e49daec88 100644 --- a/ensure-verifier-no_std/Cargo.lock +++ b/ensure-verifier-no_std/Cargo.lock @@ -402,6 +402,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lz4_flex" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a" + [[package]] name = "memory_units" version = "0.4.0" @@ -647,11 +653,31 @@ dependencies = [ "starknet-crypto", "starknet-ff", "std-shims", + "stwo-compact-binary", + "stwo-compact-binary-derive", "thiserror", "tracing", "tracing-subscriber", ] +[[package]] +name = "stwo-compact-binary" +version = "0.1.1" +dependencies = [ + "lz4_flex", + "starknet-ff", + "std-shims", + "unsigned-varint", +] + +[[package]] +name = "stwo-compact-binary-derive" +version = "0.1.0" +dependencies = [ + "quote", + "syn 2.0.104", +] + [[package]] name = "stwo-constraint-framework" version = "0.1.1" @@ -750,6 +776,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unsigned-varint" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb066959b24b5196ae73cb057f45598450d2c5f71460e98c49b738086eff9c06" + [[package]] name = "version_check" version = "0.9.5"