@@ -81,7 +81,7 @@ class SchemaParser internal constructor(
8181 val inputObjects: MutableList <GraphQLInputObjectType > = mutableListOf ()
8282 inputObjectDefinitions.forEach {
8383 if (inputObjects.none { io -> io.name == it.name }) {
84- inputObjects.add(createInputObject(it, inputObjects))
84+ inputObjects.add(createInputObject(it, inputObjects, mutableSetOf () ))
8585 }
8686 }
8787 val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
@@ -173,7 +173,8 @@ class SchemaParser internal constructor(
173173 return output.toTypedArray()
174174 }
175175
176- private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List <GraphQLInputObjectType >): GraphQLInputObjectType {
176+ private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List <GraphQLInputObjectType >,
177+ referencingInputObjects : MutableSet <String >): GraphQLInputObjectType {
177178 val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name }
178179
179180 val builder = GraphQLInputObjectType .newInputObject()
@@ -184,14 +185,16 @@ class SchemaParser internal constructor(
184185
185186 builder.withDirectives(* buildDirectives(definition.directives, setOf (), Introspection .DirectiveLocation .INPUT_OBJECT ))
186187
188+ referencingInputObjects.add(definition.name)
189+
187190 (extensionDefinitions + definition).forEach {
188191 it.inputValueDefinitions.forEach { inputDefinition ->
189192 val fieldBuilder = GraphQLInputObjectField .newInputObjectField()
190193 .name(inputDefinition.name)
191194 .definition(inputDefinition)
192195 .description(if (inputDefinition.description != null ) inputDefinition.description.content else getDocumentation(inputDefinition))
193196 .defaultValue(buildDefaultValue(inputDefinition.defaultValue))
194- .type(determineInputType(inputDefinition.type, inputObjects))
197+ .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects ))
195198 .withDirectives(* buildDirectives(inputDefinition.directives, setOf (), Introspection .DirectiveLocation .INPUT_FIELD_DEFINITION ))
196199 builder.field(fieldBuilder.build())
197200 }
@@ -297,7 +300,7 @@ class SchemaParser internal constructor(
297300 .definition(argumentDefinition)
298301 .description(if (argumentDefinition.description != null ) argumentDefinition.description.content else getDocumentation(argumentDefinition))
299302 .defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
300- .type(determineInputType(argumentDefinition.type, inputObjects))
303+ .type(determineInputType(argumentDefinition.type, inputObjects, setOf () ))
301304 .withDirectives(* buildDirectives(argumentDefinition.directives, setOf (), Introspection .DirectiveLocation .ARGUMENT_DEFINITION ))
302305 field.argument(argumentBuilder.build())
303306 }
@@ -328,7 +331,7 @@ class SchemaParser internal constructor(
328331 is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
329332 is InputObjectTypeDefinition -> {
330333 log.info(" Create input object" )
331- createInputObject(typeDefinition, inputObjects)
334+ createInputObject(typeDefinition, inputObjects, mutableSetOf () )
332335 }
333336 is TypeName -> {
334337 val scalarType = customScalars[typeDefinition.name]
@@ -346,16 +349,19 @@ class SchemaParser internal constructor(
346349 else -> throw SchemaError (" Unknown type: $typeDefinition " )
347350 }
348351
349- private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >) =
350- determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
352+ private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >, referencingInputObjects : Set < String > ) =
353+ determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects ) as GraphQLInputType
351354
352- private fun <T : Any > determineInputType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
355+ private fun <T : Any > determineInputType (expectedType : KClass <T >,
356+ typeDefinition : Type <* >, allowedTypeReferences : Set <String >,
357+ inputObjects : List <GraphQLInputObjectType >,
358+ referencingInputObjects : Set <String >): GraphQLType =
353359 when (typeDefinition) {
354360 is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
355361 is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
356362 is InputObjectTypeDefinition -> {
357363 log.info(" Create input object" )
358- createInputObject(typeDefinition, inputObjects)
364+ createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet < String > )
359365 }
360366 is TypeName -> {
361367 val scalarType = customScalars[typeDefinition.name]
@@ -373,9 +379,14 @@ class SchemaParser internal constructor(
373379 } else {
374380 val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
375381 if (filteredDefinitions.isNotEmpty()) {
376- val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects)
377- (inputObjects as MutableList ).add(inputObject)
378- inputObject
382+ val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name }
383+ if (referencingInputObject != null ) {
384+ GraphQLTypeReference (referencingInputObject)
385+ } else {
386+ val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects, referencingInputObjects as MutableSet <String >)
387+ (inputObjects as MutableList ).add(inputObject)
388+ inputObject
389+ }
379390 } else {
380391 // todo: handle enum type
381392 GraphQLTypeReference (typeDefinition.name)
0 commit comments