fix(macros): where clause

This commit is contained in:
2025-08-11 14:20:13 +02:00
parent 7aa66ea38c
commit 5f8dbdc764
3 changed files with 140 additions and 51 deletions

View File

@@ -17,7 +17,7 @@ pub fn where_clause(input: TokenStream) -> TokenStream {
}
fn crate_name() -> proc_macro2::TokenStream {
match proc_macro_crate::crate_name("sqlx_utils") {
match proc_macro_crate::crate_name("sqlx-utils") {
Err(_) => quote! {::sqlx_utils},
Ok(proc_macro_crate::FoundCrate::Itself) => quote! {crate},
Ok(proc_macro_crate::FoundCrate::Name(e)) => {

View File

@@ -22,6 +22,7 @@ struct WhereClauseField {
struct WhereClauseArgs {
ident: Ident,
data: Data<(), WhereClauseField>,
driver: Type,
}
pub fn where_clause(input: TokenStream) -> TokenStream {
@@ -39,71 +40,92 @@ pub fn where_clause(input: TokenStream) -> TokenStream {
let crate_name = crate_name();
let push_clauses = args.data.take_struct().expect("Should never be `None`").into_iter().map(|e| {
let ident = e.ident.expect("Should never be `None`");
let driver = args.driver;
let col_name = e.rename.unwrap_or_else(|| ident.to_string());
let push_clauses = args
.data
.take_struct()
.expect("Should never be `None`")
.into_iter()
.map(|e| {
let ident = e.ident.expect("Should never be `None`");
let op = e.op.unwrap_or_else(|| {
let mut segments: Punctuated<PathSegment, Token![::]> = Punctuated::new();
segments.push_value(PathSegment::from(Ident::new("sqlx_utils", Span::call_site())));
segments.push_punct(PathSep { spans: [Span::call_site(); 2] });
segments.push_value(PathSegment::from(Ident::new("builder", Span::call_site())));
segments.push_punct(PathSep { spans: [Span::call_site(); 2] });
segments.push_value(PathSegment::from(Ident::new("expr", Span::call_site())));
segments.push_punct(PathSep { spans: [Span::call_site(); 2] });
segments.push_value(PathSegment::from(Ident::new("Equals", Span::call_site())));
let col_name = e.rename.unwrap_or_else(|| ident.to_string());
Type::Path(TypePath {
qself: None,
path: Path {
leading_colon: Some(PathSep {
spans: [Span::call_site(); 2],
}),
segments,
},
})
});
let op = e.op.unwrap_or_else(|| {
let mut segments: Punctuated<PathSegment, Token![::]> = Punctuated::new();
segments.push_value(PathSegment::from(Ident::new(
"sqlx_utils",
Span::call_site(),
)));
segments.push_punct(PathSep {
spans: [Span::call_site(); 2],
});
segments.push_value(PathSegment::from(Ident::new("builder", Span::call_site())));
segments.push_punct(PathSep {
spans: [Span::call_site(); 2],
});
segments.push_value(PathSegment::from(Ident::new("expr", Span::call_site())));
segments.push_punct(PathSep {
spans: [Span::call_site(); 2],
});
segments.push_value(PathSegment::from(Ident::new("Equals", Span::call_site())));
let is_option = if let Type::Path(ty) = &e.ty {
ty.path.segments.first().is_some_and(|e| e.ident.eq("Option"))
} else {
false
};
Type::Path(TypePath {
qself: None,
path: Path {
leading_colon: Some(PathSep {
spans: [Span::call_site(); 2],
}),
segments,
},
})
});
let push_quote = quote! {
let expr = #crate_name::builder::BracketsExpr::new(
#crate_name::builder::BinaryExpr::new(#col_name, #crate_name::builder::expr::ensure_where(#op), ident.clone())
);
list.push(expr);
};
let is_option = if let Type::Path(ty) = &e.ty {
ty.path
.segments
.first()
.is_some_and(|e| e.ident.eq("Option"))
} else {
false
};
if is_option {
quote! {
if let Some(ident) = &self.#ident {
#push_quote
let push_quote = quote! {
let expr = #crate_name::builder::BinaryExpr::new(
#crate_name::builder::ColumnIdent::new(#col_name),
#crate_name::builder::expr::ensure_where(#op),
#crate_name::builder::QueryVariable::new(ident.clone())
);
list.push(::std::sync::Arc::new(expr));
};
if is_option {
quote! {
if let Some(ident) = &self.#ident {
#push_quote
}
}
} else {
quote! {
{
let ident = &self.#ident;
#push_quote;
}
}
}
} else {
quote! {
{
let ident = &self.#ident;
#push_quote;
}
}
}
});
});
quote! {
impl<DB> #crate_name::builder::PushToBuilder<DB> for #struct_name where DB: ::sqlx::Database {
fn push_to(&self, builder: &mut ::sqlx::QueryBuilder<'_, DB>) {
let mut list = #crate_name::builder::ExprList::<DB, #crate_name::builder::expr::And>::new(::std::vec::Vec::new());
impl #crate_name::builder::PushToBuilder<#driver> for #struct_name {
fn push_to(&self, builder: &mut ::sqlx::QueryBuilder<'_, #driver>) {
let mut list = #crate_name::builder::ExprList::<#driver, #crate_name::builder::expr::And>::new(::std::vec::Vec::new());
#(
#push_clauses
)*
#crate_name::builder::PushToBuilder::push_to(list, builder);
#crate_name::builder::PushToBuilder::push_to(&list, builder);
}
}
}
@@ -118,6 +140,7 @@ mod test {
fn derive() {
let input = r#"
#[derive(WhereClause)]
#[sqlxu(driver = "::sqlx::Any")]
struct NagelTest {
#[sqlxu(rename = "fusgesicht")]
name: String

View File

@@ -63,6 +63,20 @@ where
marker: PhantomData<DB>,
}
impl<T, DB> QueryVariable<T, DB>
where
T: Type<DB> + for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Clone + 'static,
DB: Database,
{
#[must_use]
pub const fn new(inner: T) -> Self {
Self {
inner,
marker: PhantomData,
}
}
}
impl<T, DB> PushToBuilder<DB> for QueryVariable<T, DB>
where
T: Type<DB> + for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Clone + 'static,
@@ -73,6 +87,26 @@ where
}
}
pub struct ColumnIdent(&'static str);
impl ColumnIdent {
#[must_use]
pub const fn new(inner: &'static str) -> Self {
Self(inner)
}
}
impl<DB> PushToBuilder<DB> for ColumnIdent
where
DB: Database,
{
fn push_to(&self, builder: &mut QueryBuilder<'_, DB>) {
builder.push("\"");
builder.push(self.0);
builder.push("\"");
}
}
pub struct BracketsExpr<P, DB>
where
P: PushToBuilder<DB>,
@@ -221,3 +255,35 @@ where
}
}
}
#[cfg(test)]
#[allow(clippy::print_stdout)]
mod test {
use super::PushToBuilder;
use super::expr::{Equals, NotEquals};
use sqlx::{Any, QueryBuilder};
use sqlx_utils_macros::WhereClause;
#[test]
fn where_derive() {
#[derive(WhereClause)]
#[sqlxu(driver = "Any")]
struct WhereTest {
#[sqlxu(op = "Equals")]
name: String,
#[sqlxu(op = "NotEquals")]
id: i64,
}
let mut builder: QueryBuilder<'_, Any> = QueryBuilder::new("");
let s = WhereTest {
name: "Flip".to_string(),
id: 69,
};
s.push_to(&mut builder);
println!("{}", builder.sql());
}
}