|
10 | 10 | from neo4j.exceptions import ClientError |
11 | 11 | from pandas import DataFrame |
12 | 12 | from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight |
| 13 | +from pyarrow import __version__ as arrow_version |
13 | 14 | from pyarrow._flight import FlightStreamReader, FlightStreamWriter |
14 | 15 | from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory |
15 | 16 | from pyarrow.types import is_dictionary |
16 | 17 |
|
17 | 18 | from ..server_version.server_version import ServerVersion |
| 19 | +from ..version import __version__ |
18 | 20 | from .arrow_endpoint_version import ArrowEndpointVersion |
19 | 21 | from .arrow_info import ArrowInfo |
20 | 22 | from .query_runner import QueryRunner |
@@ -75,7 +77,8 @@ def __init__( |
75 | 77 | client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} |
76 | 78 | if auth: |
77 | 79 | self._auth_middleware = AuthMiddleware(auth) |
78 | | - client_options["middleware"] = [AuthFactory(self._auth_middleware)] |
| 80 | + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" |
| 81 | + client_options["middleware"] = [AuthFactory(self._auth_middleware), UserAgentFactory(useragent=user_agent)] |
79 | 82 | if tls_root_certs: |
80 | 83 | client_options["tls_root_certs"] = tls_root_certs |
81 | 84 |
|
@@ -219,12 +222,33 @@ def handle_flight_error(e: Exception): |
219 | 222 | raise e |
220 | 223 |
|
221 | 224 |
|
222 | | -class AuthFactory(ClientMiddlewareFactory): # type: ignore |
223 | | - def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None: |
| 225 | +class UserAgentFactory(ClientMiddlewareFactory): |
| 226 | + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: |
| 227 | + super().__init__(*args, **kwargs) |
| 228 | + self._middleware = UserAgentMiddleware(useragent) |
| 229 | + |
| 230 | + def start_call(self, info: Any) -> ClientMiddleware: |
| 231 | + return self._middleware |
| 232 | + |
| 233 | + |
| 234 | +class UserAgentMiddleware(ClientMiddleware): |
| 235 | + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: |
| 236 | + super().__init__(*args, **kwargs) |
| 237 | + self._useragent = useragent |
| 238 | + |
| 239 | + def sending_headers(self) -> Dict[str, str]: |
| 240 | + return {"User-Agent": self._useragent} |
| 241 | + |
| 242 | + def received_headers(self, headers: Dict[str, Any]) -> None: |
| 243 | + pass |
| 244 | + |
| 245 | + |
| 246 | +class AuthFactory(ClientMiddlewareFactory): |
| 247 | + def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None: |
224 | 248 | super().__init__(*args, **kwargs) |
225 | 249 | self._middleware = middleware |
226 | 250 |
|
227 | | - def start_call(self, info: Any) -> "AuthMiddleware": |
| 251 | + def start_call(self, info: Any) -> AuthMiddleware: |
228 | 252 | return self._middleware |
229 | 253 |
|
230 | 254 |
|
|
0 commit comments