diff --git a/internal/src/init.rs b/internal/src/init.rs new file mode 100644 index 0000000..b5c04d0 --- /dev/null +++ b/internal/src/init.rs @@ -0,0 +1,331 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use std::iter::Peekable; + +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + +use proc_macro::Punct; +use proc_macro::{Delimiter, Ident, Spacing, TokenStream, TokenTree}; + +pub fn expand(input: TokenStream) -> TokenStream { + let mut tokens = input.into_iter().peekable(); + let attrs = parse_attrs(&mut tokens); + let default_error = attrs + .iter() + .filter_map(|attr| match attr { + Attr::Pin => None, + Attr::DefaultError(err) => Some(err), + }) + .next_back(); + let closure = parse_closure(&mut tokens, default_error); + let mut statements = vec![]; + let tail = match tokens.peek() { + Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => { + tokens.next(); + match tokens.next() { + None => {} + Some(rest) => panic!("unexpected token after initializer body: {rest:?}"), + } + loop { + let mut statement: TokenStream = (&mut tokens) + .take_while(|t| !matches!(t, TokenTree::Punct(p) if p.as_char() == ';')) + .collect(); + match tokens.peek() { + None => break parse_initializer_tail(statement), + Some(TokenTree::Punct(p)) if p.as_char() == ';' => { + statement.extend([tokens.next().unwrap()]) + } + Some(_) => unreachable!(), + } + statements.push(statement); + } + } + Some(_) => { + if closure.is_some() { + panic!("expected initializer body when using closure") + } + parse_initializer_tail(&mut tokens) + } + None => panic!("missing initializer body"), + }; + let Tail { path, fields } = tail; + let ty = closure + .as_ref() + .and_then(|c| c.ty.as_ref().cloned()) + .unwrap_or_else(|| quote!(_)); + let err = closure + .as_ref() + .and_then(|c| c.err.as_ref().cloned()) + .or(default_error.cloned()) + .unwrap_or_else(|| quote!(::core::convert::Infallible)); + quote! { + ::pin_init::__init_internal!( + statements(#(#statements)*), + ty_hint(#ty), + err(#err), + struct_path(#(#path)*), + fields(#fields), + ) + } +} + +struct Tail { + path: Vec, + fields: TokenStream, +} + +fn parse_initializer_tail(tokens: impl IntoIterator) -> Tail { + let mut tokens: Vec = tokens.into_iter().collect(); + if tokens.is_empty() { + panic!("incomplete initializer body") + } + let last = tokens.remove(tokens.len() - 1); + match last { + TokenTree::Group(g) if g.delimiter() == Delimiter::Brace => Tail { + path: tokens, + fields: g.stream(), + }, + _ => panic!("expected `{{}}` as the last token in the initializer body, found {last:?}"), + } +} + +enum Attr { + Pin, + DefaultError(TokenStream), +} + +fn parse_attr(meta: TokenStream) -> Attr { + let mut tokens = meta.into_iter(); + match tokens.next() { + Some(TokenTree::Ident(name)) => { + if name == "pin" { + match tokens.next() { + None => {} + Some(next) => panic!("unexpected token in `#[pin]` attribute: {next:?}"), + } + Attr::Pin + } else if name == "default_error" { + Attr::DefaultError(tokens.collect()) + } else { + panic!("unexpected attribute name: `{name}`") + } + } + Some(rest) => panic!("unexpected token in attribute: {rest:?}"), + None => panic!("expected name inside of attribute"), + } +} + +fn parse_attrs(tokens: &mut Peekable>) -> Vec { + let mut attrs = vec![]; + loop { + match tokens.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == '#' => { + tokens.next(); + match tokens.next() { + Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { + attrs.push(parse_attr(g.stream())); + } + next => { + panic!("expected `[...]` after `#`, but found {next:?}") + } + } + } + Some(_) => break attrs, + None => panic!("missing initializer body"), + } + } +} + +enum Arg { + Untyped(Ident), + Typed { name: Ident, ty: TokenStream }, +} + +struct ClosureSig { + /// arguments in the `|arg0, arg1|` + args: Vec, + /// returned type `-> MyType` or `-> Result` + ty: Option, + /// returned error type `-> Result<, Err>` + err: Option, +} + +fn parse_closure( + tokens: &mut Peekable>, + default_error: Option<&TokenStream>, +) -> Option { + match tokens.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == '|' => { + tokens.next(); + } + Some(_) => return None, + None => panic!("missing initializer body"), + } + let mut args = vec![]; + loop { + match tokens.next() { + Some(TokenTree::Ident(name)) => { + match tokens.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == ':' => { + tokens.next(); + args.push(Arg::Typed { + name: name.clone(), + ty: parse_ty_until_punct(tokens, |p| matches!(p.as_char(), ',' | '|')), + }); + } + _ => args.push(Arg::Untyped(name.clone())), + } + match tokens.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == ',' => { + tokens.next(); + } + Some(TokenTree::Punct(p)) if p.as_char() == '|' => break, + Some(rest) => { + panic!("expected comma after argument in initializer closure signature: {rest:?}") + } + _ => {} + } + } + Some(TokenTree::Punct(p)) if p.as_char() == '|' => break, + Some(rest) => panic!("unexpected token in initializer closure signature: {rest:?}"), + None => panic!("incomplete initializer body"), + } + } + // check for an `->` indicating a return type + match tokens.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == '-' && p.spacing() == Spacing::Joint => { + tokens.next(); + match tokens.next() { + Some(TokenTree::Punct(p)) if p.as_char() == '>' => {} + Some(rest) => panic!("expected arrow `->` in initializer closure signature, found `-` and then {rest:?}"), + None => panic!("incomplete initializer body"), + } + } + _ => { + return Some(ClosureSig { + args, + ty: None, + err: None, + }) + } + } + // we support several different constructs here before the opening `{`: + // * just having your own type here (then the error will be assumed to be `Infallible`), + // * you can have `Result`, + // * `Result` combined with a `#[default_error(Err)]` attribute, + match tokens.peek() { + Some(TokenTree::Ident(res)) if res == "Result" => { + tokens.next(); + match tokens.next() { + Some(TokenTree::Punct(p)) if p.as_char() == '<' => {} + _ => panic!("expected `<` after `Result` in initializer return type"), + } + let ty = parse_ty_until_punct(tokens, |p| p.as_char() == ','); + match tokens.next() { + Some(TokenTree::Punct(p)) if p.as_char() == ',' => {} + _ => { + panic!("expected `,` after first type in `Result<` in initializer return type") + } + } + let mut err = parse_ty_until_punct(tokens, |p| matches!(p.as_char(), ',' | '>')); + match tokens.next() { + Some(TokenTree::Punct(p)) if p.as_char() == ',' => match tokens.next() { + Some(TokenTree::Punct(p)) if p.as_char() == '>' => {} + _ => { + panic!("expected `>` after second type in `Result<` in initializer return type") + } + }, + Some(TokenTree::Punct(p)) if p.as_char() == '>' => {} + _ => { + panic!("expected `,` or `>` after second type in `Result<` in initializer return type") + } + } + let mut err_inspect = err.into_iter().peekable(); + if matches!(&err_inspect.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '_') { + assert_eq!( + err_inspect.count(), + 1, + "expected type to only be `_` if it starts with `_`" + ); + err = default_error + .expect( + "need a `#[default_error()]` attribute to be able to use `_` in errors.", + ) + .clone(); + } else { + err = err_inspect.collect(); + } + Some(ClosureSig { + args, + ty: Some(ty), + err: Some(err), + }) + } + None => panic!("incomplete initializer body"), + _ => Some(ClosureSig { + args, + ty: Some(parse_ty_until_brace(tokens)), + err: None, + }), + } +} + +fn parse_ty_until_punct( + tokens: &mut Peekable>, + mut punct: impl FnMut(&Punct) -> bool, +) -> TokenStream { + let mut nesting = 0u64; + let mut res = TokenStream::new(); + loop { + match tokens.peek() { + Some(TokenTree::Punct(p)) => { + if nesting == 0 && punct(p) { + return res; + } + match p.as_char() { + '<' => nesting += 1, + '>' => { + nesting = nesting + .checked_sub(1) + .expect("nestings of `<`/`>` became negative"); + } + _ => {} + } + } + Some(_) => {} + None => panic!("incomplete initializer body"), + } + let Some(tok) = tokens.next() else { + unreachable!() + }; + res.extend([tok]); + } +} + +fn parse_ty_until_brace(tokens: &mut Peekable>) -> TokenStream { + let mut nesting = 0u64; + let mut res = TokenStream::new(); + loop { + match tokens.peek() { + Some(TokenTree::Punct(p)) => match p.as_char() { + '<' => nesting += 1, + '>' => { + nesting = nesting + .checked_sub(1) + .expect("nestings of `<`/`>` became negative"); + } + _ => {} + }, + Some(TokenTree::Group(g)) if nesting == 0 && g.delimiter() == Delimiter::Brace => { + return res; + } + Some(_) => {} + None => panic!("incomplete initializer body"), + } + let Some(tok) = tokens.next() else { + unreachable!() + }; + res.extend([tok]); + } +} diff --git a/internal/src/lib.rs b/internal/src/lib.rs index 297b012..10d3b8a 100644 --- a/internal/src/lib.rs +++ b/internal/src/lib.rs @@ -29,6 +29,7 @@ mod quote; extern crate quote; mod helpers; +mod init; mod pin_data; mod pinned_drop; mod zeroable; @@ -52,3 +53,8 @@ pub fn derive_zeroable(input: TokenStream) -> TokenStream { pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream { zeroable::maybe_derive(input.into()).into() } + +#[proc_macro] +pub fn init(input: TokenStream) -> TokenStream { + init::expand(input.into()).into() +} diff --git a/src/macros.rs b/src/macros.rs index 9ced630..752e77d 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -1202,6 +1202,21 @@ macro_rules! __init_internal { // have been initialized. Therefore we can now dismiss the guards by forgetting them. $(::core::mem::forget($guards);)* }; + (init_slot($($use_data:ident)?): + @data($data:ident), + @slot($slot:ident), + @guards($($guards:ident,)*), + // arbitrary code block + @munch_fields(_: { $($code:tt)* }, $($rest:tt)*), + ) => { + { $($code)* } + $crate::__init_internal!(init_slot($($use_data)?): + @data($data), + @slot($slot), + @guards($($guards,)*), + @munch_fields($($rest)*), + ); + }; (init_slot($use_data:ident): // `use_data` is present, so we use the `data` to init fields. @data($data:ident), @slot($slot:ident), @@ -1351,6 +1366,20 @@ macro_rules! __init_internal { ); } }; + (make_initializer: + @slot($slot:ident), + @type_name($t:path), + @munch_fields(_: { $($code:tt)* }, $($rest:tt)*), + @acc($($acc:tt)*), + ) => { + // code blocks are ignored for the initializer check + $crate::__init_internal!(make_initializer: + @slot($slot), + @type_name($t), + @munch_fields($($rest)*), + @acc($($acc)*), + ); + }; (make_initializer: @slot($slot:ident), @type_name($t:path), diff --git a/tests/ui/compile-fail/init/wrong_generics2.stderr b/tests/ui/compile-fail/init/wrong_generics2.stderr index d1b1e7a..cc41892 100644 --- a/tests/ui/compile-fail/init/wrong_generics2.stderr +++ b/tests/ui/compile-fail/init/wrong_generics2.stderr @@ -12,7 +12,7 @@ help: you might have forgotten to add the struct literal inside the block --> src/macros.rs | ~ ::core::ptr::write($slot, $t { SomeStruct { - |9 $($acc)* + |4 $($acc)* ~ } }); | diff --git a/tests/underscore.rs b/tests/underscore.rs new file mode 100644 index 0000000..583b3b3 --- /dev/null +++ b/tests/underscore.rs @@ -0,0 +1,31 @@ +use pin_init::{try_init, Init}; + +pub struct Foo { + x: u64, +} + +fn foo() -> bool { + false +} + +fn bar() -> bool { + true +} + +impl Foo { + pub fn new() -> impl Init { + try_init!(Self { + _: { + if foo() { + return Err(()); + } + }, + x: 0, + _: { + if bar() { + return Err(()); + } + } + }? ()) + } +}