Skip to content

Commit 9334c72

Browse files
authored
Make TryFrom and FromStr infallible if there's a default (#476)
1 parent 0ccbbf8 commit 9334c72

5 files changed

Lines changed: 182 additions & 88 deletions

File tree

strum_macros/src/lib.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
3434

3535
/// Converts strings to enum variants based on their name.
3636
///
37-
/// auto-derives `std::str::FromStr` on the enum (for Rust 1.34 and above, `std::convert::TryFrom<&str>`
38-
/// will be derived as well). Each variant of the enum will match on it's own name.
39-
/// This can be overridden using `serialize="DifferentName"` or `to_string="DifferentName"`
37+
/// auto-derives `std::str::FromStr` on the enum. Each variant of the enum will match on its own
38+
/// name. This can be overridden using `serialize="DifferentName"` or `to_string="DifferentName"`
4039
/// on the attribute as shown below.
41-
/// Multiple deserializations can be added to the same variant. If the variant contains additional data,
42-
/// they will be set to their default values upon deserialization.
40+
/// Multiple deserializations can be added to the same variant. If the variant contains additional
41+
/// data, they will be set to their default values upon deserialization.
4342
///
44-
/// The `default` attribute can be applied to a tuple variant with a single data parameter. When a match isn't
45-
/// found, the given variant will be returned and the input string will be captured in the parameter.
43+
/// The `default` attribute can be applied to a tuple variant with a single data parameter. When a
44+
/// match isn't found, the given variant will be returned and the input string will be captured in
45+
/// the parameter.
4646
///
4747
/// Note that the implementation of `FromStr` by default only matches on the name of the
4848
/// variant. There is an option to match on different case conversions through the
@@ -57,15 +57,20 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
5757
/// rather than just assume it will be faster. With SIMD + pipelining, linear string search (aka memcmp)
5858
/// can be very fast for enums with a surprisingly large number of enum variants.
5959
///
60-
/// The default error type is `strum::ParseError`. This can be overriden by applying both the
61-
/// `parse_err_ty` and `parse_err_fn` attributes at the type level. `parse_err_fn` should be a
60+
/// # Infallible Parsing
61+
///
62+
/// If the enum has a `#[strum(default)]` variant and no `parse_err_ty` is set, parsing is
63+
/// infallible: `From<&str>` is derived instead of `TryFrom<&str>`, which allows calling
64+
/// `MyEnum::from("string")` directly.
65+
///
66+
/// # Custom Error Types
67+
///
68+
/// The default error type is `strum::ParseError`. This can be overridden by applying both the
69+
/// `parse_err_ty` and `parse_err_fn` attributes at the type level. `parse_err_fn` should be a
6270
/// function that accepts an `&str` and returns the type `parse_err_ty`. See [this test
6371
/// case](https://github.com/Peternator7/strum/blob/9db3c4dc9b6f585aeb9f5f15f9cc18b6cf4fd780/strum_tests/tests/from_str.rs#L233)
64-
/// for an example.
65-
///
66-
/// If the enum has a default variant (annotated with `#[strum(default)]`), then parsing is
67-
/// infallible. In that case, `parse_err_fn` need not exist (it will never be called) and
68-
/// `parse_err_ty` can be safely set to [`std::convert::Infallible`].
72+
/// for an example. When `parse_err_ty` is set, `TryFrom<&str>` is always derived, even if the
73+
/// enum has a `#[strum(default)]` variant.
6974
///
7075
/// # Example how to use `EnumString`
7176
/// ```

strum_macros/src/macros/strings/from_string.rs

Lines changed: 83 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use proc_macro2::TokenStream;
22
use quote::quote;
3-
use syn::{parse_quote, Data, DeriveInput, Fields, Path};
3+
use syn::{Data, DeriveInput, Fields};
44

55
use crate::helpers::{
66
missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties,
@@ -18,26 +18,14 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
1818
let type_properties = ast.get_type_properties()?;
1919
let strum_module_path = type_properties.crate_module_path();
2020

21+
// It's an error to provide an err_fn but not an err_ty.
22+
if type_properties.parse_err_fn.is_some() && type_properties.parse_err_ty.is_none() {
23+
return Err(missing_parse_err_attr_error());
24+
}
25+
2126
let mut default_kw = None;
22-
let (default_err_ty, mut default_match_arm) = match (
23-
type_properties.parse_err_ty,
24-
type_properties.parse_err_fn,
25-
) {
26-
(None, None) => (
27-
quote! { #strum_module_path::ParseError },
28-
quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) },
29-
),
30-
(Some(ty), Some(f)) => {
31-
let ty_path: Path = parse_quote!(#ty);
32-
let fn_path: Path = parse_quote!(#f);
33-
34-
(
35-
quote! { #ty_path },
36-
quote! { ::core::result::Result::Err(#fn_path(s)) },
37-
)
38-
}
39-
_ => return Err(missing_parse_err_attr_error()),
40-
};
27+
let mut default_match_arm = None;
28+
4129
let mut phf_exact_match_arms = Vec::new();
4230
let mut standard_match_arms = Vec::new();
4331
for variant in variants {
@@ -57,15 +45,13 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
5745

5846
match &variant.fields {
5947
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
60-
default_match_arm = quote! {
61-
::core::result::Result::Ok(#name::#ident(s.into()))
62-
};
48+
default_match_arm = Some(quote! {
49+
#name::#ident(s.into())
50+
});
6351
}
6452
Fields::Named(ref f) if f.named.len() == 1 => {
6553
let field_name = f.named.last().unwrap().ident.as_ref().unwrap();
66-
default_match_arm = quote! {
67-
::core::result::Result::Ok(#name::#ident { #field_name : s.into() } )
68-
};
54+
default_match_arm = Some(quote! { #name::#ident { #field_name : s.into() } });
6955
}
7056
_ => {
7157
return Err(syn::Error::new_spanned(
@@ -133,85 +119,109 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
133119
phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
134120
standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
135121
}
122+
} else if !is_ascii_case_insensitive {
123+
standard_match_arms.push(quote! { #serialization => #name::#ident #params, });
136124
} else {
137-
standard_match_arms.push(if !is_ascii_case_insensitive {
138-
quote! { #serialization => #name::#ident #params, }
139-
} else {
140-
quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
141-
});
125+
standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
142126
}
143127
}
144128
}
145129

146-
let phf_body = if phf_exact_match_arms.is_empty() {
147-
quote!()
130+
// Determine the error type on FromStr and TryFrom based on what the user
131+
// has configured and whether there is a default variant.
132+
let is_infallible = default_match_arm.is_some();
133+
let has_custom_err_ty = type_properties.parse_err_ty.is_some();
134+
let err_ty = if let Some(ty) = type_properties.parse_err_ty {
135+
quote! { #ty }
136+
} else if is_infallible {
137+
quote! { ::core::convert::Infallible }
138+
} else {
139+
quote! { #strum_module_path::ParseError }
140+
};
141+
142+
// Determine the default match arm behavior based on whether the user provided a "default"
143+
// or if the user provided a custom error function.
144+
let default_match_arm = if let Some(default_match_arm) = default_match_arm {
145+
default_match_arm
146+
} else if let Some(f) = type_properties.parse_err_fn {
147+
quote! { return ::core::result::Result::Err(#f(s)) }
148+
} else if has_custom_err_ty {
149+
// The user defined a custom error type, but not a custom error function. This is an error
150+
// if the method isn't infallible.
151+
return Err(missing_parse_err_attr_error());
152+
} else {
153+
quote! { return ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }
154+
};
155+
156+
let mut match_expression = if standard_match_arms.is_empty() {
157+
default_match_arm
148158
} else {
149159
quote! {
160+
match s {
161+
#(#standard_match_arms)*
162+
_ => #default_match_arm,
163+
}
164+
}
165+
};
166+
167+
if !phf_exact_match_arms.is_empty() {
168+
match_expression = quote! {
150169
use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
151170
static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
152171
#(#phf_exact_match_arms)*
153172
};
173+
154174
if let Some(value) = PHF.get(s).cloned() {
155-
return ::core::result::Result::Ok(value);
175+
value
176+
} else {
177+
#match_expression
156178
}
157179
}
158-
};
180+
}
159181

160-
let standard_match_body = if standard_match_arms.is_empty() {
161-
default_match_arm
182+
let from_impl = if is_infallible && !has_custom_err_ty {
183+
quote! {
184+
#[allow(clippy::use_self)]
185+
#[automatically_derived]
186+
impl #impl_generics ::core::convert::From<&str> for #name #ty_generics #where_clause {
187+
#[inline]
188+
fn from(s: &str) -> #name #ty_generics {
189+
#match_expression
190+
}
191+
}
192+
}
162193
} else {
163194
quote! {
164-
::core::result::Result::Ok(match s {
165-
#(#standard_match_arms)*
166-
_ => return #default_match_arm,
167-
})
195+
#[allow(clippy::use_self)]
196+
#[automatically_derived]
197+
impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
198+
type Error = #err_ty;
199+
200+
#[inline]
201+
fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
202+
Ok({
203+
#match_expression
204+
})
205+
}
206+
}
168207
}
169208
};
170209

171210
let from_str = quote! {
172211
#[allow(clippy::use_self)]
173212
#[automatically_derived]
174213
impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
175-
type Err = #default_err_ty;
214+
type Err = #err_ty;
176215

177216
#[inline]
178217
fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
179-
#phf_body
180-
#standard_match_body
218+
<Self as ::core::convert::TryFrom<&str>>::try_from(s)
181219
}
182220
}
183221
};
184-
let try_from_str = try_from_str(
185-
name,
186-
&impl_generics,
187-
&ty_generics,
188-
where_clause,
189-
&default_err_ty,
190-
);
191222

192223
Ok(quote! {
193224
#from_str
194-
#try_from_str
225+
#from_impl
195226
})
196227
}
197-
198-
fn try_from_str(
199-
name: &proc_macro2::Ident,
200-
impl_generics: &syn::ImplGenerics,
201-
ty_generics: &syn::TypeGenerics,
202-
where_clause: Option<&syn::WhereClause>,
203-
default_err_ty: &TokenStream,
204-
) -> TokenStream {
205-
quote! {
206-
#[allow(clippy::use_self)]
207-
#[automatically_derived]
208-
impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
209-
type Error = #default_err_ty;
210-
211-
#[inline]
212-
fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
213-
::core::str::FromStr::from_str(s)
214-
}
215-
}
216-
}
217-
}

strum_tests/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub enum Color {
1212
Yellow,
1313
#[strum(disabled)]
1414
Green(String),
15+
#[strum(default)]
16+
Purple(String),
1517
}
1618

1719
/// A bunch of errors

strum_tests/tests/from_str.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use std::str::FromStr;
1+
#![allow(clippy::infallible_try_from)]
2+
3+
use std::{convert::Infallible, str::FromStr};
24
use strum::EnumString;
35

46
mod core {} // ensure macros call `::core`
@@ -70,6 +72,16 @@ fn color_default() {
7072
assert_from_str(Color::Green(String::from("not found")), "not found");
7173
}
7274

75+
#[test]
76+
#[allow(clippy::unnecessary_fallible_conversions)]
77+
fn color2_infallible() {
78+
let r: Result<Color2, Infallible> = Color2::from_str("infallible");
79+
assert!(r.is_ok());
80+
let r: Result<Color2, Infallible> = Color2::try_from("infallible");
81+
assert!(r.is_ok());
82+
let _ = Color2::from("infallible");
83+
}
84+
7385
#[test]
7486
fn color2_default() {
7587
assert_from_str(
@@ -300,3 +312,26 @@ fn case_custom_infallible_parsing_with_default() {
300312
r
301313
);
302314
}
315+
316+
enum Never {}
317+
318+
#[derive(Debug, EnumString, Eq, PartialEq)]
319+
#[strum(
320+
parse_err_ty = Never
321+
)]
322+
enum CustomErrorTyWithNoErrorFn {
323+
#[strum(serialize = "foo")]
324+
Foo,
325+
#[strum(serialize = "bar")]
326+
Bar,
327+
#[strum(default)]
328+
Unknown(String),
329+
}
330+
331+
#[test]
332+
fn case_custom_infallible_parsing_with_default_no_err_fn() {
333+
let r: Result<CustomErrorTyWithNoErrorFn, Never> =
334+
"yellow".parse::<CustomErrorTyWithNoErrorFn>();
335+
336+
assert!(r.is_ok());
337+
}

strum_tests/tests/phf.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,48 @@ fn from_str_with_phf() {
1414
assert_eq!("bLuE".parse::<Color>().unwrap(), Color::Blue);
1515
}
1616

17+
#[cfg(feature = "test_phf")]
18+
#[test]
19+
fn from_str_with_phf_infallible() {
20+
#[derive(Debug, PartialEq, Eq, Clone, strum::EnumString)]
21+
#[strum(use_phf)]
22+
enum Color {
23+
Red,
24+
Blue,
25+
#[strum(default)]
26+
Unknown(String),
27+
}
28+
29+
// Known variants still parse correctly
30+
let c: Color = Color::from("Red");
31+
assert_eq!(c, Color::Red);
32+
let c: Color = Color::from("Blue");
33+
assert_eq!(c, Color::Blue);
34+
35+
// Unknown input falls through to the default variant
36+
let c: Color = Color::from("notacolor");
37+
assert_eq!(c, Color::Unknown("notacolor".to_string()));
38+
}
39+
40+
#[cfg(feature = "test_phf")]
41+
#[test]
42+
fn from_str_with_phf_infallible_case_insensitive() {
43+
#[derive(Debug, PartialEq, Eq, Clone, strum::EnumString)]
44+
#[strum(use_phf)]
45+
enum Color {
46+
#[strum(ascii_case_insensitive)]
47+
Blue,
48+
Red,
49+
#[strum(default)]
50+
Unknown(String),
51+
}
52+
53+
let c: Color = Color::from("bLuE");
54+
assert_eq!(c, Color::Blue);
55+
let c: Color = Color::from("notacolor");
56+
assert_eq!(c, Color::Unknown("notacolor".to_string()));
57+
}
58+
1759
#[cfg(feature = "test_phf")]
1860
#[test]
1961
fn from_str_with_phf_big() {

0 commit comments

Comments
 (0)