Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,41 @@ pub enum Gender {
}
```

### Skipping Fields of Existing Types
You may skip encoding or decoding fields in existing Rust types using the `skip`
attribute. This will prevent the field from being encoded or decoded as part of
the Protobuf message but otherwise allow the field to be part of the Rust type.

```rust,ignore
use prost;
use prost::Message;

#[derive(Clone, PartialEq, Message)]
struct Person {
#[prost(string, tag = "1")]
pub id: String,
#[prost(skip)]
pub temp_data: String, // This field will be skipped
}
```

If the skipped field type does not implement `Default`, you must provide a
default value for the field using the `default` attribute.

```rust,ignore
use prost;
use prost::Message;
use std::collections::HashMap;

#[derive(Clone, PartialEq, Message)]
struct Person {
#[prost(string, tag = "1")]
pub id: String,
#[prost(skip, default = "HashMap::new")]
pub temp_data: HashMap<String, String>, // This field will be skipped
}
```

## Nix

The prost project maintains flakes support for local development. Once you have
Expand Down
16 changes: 11 additions & 5 deletions prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod map;
mod message;
mod oneof;
mod scalar;
mod skip;

use std::fmt;
use std::slice;
Expand All @@ -26,6 +27,8 @@ pub enum Field {
Oneof(oneof::Field),
/// A group field.
Group(group::Field),
/// An ignored field.
Skip(skip::Field),
}

impl Field {
Expand All @@ -36,9 +39,9 @@ impl Field {
pub fn new(attrs: Vec<Attribute>, inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
let attrs = prost_attrs(attrs)?;

// TODO: check for ignore attribute.

let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? {
let field = if let Some(field) = skip::Field::new(&attrs)? {
Field::Skip(field)
} else if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? {
Field::Scalar(field)
} else if let Some(field) = message::Field::new(&attrs, inferred_tag)? {
Field::Message(field)
Expand All @@ -62,8 +65,6 @@ impl Field {
pub fn new_oneof(attrs: Vec<Attribute>) -> Result<Option<Field>, Error> {
let attrs = prost_attrs(attrs)?;

// TODO: check for ignore attribute.

let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? {
Field::Scalar(field)
} else if let Some(field) = message::Field::new_oneof(&attrs)? {
Expand All @@ -81,6 +82,7 @@ impl Field {

pub fn tags(&self) -> Vec<u32> {
match *self {
Field::Skip(_) => vec![],
Field::Scalar(ref scalar) => vec![scalar.tag],
Field::Message(ref message) => vec![message.tag],
Field::Map(ref map) => vec![map.tag],
Expand All @@ -92,6 +94,7 @@ impl Field {
/// Returns a statement which encodes the field.
pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
match *self {
Field::Skip(_) => TokenStream::default(),
Field::Scalar(ref scalar) => scalar.encode(prost_path, ident),
Field::Message(ref message) => message.encode(prost_path, ident),
Field::Map(ref map) => map.encode(prost_path, ident),
Expand All @@ -104,6 +107,7 @@ impl Field {
/// value into the field.
pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
match *self {
Field::Skip(_) => TokenStream::default(),
Field::Scalar(ref scalar) => scalar.merge(prost_path, ident),
Field::Message(ref message) => message.merge(prost_path, ident),
Field::Map(ref map) => map.merge(prost_path, ident),
Expand All @@ -115,6 +119,7 @@ impl Field {
/// Returns an expression which evaluates to the encoded length of the field.
pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
match *self {
Field::Skip(_) => quote!(0),
Field::Scalar(ref scalar) => scalar.encoded_len(prost_path, ident),
Field::Map(ref map) => map.encoded_len(prost_path, ident),
Field::Message(ref msg) => msg.encoded_len(prost_path, ident),
Expand All @@ -126,6 +131,7 @@ impl Field {
/// Returns a statement which clears the field.
pub fn clear(&self, ident: TokenStream) -> TokenStream {
match *self {
Field::Skip(ref skip) => skip.clear(ident),
Field::Scalar(ref scalar) => scalar.clear(ident),
Field::Message(ref message) => message.clear(ident),
Field::Map(ref map) => map.clear(ident),
Expand Down
75 changes: 75 additions & 0 deletions prost-derive/src/field/skip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use anyhow::{bail, Error};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Expr, ExprLit, Lit, Meta, MetaNameValue, Path};

use crate::field::{set_bool, set_option, word_attr};

#[derive(Clone)]
pub struct Field {
pub default_fn: Option<Path>,
}

impl Field {
pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
let mut skip = false;
let mut default_fn = None;
let mut default_lit = None;
let mut unknown_attrs = Vec::new();

for attr in attrs {
if word_attr("skip", attr) {
set_bool(&mut skip, "duplicate skip attribute")?;
} else if let Meta::NameValue(MetaNameValue { path, value, .. }) = attr {
if path.is_ident("default") {
match value {
// There has to be a better way...
Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) => set_option(&mut default_lit, lit, "duplicate default attributes")?,
_ => bail!("default attribute value must be a string literal"),
};
} else {
unknown_attrs.push(attr);
}
} else {
unknown_attrs.push(attr);
}
}

if !skip {
return Ok(None);
}

if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for skipped field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

if let Some(lit) = default_lit {
let fn_path: Path = syn::parse_str(&lit.value())
.map_err(|_| anyhow::anyhow!("invalid path for default function"))?;
if default_fn.is_some() {
bail!("duplicate default attribute for skipped field");
}
default_fn = Some(fn_path);
}

Ok(Some(Field { default_fn }))
}

pub fn clear(&self, ident: TokenStream) -> TokenStream {
let default = self.default_value();
quote!( #ident = #default; )
}

pub fn default_value(&self) -> TokenStream {
if let Some(ref path) = self.default_fn {
quote! { #path() }
} else {
quote! { ::core::default::Default::default() }
}
}
}
31 changes: 23 additions & 8 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,20 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
// We want Debug to be in declaration order
let unsorted_fields = fields.clone();

// Filter out ignored fields
fields.retain(|(_, field)| !matches!(field, Field::Skip(..)));

// Sort the fields by tag number so that fields will be encoded in tag order.
// TODO: This encodes oneof fields in the position of their lowest tag,
// regardless of the currently occupied variant, is that consequential?
let all_fields = unsorted_fields.clone();
let mut active_fields = all_fields.clone();
// Filter out skipped fields for encoding/decoding/length
active_fields.retain(|(_, field)| !matches!(field, Field::Skip(_)));
// Sort the active fields by tag number so that fields will be encoded in tag order.
// See: https://protobuf.dev/programming-guides/encoding/#order
fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
let fields = fields;
active_fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
let fields = active_fields;

if let Some(duplicate_tag) = fields
.iter()
Expand Down Expand Up @@ -128,29 +136,36 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
}
});

let struct_name = if fields.is_empty() {
let struct_name = if all_fields.is_empty() {
quote!()
} else {
quote!(
const STRUCT_NAME: &'static str = stringify!(#ident);
)
};

let clear = fields
let clear = all_fields
.iter()
.map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));

// For Default implementation, use all_fields (including skipped)
let default = if is_struct {
let default = fields.iter().map(|(field_ident, field)| {
let value = field.default(&prost_path);
let default = all_fields.iter().map(|(field_ident, field)| {
let value = match field {
Field::Skip(skip_field) => skip_field.default_value(),
_ => field.default(&prost_path),
};
quote!(#field_ident: #value,)
});
quote! {#ident {
#(#default)*
}}
} else {
let default = fields.iter().map(|(_, field)| {
let value = field.default(&prost_path);
let default = all_fields.iter().map(|(_, field)| {
let value = match field {
Field::Skip(skip_field) => skip_field.default_value(),
_ => field.default(&prost_path),
};
quote!(#value,)
});
quote! {#ident (
Expand Down
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ pub mod oneof_attributes {

#[cfg(test)]
mod proto3_presence;
#[cfg(test)]
mod skipped_fields;

use core::fmt::Debug;

Expand Down
46 changes: 46 additions & 0 deletions tests/src/skipped_fields.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//! Tests for skipping fields when using prost-derive.

use crate::alloc::string::ToString;
use crate::check_serialize_equivalent;
use alloc::collections::BTreeMap;
use alloc::string::String;
use prost::Message;

/// A struct with the same data as another, but with a skipped field, should be equal when encoded.
#[test]
fn skipped_field_serial_equality() {
#[derive(Clone, PartialEq, prost::Message)]
struct TypeWithoutSkippedField {
#[prost(string, tag = "1")]
value: String,
}

fn create_hashmap() -> BTreeMap<String, String> {
let mut map = BTreeMap::new();
map.insert("key".to_string(), "value".to_string());
map
}

#[derive(Clone, PartialEq, prost::Message)]
struct TypeWithSkippedField {
#[prost(string, tag = "1")]
value: String,
#[prost(skip, default = "create_hashmap")]
pub temp_data: BTreeMap<String, String>, // This field will be skipped
}

let a = TypeWithoutSkippedField {
value: "hello".to_string(),
};
let b = TypeWithSkippedField {
value: "hello".to_string(),
temp_data: create_hashmap(),
};

// Encoded forms should be equal
check_serialize_equivalent(&a, &b);

// Decoded forms should be equal, with the skipped field initialized using the default attribute
let decoded = TypeWithSkippedField::decode(a.encode_to_vec().as_slice()).unwrap();
assert_eq!(b, decoded);
}
Loading