Skip to content

Commit 70f22ac

Browse files
committed
Pick discriminants automatically when unspecified
Closes #21
1 parent 88cd25e commit 70f22ac

File tree

6 files changed

+104
-51
lines changed

6 files changed

+104
-51
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ enumflags2 = "^0.6"
2020
## Features
2121

2222
- [x] Uses enums to represent individual flags—a set of flags is a separate type from a single flag.
23+
- [x] Automatically chooses a free bit when you don't specify.
2324
- [x] Detects incorrect BitFlags at compile time.
2425
- [x] Has a similar API compared to the popular [bitflags](https://crates.io/crates/bitflags) crate.
2526
- [x] Does not expose the generated types explicity. The user interacts exclusively with `struct BitFlags<Enum>;`.
@@ -37,7 +38,7 @@ use enumflags2::{bitflags, make_bitflags, BitFlags};
3738
enum Test {
3839
A = 0b0001,
3940
B = 0b0010,
40-
C = 0b0100,
41+
C, // unspecified variants pick unused bits automatically
4142
D = 0b1000,
4243
}
4344

enumflags_derive/src/lib.rs

Lines changed: 64 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,28 @@ use syn::{
99
parse::{Parse, ParseStream},
1010
parse_macro_input,
1111
spanned::Spanned,
12-
Ident, Item, ItemEnum, Token,
12+
Expr, Ident, Item, ItemEnum, Token, Variant,
1313
};
1414

15-
#[derive(Debug)]
16-
struct Flag {
15+
struct Flag<'a> {
1716
name: Ident,
1817
span: Span,
19-
value: FlagValue,
18+
value: FlagValue<'a>,
2019
}
2120

22-
#[derive(Debug)]
23-
enum FlagValue {
21+
enum FlagValue<'a> {
2422
Literal(u128),
2523
Deferred,
26-
Inferred,
24+
Inferred(&'a mut Variant),
25+
}
26+
27+
impl FlagValue<'_> {
28+
fn is_inferred(&self) -> bool {
29+
match self {
30+
FlagValue::Inferred(_) => true,
31+
_ => false,
32+
}
33+
}
2734
}
2835

2936
struct Parameters {
@@ -54,9 +61,9 @@ pub fn bitflags_internal(
5461
input: proc_macro::TokenStream,
5562
) -> proc_macro::TokenStream {
5663
let Parameters { default } = parse_macro_input!(attr as Parameters);
57-
let ast = parse_macro_input!(input as Item);
64+
let mut ast = parse_macro_input!(input as Item);
5865
let output = match ast {
59-
Item::Enum(ref item_enum) => gen_enumflags(item_enum, default),
66+
Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default),
6067
_ => Err(syn::Error::new_spanned(
6168
&ast,
6269
"#[bitflags] requires an enum",
@@ -76,7 +83,6 @@ pub fn bitflags_internal(
7683

7784
/// Try to evaluate the expression given.
7885
fn fold_expr(expr: &syn::Expr) -> Option<u128> {
79-
use syn::Expr;
8086
match expr {
8187
Expr::Lit(ref expr_lit) => match expr_lit.lit {
8288
syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
@@ -98,8 +104,8 @@ fn fold_expr(expr: &syn::Expr) -> Option<u128> {
98104
}
99105

100106
fn collect_flags<'a>(
101-
variants: impl Iterator<Item = &'a syn::Variant>,
102-
) -> Result<Vec<Flag>, syn::Error> {
107+
variants: impl Iterator<Item = &'a mut Variant>,
108+
) -> Result<Vec<Flag<'a>>, syn::Error> {
103109
variants
104110
.map(|variant| {
105111
// MSRV: Would this be cleaner with `matches!`?
@@ -113,25 +119,51 @@ fn collect_flags<'a>(
113119
}
114120
}
115121

122+
let name = variant.ident.clone();
123+
let span = variant.span();
116124
let value = if let Some(ref expr) = variant.discriminant {
117125
if let Some(n) = fold_expr(&expr.1) {
118126
FlagValue::Literal(n)
119127
} else {
120128
FlagValue::Deferred
121129
}
122130
} else {
123-
FlagValue::Inferred
131+
FlagValue::Inferred(variant)
124132
};
125133

126-
Ok(Flag {
127-
name: variant.ident.clone(),
128-
span: variant.span(),
129-
value,
130-
})
134+
Ok(Flag { name, span, value })
131135
})
132136
.collect()
133137
}
134138

139+
fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
140+
let tokens = if previous_variants.is_empty() {
141+
quote!(1)
142+
} else {
143+
quote!(::enumflags2::_internal::next_bit(
144+
#(#type_name::#previous_variants as u128)|*
145+
) as #repr)
146+
};
147+
148+
syn::parse2(tokens).expect("couldn't parse inferred value")
149+
}
150+
151+
fn infer_values<'a>(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
152+
let mut previous_variants: Vec<Ident> = flags.iter()
153+
.filter(|flag| !flag.value.is_inferred())
154+
.map(|flag| flag.name.clone()).collect();
155+
156+
for flag in flags {
157+
match flag.value {
158+
FlagValue::Inferred(ref mut variant) => {
159+
variant.discriminant = Some((<Token![=]>::default(), inferred_value(type_name, &previous_variants, repr)));
160+
previous_variants.push(flag.name.clone());
161+
}
162+
_ => {}
163+
}
164+
}
165+
}
166+
135167
/// Given a list of attributes, find the `repr`, if any, and return the integer
136168
/// type specified.
137169
fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
@@ -210,10 +242,7 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
210242
Ok(None)
211243
}
212244
}
213-
Inferred => Err(syn::Error::new(
214-
flag.span,
215-
"Please add an explicit discriminant",
216-
)),
245+
Inferred(_) => Ok(None),
217246
Deferred => {
218247
let variant_name = &flag.name;
219248
// MSRV: Use an unnamed constant (`const _: ...`).
@@ -235,33 +264,34 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
235264
}
236265
}
237266

238-
fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
267+
fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
239268
let ident = &ast.ident;
240269

241270
let span = Span::call_site();
242-
// for quote! interpolation
243-
let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
244-
let repeated_name = vec![&ident; ast.variants.len()];
245271

246-
let ty = extract_repr(&ast.attrs)?
272+
let repr = extract_repr(&ast.attrs)?
247273
.ok_or_else(|| syn::Error::new_spanned(&ident,
248274
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
249-
let bits = type_bits(&ty)?;
275+
let bits = type_bits(&repr)?;
250276

251-
let variants = collect_flags(ast.variants.iter())?;
277+
let mut variants = collect_flags(ast.variants.iter_mut())?;
252278
let deferred = variants
253279
.iter()
254280
.flat_map(|variant| check_flag(ident, variant, bits).transpose())
255281
.collect::<Result<Vec<_>, _>>()?;
256282

283+
infer_values(&mut variants, ident, &repr);
284+
257285
if (bits as usize) < variants.len() {
258286
return Err(syn::Error::new_spanned(
259-
&ty,
287+
&repr,
260288
format!("Not enough bits for {} flags", variants.len()),
261289
));
262290
}
263291

264292
let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
293+
let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
294+
let repeated_name = vec![&ident; ast.variants.len()];
265295

266296
Ok(quote_spanned! {
267297
span =>
@@ -303,15 +333,15 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
303333
}
304334

305335
impl ::enumflags2::_internal::RawBitFlags for #ident {
306-
type Numeric = #ty;
336+
type Numeric = #repr;
307337

308338
const EMPTY: Self::Numeric = 0;
309339

310340
const DEFAULT: Self::Numeric =
311-
0 #(| (#repeated_name::#default as #ty))*;
341+
0 #(| (#repeated_name::#default as #repr))*;
312342

313343
const ALL_BITS: Self::Numeric =
314-
0 #(| (#repeated_name::#variant_names as #ty))*;
344+
0 #(| (#repeated_name::#variant_names as #repr))*;
315345

316346
const FLAG_LIST: &'static [Self] =
317347
&[#(#repeated_name::#variant_names),*];
@@ -320,7 +350,7 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
320350
concat!("BitFlags<", stringify!(#ident), ">");
321351

322352
fn bits(self) -> Self::Numeric {
323-
self as #ty
353+
self as #repr
324354
}
325355
}
326356

src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
//! enum Test {
1414
//! A = 0b0001,
1515
//! B = 0b0010,
16-
//! C = 0b0100,
16+
//! C, // unspecified variants pick unused bits automatically
1717
//! D = 0b1000,
1818
//! }
1919
//!
@@ -259,6 +259,11 @@ pub mod _internal {
259259
impl AssertionHelper for [(); 0] {
260260
type Status = AssertionFailed;
261261
}
262+
263+
pub const fn next_bit(x: u128) -> u128 {
264+
// trailing_ones is beyond our MSRV
265+
1 << (!x).trailing_zeros()
266+
}
262267
}
263268

264269
// Internal debug formatting implementations

test_suite/common.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ enum Test1 {
1919
E = 1 << 34,
2020
}
2121

22-
#[enumflags2::bitflags(default = B | C)]
22+
#[bitflags(default = B | C)]
2323
#[derive(Copy, Clone, Debug)]
2424
#[repr(u8)]
2525
enum Default6 {
@@ -129,3 +129,34 @@ fn module() {
129129
}
130130
}
131131
}
132+
133+
#[test]
134+
fn inferred_values() {
135+
#[bitflags]
136+
#[derive(Copy, Clone, Debug)]
137+
#[repr(u8)]
138+
enum Inferred {
139+
Infer2,
140+
SpecifiedA = 1,
141+
Infer8,
142+
SpecifiedB = 4,
143+
}
144+
145+
assert_eq!(Inferred::Infer2 as u8, 2);
146+
assert_eq!(Inferred::Infer8 as u8, 8);
147+
148+
#[bitflags]
149+
#[derive(Copy, Clone, Debug)]
150+
#[repr(u8)]
151+
enum OnlyInferred {
152+
Infer1,
153+
Infer2,
154+
Infer4,
155+
Infer8,
156+
}
157+
158+
assert_eq!(OnlyInferred::Infer1 as u8, 1);
159+
assert_eq!(OnlyInferred::Infer2 as u8, 2);
160+
assert_eq!(OnlyInferred::Infer4 as u8, 4);
161+
assert_eq!(OnlyInferred::Infer8 as u8, 8);
162+
}

test_suite/ui/missing_disciminant.rs

Lines changed: 0 additions & 9 deletions
This file was deleted.

test_suite/ui/missing_disciminant.stderr

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)