Skip to content

Commit ba6b6e4

Browse files
committed
backport: Enable passing values configuration to GraphQLEnumType as a thunk
Replicates graphql/graphql-js@6a1614c
1 parent 6687245 commit ba6b6e4

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

src/graphql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@
342342
GraphQLArgumentMap,
343343
GraphQLEnumValue,
344344
GraphQLEnumValueMap,
345+
GraphQLEnumValuesDefinition,
345346
GraphQLField,
346347
GraphQLFieldMap,
347348
GraphQLFieldResolver,
@@ -564,6 +565,7 @@
564565
"GraphQLArgumentMap",
565566
"GraphQLEnumValue",
566567
"GraphQLEnumValueMap",
568+
"GraphQLEnumValuesDefinition",
567569
"GraphQLField",
568570
"GraphQLFieldMap",
569571
"GraphQLFieldResolver",

src/graphql/type/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
GraphQLArgumentMap,
9393
GraphQLEnumValue,
9494
GraphQLEnumValueMap,
95+
GraphQLEnumValuesDefinition,
9596
GraphQLField,
9697
GraphQLFieldMap,
9798
GraphQLInputField,
@@ -245,6 +246,7 @@
245246
"GraphQLArgumentMap",
246247
"GraphQLEnumValue",
247248
"GraphQLEnumValueMap",
249+
"GraphQLEnumValuesDefinition",
248250
"GraphQLField",
249251
"GraphQLFieldMap",
250252
"GraphQLInputField",

src/graphql/type/definition.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
"GraphQLEnumValue",
117117
"GraphQLEnumValueKwargs",
118118
"GraphQLEnumValueMap",
119+
"GraphQLEnumValuesDefinition",
119120
"GraphQLField",
120121
"GraphQLFieldKwargs",
121122
"GraphQLFieldMap",
@@ -1106,6 +1107,8 @@ def assert_union_type(type_: Any) -> GraphQLUnionType:
11061107

11071108
GraphQLEnumValueMap = Dict[str, "GraphQLEnumValue"]
11081109

1110+
GraphQLEnumValuesDefinition = Union[GraphQLEnumValueMap, Mapping[str, Any], Type[Enum]]
1111+
11091112

11101113
class GraphQLEnumTypeKwargs(GraphQLNamedTypeKwargs, total=False):
11111114
values: GraphQLEnumValueMap
@@ -1153,7 +1156,7 @@ class RGBEnum(enum.Enum):
11531156
def __init__(
11541157
self,
11551158
name: str,
1156-
values: Union[GraphQLEnumValueMap, Mapping[str, Any], Type[Enum]],
1159+
values: Thunk[GraphQLEnumValuesDefinition],
11571160
names_as_values: Optional[bool] = False,
11581161
description: Optional[str] = None,
11591162
extensions: Optional[Dict[str, Any]] = None,
@@ -1167,6 +1170,8 @@ def __init__(
11671170
ast_node=ast_node,
11681171
extension_ast_nodes=extension_ast_nodes,
11691172
)
1173+
if not isinstance(values, type):
1174+
values = resolve_thunk(values) # type: ignore
11701175
try: # check for enum
11711176
values = cast(Enum, values).__members__ # type: ignore
11721177
except AttributeError:
@@ -1175,7 +1180,7 @@ def __init__(
11751180
):
11761181
try:
11771182
# noinspection PyTypeChecker
1178-
values = dict(values)
1183+
values = dict(values) # type: ignore
11791184
except (TypeError, ValueError) as error:
11801185
raise TypeError(
11811186
f"{name} values must be an Enum or a mapping"

tests/type/test_enum.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class Complex2:
4141

4242
ColorType2 = GraphQLEnumType("Color", ColorTypeEnumValues)
4343

44+
ThunkValuesEnum = GraphQLEnumType("ThunkValues", lambda: {"A": "a", "B": "b"})
45+
4446
QueryType = GraphQLObjectType(
4547
"Query",
4648
{
@@ -83,6 +85,13 @@ class Complex2:
8385
else Complex2() if args.get("provideBadValue") else args.get("fromEnum")
8486
),
8587
),
88+
"thunkValuesString": GraphQLField(
89+
GraphQLString,
90+
args={
91+
"fromEnum": GraphQLArgument(ThunkValuesEnum),
92+
},
93+
resolve=lambda _source, _info, fromEnum: fromEnum,
94+
),
8695
},
8796
)
8897

@@ -345,5 +354,10 @@ def may_be_internally_represented_with_complex_values():
345354
],
346355
)
347356

357+
def may_have_values_specified_via_a_callable():
358+
result = execute_query("{ thunkValuesString(fromEnum: B) }")
359+
360+
assert result == ({"thunkValuesString": "b"}, None)
361+
348362
def can_be_introspected_without_error():
349363
introspection_from_schema(schema)

0 commit comments

Comments
 (0)