1+ import json
2+ import logging
13from ssl import SSLContext
24from typing import Any , AsyncGenerator , Dict , Optional , Union
35
810from aiohttp .typedefs import LooseCookies , LooseHeaders
911from graphql import DocumentNode , ExecutionResult , print_ast
1012
13+ from ..utils import extract_files
1114from .async_transport import AsyncTransport
1215from .exceptions import (
1316 TransportAlreadyConnected ,
1619 TransportServerError ,
1720)
1821
22+ log = logging .getLogger (__name__ )
23+
1924
2025class AIOHTTPTransport (AsyncTransport ):
2126 """:ref:`Async Transport <async_transports>` to execute GraphQL queries
@@ -32,7 +37,7 @@ def __init__(
3237 auth : Optional [BasicAuth ] = None ,
3338 ssl : Union [SSLContext , bool , Fingerprint ] = False ,
3439 timeout : Optional [int ] = None ,
35- client_session_args : Dict [str , Any ] = {} ,
40+ client_session_args : Optional [ Dict [str , Any ]] = None ,
3641 ) -> None :
3742 """Initialize the transport with the given aiohttp parameters.
3843
@@ -54,7 +59,6 @@ def __init__(
5459 self .ssl : Union [SSLContext , bool , Fingerprint ] = ssl
5560 self .timeout : Optional [int ] = timeout
5661 self .client_session_args = client_session_args
57-
5862 self .session : Optional [aiohttp .ClientSession ] = None
5963
6064 async def connect (self ) -> None :
@@ -81,7 +85,8 @@ async def connect(self) -> None:
8185 )
8286
8387 # Adding custom parameters passed from init
84- client_session_args .update (self .client_session_args )
88+ if self .client_session_args :
89+ client_session_args .update (self .client_session_args ) # type: ignore
8590
8691 self .session = aiohttp .ClientSession (** client_session_args )
8792
@@ -104,7 +109,8 @@ async def execute(
104109 document : DocumentNode ,
105110 variable_values : Optional [Dict [str , str ]] = None ,
106111 operation_name : Optional [str ] = None ,
107- extra_args : Dict [str , Any ] = {},
112+ extra_args : Dict [str , Any ] = None ,
113+ upload_files : bool = False ,
108114 ) -> ExecutionResult :
109115 """Execute the provided document AST against the configured remote server
110116 using the current session.
@@ -118,25 +124,70 @@ async def execute(
118124 :param variables_values: An optional Dict of variable values
119125 :param operation_name: An optional Operation name for the request
120126 :param extra_args: additional arguments to send to the aiohttp post method
127+ :param upload_files: Set to True if you want to put files in the variable values
121128 :returns: an ExecutionResult object.
122129 """
123130
124131 query_str = print_ast (document )
132+
125133 payload : Dict [str , Any ] = {
126134 "query" : query_str ,
127135 }
128136
129- if variable_values :
130- payload ["variables" ] = variable_values
131137 if operation_name :
132138 payload ["operationName" ] = operation_name
133139
134- post_args = {
135- "json" : payload ,
136- }
140+ if upload_files :
141+
142+ # If the upload_files flag is set, then we need variable_values
143+ assert variable_values is not None
144+
145+ # If we upload files, we will extract the files present in the
146+ # variable_values dict and replace them by null values
147+ nulled_variable_values , files = extract_files (variable_values )
148+
149+ # Save the nulled variable values in the payload
150+ payload ["variables" ] = nulled_variable_values
151+
152+ # Prepare aiohttp to send multipart-encoded data
153+ data = aiohttp .FormData ()
154+
155+ # Generate the file map
156+ # path is nested in a list because the spec allows multiple pointers
157+ # to the same file. But we don't support that.
158+ # Will generate something like {"0": ["variables.file"]}
159+ file_map = {str (i ): [path ] for i , path in enumerate (files )}
160+
161+ # Enumerate the file streams
162+ # Will generate something like {'0': <_io.BufferedReader ...>}
163+ file_streams = {str (i ): files [path ] for i , path in enumerate (files )}
164+
165+ # Add the payload to the operations field
166+ operations_str = json .dumps (payload )
167+ log .debug ("operations %s" , operations_str )
168+ data .add_field (
169+ "operations" , operations_str , content_type = "application/json"
170+ )
171+
172+ # Add the file map field
173+ file_map_str = json .dumps (file_map )
174+ log .debug ("file_map %s" , file_map_str )
175+ data .add_field ("map" , file_map_str , content_type = "application/json" )
176+
177+ # Add the extracted files as remaining fields
178+ data .add_fields (* file_streams .items ())
179+
180+ post_args : Dict [str , Any ] = {"data" : data }
181+
182+ else :
183+ if variable_values :
184+ payload ["variables" ] = variable_values
185+
186+ post_args = {"json" : payload }
137187
138188 # Pass post_args to aiohttp post method
139- post_args .update (extra_args )
189+ if extra_args :
190+ post_args .update (extra_args )
140191
141192 if self .session is None :
142193 raise TransportClosed ("Transport is not connected" )
0 commit comments