Skip to content

Commit 0455f75

Browse files
authored
Merge pull request #752 from FlorentinD/arrow-client-set-useragent
Set user agent for Arrow client
2 parents 4117305 + 955c04b commit 0455f75

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from neo4j.exceptions import ClientError
1111
from pandas import DataFrame
1212
from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight
13+
from pyarrow import __version__ as arrow_version
1314
from pyarrow._flight import FlightStreamReader, FlightStreamWriter
1415
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
1516
from pyarrow.types import is_dictionary
1617

1718
from ..server_version.server_version import ServerVersion
19+
from ..version import __version__
1820
from .arrow_endpoint_version import ArrowEndpointVersion
1921
from .arrow_info import ArrowInfo
2022
from .query_runner import QueryRunner
@@ -75,7 +77,8 @@ def __init__(
7577
client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
7678
if auth:
7779
self._auth_middleware = AuthMiddleware(auth)
78-
client_options["middleware"] = [AuthFactory(self._auth_middleware)]
80+
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
81+
client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)]
7982
if tls_root_certs:
8083
client_options["tls_root_certs"] = tls_root_certs
8184

@@ -219,12 +222,33 @@ def handle_flight_error(e: Exception):
219222
raise e
220223

221224

222-
class AuthFactory(ClientMiddlewareFactory): # type: ignore
223-
def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None:
225+
class UserAgentFactory(ClientMiddlewareFactory):
226+
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
227+
super().__init__(*args, **kwargs)
228+
self._middleware = UserAgentMiddleware(useragent)
229+
230+
def start_call(self, info: Any) -> ClientMiddleware:
231+
return self._middleware
232+
233+
234+
class UserAgentMiddleware(ClientMiddleware):
235+
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
236+
super().__init__(*args, **kwargs)
237+
self._useragent = useragent
238+
239+
def sending_headers(self) -> Dict[str, str]:
240+
return {"x-gds-user-agent": self._useragent}
241+
242+
def received_headers(self, headers: Dict[str, Any]) -> None:
243+
pass
244+
245+
246+
class AuthFactory(ClientMiddlewareFactory):
247+
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
224248
super().__init__(*args, **kwargs)
225249
self._middleware = middleware
226250

227-
def start_call(self, info: Any) -> "AuthMiddleware":
251+
def start_call(self, info: Any) -> AuthMiddleware:
228252
return self._middleware
229253

230254

0 commit comments

Comments
 (0)