fix(macros): where clause
This commit is contained in:
@@ -17,7 +17,7 @@ pub fn where_clause(input: TokenStream) -> TokenStream {
|
||||
}
|
||||
|
||||
fn crate_name() -> proc_macro2::TokenStream {
|
||||
match proc_macro_crate::crate_name("sqlx_utils") {
|
||||
match proc_macro_crate::crate_name("sqlx-utils") {
|
||||
Err(_) => quote! {::sqlx_utils},
|
||||
Ok(proc_macro_crate::FoundCrate::Itself) => quote! {crate},
|
||||
Ok(proc_macro_crate::FoundCrate::Name(e)) => {
|
||||
|
||||
@@ -22,6 +22,7 @@ struct WhereClauseField {
|
||||
struct WhereClauseArgs {
|
||||
ident: Ident,
|
||||
data: Data<(), WhereClauseField>,
|
||||
driver: Type,
|
||||
}
|
||||
|
||||
pub fn where_clause(input: TokenStream) -> TokenStream {
|
||||
@@ -39,71 +40,92 @@ pub fn where_clause(input: TokenStream) -> TokenStream {
|
||||
|
||||
let crate_name = crate_name();
|
||||
|
||||
let push_clauses = args.data.take_struct().expect("Should never be `None`").into_iter().map(|e| {
|
||||
let ident = e.ident.expect("Should never be `None`");
|
||||
let driver = args.driver;
|
||||
|
||||
let col_name = e.rename.unwrap_or_else(|| ident.to_string());
|
||||
let push_clauses = args
|
||||
.data
|
||||
.take_struct()
|
||||
.expect("Should never be `None`")
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let ident = e.ident.expect("Should never be `None`");
|
||||
|
||||
let op = e.op.unwrap_or_else(|| {
|
||||
let mut segments: Punctuated<PathSegment, Token![::]> = Punctuated::new();
|
||||
segments.push_value(PathSegment::from(Ident::new("sqlx_utils", Span::call_site())));
|
||||
segments.push_punct(PathSep { spans: [Span::call_site(); 2] });
|
||||
segments.push_value(PathSegment::from(Ident::new("builder", Span::call_site())));
|
||||
segments.push_punct(PathSep { spans: [Span::call_site(); 2] });
|
||||
segments.push_value(PathSegment::from(Ident::new("expr", Span::call_site())));
|
||||
segments.push_punct(PathSep { spans: [Span::call_site(); 2] });
|
||||
segments.push_value(PathSegment::from(Ident::new("Equals", Span::call_site())));
|
||||
let col_name = e.rename.unwrap_or_else(|| ident.to_string());
|
||||
|
||||
Type::Path(TypePath {
|
||||
qself: None,
|
||||
path: Path {
|
||||
leading_colon: Some(PathSep {
|
||||
spans: [Span::call_site(); 2],
|
||||
}),
|
||||
segments,
|
||||
},
|
||||
})
|
||||
});
|
||||
let op = e.op.unwrap_or_else(|| {
|
||||
let mut segments: Punctuated<PathSegment, Token![::]> = Punctuated::new();
|
||||
segments.push_value(PathSegment::from(Ident::new(
|
||||
"sqlx_utils",
|
||||
Span::call_site(),
|
||||
)));
|
||||
segments.push_punct(PathSep {
|
||||
spans: [Span::call_site(); 2],
|
||||
});
|
||||
segments.push_value(PathSegment::from(Ident::new("builder", Span::call_site())));
|
||||
segments.push_punct(PathSep {
|
||||
spans: [Span::call_site(); 2],
|
||||
});
|
||||
segments.push_value(PathSegment::from(Ident::new("expr", Span::call_site())));
|
||||
segments.push_punct(PathSep {
|
||||
spans: [Span::call_site(); 2],
|
||||
});
|
||||
segments.push_value(PathSegment::from(Ident::new("Equals", Span::call_site())));
|
||||
|
||||
let is_option = if let Type::Path(ty) = &e.ty {
|
||||
ty.path.segments.first().is_some_and(|e| e.ident.eq("Option"))
|
||||
} else {
|
||||
false
|
||||
};
|
||||
Type::Path(TypePath {
|
||||
qself: None,
|
||||
path: Path {
|
||||
leading_colon: Some(PathSep {
|
||||
spans: [Span::call_site(); 2],
|
||||
}),
|
||||
segments,
|
||||
},
|
||||
})
|
||||
});
|
||||
|
||||
let push_quote = quote! {
|
||||
let expr = #crate_name::builder::BracketsExpr::new(
|
||||
#crate_name::builder::BinaryExpr::new(#col_name, #crate_name::builder::expr::ensure_where(#op), ident.clone())
|
||||
);
|
||||
list.push(expr);
|
||||
};
|
||||
let is_option = if let Type::Path(ty) = &e.ty {
|
||||
ty.path
|
||||
.segments
|
||||
.first()
|
||||
.is_some_and(|e| e.ident.eq("Option"))
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if is_option {
|
||||
quote! {
|
||||
if let Some(ident) = &self.#ident {
|
||||
#push_quote
|
||||
let push_quote = quote! {
|
||||
let expr = #crate_name::builder::BinaryExpr::new(
|
||||
#crate_name::builder::ColumnIdent::new(#col_name),
|
||||
#crate_name::builder::expr::ensure_where(#op),
|
||||
#crate_name::builder::QueryVariable::new(ident.clone())
|
||||
);
|
||||
list.push(::std::sync::Arc::new(expr));
|
||||
};
|
||||
|
||||
if is_option {
|
||||
quote! {
|
||||
if let Some(ident) = &self.#ident {
|
||||
#push_quote
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
{
|
||||
let ident = &self.#ident;
|
||||
#push_quote;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
{
|
||||
let ident = &self.#ident;
|
||||
#push_quote;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
quote! {
|
||||
impl<DB> #crate_name::builder::PushToBuilder<DB> for #struct_name where DB: ::sqlx::Database {
|
||||
fn push_to(&self, builder: &mut ::sqlx::QueryBuilder<'_, DB>) {
|
||||
let mut list = #crate_name::builder::ExprList::<DB, #crate_name::builder::expr::And>::new(::std::vec::Vec::new());
|
||||
impl #crate_name::builder::PushToBuilder<#driver> for #struct_name {
|
||||
fn push_to(&self, builder: &mut ::sqlx::QueryBuilder<'_, #driver>) {
|
||||
let mut list = #crate_name::builder::ExprList::<#driver, #crate_name::builder::expr::And>::new(::std::vec::Vec::new());
|
||||
|
||||
#(
|
||||
#push_clauses
|
||||
)*
|
||||
|
||||
#crate_name::builder::PushToBuilder::push_to(list, builder);
|
||||
#crate_name::builder::PushToBuilder::push_to(&list, builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,6 +140,7 @@ mod test {
|
||||
fn derive() {
|
||||
let input = r#"
|
||||
#[derive(WhereClause)]
|
||||
#[sqlxu(driver = "::sqlx::Any")]
|
||||
struct NagelTest {
|
||||
#[sqlxu(rename = "fusgesicht")]
|
||||
name: String
|
||||
|
||||
@@ -63,6 +63,20 @@ where
|
||||
marker: PhantomData<DB>,
|
||||
}
|
||||
|
||||
impl<T, DB> QueryVariable<T, DB>
|
||||
where
|
||||
T: Type<DB> + for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Clone + 'static,
|
||||
DB: Database,
|
||||
{
|
||||
#[must_use]
|
||||
pub const fn new(inner: T) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, DB> PushToBuilder<DB> for QueryVariable<T, DB>
|
||||
where
|
||||
T: Type<DB> + for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Clone + 'static,
|
||||
@@ -73,6 +87,26 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ColumnIdent(&'static str);
|
||||
|
||||
impl ColumnIdent {
|
||||
#[must_use]
|
||||
pub const fn new(inner: &'static str) -> Self {
|
||||
Self(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB> PushToBuilder<DB> for ColumnIdent
|
||||
where
|
||||
DB: Database,
|
||||
{
|
||||
fn push_to(&self, builder: &mut QueryBuilder<'_, DB>) {
|
||||
builder.push("\"");
|
||||
builder.push(self.0);
|
||||
builder.push("\"");
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BracketsExpr<P, DB>
|
||||
where
|
||||
P: PushToBuilder<DB>,
|
||||
@@ -221,3 +255,35 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::print_stdout)]
|
||||
mod test {
|
||||
use super::PushToBuilder;
|
||||
use super::expr::{Equals, NotEquals};
|
||||
use sqlx::{Any, QueryBuilder};
|
||||
use sqlx_utils_macros::WhereClause;
|
||||
|
||||
#[test]
|
||||
fn where_derive() {
|
||||
#[derive(WhereClause)]
|
||||
#[sqlxu(driver = "Any")]
|
||||
struct WhereTest {
|
||||
#[sqlxu(op = "Equals")]
|
||||
name: String,
|
||||
#[sqlxu(op = "NotEquals")]
|
||||
id: i64,
|
||||
}
|
||||
|
||||
let mut builder: QueryBuilder<'_, Any> = QueryBuilder::new("");
|
||||
|
||||
let s = WhereTest {
|
||||
name: "Flip".to_string(),
|
||||
id: 69,
|
||||
};
|
||||
|
||||
s.push_to(&mut builder);
|
||||
|
||||
println!("{}", builder.sql());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user