Skip to content

Commit 3148b27

Browse files
authored
Fix generating unions (#43)
1 parent 3e89cc6 commit 3148b27

File tree

7 files changed

+252
-116
lines changed

7 files changed

+252
-116
lines changed

apollo-execution-processor/src/main/kotlin/com/apollographql/execution/processor/definitions.kt

Lines changed: 165 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,14 @@ private class TypeDefinitionContext(
7272

7373
val declarationsToVisit = mutableListOf<DeclarationToVisit>()
7474

75+
val usedTypeNames = mutableSetOf<String>()
76+
val unions = mutableMapOf<String, Set<String>>()
7577

78+
/**
79+
* Walk the Kotlin type graph. It goes:
80+
* - recursively (depth first) for supertypes so we can filter out union markers
81+
* - breadth first for subtypes/fields so we don't loop on circular field/interfaces references
82+
*/
7683
fun walk(
7784
query: KSClassDeclaration,
7885
mutation: KSClassDeclaration?,
@@ -86,84 +93,9 @@ private class TypeDefinitionContext(
8693
declarationsToVisit.add(DeclarationToVisit(subscription, VisitContext.OUTPUT, "subscription"))
8794
}
8895

89-
val usedNames = mutableSetOf<String>()
9096
while (declarationsToVisit.isNotEmpty()) {
9197
val declarationToVisit = declarationsToVisit.removeFirst()
92-
val declaration = declarationToVisit.declaration
93-
val context = declarationToVisit.context
94-
95-
val qualifiedName = declaration.asClassName().asString()
96-
if (typeDefinitions.containsKey(qualifiedName)) {
97-
// Already visited
98-
continue
99-
}
100-
101-
if (builtinTypes.contains(qualifiedName)) {
102-
typeDefinitions.put(qualifiedName, builtinScalarDefinition(qualifiedName))
103-
continue
104-
}
105-
106-
val name = declaration.graphqlName()
107-
if (usedNames.contains(name)) {
108-
logger.error("Duplicate type '$name'. Either rename the declaration or use @GraphQLName.", declaration)
109-
typeDefinitions.put(qualifiedName, null)
110-
continue
111-
}
112-
usedNames.add(name)
113-
114-
if (declaration.typeParameters.isNotEmpty()) {
115-
logger.error("Generic classes are not supported")
116-
typeDefinitions.put(qualifiedName, null)
117-
continue
118-
}
119-
120-
if (unsupportedTypes.contains(qualifiedName)) {
121-
logger.error(
122-
"'$qualifiedName' is not a supported built-in type. Either use one of the built-in types (Boolean, String, Int, Double) or use a custom scalar.",
123-
declaration
124-
)
125-
typeDefinitions.put(qualifiedName, null)
126-
continue
127-
}
128-
129-
if (declaration.isExternal()) {
130-
logger.error(
131-
"'$qualifiedName' doesn't have a containing file and probably comes from a dependency.",
132-
declaration
133-
)
134-
typeDefinitions.put(qualifiedName, null)
135-
continue
136-
}
137-
138-
/**
139-
* Track the files
140-
*/
141-
ksFiles.add(declaration.containingFile)
142-
143-
if (declaration is KSTypeAlias) {
144-
typeDefinitions.put(qualifiedName, declaration.toSirScalarDefinition(qualifiedName))
145-
continue
146-
}
147-
if (declaration !is KSClassDeclaration) {
148-
logger.error("Unsupported type", declaration)
149-
continue
150-
}
151-
if (declaration.classKind == ClassKind.ENUM_CLASS) {
152-
typeDefinitions.put(qualifiedName, declaration.toSirEnumDefinition())
153-
continue
154-
}
155-
if (declaration.findAnnotation("GraphQLScalar") != null) {
156-
typeDefinitions.put(qualifiedName, declaration.toSirScalarDefinition(qualifiedName))
157-
continue
158-
}
159-
if (context == VisitContext.INPUT) {
160-
typeDefinitions.put(qualifiedName, declaration.toSirInputObject())
161-
continue
162-
}
163-
if (context == VisitContext.OUTPUT) {
164-
typeDefinitions.put(qualifiedName, declaration.toSirComposite(declarationToVisit.isoperationType))
165-
continue
166-
}
98+
getOrResolve(declarationToVisit)
16799
}
168100

169101
val finalizedDirectiveDefinitions = directiveDefinitions.mapNotNull {
@@ -180,11 +112,91 @@ private class TypeDefinitionContext(
180112
}
181113

182114
return TraversalResults(
183-
definitions = finalizedDirectiveDefinitions + typeDefinitions.values.filterNotNull().toList(),
115+
/**
116+
* Not 100% sure what order to use for the types.
117+
* Fields in source order make sense but for classes that may be defined in different files, it's a lot less clear
118+
*/
119+
definitions = typeDefinitions.patchUnions(unions).sortedBy { it.type() + it.name } + finalizedDirectiveDefinitions.sortedBy { it.name },
184120
analyzedFiles = ksFiles.filterNotNull()
185121
)
186122
}
187123

124+
private fun getOrResolve(declarationToVisit: DeclarationToVisit): SirTypeDefinition? {
125+
val qualifiedName = declarationToVisit. declaration.asClassName().asString()
126+
if (typeDefinitions.containsKey(qualifiedName)) {
127+
// Already visited (maybe error)
128+
return typeDefinitions.get(qualifiedName)
129+
}
130+
131+
val typeDefinition = resolveType(qualifiedName, declarationToVisit)
132+
typeDefinitions.put(qualifiedName, typeDefinition)
133+
return typeDefinition
134+
}
135+
136+
/**
137+
* If returning null, this function also logs an error to fail the processor.
138+
*
139+
* @return the definition or null if there was an error
140+
*/
141+
private fun resolveType(qualifiedName: String, declarationToVisit: DeclarationToVisit): SirTypeDefinition? {
142+
val declaration = declarationToVisit.declaration
143+
val context = declarationToVisit.context
144+
145+
if (builtinTypes.contains(qualifiedName)) {
146+
return builtinScalarDefinition(qualifiedName)
147+
}
148+
if (unsupportedTypes.contains(qualifiedName)) {
149+
logger.error(
150+
"'$qualifiedName' is not a supported built-in type. Either use one of the built-in types (Boolean, String, Int, Double) or use a custom scalar.",
151+
declaration
152+
)
153+
return null
154+
}
155+
if (declaration.containingFile == null) {
156+
logger.error(
157+
"'$qualifiedName' doesn't have a containing file and probably comes from a dependency.",
158+
declaration
159+
)
160+
return null
161+
}
162+
163+
/**
164+
* Track the files
165+
*/
166+
ksFiles.add(declaration.containingFile)
167+
168+
val name = declaration.graphqlName()
169+
if (usedTypeNames.contains(name)) {
170+
logger.error("Duplicate type '$name'. Either rename the declaration or use @GraphQLName.", declaration)
171+
return null
172+
}
173+
usedTypeNames.add(name)
174+
175+
if (declaration.typeParameters.isNotEmpty()) {
176+
logger.error("Generic classes are not supported")
177+
return null
178+
}
179+
180+
if (declaration is KSTypeAlias) {
181+
return declaration.toSirScalarDefinition(qualifiedName)
182+
}
183+
if (declaration !is KSClassDeclaration) {
184+
logger.error("Unsupported type", declaration)
185+
return null
186+
}
187+
if (declaration.classKind == ClassKind.ENUM_CLASS) {
188+
return declaration.toSirEnumDefinition()
189+
}
190+
if (declaration.findAnnotation("GraphQLScalar") != null) {
191+
return declaration.toSirScalarDefinition(qualifiedName)
192+
}
193+
194+
return when(context) {
195+
VisitContext.OUTPUT -> declaration.toSirComposite(declarationToVisit.operationType)
196+
VisitContext.INPUT -> declaration.toSirInputObject()
197+
}
198+
}
199+
188200
/**
189201
* Same code for both type aliases and classes
190202
*/
@@ -353,6 +365,7 @@ private class TypeDefinitionContext(
353365
GQLEnumValue(null, simpleName.asString())
354366
}
355367
}
368+
356369
else -> {
357370
logger.error("Cannot convert $this to a GQLValue", argument)
358371
GQLNullValue(null) // not correct but compilation should fail anyway
@@ -410,7 +423,7 @@ private class TypeDefinitionContext(
410423
name = name,
411424
description = description,
412425
qualifiedName = qualifiedName,
413-
interfaces = interfaces(),
426+
interfaces = interfaces(name),
414427
targetClassName = asClassName(),
415428
instantiation = instantiation(),
416429
operationType = operationType,
@@ -426,18 +439,31 @@ private class TypeDefinitionContext(
426439
return null
427440
}
428441

429-
val subclasses = getSealedSubclasses().map {
430-
// Look into subclasses
442+
getSealedSubclasses().forEach {
443+
/**
444+
* We go depth first on the superclasses but need to escape the callstack and
445+
* remember to also go the other direction to not miss anything from the graph.
446+
*
447+
* If we were to go depth first only, we would miss all the concrete animal types
448+
* below:
449+
*
450+
* ```graphql
451+
* type Query {
452+
* animal: Animal
453+
* }
454+
*
455+
* union Animal = Cat | Dog | Lion ...
456+
* ```
457+
*/
431458
declarationsToVisit.add(DeclarationToVisit(it, VisitContext.OUTPUT, null))
432-
it.graphqlName()
433-
}.toList()
459+
}
434460

435461
if (allFields.isEmpty()) {
436462
SirUnionDefinition(
437463
name = name,
438464
description = description,
439465
qualifiedName = qualifiedName,
440-
memberTypes = subclasses,
466+
memberTypes = emptyList(), // we'll patch that later
441467
directives = directives(GQLDirectiveLocation.UNION),
442468
)
443469
} else {
@@ -448,7 +474,7 @@ private class TypeDefinitionContext(
448474
name = name,
449475
description = description,
450476
qualifiedName = qualifiedName,
451-
interfaces = interfaces(),
477+
interfaces = interfaces(null),
452478
fields = allFields,
453479
directives = directives(GQLDirectiveLocation.INTERFACE),
454480
)
@@ -474,24 +500,33 @@ private class TypeDefinitionContext(
474500
)
475501
}
476502

477-
private fun KSClassDeclaration.interfaces(): List<String> {
503+
private fun KSClassDeclaration.interfaces(objectName: String?): List<String> {
478504
return getAllSuperTypes().mapNotNull {
479505
val declaration = it.declaration
480506
if (it.arguments.isNotEmpty()) {
481507
logger.error("Generic interfaces are not supported", this)
482508
null
483509
} else if (declaration is KSClassDeclaration) {
484510
if (declaration.asClassName().asString() == "kotlin.Any") {
485-
null
486-
} else if (declaration.containingFile == null) {
487-
logger.error(
488-
"Class '${simpleName.asString()}' has a super class without a containing file that probably comes from a dependency.",
489-
this
490-
)
511+
// kotlin.Any is a super type of everything, just ignore it
491512
null
492513
} else {
493-
declarationsToVisit.add(DeclarationToVisit(declaration, VisitContext.OUTPUT, null))
494-
declaration.graphqlName()
514+
val supertype = getOrResolve(DeclarationToVisit(declaration, VisitContext.OUTPUT, null))
515+
if (supertype is SirInterfaceDefinition) {
516+
supertype.name
517+
} else if (supertype is SirUnionDefinition) {
518+
if (objectName == null) {
519+
logger.error("Interfaces are not allowed to extend union markers. Only classes can")
520+
} else {
521+
unions.compute(supertype.name) { _, oldValue ->
522+
oldValue.orEmpty() + objectName
523+
}
524+
}
525+
null
526+
} else {
527+
// error
528+
null
529+
}
495530
}
496531
} else {
497532
logger.error("Unrecognized super class", this)
@@ -658,6 +693,11 @@ private class TypeDefinitionContext(
658693
}
659694

660695
if (!argumentType.isMarkedNullable) {
696+
/*
697+
* Note: it's still possible to have a missing variable at runtime in a non-null position.
698+
* Those cases trigger request error before reaching the resolver and the argument cannot
699+
* be of Optional type.
700+
*/
661701
logger.error("Input value is not nullable and cannot be optional", debugContext.node)
662702
return SirErrorType
663703
}
@@ -666,13 +706,7 @@ private class TypeDefinitionContext(
666706
} else {
667707
if (!hasDefaultValue && isMarkedNullable) {
668708
logger.error(
669-
"""
670-
Input value is nullable and doesn't have a default value: it must also be optional.
671-
672-
If the type is nullable with a default value and no value is provided by the user, the default value is passed to the resolver, the resolver code does not need to handle the `Absent` case.
673-
If the type is non-nullable and there is no default value, variable values may still be absent at runtime. These cases are caught during coercion before it reaches the resolver code.
674-
675-
""".trimIndent(),
709+
"Input value is nullable and doesn't have a default value: it must also be optional.",
676710
debugContext.node
677711
)
678712
return SirErrorType
@@ -780,6 +814,36 @@ private class TypeDefinitionContext(
780814
}
781815
}
782816

817+
/**
818+
* Sorting helper function. Not 100% sure of the order here
819+
*/
820+
private fun SirTypeDefinition.type(): String {
821+
return when (this) {
822+
is SirScalarDefinition -> "0"
823+
is SirEnumDefinition -> "1"
824+
is SirObjectDefinition -> "2"
825+
is SirInterfaceDefinition -> "3"
826+
is SirUnionDefinition -> "4"
827+
is SirInputObjectDefinition -> "5"
828+
}
829+
}
830+
831+
private fun Map<String, SirTypeDefinition?>.patchUnions(unions: Map<String, Set<String>>): List<SirTypeDefinition> {
832+
return values.filterNotNull().map {
833+
if (it is SirUnionDefinition) {
834+
SirUnionDefinition(
835+
it.name,
836+
it.description,
837+
it.qualifiedName,
838+
unions.get(it.name)!!.toList(),
839+
it.directives
840+
)
841+
} else {
842+
it
843+
}
844+
}
845+
}
846+
783847

784848
private fun KSDeclaration.isApolloOptional(): Boolean {
785849
return asClassName().asString() == "com.apollographql.apollo.api.Optional"
@@ -807,7 +871,7 @@ private val builtinTypes = listOf("Double", "String", "Boolean", "Int").map {
807871
private class DeclarationToVisit(
808872
val declaration: KSDeclaration,
809873
val context: VisitContext,
810-
val isoperationType: String? = null
874+
val operationType: String? = null
811875
)
812876

813877
private enum class VisitContext {

0 commit comments

Comments
 (0)