Skip to content

Commit 0be6b1f

Browse files
fix(_patch_enum_parse_value): Fix patching enum values
1 parent e3e7424 commit 0be6b1f

2 files changed

Lines changed: 160 additions & 2 deletions

File tree

ariadne/enums_default_values.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from enum import Enum
23
from typing import Any
34

45
from graphql import (
@@ -152,9 +153,13 @@ def make_patched_parse_value(
152153
original: Callable[[str], Any],
153154
) -> Callable[[Any], Any]:
154155
def patched_parse_value(input_value: Any) -> Any:
155-
# Check if already a valid Python enum value
156+
# If input is an Enum instance or a non-string (e.g. int), return as-is.
157+
# If input is a string, parse to convert to the enum object.
158+
# For StrEnum, strings matching member names should be converted to
159+
# enum objects.
156160
if input_value in enum_type._value_lookup:
157-
return input_value
161+
if isinstance(input_value, Enum) or not isinstance(input_value, str):
162+
return input_value
158163
return original(input_value)
159164

160165
return patched_parse_value

tests/test_graphql_enum_fixes.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,156 @@ def test_invalid_default_field_input_arg_object_enum_list_value_fails_validation
586586
)
587587

588588
assert "Undefined enum value" in str(exc_info.value)
589+
590+
591+
class StrEnumMatchingValue(str, Enum):
592+
"""StrEnum where value equals name (e.g., ACTIVE = "ACTIVE")."""
593+
594+
ACTIVE = "ACTIVE"
595+
INACTIVE = "INACTIVE"
596+
597+
598+
class StrEnumDifferentValue(str, Enum):
599+
"""StrEnum where value differs from name (e.g., USER = "u")."""
600+
601+
USER = "u"
602+
ADMIN = "a"
603+
604+
605+
class IntEnumType(int, Enum):
606+
"""IntEnum with integer values."""
607+
608+
LOW = 0
609+
HIGH = 1
610+
611+
612+
class PlainEnumType(Enum):
613+
"""Plain Enum without type mixin."""
614+
615+
FOO = "foo_value"
616+
BAR = "bar_value"
617+
618+
619+
@pytest.mark.parametrize(
620+
"enum_class,member_name,expected_member",
621+
[
622+
(StrEnumMatchingValue, "ACTIVE", StrEnumMatchingValue.ACTIVE),
623+
(StrEnumDifferentValue, "ADMIN", StrEnumDifferentValue.ADMIN),
624+
(IntEnumType, "HIGH", IntEnumType.HIGH),
625+
(PlainEnumType, "FOO", PlainEnumType.FOO),
626+
],
627+
ids=[
628+
"str_enum_matching_value",
629+
"str_enum_different_value",
630+
"int_enum",
631+
"plain_enum",
632+
],
633+
)
634+
def test_enum_inline_literal_returns_enum_object(
635+
enum_class, member_name, expected_member
636+
):
637+
"""Enum inline literal is parsed to correct enum object."""
638+
enum_name = enum_class.__name__
639+
members = " ".join(m.name for m in enum_class)
640+
captured = {}
641+
642+
schema = make_executable_schema(
643+
f"""
644+
enum {enum_name} {{ {members} }}
645+
type Query {{ getValue(arg: {enum_name}!): String! }}
646+
""",
647+
enum_class,
648+
)
649+
650+
def resolver(*_, arg):
651+
captured["value"] = arg
652+
return "ok"
653+
654+
set_resolver(schema, "getValue", resolver)
655+
656+
result = graphql_sync(schema, f"{{ getValue(arg: {member_name}) }}")
657+
658+
assert not result.errors
659+
assert captured["value"] is expected_member
660+
661+
662+
@pytest.mark.parametrize(
663+
"enum_class,member_name,expected_member",
664+
[
665+
(StrEnumMatchingValue, "ACTIVE", StrEnumMatchingValue.ACTIVE),
666+
(StrEnumDifferentValue, "ADMIN", StrEnumDifferentValue.ADMIN),
667+
(IntEnumType, "HIGH", IntEnumType.HIGH),
668+
(PlainEnumType, "BAR", PlainEnumType.BAR),
669+
],
670+
ids=[
671+
"str_enum_matching_value",
672+
"str_enum_different_value",
673+
"int_enum",
674+
"plain_enum",
675+
],
676+
)
677+
def test_enum_from_variable_returns_enum_object(
678+
enum_class, member_name, expected_member
679+
):
680+
"""Enum from variable is parsed to correct enum object."""
681+
enum_name = enum_class.__name__
682+
members = " ".join(m.name for m in enum_class)
683+
captured = {}
684+
685+
schema = make_executable_schema(
686+
f"""
687+
enum {enum_name} {{ {members} }}
688+
type Query {{ getValue(arg: {enum_name}!): String! }}
689+
""",
690+
enum_class,
691+
)
692+
693+
def resolver(*_, arg):
694+
captured["value"] = arg
695+
return "ok"
696+
697+
set_resolver(schema, "getValue", resolver)
698+
699+
result = graphql_sync(
700+
schema,
701+
f"query($v: {enum_name}!) {{ getValue(arg: $v) }}",
702+
variable_values={"v": member_name},
703+
)
704+
705+
assert not result.errors
706+
assert captured["value"] is expected_member
707+
708+
709+
@pytest.mark.parametrize(
710+
"enum_class,member_name,expected_member",
711+
[
712+
(StrEnumMatchingValue, "ACTIVE", StrEnumMatchingValue.ACTIVE),
713+
(StrEnumDifferentValue, "USER", StrEnumDifferentValue.USER),
714+
(IntEnumType, "LOW", IntEnumType.LOW),
715+
(PlainEnumType, "FOO", PlainEnumType.FOO),
716+
],
717+
ids=[
718+
"str_enum_matching_value",
719+
"str_enum_different_value",
720+
"int_enum",
721+
"plain_enum",
722+
],
723+
)
724+
def test_parse_value_returns_enum_object(enum_class, member_name, expected_member):
725+
"""parse_value returns enum object, not the raw string."""
726+
enum_name = enum_class.__name__
727+
members = " ".join(m.name for m in enum_class)
728+
729+
schema = make_executable_schema(
730+
f"""
731+
enum {enum_name} {{ {members} }}
732+
type Query {{ value: {enum_name} }}
733+
""",
734+
enum_class,
735+
)
736+
737+
enum_type = schema.type_map[enum_name]
738+
parsed = enum_type.parse_value(member_name)
739+
740+
assert isinstance(parsed, enum_class)
741+
assert parsed is expected_member

0 commit comments

Comments
 (0)