Skip to content

Commit 92285e4

Browse files
committed
Fix handling of Enums in Literal types
1 parent 3ee4b40 commit 92285e4

2 files changed

Lines changed: 51 additions & 4 deletions

File tree

src/cattr/converters.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,17 @@ def _structure_call(obj, cl):
405405

406406
@staticmethod
407407
def _structure_literal(val, type):
408-
if val not in type.__args__:
409-
raise Exception(f"{val} not in literal {type}")
410-
return val
408+
vals = set(type.__args__)
409+
enums = {x for x in vals if isinstance(x, Enum)}
410+
literal_vals = vals.difference(enums)
411+
if val not in literal_vals:
412+
enum_vals = {x.value: x for x in enums}
413+
if val not in enum_vals:
414+
raise Exception(f"{val} not in literal {type}")
415+
else:
416+
return enum_vals[val]
417+
else:
418+
return val
411419

412420
# Attrs classes.
413421

tests/test_structure_attrs.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Loading of attrs classes."""
2+
from enum import Enum
23
from ipaddress import IPv4Address, IPv6Address, ip_address
34
from typing import Union
45
from unittest.mock import Mock
@@ -164,6 +165,27 @@ class ClassWithLiteral:
164165
) == ClassWithLiteral(4)
165166

166167

168+
@pytest.mark.skipif(is_py37, reason="Not supported on 3.7")
169+
@pytest.mark.parametrize("converter_cls", [Converter, GenConverter])
170+
def test_structure_literal_enum(converter_cls):
171+
"""Structuring a class with a literal field works."""
172+
from typing import Literal
173+
174+
converter = converter_cls()
175+
176+
class Foo(Enum):
177+
FOO = 1
178+
BAR = 2
179+
180+
@define
181+
class ClassWithLiteral:
182+
literal_field: Literal[Foo.FOO] = Foo.FOO
183+
184+
assert converter.structure(
185+
{"literal_field": 1}, ClassWithLiteral
186+
) == ClassWithLiteral(Foo.FOO)
187+
188+
167189
@pytest.mark.skipif(is_py37, reason="Not supported on 3.7")
168190
@pytest.mark.parametrize("converter_cls", [Converter, GenConverter])
169191
def test_structure_literal_multiple(converter_cls):
@@ -172,16 +194,33 @@ def test_structure_literal_multiple(converter_cls):
172194

173195
converter = converter_cls()
174196

197+
class Foo(Enum):
198+
FOO = 1
199+
FOOFOO = 2
200+
201+
class Bar(Enum):
202+
BAR = 8
203+
BARBAR = 9
204+
175205
@define
176206
class ClassWithLiteral:
177-
literal_field: Literal[4, 5] = 4
207+
literal_field: Literal[4, 5, Literal[6], Foo.FOO, Bar.BARBAR] = 4
178208

179209
assert converter.structure(
180210
{"literal_field": 4}, ClassWithLiteral
181211
) == ClassWithLiteral(4)
182212
assert converter.structure(
183213
{"literal_field": 5}, ClassWithLiteral
184214
) == ClassWithLiteral(5)
215+
assert converter.structure(
216+
{"literal_field": 6}, ClassWithLiteral
217+
) == ClassWithLiteral(6)
218+
assert converter.structure(
219+
{"literal_field": 1}, ClassWithLiteral
220+
) == ClassWithLiteral(Foo.FOO)
221+
assert converter.structure(
222+
{"literal_field": 9}, ClassWithLiteral
223+
) == ClassWithLiteral(Bar.BARBAR)
185224

186225

187226
@pytest.mark.skipif(is_py37, reason="Not supported on 3.7")

0 commit comments

Comments
 (0)