Skip to content

Commit 463c0b2

Browse files
committed
improvements, restructuring, comments and rename attr
1 parent 2e786ba commit 463c0b2

File tree

1 file changed

+131
-51
lines changed
  • postgres-from-row-derive/src

1 file changed

+131
-51
lines changed

postgres-from-row-derive/src/lib.rs

Lines changed: 131 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use darling::ast::{self, Style};
2-
use darling::{FromDeriveInput, FromField, ToTokens};
1+
use darling::{ast::Data, Error, FromDeriveInput, FromField, ToTokens};
32
use proc_macro::TokenStream;
3+
use proc_macro2::TokenStream as TokenStream2;
44
use quote::quote;
5-
use syn::{parse_macro_input, DeriveInput, Ident};
5+
use syn::{parse_macro_input, DeriveInput, Ident, Result};
66

77
#[proc_macro_derive(FromRowTokioPostgres, attributes(from_row))]
88
pub fn derive_from_row_tokio_postgres(input: TokenStream) -> TokenStream {
@@ -14,6 +14,7 @@ pub fn derive_from_row_postgres(input: TokenStream) -> TokenStream {
1414
derive_from_row(input, quote::format_ident!("postgres"))
1515
}
1616

17+
/// Calls the fallible entry point and writes any errors to the tokenstream.
1718
fn derive_from_row(input: TokenStream, module: Ident) -> TokenStream {
1819
let derive_input = parse_macro_input!(input as DeriveInput);
1920
match try_derive_from_row(&derive_input, module) {
@@ -22,11 +23,16 @@ fn derive_from_row(input: TokenStream, module: Ident) -> TokenStream {
2223
}
2324
}
2425

25-
fn try_derive_from_row(input: &DeriveInput, module: Ident) -> Result<TokenStream, darling::Error> {
26+
/// Fallible entry point for generating a `FromRow` implementation
27+
fn try_derive_from_row(
28+
input: &DeriveInput,
29+
module: Ident,
30+
) -> std::result::Result<TokenStream, Error> {
2631
let from_row_derive = DeriveFromRow::from_derive_input(input)?;
27-
from_row_derive.generate(module)
32+
Ok(from_row_derive.generate(module)?)
2833
}
2934

35+
/// Main struct for deriving `FromRow` for a struct.
3036
#[derive(Debug, FromDeriveInput)]
3137
#[darling(
3238
attributes(from_row),
@@ -36,59 +42,60 @@ fn try_derive_from_row(input: &DeriveInput, module: Ident) -> Result<TokenStream
3642
struct DeriveFromRow {
3743
ident: syn::Ident,
3844
generics: syn::Generics,
39-
data: ast::Data<(), FromRowField>,
45+
data: Data<(), FromRowField>,
4046
}
4147

4248
impl DeriveFromRow {
43-
fn generate(self, module: Ident) -> Result<TokenStream, darling::Error> {
44-
let ident = &self.ident;
49+
/// Validates all fields
50+
fn validate(&self) -> Result<()> {
51+
for field in self.fields() {
52+
field.validate()?;
53+
}
4554

46-
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
55+
Ok(())
56+
}
4757

48-
let fields = self
49-
.data
50-
.take_struct()
51-
.ok_or_else(|| darling::Error::unsupported_shape("enum").with_span(&self.ident))?;
58+
/// Generates any additional where clause predicates needed for the fields in this struct.
59+
fn predicates(&self, module: &Ident) -> Result<Vec<TokenStream2>> {
60+
let mut predicates = Vec::new();
5261

53-
let fields = match fields.style {
54-
Style::Unit => {
55-
return Err(darling::Error::unsupported_shape("unit struct").with_span(&self.ident))
56-
}
57-
Style::Tuple => {
58-
return Err(darling::Error::unsupported_shape("tuple struct").with_span(&self.ident))
59-
}
60-
Style::Struct => fields.fields,
61-
};
62+
for field in self.fields() {
63+
field.add_predicates(&module, &mut predicates)?;
64+
}
6265

63-
let from_row_fields = fields
66+
Ok(predicates)
67+
}
68+
69+
/// Provides a slice of this struct's fields.
70+
fn fields(&self) -> &[FromRowField] {
71+
match &self.data {
72+
Data::Struct(fields) => &fields.fields,
73+
_ => panic!("invalid shape"),
74+
}
75+
}
76+
77+
/// Generate the `FromRow` implementation.
78+
fn generate(self, module: Ident) -> Result<TokenStream> {
79+
self.validate()?;
80+
81+
let ident = &self.ident;
82+
83+
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
84+
let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter();
85+
let predicates = self.predicates(&module)?;
86+
87+
let from_row_fields = self
88+
.fields()
6489
.iter()
6590
.map(|f| f.generate_from_row(&module))
6691
.collect::<syn::Result<Vec<_>>>()?;
6792

68-
let try_from_row_fields = fields
93+
let try_from_row_fields = self
94+
.fields()
6995
.iter()
7096
.map(|f| f.generate_try_from_row(&module))
7197
.collect::<syn::Result<Vec<_>>>()?;
7298

73-
let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter();
74-
let mut predicates = Vec::new();
75-
76-
for field in fields.iter() {
77-
let target_ty = &field.target_ty()?;
78-
let ty = &field.ty;
79-
predicates.push(if field.flatten {
80-
quote! (#target_ty: postgres_from_row::FromRow)
81-
} else {
82-
quote! (#target_ty: for<'a> #module::types::FromSql<'a>)
83-
});
84-
85-
if field.from.is_some() {
86-
predicates.push(quote!(#ty: std::convert::From<#target_ty>))
87-
} else if field.try_from.is_some() {
88-
predicates.push(quote!(#ty: std::convert::From<#target_ty>))
89-
}
90-
}
91-
9299
Ok(quote! {
93100
impl #impl_generics postgres_from_row::FromRow for #ident #ty_generics where #(#original_predicates),* #(#predicates),* {
94101

@@ -109,19 +116,52 @@ impl DeriveFromRow {
109116
}
110117
}
111118

119+
/// A single field inside of a struct that derives `FromRow`
112120
#[derive(Debug, FromField)]
113121
#[darling(attributes(from_row), forward_attrs(allow, doc, cfg))]
114122
struct FromRowField {
123+
/// The identifier of this field.
115124
ident: Option<syn::Ident>,
125+
/// The type specified in this field.
116126
ty: syn::Type,
127+
/// Wether to flatten this field. Flattening means calling the `FromRow` implementation
128+
/// of `self.ty` instead of extracting it directly from the row.
117129
#[darling(default)]
118130
flatten: bool,
131+
/// Optionaly use this type as the target for `FromRow` or `FromSql`, and then
132+
/// call `TryFrom::try_from` to convert it the `self.ty`.
119133
try_from: Option<String>,
134+
/// Optionaly use this type as the target for `FromRow` or `FromSql`, and then
135+
/// call `From::from` to convert it the `self.ty`.
120136
from: Option<String>,
137+
/// Override the name of the actual sql column instead of using `self.ident`.
138+
/// Is not compatible with `flatten` since no column is needed there.
139+
rename: Option<String>,
121140
}
122141

123142
impl FromRowField {
124-
fn target_ty(&self) -> syn::Result<proc_macro2::TokenStream> {
143+
/// Checks wether this field has a valid combination of attributes
144+
fn validate(&self) -> Result<()> {
145+
if self.from.is_some() && self.try_from.is_some() {
146+
return Err(Error::custom(
147+
r#"can't combine `#[from_row(from = "..")]` with `#[from_row(try_from = "..")]`"#,
148+
)
149+
.into());
150+
}
151+
152+
if self.rename.is_some() && self.flatten {
153+
return Err(Error::custom(
154+
r#"can't combine `#[from_row(flatten)]` with `#[from_row(rename = "..")]`"#,
155+
)
156+
.into());
157+
}
158+
159+
Ok(())
160+
}
161+
162+
/// Returns a tokenstream of the type that should be returned from either
163+
/// `FromRow` (when using `flatten`) or `FromSql`.
164+
fn target_ty(&self) -> Result<TokenStream2> {
125165
if let Some(from) = &self.from {
126166
Ok(from.parse()?)
127167
} else if let Some(try_from) = &self.try_from {
@@ -131,17 +171,56 @@ impl FromRowField {
131171
}
132172
}
133173

134-
fn generate_from_row(&self, module: &Ident) -> syn::Result<proc_macro2::TokenStream> {
174+
/// Returns the name that maps to the actuall sql column
175+
/// By default this is the same as the rust field name but can be overwritten by `#[from_row(rename = "..")]`.
176+
fn column_name(&self) -> String {
177+
self.rename
178+
.as_ref()
179+
.map(Clone::clone)
180+
.unwrap_or_else(|| self.ident.as_ref().unwrap().to_string())
181+
}
182+
183+
/// Pushes the needed where clause predicates for this field.
184+
///
185+
/// By default this is `T: for<'a> postgres::types::FromSql<'a>`,
186+
/// when using `flatten` it's: `T: postgres_from_row::FromRow`
187+
/// and when using either `from` or `try_from` attributes it additionally pushes this bound:
188+
/// `T: std::convert::From<R>`, where `T` is the type specified in the struct and `R` is the
189+
/// type specified in the `[try]_from` attribute.
190+
fn add_predicates(&self, module: &Ident, predicates: &mut Vec<TokenStream2>) -> Result<()> {
191+
let target_ty = &self.target_ty()?;
192+
let ty = &self.ty;
193+
194+
predicates.push(if self.flatten {
195+
quote! (#target_ty: postgres_from_row::FromRow)
196+
} else {
197+
quote! (#target_ty: for<'a> #module::types::FromSql<'a>)
198+
});
199+
200+
if self.from.is_some() {
201+
predicates.push(quote!(#ty: std::convert::From<#target_ty>))
202+
} else if self.try_from.is_some() {
203+
let try_from = quote!(std::convert::TryFrom<#target_ty>);
204+
205+
predicates.push(quote!(#ty: #try_from));
206+
predicates.push(quote!(#module::Error: std::convert::From<<#ty as #try_from>::Error>));
207+
predicates.push(quote!(<#ty as #try_from>::Error: std::fmt::Debug));
208+
}
209+
210+
Ok(())
211+
}
212+
213+
/// Generate the line needed to retrievee this field from a row when calling `from_row`.
214+
fn generate_from_row(&self, module: &Ident) -> Result<TokenStream2> {
135215
let ident = self.ident.as_ref().unwrap();
136-
let str_ident = ident.to_string();
216+
let column_name = self.column_name();
137217
let field_ty = &self.ty;
138-
139218
let target_ty = self.target_ty()?;
140219

141220
let mut base = if self.flatten {
142221
quote!(<#target_ty as postgres_from_row::FromRow>::from_row(row))
143222
} else {
144-
quote!(#module::Row::get::<&str, #target_ty>(row, #str_ident))
223+
quote!(#module::Row::get::<&str, #target_ty>(row, #column_name))
145224
};
146225

147226
if self.from.is_some() {
@@ -153,16 +232,17 @@ impl FromRowField {
153232
Ok(quote!(#ident: #base))
154233
}
155234

156-
fn generate_try_from_row(&self, module: &Ident) -> syn::Result<proc_macro2::TokenStream> {
235+
/// Generate the line needed to retrieve this field from a row when calling `try_from_row`.
236+
fn generate_try_from_row(&self, module: &Ident) -> Result<TokenStream2> {
157237
let ident = self.ident.as_ref().unwrap();
158-
let str_ident = ident.to_string();
238+
let column_name = self.column_name();
159239
let field_ty = &self.ty;
160240
let target_ty = self.target_ty()?;
161241

162242
let mut base = if self.flatten {
163243
quote!(<#target_ty as postgres_from_row::FromRow>::try_from_row(row)?)
164244
} else {
165-
quote!(#module::Row::try_get::<&str, #target_ty>(row, #str_ident)?)
245+
quote!(#module::Row::try_get::<&str, #target_ty>(row, #column_name)?)
166246
};
167247

168248
if self.from.is_some() {

0 commit comments

Comments
 (0)