Skip to content

Commit 1c9eeec

Browse files
etareductionGoodGrief1488DamianCzajkowski
authored
Add a setting that sets default=None to all optional fields so they can be missing from response and still parse (#385)
* Add a setting that sets default=None to all optional fields so they can be missing from response and still parse * Unit tests for default_optional_fields_to_none setting --------- Co-authored-by: GoodGrief1488 <55449535+GoodGrief1488@users.noreply.github.com> Co-authored-by: DamianCzajkowski <43958031+DamianCzajkowski@users.noreply.github.com>
1 parent efa3cbc commit 1c9eeec

File tree

5 files changed

+159
-2
lines changed

5 files changed

+159
-2
lines changed

ariadne_codegen/client_generators/fragments.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
convert_to_snake_case: bool = True,
2121
custom_scalars: Optional[dict[str, ScalarData]] = None,
2222
plugin_manager: Optional[PluginManager] = None,
23+
default_optional_fields_to_none: bool = False,
2324
include_typename: bool = True,
2425
) -> None:
2526
self.schema = schema
@@ -29,6 +30,7 @@ def __init__(
2930
self.convert_to_snake_case = convert_to_snake_case
3031
self.custom_scalars = custom_scalars
3132
self.plugin_manager = plugin_manager
33+
self.default_optional_fields_to_none = default_optional_fields_to_none
3234
self.include_typename = include_typename
3335

3436
self._fragments_names = set(self.fragments_definitions.keys())
@@ -54,6 +56,7 @@ def generate(self, exclude_names: Optional[set[str]] = None) -> ast.Module:
5456
convert_to_snake_case=self.convert_to_snake_case,
5557
custom_scalars=self.custom_scalars,
5658
plugin_manager=self.plugin_manager,
59+
default_optional_fields_to_none=self.default_optional_fields_to_none,
5760
include_typename=self.include_typename,
5861
)
5962
imports.extend(generator.get_imports())

ariadne_codegen/client_generators/package.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
custom_scalars: Optional[dict[str, ScalarData]] = None,
8585
plugin_manager: Optional[PluginManager] = None,
8686
enable_custom_operations: bool = False,
87+
default_optional_fields_to_none: bool = False,
8788
include_typename: bool = True,
8889
) -> None:
8990
self.package_path = Path(target_path) / package_name
@@ -134,6 +135,7 @@ def __init__(
134135
)
135136
self.custom_scalars = custom_scalars if custom_scalars else {}
136137
self.plugin_manager = plugin_manager
138+
self.default_optional_fields_to_none = default_optional_fields_to_none
137139
self.include_typename = include_typename
138140

139141
self._result_types_files: dict[str, ast.Module] = {}
@@ -201,6 +203,7 @@ def add_operation(self, definition: OperationDefinitionNode):
201203
convert_to_snake_case=self.convert_to_snake_case,
202204
custom_scalars=self.custom_scalars,
203205
plugin_manager=self.plugin_manager,
206+
default_optional_fields_to_none=self.default_optional_fields_to_none,
204207
include_typename=self.include_typename,
205208
)
206209
self._unpacked_fragments = self._unpacked_fragments.union(
@@ -457,6 +460,7 @@ def get_package_generator(
457460
convert_to_snake_case=settings.convert_to_snake_case,
458461
custom_scalars=settings.scalars,
459462
plugin_manager=plugin_manager,
463+
default_optional_fields_to_none=settings.default_optional_fields_to_none,
460464
include_typename=settings.include_typename,
461465
)
462466
custom_fields_generator = CustomFieldsGenerator(
@@ -537,5 +541,6 @@ def get_package_generator(
537541
custom_scalars=settings.scalars,
538542
plugin_manager=plugin_manager,
539543
enable_custom_operations=settings.enable_custom_operations,
544+
default_optional_fields_to_none=settings.default_optional_fields_to_none,
540545
include_typename=settings.include_typename,
541546
)

ariadne_codegen/client_generators/result_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
convert_to_snake_case: bool = True,
8585
custom_scalars: Optional[dict[str, ScalarData]] = None,
8686
plugin_manager: Optional[PluginManager] = None,
87+
default_optional_fields_to_none: bool = False,
8788
include_typename: bool = True,
8889
) -> None:
8990
self.schema = schema
@@ -99,6 +100,7 @@ def __init__(
99100
self.custom_scalars = custom_scalars if custom_scalars else {}
100101
self.convert_to_snake_case = convert_to_snake_case
101102
self.plugin_manager = plugin_manager
103+
self.default_optional_fields_to_none = default_optional_fields_to_none
102104
self.include_typename = include_typename
103105

104106
self._imports: list[ast.ImportFrom] = [
@@ -447,6 +449,14 @@ def _process_field_implementation(
447449
keywords[DEFAULT_KEYWORD] = generate_constant(
448450
field_implementation.value.value
449451
)
452+
elif (
453+
self.default_optional_fields_to_none
454+
and field_implementation.value is None
455+
and isinstance(field_implementation.annotation, ast.Subscript)
456+
and isinstance(field_implementation.annotation.value, ast.Name)
457+
and field_implementation.annotation.value.id == OPTIONAL
458+
):
459+
keywords[DEFAULT_KEYWORD] = generate_constant(None)
450460

451461
if keywords:
452462
field_implementation.value = generate_pydantic_field(keywords)

ariadne_codegen/settings.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ class ClientSettings(BaseSettings):
7070
include_all_enums: bool = True
7171
async_client: bool = True
7272
opentelemetry_client: bool = False
73-
files_to_include: list[str] = field(default_factory=list)
74-
scalars: dict[str, ScalarData] = field(default_factory=dict)
73+
files_to_include: List[str] = field(default_factory=list)
74+
scalars: Dict[str, ScalarData] = field(default_factory=dict)
75+
default_optional_fields_to_none: bool = False
7576
include_typename: bool = True
7677

7778
def __post_init__(self):
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import ast
2+
from typing import cast
3+
4+
from graphql import (
5+
OperationDefinitionNode,
6+
build_ast_schema,
7+
parse,
8+
)
9+
10+
from ariadne_codegen.client_generators.constants import (
11+
ALIAS_KEYWORD,
12+
DEFAULT_KEYWORD,
13+
FIELD_CLASS,
14+
OPTIONAL,
15+
)
16+
from ariadne_codegen.client_generators.result_types import ResultTypesGenerator
17+
18+
from ...utils import compare_ast, format_graphql_str, get_class_def
19+
from .schema import SCHEMA_STR
20+
21+
22+
def test_default_optional_fields_true():
23+
query_str = format_graphql_str(
24+
"""
25+
query CustomQuery {
26+
query1 {
27+
... on CustomType {
28+
field1
29+
field2
30+
}
31+
}
32+
}
33+
"""
34+
)
35+
expected_results = [
36+
ast.AnnAssign(
37+
target=ast.Name(id="field_1"),
38+
annotation=ast.Name(id='"CustomQueryQuery1Field1"'),
39+
value=ast.Call(
40+
func=ast.Name(id=FIELD_CLASS),
41+
args=[],
42+
keywords=[
43+
ast.keyword(
44+
arg=ALIAS_KEYWORD,
45+
value=ast.Constant(value="field1"),
46+
)
47+
],
48+
),
49+
simple=1,
50+
),
51+
ast.AnnAssign(
52+
target=ast.Name(id="field_2"),
53+
annotation=ast.Subscript(
54+
value=ast.Name(id=OPTIONAL),
55+
slice=ast.Name(id='"CustomQueryQuery1Field2"'),
56+
),
57+
value=ast.Call(
58+
func=ast.Name(id=FIELD_CLASS),
59+
args=[],
60+
keywords=[
61+
ast.keyword(arg=ALIAS_KEYWORD, value=ast.Constant(value="field2")),
62+
ast.keyword(arg=DEFAULT_KEYWORD, value=ast.Constant(value=None)),
63+
],
64+
),
65+
simple=1,
66+
),
67+
]
68+
generator = ResultTypesGenerator(
69+
schema=build_ast_schema(parse(SCHEMA_STR)),
70+
operation_definition=cast(
71+
OperationDefinitionNode, parse(query_str).definitions[0]
72+
),
73+
enums_module_name="enums",
74+
default_optional_fields_to_none=True,
75+
)
76+
result = generator.generate()
77+
classdef = get_class_def(result, 1)
78+
assert compare_ast(classdef.body[0], expected_results[0])
79+
assert compare_ast(classdef.body[1], expected_results[1])
80+
81+
82+
def test_default_optional_fields_false():
83+
query_str = format_graphql_str(
84+
"""
85+
query CustomQuery {
86+
query1 {
87+
... on CustomType {
88+
field1
89+
field2
90+
}
91+
}
92+
}
93+
"""
94+
)
95+
expected_results = [
96+
ast.AnnAssign(
97+
target=ast.Name(id="field_1"),
98+
annotation=ast.Name(id='"CustomQueryQuery1Field1"'),
99+
value=ast.Call(
100+
func=ast.Name(id=FIELD_CLASS),
101+
args=[],
102+
keywords=[
103+
ast.keyword(
104+
arg=ALIAS_KEYWORD,
105+
value=ast.Constant(value="field1"),
106+
)
107+
],
108+
),
109+
simple=1,
110+
),
111+
ast.AnnAssign(
112+
target=ast.Name(id="field_2"),
113+
annotation=ast.Subscript(
114+
value=ast.Name(id=OPTIONAL),
115+
slice=ast.Name(id='"CustomQueryQuery1Field2"'),
116+
),
117+
value=ast.Call(
118+
func=ast.Name(id=FIELD_CLASS),
119+
args=[],
120+
keywords=[
121+
ast.keyword(arg=ALIAS_KEYWORD, value=ast.Constant(value="field2"))
122+
],
123+
),
124+
simple=1,
125+
),
126+
]
127+
generator = ResultTypesGenerator(
128+
schema=build_ast_schema(parse(SCHEMA_STR)),
129+
operation_definition=cast(
130+
OperationDefinitionNode, parse(query_str).definitions[0]
131+
),
132+
enums_module_name="enums",
133+
default_optional_fields_to_none=False,
134+
)
135+
result = generator.generate()
136+
classdef = get_class_def(result, 1)
137+
assert compare_ast(classdef.body[0], expected_results[0])
138+
assert compare_ast(classdef.body[1], expected_results[1])

0 commit comments

Comments
 (0)