1+ import io
12import json
23import logging
3- from typing import Any , Dict , Optional , Union
4+ from typing import Any , Dict , Optional , Tuple , Type , Union
45
56import requests
67from graphql import DocumentNode , ExecutionResult , print_ast
78from requests .adapters import HTTPAdapter , Retry
89from requests .auth import AuthBase
910from requests .cookies import RequestsCookieJar
11+ from requests_toolbelt .multipart .encoder import MultipartEncoder
1012
1113from gql .transport import Transport
1214
15+ from ..utils import extract_files
1316from .exceptions import (
1417 TransportAlreadyConnected ,
1518 TransportClosed ,
@@ -27,6 +30,8 @@ class RequestsHTTPTransport(Transport):
2730 The transport uses the requests library to send HTTP POST requests.
2831 """
2932
33+ file_classes : Tuple [Type [Any ], ...] = (io .IOBase ,)
34+
3035 def __init__ (
3136 self ,
3237 url : str ,
@@ -104,6 +109,7 @@ def execute( # type: ignore
104109 operation_name : Optional [str ] = None ,
105110 timeout : Optional [int ] = None ,
106111 extra_args : Dict [str , Any ] = None ,
112+ upload_files : bool = False ,
107113 ) -> ExecutionResult :
108114 """Execute GraphQL query.
109115
@@ -116,6 +122,7 @@ def execute( # type: ignore
116122 Only required in multi-operation documents (Default: None).
117123 :param timeout: Specifies a default timeout for requests (Default: None).
118124 :param extra_args: additional arguments to send to the requests post method
125+ :param upload_files: Set to True if you want to put files in the variable values
119126 :return: The result of execution.
120127 `data` is the result of executing the query, `errors` is null
121128 if no errors occurred, and is a non-empty array if an error occurred.
@@ -126,21 +133,77 @@ def execute( # type: ignore
126133
127134 query_str = print_ast (document )
128135 payload : Dict [str , Any ] = {"query" : query_str }
129- if variable_values :
130- payload ["variables" ] = variable_values
136+
131137 if operation_name :
132138 payload ["operationName" ] = operation_name
133139
134- data_key = "json" if self .use_json else "data"
135140 post_args = {
136141 "headers" : self .headers ,
137142 "auth" : self .auth ,
138143 "cookies" : self .cookies ,
139144 "timeout" : timeout or self .default_timeout ,
140145 "verify" : self .verify ,
141- data_key : payload ,
142146 }
143147
148+ if upload_files :
149+ # If the upload_files flag is set, then we need variable_values
150+ assert variable_values is not None
151+
152+ # If we upload files, we will extract the files present in the
153+ # variable_values dict and replace them by null values
154+ nulled_variable_values , files = extract_files (
155+ variables = variable_values , file_classes = self .file_classes ,
156+ )
157+
158+ # Save the nulled variable values in the payload
159+ payload ["variables" ] = nulled_variable_values
160+
161+ # Add the payload to the operations field
162+ operations_str = json .dumps (payload )
163+ log .debug ("operations %s" , operations_str )
164+
165+ # Generate the file map
166+ # path is nested in a list because the spec allows multiple pointers
167+ # to the same file. But we don't support that.
168+ # Will generate something like {"0": ["variables.file"]}
169+ file_map = {str (i ): [path ] for i , path in enumerate (files )}
170+
171+ # Enumerate the file streams
172+ # Will generate something like {'0': <_io.BufferedReader ...>}
173+ file_streams = {str (i ): files [path ] for i , path in enumerate (files )}
174+
175+ # Add the file map field
176+ file_map_str = json .dumps (file_map )
177+ log .debug ("file_map %s" , file_map_str )
178+
179+ fields = {"operations" : operations_str , "map" : file_map_str }
180+
181+ # Add the extracted files as remaining fields
182+ for k , v in file_streams .items ():
183+ fields [k ] = (getattr (v , "name" , k ), v )
184+
185+ # Prepare requests http to send multipart-encoded data
186+ data = MultipartEncoder (fields = fields )
187+
188+ post_args ["data" ] = data
189+
190+ if post_args ["headers" ] is None :
191+ post_args ["headers" ] = {}
192+ else :
193+ post_args ["headers" ] = {** post_args ["headers" ]}
194+
195+ post_args ["headers" ]["Content-Type" ] = data .content_type
196+
197+ else :
198+ if variable_values :
199+ payload ["variables" ] = variable_values
200+
201+ if log .isEnabledFor (logging .INFO ):
202+ log .info (">>> %s" , json .dumps (payload ))
203+
204+ data_key = "json" if self .use_json else "data"
205+ post_args [data_key ] = payload
206+
144207 # Log the payload
145208 if log .isEnabledFor (logging .INFO ):
146209 log .info (">>> %s" , json .dumps (payload ))
0 commit comments