66import graphql .introspection .IntrospectionQuery ;
77import graphql .schema .GraphQLFieldDefinition ;
88import graphql .servlet .internal .GraphQLRequest ;
9+ import graphql .servlet .internal .VariableMapper ;
910import org .slf4j .Logger ;
1011import org .slf4j .LoggerFactory ;
1112
1617import javax .servlet .http .HttpServletRequest ;
1718import javax .servlet .http .HttpServletResponse ;
1819import javax .servlet .http .Part ;
19- import java .io .*;
20- import java .util .*;
20+ import java .io .BufferedInputStream ;
21+ import java .io .ByteArrayOutputStream ;
22+ import java .io .IOException ;
23+ import java .io .InputStream ;
24+ import java .io .Writer ;
25+ import java .util .ArrayList ;
26+ import java .util .Arrays ;
27+ import java .util .Collections ;
28+ import java .util .HashMap ;
29+ import java .util .List ;
30+ import java .util .Map ;
31+ import java .util .Objects ;
32+ import java .util .Optional ;
2133import java .util .function .BiConsumer ;
2234import java .util .function .Consumer ;
2335import java .util .function .Function ;
2436import java .util .stream .Collectors ;
25- import java .util .stream .Stream ;
2637
2738/**
2839 * @author Andrew Potter
@@ -37,6 +48,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements
3748 public static final int STATUS_BAD_REQUEST = 400 ;
3849
3950 private static final GraphQLRequest INTROSPECTION_REQUEST = new GraphQLRequest (IntrospectionQuery .INTROSPECTION_QUERY , new HashMap <>(), null );
51+ private static final String [] MULTIPART_KEYS = new String []{"operations" , "graphql" , "query" };
4052
4153 protected abstract GraphQLQueryInvoker getQueryInvoker ();
4254
@@ -103,79 +115,58 @@ public AbstractGraphQLHttpServlet(List<GraphQLServletListener> listeners, boolea
103115 String query = CharStreams .toString (request .getReader ());
104116 query (queryInvoker , graphQLObjectMapper , invocationInputFactory .create (new GraphQLRequest (query , null , null )), response );
105117 } else if (request .getContentType () != null && request .getContentType ().startsWith ("multipart/form-data" ) && !request .getParts ().isEmpty ()) {
106- final Map <String , List <Part >> fileItems = request .getParts ().stream ()
107- .collect (Collectors .toMap (
108- Part ::getName ,
109- Collections ::singletonList ,
110- (l1 , l2 ) -> Stream .concat (l1 .stream (), l2 .stream ()).collect (Collectors .toList ())));
111-
112- if (fileItems .containsKey ("graphql" )) {
113- final Optional <Part > graphqlItem = getFileItem (fileItems , "graphql" );
114- if (graphqlItem .isPresent ()) {
115- InputStream inputStream = graphqlItem .get ().getInputStream ();
116-
117- if (!inputStream .markSupported ()) {
118- inputStream = new BufferedInputStream (inputStream );
119- }
120-
121- if (isBatchedQuery (inputStream )) {
122- GraphQLBatchedInvocationInput invocationInput = invocationInputFactory .create (graphQLObjectMapper .readBatchedGraphQLRequest (inputStream ), request );
123- invocationInput .getContext ().setFiles (fileItems );
124- queryBatched (queryInvoker , graphQLObjectMapper , invocationInput , response );
125- return ;
126- } else {
127- GraphQLSingleInvocationInput invocationInput = invocationInputFactory .create (graphQLObjectMapper .readGraphQLRequest (inputStream ), request );
128- invocationInput .getContext ().setFiles (fileItems );
129- query (queryInvoker , graphQLObjectMapper , invocationInput , response );
130- return ;
131- }
118+ final Map <String , List <Part >> fileItems = request .getParts ()
119+ .stream ()
120+ .collect (Collectors .groupingBy (Part ::getName ));
121+
122+ for (String key : MULTIPART_KEYS ) {
123+ // Check to see if there is a part under the key we seek
124+ if (!fileItems .containsKey (key )) {
125+ continue ;
132126 }
133- } else if (fileItems .containsKey ("query" )) {
134- final Optional <Part > queryItem = getFileItem (fileItems , "query" );
135- if (queryItem .isPresent ()) {
136- InputStream inputStream = queryItem .get ().getInputStream ();
137127
138- if (!inputStream .markSupported ()) {
139- inputStream = new BufferedInputStream (inputStream );
140- }
128+ final Optional <Part > queryItem = getFileItem (fileItems , key );
129+ if (!queryItem .isPresent ()) {
130+ // If there is a part, but we don't see an item, then break and return BAD_REQUEST
131+ break ;
132+ }
141133
142- if (isBatchedQuery (inputStream )) {
143- GraphQLBatchedInvocationInput invocationInput = invocationInputFactory .create (graphQLObjectMapper .readBatchedGraphQLRequest (inputStream ), request );
144- invocationInput .getContext ().setFiles (fileItems );
145- queryBatched (queryInvoker , graphQLObjectMapper , invocationInput , response );
146- return ;
134+ InputStream inputStream = asMarkableInputStream (queryItem .get ().getInputStream ());
135+
136+ final Optional <Map <String , List <String >>> variablesMap =
137+ getFileItem (fileItems , "map" ).map (graphQLObjectMapper ::deserializeMultipartMap );
138+
139+ if (isBatchedQuery (inputStream )) {
140+ List <GraphQLRequest > graphQLRequests =
141+ graphQLObjectMapper .readBatchedGraphQLRequest (inputStream );
142+ variablesMap .ifPresent (map -> graphQLRequests .forEach (r -> mapMultipartVariables (r , map , fileItems )));
143+ GraphQLBatchedInvocationInput invocationInput =
144+ invocationInputFactory .create (graphQLRequests , request );
145+ invocationInput .getContext ().setFiles (fileItems );
146+ queryBatched (queryInvoker , graphQLObjectMapper , invocationInput , response );
147+ return ;
148+ } else {
149+ GraphQLRequest graphQLRequest ;
150+ if ("query" .equals (key )) {
151+ graphQLRequest = buildRequestFromQuery (inputStream , graphQLObjectMapper , fileItems );
147152 } else {
148- String query = new String (ByteStreams .toByteArray (inputStream ));
149-
150- Map <String , Object > variables = null ;
151- final Optional <Part > variablesItem = getFileItem (fileItems , "variables" );
152- if (variablesItem .isPresent ()) {
153- variables = graphQLObjectMapper .deserializeVariables (new String (ByteStreams .toByteArray (variablesItem .get ().getInputStream ())));
154- }
155-
156- String operationName = null ;
157- final Optional <Part > operationNameItem = getFileItem (fileItems , "operationName" );
158- if (operationNameItem .isPresent ()) {
159- operationName = new String (ByteStreams .toByteArray (operationNameItem .get ().getInputStream ())).trim ();
160- }
161-
162- GraphQLSingleInvocationInput invocationInput = invocationInputFactory .create (new GraphQLRequest (query , variables , operationName ), request );
163- invocationInput .getContext ().setFiles (fileItems );
164- query (queryInvoker , graphQLObjectMapper , invocationInput , response );
165- return ;
153+ graphQLRequest = graphQLObjectMapper .readGraphQLRequest (inputStream );
166154 }
155+
156+ variablesMap .ifPresent (m -> mapMultipartVariables (graphQLRequest , m , fileItems ));
157+ GraphQLSingleInvocationInput invocationInput =
158+ invocationInputFactory .create (graphQLRequest , request );
159+ invocationInput .getContext ().setFiles (fileItems );
160+ query (queryInvoker , graphQLObjectMapper , invocationInput , response );
161+ return ;
167162 }
168163 }
169164
170165 response .setStatus (STATUS_BAD_REQUEST );
171- log .info ("Bad POST multipart request: no part named \" graphql \" or \" query \" " );
166+ log .info ("Bad POST multipart request: no part named " + Arrays . toString ( MULTIPART_KEYS ) );
172167 } else {
173168 // this is not a multipart request
174- InputStream inputStream = request .getInputStream ();
175-
176- if (!inputStream .markSupported ()) {
177- inputStream = new BufferedInputStream (inputStream );
178- }
169+ InputStream inputStream = asMarkableInputStream (request .getInputStream ());
179170
180171 if (isBatchedQuery (inputStream )) {
181172 queryBatched (queryInvoker , graphQLObjectMapper , invocationInputFactory .create (graphQLObjectMapper .readBatchedGraphQLRequest (inputStream ), request ), response );
@@ -190,6 +181,52 @@ public AbstractGraphQLHttpServlet(List<GraphQLServletListener> listeners, boolea
190181 };
191182 }
192183
184+ private static InputStream asMarkableInputStream (InputStream inputStream ) {
185+ if (!inputStream .markSupported ()) {
186+ inputStream = new BufferedInputStream (inputStream );
187+ }
188+ return inputStream ;
189+ }
190+
191+ private GraphQLRequest buildRequestFromQuery (InputStream inputStream ,
192+ GraphQLObjectMapper graphQLObjectMapper ,
193+ Map <String , List <Part >> fileItems ) throws IOException
194+ {
195+ GraphQLRequest graphQLRequest ;
196+ String query = new String (ByteStreams .toByteArray (inputStream ));
197+
198+ Map <String , Object > variables = null ;
199+ final Optional <Part > variablesItem = getFileItem (fileItems , "variables" );
200+ if (variablesItem .isPresent ()) {
201+ variables = graphQLObjectMapper .deserializeVariables (new String (ByteStreams .toByteArray (variablesItem .get ().getInputStream ())));
202+ }
203+
204+ String operationName = null ;
205+ final Optional <Part > operationNameItem = getFileItem (fileItems , "operationName" );
206+ if (operationNameItem .isPresent ()) {
207+ operationName = new String (ByteStreams .toByteArray (operationNameItem .get ().getInputStream ())).trim ();
208+ }
209+
210+ graphQLRequest = new GraphQLRequest (query , variables , operationName );
211+ return graphQLRequest ;
212+ }
213+
214+ private void mapMultipartVariables (GraphQLRequest request ,
215+ Map <String , List <String >> variablesMap ,
216+ Map <String , List <Part >> fileItems )
217+ {
218+ Map <String , Object > variables = request .getVariables ();
219+
220+ variablesMap .forEach ((partName , objectPaths ) -> {
221+ Part part = getFileItem (fileItems , partName )
222+ .orElseThrow (() -> new RuntimeException ("unable to find part name " +
223+ partName +
224+ " as referenced in the variables map" ));
225+
226+ objectPaths .forEach (objectPath -> VariableMapper .mapVariable (objectPath , variables , part ));
227+ });
228+ }
229+
193230 public void addListener (GraphQLServletListener servletListener ) {
194231 listeners .add (servletListener );
195232 }
0 commit comments