Skip to content

Commit a5db4e6

Browse files
arunsureshkumarpatrick91arun-sureshkumaradarshdigievo
authored
fix: Federation V2 Scalar Support (#14)
Co-authored-by: Patrick Arminio <patrick.arminio@gmail.com> Co-authored-by: Arun Suresh Kumar <arun@strollby.com> Co-authored-by: Adarsh Divakaran <adarshdevamritham@gmail.com> Co-authored-by: Arun Suresh Kumar <89654966+arun-sureshkumar@users.noreply.github.com>
1 parent 46846e0 commit a5db4e6

File tree

5 files changed

+91
-15
lines changed

5 files changed

+91
-15
lines changed

examples/entities.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@ def get_file_by_id(id):
66
return File(**{'id': id, 'name': 'test_name'})
77

88

9+
class Author(graphene.ObjectType):
10+
id = graphene.ID(required=True)
11+
name = graphene.String(required=True)
12+
13+
914
@key(fields='id')
15+
@key(fields='id author { name }')
16+
@key(fields='id author { id name }')
1017
class File(graphene.ObjectType):
1118
id = graphene.Int(required=True)
1219
name = graphene.String()
20+
author = graphene.Field(Author, required=True)
1321

1422
def resolve_id(self, info, **kwargs):
1523
return 1
@@ -25,7 +33,7 @@ class Query(graphene.ObjectType):
2533
file = graphene.Field(File)
2634

2735
def resolve_file(self, **kwargs):
28-
return None # no direct access
36+
return None # no direct access
2937

3038

3139
schema = build_schema(Query)
@@ -41,7 +49,7 @@ def resolve_file(self, **kwargs):
4149
print(result.data)
4250
# {'_service': {'sdl': 'type Query {\n file: File\n}\n\ntype File @key(fields: "id") {\n id: Int!\n name: String\n}'}}
4351

44-
query ='''
52+
query = '''
4553
query entities($_representations: [_Any!]!) {
4654
_entities(representations: $_representations) {
4755
... on File {
@@ -55,10 +63,10 @@ def resolve_file(self, **kwargs):
5563

5664
result = schema.execute(query, variables={
5765
"_representations": [
58-
{
59-
"__typename": "File",
60-
"id": 1
61-
}
66+
{
67+
"__typename": "File",
68+
"id": 1
69+
}
6270
]
6371
})
6472
print(result.data)

graphene_federation/entity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def decorator(type_):
121121
if "{" not in fields:
122122
# Skip valid fields check if the key is a compound key. The validation for compound keys
123123
# is done on calling get_entities()
124-
fields_set = set(fields.replace(" ", "").split(","))
124+
fields_set = set(fields.split(" "))
125125
assert check_fields_exist_on_type(
126126
fields=fields_set, type_=type_
127127
), f'Field "{fields}" does not exist on type "{type_._meta.name}"'

graphene_federation/extend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def decorator(type_):
4343
if "{" not in fields: # Check for compound keys
4444
# Skip valid fields check if the key is a compound key. The validation for compound keys
4545
# is done on calling get_extended_types()
46-
fields_set = set(fields.replace(" ", "").split(","))
46+
fields_set = set(fields.split(" "))
4747
assert check_fields_exist_on_type(
4848
fields=fields_set, type_=type_
4949
), f'Field "{fields}" does not exist on type "{type_._meta.name}"'
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Any
2+
3+
import graphene
4+
from graphene import Scalar, String, ObjectType
5+
from graphql import graphql_sync
6+
7+
from graphene_federation import build_schema, shareable, inaccessible
8+
9+
10+
def test_custom_scalar():
11+
class AddressScalar(Scalar):
12+
base = String
13+
14+
@staticmethod
15+
def coerce_address(value: Any):
16+
...
17+
18+
serialize = coerce_address
19+
parse_value = coerce_address
20+
21+
@staticmethod
22+
def parse_literal(ast):
23+
...
24+
25+
@shareable
26+
class TestScalar(graphene.ObjectType):
27+
test_shareable_scalar = shareable(String(x=AddressScalar()))
28+
test_inaccessible_scalar = inaccessible(String(x=AddressScalar()))
29+
30+
class Query(ObjectType):
31+
test = String(x=AddressScalar())
32+
test2 = graphene.List(AddressScalar, required=True)
33+
34+
schema = build_schema(query=Query, enable_federation_2=True, types=(TestScalar,))
35+
query = """
36+
query {
37+
_service {
38+
sdl
39+
}
40+
}
41+
"""
42+
result = graphql_sync(schema.graphql_schema, query)
43+
assert (
44+
result.data["_service"]["sdl"].strip()
45+
== """extend schema @link(url: "https://specs.apollo.dev/federation/v2.0", import: ["@inaccessible", "@shareable"])
46+
type TestScalar @shareable {
47+
testShareableScalar(x: AddressScalar): String @shareable
48+
testInaccessibleScalar(x: AddressScalar): String @inaccessible
49+
}
50+
51+
scalar AddressScalar
52+
53+
type Query {
54+
test(x: AddressScalar): String
55+
test2: [AddressScalar]!
56+
}"""
57+
)

graphene_federation/utils.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import graphene
44
from graphene import Schema, ObjectType
55
from graphene.types.definitions import GrapheneObjectType
6+
from graphene.types.scalars import ScalarOptions
67
from graphene.types.union import UnionOptions
78
from graphene.utils.str_converters import to_camel_case
8-
from graphql import parse, GraphQLScalarType
9+
from graphql import parse, GraphQLScalarType, GraphQLNonNull
910

1011

1112
def field_name_to_type_attribute(schema: Schema, model: Any) -> Callable[[str], str]:
@@ -47,9 +48,11 @@ def is_valid_compound_key(type_name: str, key: str, schema: Schema):
4748

4849
while key_nodes:
4950
selection_node, parent_object_type = key_nodes[0]
50-
51-
for field in selection_node.selection_set.selections:
51+
if isinstance(parent_object_type, GraphQLNonNull):
52+
parent_type_fields = parent_object_type.of_type.fields
53+
else:
5254
parent_type_fields = parent_object_type.fields
55+
for field in selection_node.selection_set.selections:
5356
if schema.auto_camelcase:
5457
field_name = to_camel_case(field.name.value)
5558
else:
@@ -62,14 +65,20 @@ def is_valid_compound_key(type_name: str, key: str, schema: Schema):
6265
if field.selection_set:
6366
# If the field has sub-selections, add it to node mappings to check for valid subfields
6467

65-
if isinstance(field_type, GraphQLScalarType):
68+
if isinstance(field_type, GraphQLScalarType) or (
69+
isinstance(field_type, GraphQLNonNull)
70+
and isinstance(field_type.of_type, GraphQLScalarType)
71+
):
6672
# sub-selections are added to a scalar type, key is not valid
6773
return False
6874

6975
key_nodes.append((field, field_type))
7076
else:
7177
# If there are no sub-selections for a field, it should be a scalar
72-
if not isinstance(field_type, GraphQLScalarType):
78+
if not isinstance(field_type, GraphQLScalarType) and not (
79+
isinstance(field_type, GraphQLNonNull)
80+
and isinstance(field_type.of_type, GraphQLScalarType)
81+
):
7382
return False
7483

7584
key_nodes.pop(0) # Remove the current node as it is fully processed
@@ -80,8 +89,10 @@ def is_valid_compound_key(type_name: str, key: str, schema: Schema):
8089
def get_attributed_fields(attribute: str, schema: Schema):
8190
fields = {}
8291
for type_name, type_ in schema.graphql_schema.type_map.items():
83-
if not hasattr(type_, "graphene_type") or isinstance(
84-
type_.graphene_type._meta, UnionOptions
92+
if (
93+
not hasattr(type_, "graphene_type")
94+
or isinstance(type_.graphene_type._meta, UnionOptions)
95+
or isinstance(type_.graphene_type._meta, ScalarOptions)
8596
):
8697
continue
8798
for field in list(type_.graphene_type._meta.fields):

0 commit comments

Comments
 (0)