Skip to content

Commit 3411dda

Browse files
committed
Allow inputs to have configurable base class (#371)
The configuration in the pyproject.toml now takes the additional setting ignore_extra_fields. ignore_extra_fields is true by default as BaseModel ignores extra fields by default. If it is set to false, extra='forbid' will be appended to the BaseModel in the client. If additional fields are then provided, an error is thrown. Additionally three unit tests have been added, which test: - If extra='forbid' is properly appended if no extra is provided - If extra='ignore' is not overriden, in case it exists - If other classes in the provided code are left untouched
1 parent af021b3 commit 3411dda

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

ariadne_codegen/client_generators/package.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from ..exceptions import ParsingError
1414
from ..plugins.manager import PluginManager
1515
from ..settings import ClientSettings, CommentsStrategy
16-
from ..utils import ast_to_str, process_name, str_to_pascal_case
16+
from ..utils import (
17+
add_extra_to_base_model,
18+
ast_to_str,
19+
process_name,
20+
str_to_pascal_case,
21+
)
1722
from .arguments import ArgumentsGenerator
1823
from .client import ClientGenerator
1924
from .comments import get_comment
@@ -85,6 +90,7 @@ def __init__(
8590
plugin_manager: Optional[PluginManager] = None,
8691
enable_custom_operations: bool = False,
8792
include_typename: bool = True,
93+
ignore_extra_fields: bool = True,
8894
) -> None:
8995
self.package_path = Path(target_path) / package_name
9096

@@ -135,6 +141,7 @@ def __init__(
135141
self.custom_scalars = custom_scalars if custom_scalars else {}
136142
self.plugin_manager = plugin_manager
137143
self.include_typename = include_typename
144+
self.ignore_extra_fields = ignore_extra_fields
138145

139146
self._result_types_files: Dict[str, ast.Module] = {}
140147
self._generated_files: List[str] = []
@@ -355,6 +362,8 @@ def _copy_files(self):
355362
]
356363
for source_path in files_to_copy:
357364
code = self._add_comments_to_code(source_path.read_text(encoding="utf-8"))
365+
if not self.ignore_extra_fields and source_path.name == "base_model.py":
366+
code = add_extra_to_base_model(code)
358367
if self.plugin_manager:
359368
code = self.plugin_manager.copy_code(code)
360369
target_path = self.package_path / source_path.name
@@ -538,4 +547,5 @@ def get_package_generator(
538547
plugin_manager=plugin_manager,
539548
enable_custom_operations=settings.enable_custom_operations,
540549
include_typename=settings.include_typename,
550+
ignore_extra_fields=settings.ignore_extra_fields,
541551
)

ariadne_codegen/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class ClientSettings(BaseSettings):
7474
files_to_include: List[str] = field(default_factory=list)
7575
scalars: Dict[str, ScalarData] = field(default_factory=dict)
7676
include_typename: bool = True
77+
ignore_extra_fields: bool = True
7778

7879
def __post_init__(self):
7980
if not self.queries_path and not self.enable_custom_operations:

ariadne_codegen/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,29 @@ def process_name(
138138
if set(name) == {"_"} and not processed_name:
139139
return "underscore_named_field_"
140140
return processed_name
141+
142+
143+
def add_extra_to_base_model(code: str) -> str:
144+
"Adds `extra='forbid'` to the ConfigDict in BaseModel if not already present."
145+
tree = ast.parse(code)
146+
for node in tree.body:
147+
if not isinstance(node, ast.ClassDef):
148+
continue
149+
if node.name != "BaseModel":
150+
continue
151+
for statement in node.body:
152+
if not isinstance(statement, ast.Assign):
153+
continue
154+
call = statement.value
155+
if not isinstance(call, ast.Call):
156+
continue
157+
if not isinstance(call.func, ast.Name):
158+
continue
159+
if call.func.id != "ConfigDict":
160+
continue
161+
if not any(kw.arg == "extra" for kw in call.keywords):
162+
call.keywords.append(
163+
ast.keyword(arg="extra", value=ast.Constant("forbid"))
164+
)
165+
ast.fix_missing_locations(tree)
166+
return ast.unparse(tree)

tests/test_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from ariadne_codegen.utils import (
7+
add_extra_to_base_model,
78
ast_to_str,
89
convert_to_multiline_string,
910
format_multiline_strings,
@@ -201,3 +202,42 @@ def test_process_name_returns_name_returned_from_plugin_for_name_with_only_under
201202
)
202203
== "name_from_plugin"
203204
)
205+
206+
207+
def test_adds_extra_to_base_model_if_missing():
208+
code = dedent("""
209+
class BaseModel:
210+
Config = ConfigDict()
211+
""")
212+
expected = dedent("""
213+
class BaseModel:
214+
Config = ConfigDict(extra='forbid')
215+
""")
216+
result = add_extra_to_base_model(code)
217+
assert dedent(result).strip() == expected.strip()
218+
219+
220+
def test_adds_extra_to_base_model_does_not_overwrite_existing_extra():
221+
code = dedent("""
222+
class BaseModel:
223+
Config = ConfigDict(extra='ignore')
224+
""")
225+
expected = dedent("""
226+
class BaseModel:
227+
Config = ConfigDict(extra='ignore')
228+
""")
229+
result = add_extra_to_base_model(code)
230+
assert dedent(result).strip() == expected.strip()
231+
232+
233+
def test_adds_extra_to_base_model_leaves_other_classes_untouched():
234+
code = dedent("""
235+
class NotBaseModel:
236+
Config = ConfigDict()
237+
""")
238+
expected = dedent("""
239+
class NotBaseModel:
240+
Config = ConfigDict()
241+
""")
242+
result = add_extra_to_base_model(code)
243+
assert dedent(result).strip() == expected.strip()

0 commit comments

Comments
 (0)