Skip to content

Configurable macros namespace #3944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,25 @@ jobs:
cargo test
-p sqlx-macros-core
--all-features

- name: Test sqlx-macros custom namespace
run: >
cargo test
-p sqlx-macros
--all-features
env:
SQLX_NAMESPACE: "external"
RUSTFLAGS: "--cfg sqlx_macros_namespace"

# Should fail compilation since the env var is not set
# and the testing `sqlx_macros_namespace` cfg is enabled.
- name: Test sqlx-macros without namespace
run: >
cargo test
-p sqlx-macros
--all-features || true
env:
RUSTFLAGS: "--cfg sqlx_macros_namespace"

# Note: use `--lib` to not run integration tests that require a DB
- name: Test sqlx
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 37 additions & 33 deletions sqlx-macros-core/src/derives/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,26 @@ use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed,
FieldsUnnamed, Stmt, TypeParamBound, Variant,
FieldsUnnamed, Ident, Stmt, TypeParamBound, Variant,
};

pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
pub fn expand_derive_decode(input: &DeriveInput, crate_name: &Ident) -> syn::Result<TokenStream> {
let attrs = parse_container_attributes(&input.attrs)?;
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
expand_derive_decode_transparent(input, unnamed.first().unwrap())
expand_derive_decode_transparent(input, unnamed.first().unwrap(), crate_name)
}
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
Some(_) => expand_derive_decode_weak_enum(input, variants),
None => expand_derive_decode_strong_enum(input, variants),
Some(_) => expand_derive_decode_weak_enum(input, variants, crate_name),
None => expand_derive_decode_strong_enum(input, variants, crate_name),
},
Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) => expand_derive_decode_struct(input, named),
}) => expand_derive_decode_struct(input, named, crate_name),
Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
Data::Struct(DataStruct {
fields: Fields::Unnamed(..),
Expand All @@ -50,6 +50,7 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<TokenStream> {
fn expand_derive_decode_transparent(
input: &DeriveInput,
field: &Field,
crate_name: &Ident,
) -> syn::Result<TokenStream> {
check_transparent_attributes(input, field)?;

Expand All @@ -64,26 +65,26 @@ fn expand_derive_decode_transparent(
let mut generics = generics.clone();
generics
.params
.insert(0, parse_quote!(DB: ::sqlx::Database));
.insert(0, parse_quote!(DB: ::#crate_name::Database));
generics.params.insert(0, parse_quote!('r));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>));
.push(parse_quote!(#ty: ::#crate_name::decode::Decode<'r, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();

let tts = quote!(
#[automatically_derived]
impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
impl #impl_generics ::#crate_name::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
fn decode(
value: <DB as ::sqlx::database::Database>::ValueRef<'r>,
value: <DB as ::#crate_name::database::Database>::ValueRef<'r>,
) -> ::std::result::Result<
Self,
::std::boxed::Box<
dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync,
>,
> {
<#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
<#ty as ::#crate_name::decode::Decode<'r, DB>>::decode(value).map(Self)
}
}
);
Expand All @@ -94,6 +95,7 @@ fn expand_derive_decode_transparent(
fn expand_derive_decode_weak_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
crate_name: &Ident,
) -> syn::Result<TokenStream> {
let attr = check_weak_enum_attributes(input, variants)?;
let repr = attr.repr.unwrap();
Expand All @@ -113,23 +115,23 @@ fn expand_derive_decode_weak_enum(

Ok(quote!(
#[automatically_derived]
impl<'r, DB: ::sqlx::Database> ::sqlx::decode::Decode<'r, DB> for #ident
impl<'r, DB: ::#crate_name::Database> ::#crate_name::decode::Decode<'r, DB> for #ident
where
#repr: ::sqlx::decode::Decode<'r, DB>,
#repr: ::#crate_name::decode::Decode<'r, DB>,
{
fn decode(
value: <DB as ::sqlx::database::Database>::ValueRef<'r>,
value: <DB as ::#crate_name::database::Database>::ValueRef<'r>,
) -> ::std::result::Result<
Self,
::std::boxed::Box<
dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync,
>,
> {
let value = <#repr as ::sqlx::decode::Decode<'r, DB>>::decode(value)?;
let value = <#repr as ::#crate_name::decode::Decode<'r, DB>>::decode(value)?;

match value {
#(#arms)*
_ => ::std::result::Result::Err(::std::boxed::Box::new(::sqlx::Error::Decode(
_ => ::std::result::Result::Err(::std::boxed::Box::new(::#crate_name::Error::Decode(
::std::format!("invalid value {:?} for enum {}", value, #ident_s).into(),
)))
}
Expand All @@ -141,6 +143,7 @@ fn expand_derive_decode_weak_enum(
fn expand_derive_decode_strong_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
crate_name: &Ident,
) -> syn::Result<TokenStream> {
let cattr = check_strong_enum_attributes(input, variants)?;

Expand Down Expand Up @@ -176,9 +179,9 @@ fn expand_derive_decode_strong_enum(
if cfg!(feature = "mysql") {
tts.extend(quote!(
#[automatically_derived]
impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::mysql::MySql> for #ident {
impl<'r> ::#crate_name::decode::Decode<'r, ::#crate_name::mysql::MySql> for #ident {
fn decode(
value: ::sqlx::mysql::MySqlValueRef<'r>,
value: ::#crate_name::mysql::MySqlValueRef<'r>,
) -> ::std::result::Result<
Self,
::std::boxed::Box<
Expand All @@ -188,9 +191,9 @@ fn expand_derive_decode_strong_enum(
+ ::std::marker::Sync,
>,
> {
let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode<
let value = <&'r ::std::primitive::str as ::#crate_name::decode::Decode<
'r,
::sqlx::mysql::MySql,
::#crate_name::mysql::MySql,
>>::decode(value)?;

#values
Expand All @@ -202,9 +205,9 @@ fn expand_derive_decode_strong_enum(
if cfg!(feature = "postgres") {
tts.extend(quote!(
#[automatically_derived]
impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::postgres::Postgres> for #ident {
impl<'r> ::#crate_name::decode::Decode<'r, ::#crate_name::postgres::Postgres> for #ident {
fn decode(
value: ::sqlx::postgres::PgValueRef<'r>,
value: ::#crate_name::postgres::PgValueRef<'r>,
) -> ::std::result::Result<
Self,
::std::boxed::Box<
Expand All @@ -214,9 +217,9 @@ fn expand_derive_decode_strong_enum(
+ ::std::marker::Sync,
>,
> {
let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode<
let value = <&'r ::std::primitive::str as ::#crate_name::decode::Decode<
'r,
::sqlx::postgres::Postgres,
::#crate_name::postgres::Postgres,
>>::decode(value)?;

#values
Expand All @@ -228,9 +231,9 @@ fn expand_derive_decode_strong_enum(
if cfg!(feature = "_sqlite") {
tts.extend(quote!(
#[automatically_derived]
impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::sqlite::Sqlite> for #ident {
impl<'r> ::#crate_name::decode::Decode<'r, ::#crate_name::sqlite::Sqlite> for #ident {
fn decode(
value: ::sqlx::sqlite::SqliteValueRef<'r>,
value: ::#crate_name::sqlite::SqliteValueRef<'r>,
) -> ::std::result::Result<
Self,
::std::boxed::Box<
Expand All @@ -240,9 +243,9 @@ fn expand_derive_decode_strong_enum(
+ ::std::marker::Sync,
>,
> {
let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode<
let value = <&'r ::std::primitive::str as ::#crate_name::decode::Decode<
'r,
::sqlx::sqlite::Sqlite,
::#crate_name::sqlite::Sqlite,
>>::decode(value)?;

#values
Expand All @@ -257,6 +260,7 @@ fn expand_derive_decode_strong_enum(
fn expand_derive_decode_struct(
input: &DeriveInput,
fields: &Punctuated<Field, Comma>,
crate_name: &Ident,
) -> syn::Result<TokenStream> {
check_struct_attributes(input, fields)?;

Expand All @@ -272,8 +276,8 @@ fn expand_derive_decode_struct(
// add db type for impl generics & where clause
for type_param in &mut generics.type_params_mut() {
type_param.bounds.extend::<[TypeParamBound; 2]>([
parse_quote!(for<'decode> ::sqlx::decode::Decode<'decode, ::sqlx::Postgres>),
parse_quote!(::sqlx::types::Type<::sqlx::Postgres>),
parse_quote!(for<'decode> ::#crate_name::decode::Decode<'decode, ::#crate_name::Postgres>),
parse_quote!(::#crate_name::types::Type<::#crate_name::Postgres>),
]);
}

Expand All @@ -294,11 +298,11 @@ fn expand_derive_decode_struct(

tts.extend(quote!(
#[automatically_derived]
impl #impl_generics ::sqlx::decode::Decode<'r, ::sqlx::Postgres> for #ident #ty_generics
impl #impl_generics ::#crate_name::decode::Decode<'r, ::#crate_name::Postgres> for #ident #ty_generics
#where_clause
{
fn decode(
value: ::sqlx::postgres::PgValueRef<'r>,
value: ::#crate_name::postgres::PgValueRef<'r>,
) -> ::std::result::Result<
Self,
::std::boxed::Box<
Expand All @@ -308,7 +312,7 @@ fn expand_derive_decode_struct(
+ ::std::marker::Sync,
>,
> {
let mut decoder = ::sqlx::postgres::types::PgRecordDecoder::new(value)?;
let mut decoder = ::#crate_name::postgres::types::PgRecordDecoder::new(value)?;

#(#reads)*

Expand Down
Loading
Loading