@@ -9,21 +9,28 @@ use syn::{
9
9
parse:: { Parse , ParseStream } ,
10
10
parse_macro_input,
11
11
spanned:: Spanned ,
12
- Ident , Item , ItemEnum , Token ,
12
+ Expr , Ident , Item , ItemEnum , Token , Variant ,
13
13
} ;
14
14
15
- #[ derive( Debug ) ]
16
- struct Flag {
15
+ struct Flag < ' a > {
17
16
name : Ident ,
18
17
span : Span ,
19
- value : FlagValue ,
18
+ value : FlagValue < ' a > ,
20
19
}
21
20
22
- #[ derive( Debug ) ]
23
- enum FlagValue {
21
+ enum FlagValue < ' a > {
24
22
Literal ( u128 ) ,
25
23
Deferred ,
26
- Inferred ,
24
+ Inferred ( & ' a mut Variant ) ,
25
+ }
26
+
27
+ impl FlagValue < ' _ > {
28
+ fn is_inferred ( & self ) -> bool {
29
+ match self {
30
+ FlagValue :: Inferred ( _) => true ,
31
+ _ => false ,
32
+ }
33
+ }
27
34
}
28
35
29
36
struct Parameters {
@@ -54,9 +61,9 @@ pub fn bitflags_internal(
54
61
input : proc_macro:: TokenStream ,
55
62
) -> proc_macro:: TokenStream {
56
63
let Parameters { default } = parse_macro_input ! ( attr as Parameters ) ;
57
- let ast = parse_macro_input ! ( input as Item ) ;
64
+ let mut ast = parse_macro_input ! ( input as Item ) ;
58
65
let output = match ast {
59
- Item :: Enum ( ref item_enum) => gen_enumflags ( item_enum, default) ,
66
+ Item :: Enum ( ref mut item_enum) => gen_enumflags ( item_enum, default) ,
60
67
_ => Err ( syn:: Error :: new_spanned (
61
68
& ast,
62
69
"#[bitflags] requires an enum" ,
@@ -76,7 +83,6 @@ pub fn bitflags_internal(
76
83
77
84
/// Try to evaluate the expression given.
78
85
fn fold_expr ( expr : & syn:: Expr ) -> Option < u128 > {
79
- use syn:: Expr ;
80
86
match expr {
81
87
Expr :: Lit ( ref expr_lit) => match expr_lit. lit {
82
88
syn:: Lit :: Int ( ref lit_int) => lit_int. base10_parse ( ) . ok ( ) ,
@@ -98,8 +104,8 @@ fn fold_expr(expr: &syn::Expr) -> Option<u128> {
98
104
}
99
105
100
106
fn collect_flags < ' a > (
101
- variants : impl Iterator < Item = & ' a syn :: Variant > ,
102
- ) -> Result < Vec < Flag > , syn:: Error > {
107
+ variants : impl Iterator < Item = & ' a mut Variant > ,
108
+ ) -> Result < Vec < Flag < ' a > > , syn:: Error > {
103
109
variants
104
110
. map ( |variant| {
105
111
// MSRV: Would this be cleaner with `matches!`?
@@ -113,25 +119,51 @@ fn collect_flags<'a>(
113
119
}
114
120
}
115
121
122
+ let name = variant. ident . clone ( ) ;
123
+ let span = variant. span ( ) ;
116
124
let value = if let Some ( ref expr) = variant. discriminant {
117
125
if let Some ( n) = fold_expr ( & expr. 1 ) {
118
126
FlagValue :: Literal ( n)
119
127
} else {
120
128
FlagValue :: Deferred
121
129
}
122
130
} else {
123
- FlagValue :: Inferred
131
+ FlagValue :: Inferred ( variant )
124
132
} ;
125
133
126
- Ok ( Flag {
127
- name : variant. ident . clone ( ) ,
128
- span : variant. span ( ) ,
129
- value,
130
- } )
134
+ Ok ( Flag { name, span, value } )
131
135
} )
132
136
. collect ( )
133
137
}
134
138
139
+ fn inferred_value ( type_name : & Ident , previous_variants : & [ Ident ] , repr : & Ident ) -> Expr {
140
+ let tokens = if previous_variants. is_empty ( ) {
141
+ quote ! ( 1 )
142
+ } else {
143
+ quote ! ( :: enumflags2:: _internal:: next_bit(
144
+ #( #type_name:: #previous_variants as u128 ) |*
145
+ ) as #repr)
146
+ } ;
147
+
148
+ syn:: parse2 ( tokens) . expect ( "couldn't parse inferred value" )
149
+ }
150
+
151
+ fn infer_values < ' a > ( flags : & mut [ Flag ] , type_name : & Ident , repr : & Ident ) {
152
+ let mut previous_variants: Vec < Ident > = flags. iter ( )
153
+ . filter ( |flag| !flag. value . is_inferred ( ) )
154
+ . map ( |flag| flag. name . clone ( ) ) . collect ( ) ;
155
+
156
+ for flag in flags {
157
+ match flag. value {
158
+ FlagValue :: Inferred ( ref mut variant) => {
159
+ variant. discriminant = Some ( ( <Token ! [ =] >:: default ( ) , inferred_value ( type_name, & previous_variants, repr) ) ) ;
160
+ previous_variants. push ( flag. name . clone ( ) ) ;
161
+ }
162
+ _ => { }
163
+ }
164
+ }
165
+ }
166
+
135
167
/// Given a list of attributes, find the `repr`, if any, and return the integer
136
168
/// type specified.
137
169
fn extract_repr ( attrs : & [ syn:: Attribute ] ) -> Result < Option < Ident > , syn:: Error > {
@@ -210,10 +242,7 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
210
242
Ok ( None )
211
243
}
212
244
}
213
- Inferred => Err ( syn:: Error :: new (
214
- flag. span ,
215
- "Please add an explicit discriminant" ,
216
- ) ) ,
245
+ Inferred ( _) => Ok ( None ) ,
217
246
Deferred => {
218
247
let variant_name = & flag. name ;
219
248
// MSRV: Use an unnamed constant (`const _: ...`).
@@ -235,33 +264,34 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
235
264
}
236
265
}
237
266
238
- fn gen_enumflags ( ast : & ItemEnum , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
267
+ fn gen_enumflags ( ast : & mut ItemEnum , default : Vec < Ident > ) -> Result < TokenStream , syn:: Error > {
239
268
let ident = & ast. ident ;
240
269
241
270
let span = Span :: call_site ( ) ;
242
- // for quote! interpolation
243
- let variant_names = ast. variants . iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
244
- let repeated_name = vec ! [ & ident; ast. variants. len( ) ] ;
245
271
246
- let ty = extract_repr ( & ast. attrs ) ?
272
+ let repr = extract_repr ( & ast. attrs ) ?
247
273
. ok_or_else ( || syn:: Error :: new_spanned ( & ident,
248
274
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ) ) ?;
249
- let bits = type_bits ( & ty ) ?;
275
+ let bits = type_bits ( & repr ) ?;
250
276
251
- let variants = collect_flags ( ast. variants . iter ( ) ) ?;
277
+ let mut variants = collect_flags ( ast. variants . iter_mut ( ) ) ?;
252
278
let deferred = variants
253
279
. iter ( )
254
280
. flat_map ( |variant| check_flag ( ident, variant, bits) . transpose ( ) )
255
281
. collect :: < Result < Vec < _ > , _ > > ( ) ?;
256
282
283
+ infer_values ( & mut variants, ident, & repr) ;
284
+
257
285
if ( bits as usize ) < variants. len ( ) {
258
286
return Err ( syn:: Error :: new_spanned (
259
- & ty ,
287
+ & repr ,
260
288
format ! ( "Not enough bits for {} flags" , variants. len( ) ) ,
261
289
) ) ;
262
290
}
263
291
264
292
let std_path = quote_spanned ! ( span => :: enumflags2:: _internal:: core) ;
293
+ let variant_names = ast. variants . iter ( ) . map ( |v| & v. ident ) . collect :: < Vec < _ > > ( ) ;
294
+ let repeated_name = vec ! [ & ident; ast. variants. len( ) ] ;
265
295
266
296
Ok ( quote_spanned ! {
267
297
span =>
@@ -303,15 +333,15 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
303
333
}
304
334
305
335
impl :: enumflags2:: _internal:: RawBitFlags for #ident {
306
- type Numeric = #ty ;
336
+ type Numeric = #repr ;
307
337
308
338
const EMPTY : Self :: Numeric = 0 ;
309
339
310
340
const DEFAULT : Self :: Numeric =
311
- 0 #( | ( #repeated_name:: #default as #ty ) ) * ;
341
+ 0 #( | ( #repeated_name:: #default as #repr ) ) * ;
312
342
313
343
const ALL_BITS : Self :: Numeric =
314
- 0 #( | ( #repeated_name:: #variant_names as #ty ) ) * ;
344
+ 0 #( | ( #repeated_name:: #variant_names as #repr ) ) * ;
315
345
316
346
const FLAG_LIST : & ' static [ Self ] =
317
347
& [ #( #repeated_name:: #variant_names) , * ] ;
@@ -320,7 +350,7 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
320
350
concat!( "BitFlags<" , stringify!( #ident) , ">" ) ;
321
351
322
352
fn bits( self ) -> Self :: Numeric {
323
- self as #ty
353
+ self as #repr
324
354
}
325
355
}
326
356
0 commit comments