diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 2dcd603c..5f3ab67b 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -19,7 +19,7 @@ get_graphql_schema_from_path, get_graphql_schema_from_url, ) -from .settings import Strategy +from .settings import Strategy, get_validation_rule @click.command() # type: ignore @@ -64,7 +64,7 @@ def client(config_dict): fragments = [] queries = [] if settings.queries_path: - definitions = get_graphql_queries(settings.queries_path, schema) + definitions = get_graphql_queries(settings.queries_path, schema, [get_validation_rule(e) for e in settings.skip_validation_rules]) queries = filter_operations_definitions(definitions) fragments = filter_fragments_definitions(definitions) diff --git a/ariadne_codegen/schema.py b/ariadne_codegen/schema.py index 1ce36268..8ee34789 100644 --- a/ariadne_codegen/schema.py +++ b/ariadne_codegen/schema.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Dict, Generator, List, Optional, Tuple, cast +from typing_extensions import Any, Sequence import httpx from graphql import ( @@ -14,6 +15,7 @@ IntrospectionQuery, NoUnusedFragmentsRule, OperationDefinitionNode, + UniqueFragmentNamesRule, build_ast_schema, build_client_schema, get_introspection_query, @@ -45,7 +47,7 @@ def filter_fragments_definitions( def get_graphql_queries( - queries_path: str, schema: GraphQLSchema + queries_path: str, schema: GraphQLSchema, skip_rules: Sequence[Any] = (NoUnusedFragmentsRule,) ) -> Tuple[DefinitionNode, ...]: """Get graphql queries definitions build from provided path.""" queries_str = load_graphql_files_from_path(Path(queries_path)) @@ -53,7 +55,7 @@ def get_graphql_queries( validation_errors = validate( schema=schema, document_ast=queries_ast, - rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule], + rules=[r for r in specified_rules if r not in skip_rules], ) if validation_errors: raise InvalidOperationForSchema( diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 808397ba..48929c27 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -6,6 +6,8 @@ from textwrap import dedent from typing import Dict, List +from graphql.validation import UniqueFragmentNamesRule, NoUnusedFragmentsRule + from .client_generators.constants import ( DEFAULT_ASYNC_BASE_CLIENT_NAME, DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME, @@ -25,6 +27,18 @@ class CommentsStrategy(str, enum.Enum): STABLE = "stable" TIMESTAMP = "timestamp" +class ValidationRuleSkips(str, enum.Enum): + UniqueFragmentNames = "UniqueFragmentNames" + NoUnusedFragments = "NoUnusedFragments" + +def get_validation_rule(rule: ValidationRuleSkips): + if rule == ValidationRuleSkips.UniqueFragmentNames: + return UniqueFragmentNamesRule + elif rule == ValidationRuleSkips.NoUnusedFragments: + return NoUnusedFragmentsRule + else: + raise ValueError(f"Unknown validation rule: {rule}") + class Strategy(str, enum.Enum): CLIENT = "client" @@ -70,6 +84,7 @@ class ClientSettings(BaseSettings): include_all_enums: bool = True async_client: bool = True opentelemetry_client: bool = False + skip_validation_rules: List[ValidationRuleSkips] = field(default_factory=lambda: [ValidationRuleSkips.UniqueFragmentNames,]) files_to_include: List[str] = field(default_factory=list) scalars: Dict[str, ScalarData] = field(default_factory=dict) diff --git a/tests/test_schema.py b/tests/test_schema.py index 42873867..367164da 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,3 +1,4 @@ +from ariadne_codegen.settings import get_validation_rule import httpx import pytest from graphql import GraphQLSchema, OperationDefinitionNode, build_schema @@ -15,6 +16,7 @@ read_graphql_file, walk_graphql_files, ) +from ariadne_codegen.settings import ValidationRuleSkips @pytest.fixture @@ -63,6 +65,49 @@ def test_query_2_str(): } """ +@pytest.fixture +def test_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + ...fragmentA + } + } + """ + +@pytest.fixture +def test_duplicate_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + ...fragmentA + } + } + """ + +@pytest.fixture +def test_unused_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + } + } + """ @pytest.fixture def single_file_schema(tmp_path_factory, schema_str): @@ -132,6 +177,24 @@ def single_file_query(tmp_path_factory, test_query_str): file_.write_text(test_query_str, encoding="utf-8") return file_ +@pytest.fixture +def single_file_query_with_fragment(tmp_path_factory, test_query_str, test_fragment_str): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_fragment.graphql") + file_.write_text(test_query_str + test_fragment_str, encoding="utf-8") + return file_ + +@pytest.fixture +def single_file_query_with_duplicate_fragment(tmp_path_factory, test_query_str, test_duplicate_fragment_str): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_duplicate_fragment.graphql") + file_.write_text(test_query_str + test_duplicate_fragment_str, encoding="utf-8") + return file_ + +@pytest.fixture +def single_file_query_with_unused_fragment(tmp_path_factory, test_query_str, test_unused_fragment_str): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_unused_fragment.graphql") + file_.write_text(test_query_str + test_unused_fragment_str, encoding="utf-8") + return file_ + @pytest.fixture def invalid_syntax_query_file(tmp_path_factory): @@ -434,3 +497,36 @@ def test_get_graphql_queries_with_invalid_query_for_schema_raises_invalid_operat get_graphql_queries( invalid_query_for_schema_file.as_posix(), build_schema(schema_str) ) + + +def test_get_graphql_queries_with_fragment_returns_schema_definitions( + single_file_query_with_fragment, schema_str +): + queries = get_graphql_queries( + single_file_query_with_fragment.as_posix(), build_schema(schema_str) + ) + + assert len(queries) == 3 + +def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation( + single_file_query_with_duplicate_fragment, schema_str +): + with pytest.raises(InvalidOperationForSchema): + get_graphql_queries( + single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str) + ) + +def test_get_graphql_queries_with_unused_fragment_and_no_skip_rules_raises_invalid_operation( + single_file_query_with_unused_fragment, schema_str +): + with pytest.raises(InvalidOperationForSchema): + get_graphql_queries( + single_file_query_with_unused_fragment.as_posix(), build_schema(schema_str), [] + ) + +def test_get_graphql_queries_with_skip_unique_fragment_names_and_duplicate_fragment_returns_schema_definition( + single_file_query_with_duplicate_fragment, schema_str +): + get_graphql_queries( + single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str), [get_validation_rule(ValidationRuleSkips.NoUnusedFragments),get_validation_rule(ValidationRuleSkips.UniqueFragmentNames)] + )