1717package com .introproventures .graphql .jpa .query .schema .impl ;
1818
1919import static graphql .Scalars .GraphQLBoolean ;
20- import static graphql .introspection .Introspection .DirectiveLocation .FIELD ;
2120import static graphql .schema .GraphQLArgument .newArgument ;
2221import static graphql .schema .GraphQLInputObjectField .newInputObjectField ;
2322import static graphql .schema .GraphQLInputObjectType .newInputObject ;
24- import static graphql .schema .GraphQLNonNull .nonNull ;
2523
2624import java .beans .Introspector ;
2725import java .lang .reflect .AnnotatedElement ;
4947import javax .persistence .metamodel .SingularAttribute ;
5048import javax .persistence .metamodel .Type ;
5149
52- import graphql .schema .GraphQLDirective ;
5350import org .dataloader .MappedBatchLoaderWithContext ;
5451import org .slf4j .Logger ;
5552import org .slf4j .LoggerFactory ;
6461import com .introproventures .graphql .jpa .query .schema .impl .EntityIntrospector .EntityIntrospectionResult .AttributePropertyDescriptor ;
6562import com .introproventures .graphql .jpa .query .schema .impl .PredicateFilter .Criteria ;
6663import com .introproventures .graphql .jpa .query .schema .relay .GraphQLJpaRelayDataFetcher ;
67-
6864import graphql .Assert ;
69- import graphql .Directives ;
7065import graphql .Scalars ;
7166import graphql .relay .Relay ;
7267import graphql .schema .Coercing ;
8681import graphql .schema .GraphQLTypeReference ;
8782import graphql .schema .PropertyDataFetcher ;
8883
84+
8985/**
9086 * JPA specific schema builder implementation of {code #GraphQLSchemaBuilder} interface
9187 *
@@ -136,7 +132,7 @@ public class GraphQLJpaSchemaBuilder implements GraphQLSchemaBuilder {
136132 private int defaultFetchSize = 100 ;
137133 private int defaultPageLimitSize = 100 ;
138134 private boolean enableDefaultMaxResults = true ;
139-
135+
140136 private RestrictedKeysProvider restrictedKeysProvider = (entityDescriptor ) -> Optional .of (Collections .emptyList ());
141137
142138 private final Relay relay = new Relay ();
@@ -505,7 +501,7 @@ private String resolveTypeName(ManagedType<?> managedType) {
505501 String typeName ="" ;
506502
507503 if (managedType instanceof EmbeddableType ){
508- typeName = managedType .getJavaType ().getSimpleName ()+ "EmbeddableType" ;
504+ typeName = managedType .getJavaType ().getSimpleName ();
509505 } else if (managedType instanceof EntityType ) {
510506 typeName = ((EntityType <?>)managedType ).getName ();
511507 }
@@ -514,7 +510,14 @@ private String resolveTypeName(ManagedType<?> managedType) {
514510 }
515511
516512 private GraphQLInputObjectType getWhereInputType (ManagedType <?> managedType ) {
517- return inputObjectCache .computeIfAbsent (managedType , this ::computeWhereInputType );
513+ GraphQLInputObjectType type = inputObjectCache .get (managedType );
514+ if (type == null ) {
515+ type = computeWhereInputType (managedType );
516+ inputObjectCache .put (managedType , type );
517+ return type ;
518+ }
519+ return type ;
520+
518521 }
519522
520523 private String resolveWhereInputTypeName (ManagedType <?> managedType ) {
@@ -610,6 +613,11 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
610613 if (whereAttributesMap .containsKey (type ))
611614 return whereAttributesMap .get (type );
612615
616+ if (isEmbeddable (attribute )) {
617+ EmbeddableType <?> embeddableType = (EmbeddableType <?>) ((SingularAttribute <?, ?>) attribute ).getType ();
618+ return getWhereInputType (embeddableType );
619+ }
620+
613621 GraphQLInputObjectType .Builder builder = GraphQLInputObjectType .newInputObject ()
614622 .name (type )
615623 .description ("Criteria expression specification of " +namingStrategy .singularize (attribute .getName ())+" attribute in entity " + attribute .getDeclaringType ().getJavaType ())
@@ -786,7 +794,7 @@ else if (attribute.getJavaMember().getClass().isAssignableFrom(Field.class)
786794 }
787795
788796 private GraphQLArgument getArgument (Attribute <?,?> attribute ) {
789- GraphQLInputType type = getAttributeInputType (attribute );
797+ GraphQLInputType type = getAttributeInputTypeForSearchByIdArg (attribute );
790798 String description = getSchemaDescription (attribute );
791799
792800 return GraphQLArgument .newArgument ()
@@ -796,25 +804,33 @@ private GraphQLArgument getArgument(Attribute<?,?> attribute) {
796804 .build ();
797805 }
798806
799- private GraphQLType getEmbeddableType (EmbeddableType <?> embeddableType , boolean input ) {
800- if (input && embeddableInputCache .containsKey (embeddableType .getJavaType ()))
801- return embeddableInputCache .get (embeddableType .getJavaType ());
802-
803- if (!input && embeddableOutputCache .containsKey (embeddableType .getJavaType ()))
804- return embeddableOutputCache .get (embeddableType .getJavaType ());
805- String embeddableTypeName = namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ())+ (input ? "Input" : "" ) +"EmbeddableType" ;
806- GraphQLType graphQLType =null ;
807+ private GraphQLType getEmbeddableType (EmbeddableType <?> embeddableType , boolean input , boolean searchByIdArg ) {
808+ GraphQLType graphQLType ;
807809 if (input ) {
808- graphQLType = GraphQLInputObjectType .newInputObject ()
809- .name (embeddableTypeName )
810+
811+ if (searchByIdArg ) {
812+ if (embeddableInputCache .containsKey (embeddableType .getJavaType ())) {
813+ return embeddableInputCache .get (embeddableType .getJavaType ());
814+ }
815+ graphQLType = GraphQLInputObjectType .newInputObject ()
816+ .name (namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ())+ "InputEmbeddableIdType" )
810817 .description (getSchemaDescription (embeddableType ))
811818 .fields (embeddableType .getAttributes ().stream ()
812- .filter (this ::isNotIgnored )
813- .map (this ::getInputObjectField )
814- .collect (Collectors .toList ())
819+ .filter (this ::isNotIgnored )
820+ .map (this ::getInputObjectField )
821+ .collect (Collectors .toList ())
815822 )
816823 .build ();
824+ embeddableInputCache .put (embeddableType .getJavaType (), (GraphQLInputObjectType ) graphQLType );
825+ return graphQLType ;
826+ }
827+
828+ graphQLType = getWhereInputType (embeddableType );
817829 } else {
830+ if (embeddableOutputCache .containsKey (embeddableType .getJavaType ())) {
831+ return embeddableOutputCache .get (embeddableType .getJavaType ());
832+ }
833+ String embeddableTypeName = namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ()) + "EmbeddableType" ;
818834 graphQLType = GraphQLObjectType .newObject ()
819835 .name (embeddableTypeName )
820836 .description (getSchemaDescription (embeddableType ))
@@ -824,13 +840,8 @@ private GraphQLType getEmbeddableType(EmbeddableType<?> embeddableType, boolean
824840 .collect (Collectors .toList ())
825841 )
826842 .build ();
827- }
828- if (input ) {
829- embeddableInputCache .putIfAbsent (embeddableType .getJavaType (), (GraphQLInputObjectType ) graphQLType );
830- } else {
831843 embeddableOutputCache .putIfAbsent (embeddableType .getJavaType (), (GraphQLObjectType ) graphQLType );
832844 }
833-
834845 return graphQLType ;
835846 }
836847
@@ -1020,31 +1031,39 @@ private Stream<Attribute<?,?>> findBasicAttributes(Collection<Attribute<?,?>> at
10201031 }
10211032
10221033 private GraphQLInputType getAttributeInputType (Attribute <?,?> attribute ) {
1034+ return getAttributeInputType (attribute , false );
1035+ }
1036+
1037+ private GraphQLInputType getAttributeInputTypeForSearchByIdArg (Attribute <?,?> attribute ) {
1038+ return getAttributeInputType (attribute , true );
1039+ }
1040+
1041+ private GraphQLInputType getAttributeInputType (Attribute <?,?> attribute , boolean searchByIdArgType ) {
10231042
10241043 try {
1025- return (GraphQLInputType ) getAttributeType (attribute , true );
1044+ return (GraphQLInputType ) getAttributeType (attribute , true , searchByIdArgType );
10261045 } catch (ClassCastException e ){
10271046 throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Input Argument" );
10281047 }
10291048 }
10301049
10311050 private GraphQLOutputType getAttributeOutputType (Attribute <?,?> attribute ) {
10321051 try {
1033- return (GraphQLOutputType ) getAttributeType (attribute , false );
1052+ return (GraphQLOutputType ) getAttributeType (attribute , false , false );
10341053 } catch (ClassCastException e ) {
10351054 throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Output Argument" );
10361055 }
10371056 }
10381057
10391058 @ SuppressWarnings ( "rawtypes" )
1040- protected GraphQLType getAttributeType (Attribute <?,?> attribute , boolean input ) {
1059+ protected GraphQLType getAttributeType (Attribute <?,?> attribute , boolean input , boolean searchByIdArgType ) {
10411060
10421061 if (isBasic (attribute )) {
10431062 return getGraphQLTypeFromJavaType (attribute .getJavaType ());
10441063 }
10451064 else if (isEmbeddable (attribute )) {
10461065 EmbeddableType embeddableType = (EmbeddableType ) ((SingularAttribute ) attribute ).getType ();
1047- return getEmbeddableType (embeddableType , input );
1066+ return getEmbeddableType (embeddableType , input , searchByIdArgType );
10481067 }
10491068 else if (isToMany (attribute )) {
10501069 EntityType foreignType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
@@ -1066,8 +1085,7 @@ else if (isElementCollection(attribute)) {
10661085 }
10671086 else if (foreignType .getPersistenceType () == Type .PersistenceType .EMBEDDABLE ) {
10681087 EmbeddableType embeddableType = EmbeddableType .class .cast (foreignType );
1069- GraphQLType graphQLType = getEmbeddableType (embeddableType ,
1070- input );
1088+ GraphQLType graphQLType = getEmbeddableType (embeddableType , input , searchByIdArgType );
10711089
10721090 return input ? graphQLType : new GraphQLList (graphQLType );
10731091 }
0 commit comments