11use proc_macro2:: TokenStream ;
22use quote:: quote;
3- use syn:: { parse_quote , Data , DeriveInput , Fields , Path } ;
3+ use syn:: { Data , DeriveInput , Fields } ;
44
55use crate :: helpers:: {
66 missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties ,
@@ -18,26 +18,14 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
1818 let type_properties = ast. get_type_properties ( ) ?;
1919 let strum_module_path = type_properties. crate_module_path ( ) ;
2020
21+ // It's an error to provide an err_fn but not an err_ty.
22+ if type_properties. parse_err_fn . is_some ( ) && type_properties. parse_err_ty . is_none ( ) {
23+ return Err ( missing_parse_err_attr_error ( ) ) ;
24+ }
25+
2126 let mut default_kw = None ;
22- let ( default_err_ty, mut default_match_arm) = match (
23- type_properties. parse_err_ty ,
24- type_properties. parse_err_fn ,
25- ) {
26- ( None , None ) => (
27- quote ! { #strum_module_path:: ParseError } ,
28- quote ! { :: core:: result:: Result :: Err ( #strum_module_path:: ParseError :: VariantNotFound ) } ,
29- ) ,
30- ( Some ( ty) , Some ( f) ) => {
31- let ty_path: Path = parse_quote ! ( #ty) ;
32- let fn_path: Path = parse_quote ! ( #f) ;
33-
34- (
35- quote ! { #ty_path } ,
36- quote ! { :: core:: result:: Result :: Err ( #fn_path( s) ) } ,
37- )
38- }
39- _ => return Err ( missing_parse_err_attr_error ( ) ) ,
40- } ;
27+ let mut default_match_arm = None ;
28+
4129 let mut phf_exact_match_arms = Vec :: new ( ) ;
4230 let mut standard_match_arms = Vec :: new ( ) ;
4331 for variant in variants {
@@ -57,15 +45,13 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
5745
5846 match & variant. fields {
5947 Fields :: Unnamed ( fields) if fields. unnamed . len ( ) == 1 => {
60- default_match_arm = quote ! {
61- :: core :: result :: Result :: Ok ( #name:: #ident( s. into( ) ) )
62- } ;
48+ default_match_arm = Some ( quote ! {
49+ #name:: #ident( s. into( ) )
50+ } ) ;
6351 }
6452 Fields :: Named ( ref f) if f. named . len ( ) == 1 => {
6553 let field_name = f. named . last ( ) . unwrap ( ) . ident . as_ref ( ) . unwrap ( ) ;
66- default_match_arm = quote ! {
67- :: core:: result:: Result :: Ok ( #name:: #ident { #field_name : s. into( ) } )
68- } ;
54+ default_match_arm = Some ( quote ! { #name:: #ident { #field_name : s. into( ) } } ) ;
6955 }
7056 _ => {
7157 return Err ( syn:: Error :: new_spanned (
@@ -133,85 +119,109 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
133119 phf_exact_match_arms. push ( quote ! { #upper => #name:: #ident #params, } ) ;
134120 standard_match_arms. push ( quote ! { s if s. eq_ignore_ascii_case( #serialization) => #name:: #ident #params, } ) ;
135121 }
122+ } else if !is_ascii_case_insensitive {
123+ standard_match_arms. push ( quote ! { #serialization => #name:: #ident #params, } ) ;
136124 } else {
137- standard_match_arms. push ( if !is_ascii_case_insensitive {
138- quote ! { #serialization => #name:: #ident #params, }
139- } else {
140- quote ! { s if s. eq_ignore_ascii_case( #serialization) => #name:: #ident #params, }
141- } ) ;
125+ standard_match_arms. push ( quote ! { s if s. eq_ignore_ascii_case( #serialization) => #name:: #ident #params, } ) ;
142126 }
143127 }
144128 }
145129
146- let phf_body = if phf_exact_match_arms. is_empty ( ) {
147- quote ! ( )
130+ // Determine the error type on FromStr and TryFrom based on what the user
131+ // has configured and whether there is a default variant.
132+ let is_infallible = default_match_arm. is_some ( ) ;
133+ let has_custom_err_ty = type_properties. parse_err_ty . is_some ( ) ;
134+ let err_ty = if let Some ( ty) = type_properties. parse_err_ty {
135+ quote ! { #ty }
136+ } else if is_infallible {
137+ quote ! { :: core:: convert:: Infallible }
138+ } else {
139+ quote ! { #strum_module_path:: ParseError }
140+ } ;
141+
142+ // Determine the default match arm behavior based on whether the user provided a "default"
143+ // or if the user provided a custom error function.
144+ let default_match_arm = if let Some ( default_match_arm) = default_match_arm {
145+ default_match_arm
146+ } else if let Some ( f) = type_properties. parse_err_fn {
147+ quote ! { return :: core:: result:: Result :: Err ( #f( s) ) }
148+ } else if has_custom_err_ty {
149+ // The user defined a custom error type, but not a custom error function. This is an error
150+ // if the method isn't infallible.
151+ return Err ( missing_parse_err_attr_error ( ) ) ;
152+ } else {
153+ quote ! { return :: core:: result:: Result :: Err ( #strum_module_path:: ParseError :: VariantNotFound ) }
154+ } ;
155+
156+ let mut match_expression = if standard_match_arms. is_empty ( ) {
157+ default_match_arm
148158 } else {
149159 quote ! {
160+ match s {
161+ #( #standard_match_arms) *
162+ _ => #default_match_arm,
163+ }
164+ }
165+ } ;
166+
167+ if !phf_exact_match_arms. is_empty ( ) {
168+ match_expression = quote ! {
150169 use #strum_module_path:: _private_phf_reexport_for_macro_if_phf_feature as phf;
151170 static PHF : phf:: Map <& ' static str , #name> = phf:: phf_map! {
152171 #( #phf_exact_match_arms) *
153172 } ;
173+
154174 if let Some ( value) = PHF . get( s) . cloned( ) {
155- return :: core:: result:: Result :: Ok ( value) ;
175+ value
176+ } else {
177+ #match_expression
156178 }
157179 }
158- } ;
180+ }
159181
160- let standard_match_body = if standard_match_arms. is_empty ( ) {
161- default_match_arm
182+ let from_impl = if is_infallible && !has_custom_err_ty {
183+ quote ! {
184+ #[ allow( clippy:: use_self) ]
185+ #[ automatically_derived]
186+ impl #impl_generics :: core:: convert:: From <& str > for #name #ty_generics #where_clause {
187+ #[ inline]
188+ fn from( s: & str ) -> #name #ty_generics {
189+ #match_expression
190+ }
191+ }
192+ }
162193 } else {
163194 quote ! {
164- :: core:: result:: Result :: Ok ( match s {
165- #( #standard_match_arms) *
166- _ => return #default_match_arm,
167- } )
195+ #[ allow( clippy:: use_self) ]
196+ #[ automatically_derived]
197+ impl #impl_generics :: core:: convert:: TryFrom <& str > for #name #ty_generics #where_clause {
198+ type Error = #err_ty;
199+
200+ #[ inline]
201+ fn try_from( s: & str ) -> :: core:: result:: Result < #name #ty_generics , <Self as :: core:: convert:: TryFrom <& str >>:: Error > {
202+ Ok ( {
203+ #match_expression
204+ } )
205+ }
206+ }
168207 }
169208 } ;
170209
171210 let from_str = quote ! {
172211 #[ allow( clippy:: use_self) ]
173212 #[ automatically_derived]
174213 impl #impl_generics :: core:: str :: FromStr for #name #ty_generics #where_clause {
175- type Err = #default_err_ty ;
214+ type Err = #err_ty ;
176215
177216 #[ inline]
178217 fn from_str( s: & str ) -> :: core:: result:: Result < #name #ty_generics , <Self as :: core:: str :: FromStr >:: Err > {
179- #phf_body
180- #standard_match_body
218+ <Self as :: core:: convert:: TryFrom <& str >>:: try_from( s)
181219 }
182220 }
183221 } ;
184- let try_from_str = try_from_str (
185- name,
186- & impl_generics,
187- & ty_generics,
188- where_clause,
189- & default_err_ty,
190- ) ;
191222
192223 Ok ( quote ! {
193224 #from_str
194- #try_from_str
225+ #from_impl
195226 } )
196227}
197-
198- fn try_from_str (
199- name : & proc_macro2:: Ident ,
200- impl_generics : & syn:: ImplGenerics ,
201- ty_generics : & syn:: TypeGenerics ,
202- where_clause : Option < & syn:: WhereClause > ,
203- default_err_ty : & TokenStream ,
204- ) -> TokenStream {
205- quote ! {
206- #[ allow( clippy:: use_self) ]
207- #[ automatically_derived]
208- impl #impl_generics :: core:: convert:: TryFrom <& str > for #name #ty_generics #where_clause {
209- type Error = #default_err_ty;
210-
211- #[ inline]
212- fn try_from( s: & str ) -> :: core:: result:: Result < #name #ty_generics , <Self as :: core:: convert:: TryFrom <& str >>:: Error > {
213- :: core:: str :: FromStr :: from_str( s)
214- }
215- }
216- }
217- }
0 commit comments