diff --git a/README.md b/README.md index a6ef559e8..4d0e63fda 100644 --- a/README.md +++ b/README.md @@ -457,6 +457,41 @@ pub enum Gender { } ``` +### Skipping Fields of Existing Types +You may skip encoding or decoding fields in existing Rust types using the `skip` +attribute. This will prevent the field from being encoded or decoded as part of +the Protobuf message but otherwise allow the field to be part of the Rust type. + +```rust,ignore +use prost; +use prost::Message; + +#[derive(Clone, PartialEq, Message)] +struct Person { + #[prost(string, tag = "1")] + pub id: String, + #[prost(skip)] + pub temp_data: String, // This field will be skipped +} +``` + +If the skipped field type does not implement `Default`, you must provide a +default value for the field using the `default` attribute. + +```rust,ignore +use prost; +use prost::Message; +use std::collections::HashMap; + +#[derive(Clone, PartialEq, Message)] +struct Person { + #[prost(string, tag = "1")] + pub id: String, + #[prost(skip, default = "HashMap::new")] + pub temp_data: HashMap, // This field will be skipped +} +``` + ## Nix The prost project maintains flakes support for local development. Once you have diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index d3922b1b4..23f093e72 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -3,6 +3,7 @@ mod map; mod message; mod oneof; mod scalar; +mod skip; use std::fmt; use std::slice; @@ -26,6 +27,8 @@ pub enum Field { Oneof(oneof::Field), /// A group field. Group(group::Field), + /// An ignored field. + Skip(skip::Field), } impl Field { @@ -36,9 +39,9 @@ impl Field { pub fn new(attrs: Vec, inferred_tag: Option) -> Result, Error> { let attrs = prost_attrs(attrs)?; - // TODO: check for ignore attribute. - - let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? { + let field = if let Some(field) = skip::Field::new(&attrs)? { + Field::Skip(field) + } else if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? { Field::Scalar(field) } else if let Some(field) = message::Field::new(&attrs, inferred_tag)? { Field::Message(field) @@ -62,8 +65,6 @@ impl Field { pub fn new_oneof(attrs: Vec) -> Result, Error> { let attrs = prost_attrs(attrs)?; - // TODO: check for ignore attribute. - let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? { Field::Scalar(field) } else if let Some(field) = message::Field::new_oneof(&attrs)? { @@ -81,6 +82,7 @@ impl Field { pub fn tags(&self) -> Vec { match *self { + Field::Skip(_) => vec![], Field::Scalar(ref scalar) => vec![scalar.tag], Field::Message(ref message) => vec![message.tag], Field::Map(ref map) => vec![map.tag], @@ -92,6 +94,7 @@ impl Field { /// Returns a statement which encodes the field. pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { match *self { + Field::Skip(_) => TokenStream::default(), Field::Scalar(ref scalar) => scalar.encode(prost_path, ident), Field::Message(ref message) => message.encode(prost_path, ident), Field::Map(ref map) => map.encode(prost_path, ident), @@ -104,6 +107,7 @@ impl Field { /// value into the field. pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { match *self { + Field::Skip(_) => TokenStream::default(), Field::Scalar(ref scalar) => scalar.merge(prost_path, ident), Field::Message(ref message) => message.merge(prost_path, ident), Field::Map(ref map) => map.merge(prost_path, ident), @@ -115,6 +119,7 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the field. pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { match *self { + Field::Skip(_) => quote!(0), Field::Scalar(ref scalar) => scalar.encoded_len(prost_path, ident), Field::Map(ref map) => map.encoded_len(prost_path, ident), Field::Message(ref msg) => msg.encoded_len(prost_path, ident), @@ -126,6 +131,7 @@ impl Field { /// Returns a statement which clears the field. pub fn clear(&self, ident: TokenStream) -> TokenStream { match *self { + Field::Skip(ref skip) => skip.clear(ident), Field::Scalar(ref scalar) => scalar.clear(ident), Field::Message(ref message) => message.clear(ident), Field::Map(ref map) => map.clear(ident), diff --git a/prost-derive/src/field/skip.rs b/prost-derive/src/field/skip.rs new file mode 100644 index 000000000..0b94b14f8 --- /dev/null +++ b/prost-derive/src/field/skip.rs @@ -0,0 +1,75 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Expr, ExprLit, Lit, Meta, MetaNameValue, Path}; + +use crate::field::{set_bool, set_option, word_attr}; + +#[derive(Clone)] +pub struct Field { + pub default_fn: Option, +} + +impl Field { + pub fn new(attrs: &[Meta]) -> Result, Error> { + let mut skip = false; + let mut default_fn = None; + let mut default_lit = None; + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("skip", attr) { + set_bool(&mut skip, "duplicate skip attribute")?; + } else if let Meta::NameValue(MetaNameValue { path, value, .. }) = attr { + if path.is_ident("default") { + match value { + // There has to be a better way... + Expr::Lit(ExprLit { + lit: Lit::Str(lit), .. + }) => set_option(&mut default_lit, lit, "duplicate default attributes")?, + _ => bail!("default attribute value must be a string literal"), + }; + } else { + unknown_attrs.push(attr); + } + } else { + unknown_attrs.push(attr); + } + } + + if !skip { + return Ok(None); + } + + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s) for skipped field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); + } + + if let Some(lit) = default_lit { + let fn_path: Path = syn::parse_str(&lit.value()) + .map_err(|_| anyhow::anyhow!("invalid path for default function"))?; + if default_fn.is_some() { + bail!("duplicate default attribute for skipped field"); + } + default_fn = Some(fn_path); + } + + Ok(Some(Field { default_fn })) + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + let default = self.default_value(); + quote!( #ident = #default; ) + } + + pub fn default_value(&self) -> TokenStream { + if let Some(ref path) = self.default_fn { + quote! { #path() } + } else { + quote! { ::core::default::Default::default() } + } + } +} diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 2804ddfbe..c6181094b 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -84,12 +84,20 @@ fn try_message(input: TokenStream) -> Result { // We want Debug to be in declaration order let unsorted_fields = fields.clone(); + // Filter out ignored fields + fields.retain(|(_, field)| !matches!(field, Field::Skip(..))); + // Sort the fields by tag number so that fields will be encoded in tag order. // TODO: This encodes oneof fields in the position of their lowest tag, // regardless of the currently occupied variant, is that consequential? + let all_fields = unsorted_fields.clone(); + let mut active_fields = all_fields.clone(); + // Filter out skipped fields for encoding/decoding/length + active_fields.retain(|(_, field)| !matches!(field, Field::Skip(_))); + // Sort the active fields by tag number so that fields will be encoded in tag order. // See: https://protobuf.dev/programming-guides/encoding/#order - fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap()); - let fields = fields; + active_fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap()); + let fields = active_fields; if let Some(duplicate_tag) = fields .iter() @@ -128,7 +136,7 @@ fn try_message(input: TokenStream) -> Result { } }); - let struct_name = if fields.is_empty() { + let struct_name = if all_fields.is_empty() { quote!() } else { quote!( @@ -136,21 +144,28 @@ fn try_message(input: TokenStream) -> Result { ) }; - let clear = fields + let clear = all_fields .iter() .map(|(field_ident, field)| field.clear(quote!(self.#field_ident))); + // For Default implementation, use all_fields (including skipped) let default = if is_struct { - let default = fields.iter().map(|(field_ident, field)| { - let value = field.default(&prost_path); + let default = all_fields.iter().map(|(field_ident, field)| { + let value = match field { + Field::Skip(skip_field) => skip_field.default_value(), + _ => field.default(&prost_path), + }; quote!(#field_ident: #value,) }); quote! {#ident { #(#default)* }} } else { - let default = fields.iter().map(|(_, field)| { - let value = field.default(&prost_path); + let default = all_fields.iter().map(|(_, field)| { + let value = match field { + Field::Skip(skip_field) => skip_field.default_value(), + _ => field.default(&prost_path), + }; quote!(#value,) }); quote! {#ident ( diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 8321187a7..a07eee2f7 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -107,6 +107,8 @@ pub mod oneof_attributes { #[cfg(test)] mod proto3_presence; +#[cfg(test)] +mod skipped_fields; use core::fmt::Debug; diff --git a/tests/src/skipped_fields.rs b/tests/src/skipped_fields.rs new file mode 100644 index 000000000..f2967cf58 --- /dev/null +++ b/tests/src/skipped_fields.rs @@ -0,0 +1,46 @@ +//! Tests for skipping fields when using prost-derive. + +use crate::alloc::string::ToString; +use crate::check_serialize_equivalent; +use alloc::collections::BTreeMap; +use alloc::string::String; +use prost::Message; + +/// A struct with the same data as another, but with a skipped field, should be equal when encoded. +#[test] +fn skipped_field_serial_equality() { + #[derive(Clone, PartialEq, prost::Message)] + struct TypeWithoutSkippedField { + #[prost(string, tag = "1")] + value: String, + } + + fn create_hashmap() -> BTreeMap { + let mut map = BTreeMap::new(); + map.insert("key".to_string(), "value".to_string()); + map + } + + #[derive(Clone, PartialEq, prost::Message)] + struct TypeWithSkippedField { + #[prost(string, tag = "1")] + value: String, + #[prost(skip, default = "create_hashmap")] + pub temp_data: BTreeMap, // This field will be skipped + } + + let a = TypeWithoutSkippedField { + value: "hello".to_string(), + }; + let b = TypeWithSkippedField { + value: "hello".to_string(), + temp_data: create_hashmap(), + }; + + // Encoded forms should be equal + check_serialize_equivalent(&a, &b); + + // Decoded forms should be equal, with the skipped field initialized using the default attribute + let decoded = TypeWithSkippedField::decode(a.encode_to_vec().as_slice()).unwrap(); + assert_eq!(b, decoded); +}