Skip to content

Commit 83054ab

Browse files
committed
Add more generic option arrow_client_options to allow user specific arrow client configuration
1 parent 7f84c3e commit 83054ab

File tree

10 files changed

+99
-70
lines changed

10 files changed

+99
-70
lines changed

changelog.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
## New features
66

7-
- `sessions.get_or_create()` now supports passing in a manual selection of root certificates for verifying server certificate
8-
- `sessions.get_or_create()` now supports disabling server certificate verification
7+
- `sessions.get_or_create()` now supports passing additional configuration options for the Arrow Flight Client
98

109

1110
## Bug fixes

doc/modules/ROOT/pages/graph-analytics-serverless.adoc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ This is a hard limit and cannot be changed.
7777

7878
==== Syntax
7979

80+
self,
81+
session_name: str,
82+
memory: SessionMemory,
83+
db_connection: Optional[DbmsConnectionInfo] = None,
84+
ttl: Optional[timedelta] = None,
85+
cloud_location: Optional[CloudLocation] = None,
86+
timeout: Optional[int] = None,
87+
neo4j_driver_options: Optional[dict[str, Any]] = None,
88+
arrow_client_options: Optional[dict[str, Any]] = None,
89+
8090
[source, role=no-test]
8191
----
8292
sessions.get_or_create(
@@ -86,8 +96,8 @@ sessions.get_or_create(
8696
ttl: Optional[timedelta] = None,
8797
cloud_location: Optional[CloudLocation] = None,
8898
timeout: Optional[int] = None,
89-
tls_root_certs: Optional[bytes] = None,
90-
disable_server_verification: bool = False,
99+
neo4j_driver_options: Optional[dict[str, Any]] = None,
100+
arrow_client_options: Optional[dict[str, Any]] = None,
91101
): AuraGraphDataScience
92102
----
93103

@@ -101,8 +111,8 @@ sessions.get_or_create(
101111
| ttl | datetime.timedelta | yes | 1h | Time-to-live for the session.
102112
| cloud_location | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/cloud_location[CloudLocation] | yes | None | Aura-supported cloud provider and region where the GDS Session will run. Required for the Self-managed and Standalone types.
103113
| timeout | int | yes | None | Seconds to wait for the session to enter Ready state. If the time is exceeded, an error will be returned.
104-
| tls_root_certs | bytes | yes | None | PEM-encoded root certificates used for verifying server certificate. If not specified, platform-specific default root certificates will be used.
105-
| disable_server_verification | bool | yes | False | Set to True to disable server TLS certificate verification. Use with caution.
114+
| neo4j_driver_options | dict[str, any] | yes | None | Additional options passed to the Neo4j driver
115+
| arrow_client_options | dict[str, any] | yes | None | Additional options passed to the Arrow Flight Client
106116
|===
107117

108118

graphdatascience/graph/graph_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def database(self) -> str:
6868
"""
6969
return self._graph_info(["database"]) # type: ignore
7070

71-
def configuration(self) -> "Series[Any]":
71+
def configuration(self) -> Series[Any]:
7272
"""
7373
Returns:
7474
the configuration of the graph

graphdatascience/graph_data_science.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
arrow_tls_root_certs: Optional[bytes] = None,
4343
bookmarks: Optional[Any] = None,
4444
show_progress: bool = True,
45+
arrow_client_options: Optional[dict[str, Any]] = None,
4546
):
4647
"""
4748
Construct a new GraphDataScience object.
@@ -65,14 +66,20 @@ def __init__(
6566
- True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint.
6667
- False will make the client use Bolt for all operations.
6768
arrow_disable_server_verification : bool, default True
69+
.. deprecated:: 1.16
70+
Use arrow_client_options instead
6871
A flag that overrides other TLS settings and disables server verification for TLS connections.
6972
arrow_tls_root_certs : Optional[bytes], default None
73+
.. deprecated:: 1.16
74+
Use arrow_client_options instead
7075
PEM-encoded certificates that are used for the connection to the
7176
GDS Arrow Flight server.
7277
bookmarks : Optional[Any], default None
7378
The Neo4j bookmarks to require a certain state before the next query gets executed.
7479
show_progress : bool, default True
7580
A flag to indicate whether to show progress bars for running procedures.
81+
arrow_client_options : Optional[dict[str, Any]], default None
82+
Additional options to be passed to the Arrow Flight client.
7683
"""
7784
if aura_ds:
7885
GraphDataScience._validate_endpoint(endpoint)
@@ -105,14 +112,19 @@ def __init__(
105112
username, password = auth
106113
arrow_auth = UsernamePasswordAuthentication(username, password)
107114

115+
if arrow_client_options is None:
116+
arrow_client_options = {}
117+
if arrow_disable_server_verification:
118+
arrow_client_options["disable_server_verification"] = True
119+
if arrow_tls_root_certs is not None:
120+
arrow_client_options["tls_root_certs"] = arrow_tls_root_certs
108121
self._query_runner = ArrowQueryRunner.create(
109122
self._query_runner,
110-
arrow_info,
111-
arrow_auth,
112-
self._query_runner.encrypted(),
113-
arrow_disable_server_verification,
114-
arrow_tls_root_certs,
115-
None if arrow is True else arrow,
123+
arrow_info=arrow_info,
124+
arrow_authentication=arrow_auth,
125+
encrypted=self._query_runner.encrypted(),
126+
arrow_client_options=arrow_client_options,
127+
connection_string_override=None if arrow is True else arrow,
116128
)
117129

118130
self._query_runner.set_show_progress(show_progress)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,20 @@ def create(
2525
arrow_info: ArrowInfo,
2626
arrow_authentication: Optional[ArrowAuthentication] = None,
2727
encrypted: bool = False,
28-
disable_server_verification: bool = False,
29-
tls_root_certs: Optional[bytes] = None,
28+
arrow_client_options: Optional[dict[str, Any]] = None,
3029
connection_string_override: Optional[str] = None,
3130
retry_config: Optional[RetryConfig] = None,
3231
) -> ArrowQueryRunner:
3332
if not arrow_info.enabled:
3433
raise ValueError("Arrow is not enabled on the server")
3534

3635
gds_arrow_client = GdsArrowClient.create(
37-
arrow_info,
38-
arrow_authentication,
39-
encrypted,
40-
disable_server_verification,
41-
tls_root_certs,
42-
connection_string_override,
36+
arrow_info=arrow_info,
37+
auth=arrow_authentication,
38+
encrypted=encrypted,
39+
connection_string_override=connection_string_override,
4340
retry_config=retry_config,
41+
arrow_client_options=arrow_client_options,
4442
)
4543

4644
return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def create(
5555
tls_root_certs: Optional[bytes] = None,
5656
connection_string_override: Optional[str] = None,
5757
retry_config: Optional[RetryConfig] = None,
58+
arrow_client_options: Optional[dict[str, Any]] = None,
5859
) -> GdsArrowClient:
5960
connection_string: str
6061
if connection_string_override is not None:
@@ -78,14 +79,15 @@ def create(
7879
)
7980

8081
return GdsArrowClient(
81-
host,
82-
retry_config,
83-
int(port),
84-
auth,
85-
encrypted,
86-
disable_server_verification,
87-
tls_root_certs,
88-
arrow_endpoint_version,
82+
host=host,
83+
retry_config=retry_config,
84+
port=int(port),
85+
auth=auth,
86+
encrypted=encrypted,
87+
disable_server_verification=disable_server_verification,
88+
tls_root_certs=tls_root_certs,
89+
arrow_endpoint_version=arrow_endpoint_version,
90+
arrow_client_options=arrow_client_options,
8991
)
9092

9193
def __init__(
@@ -99,6 +101,7 @@ def __init__(
99101
tls_root_certs: Optional[bytes] = None,
100102
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.V1,
101103
user_agent: Optional[str] = None,
104+
arrow_client_options: Optional[dict[str, Any]] = None,
102105
):
103106
"""Creates a new GdsArrowClient instance.
104107
@@ -113,27 +116,41 @@ def __init__(
113116
encrypted: bool
114117
A flag that indicates whether the connection should be encrypted (default is False)
115118
disable_server_verification: bool
119+
.. deprecated:: 1.16
120+
Use arrow_client_options instead
116121
A flag that disables server verification for TLS connections (default is False)
117122
tls_root_certs: Optional[bytes]
123+
.. deprecated:: 1.16
124+
Use arrow_client_options instead
118125
PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
119126
arrow_endpoint_version:
120127
The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
121128
user_agent: Optional[str]
122129
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
123130
retry_config: Optional[RetryConfig]
124131
The retry configuration to use for the Arrow requests send by the client.
132+
arrow_client_options: Optional[dict[str, Any]]
133+
Additional configuration for the Arrow flight client.
134+
125135
"""
126136
self._arrow_endpoint_version = arrow_endpoint_version
127137
self._host = host
128138
self._port = port
129139
self._auth = None
130140
self._encrypted = encrypted
131-
self._disable_server_verification = disable_server_verification
132-
self._tls_root_certs = tls_root_certs
133141
self._user_agent = user_agent
134142
self._retry_config = retry_config
135143
self._logger = logging.getLogger("gds_arrow_client")
136144

145+
self._arrow_client_options: dict[str, Any] = arrow_client_options # type: ignore
146+
if self._arrow_client_options is None:
147+
self._arrow_client_options = {}
148+
149+
if disable_server_verification:
150+
self._arrow_client_options["disable_server_verification"] = True
151+
if tls_root_certs is not None:
152+
self._arrow_client_options["tls_root_certs"] = tls_root_certs
153+
137154
if auth:
138155
if not isinstance(auth, ArrowAuthentication):
139156
username, password = auth
@@ -149,18 +166,27 @@ def _instantiate_flight_client(self) -> flight.FlightClient:
149166
if self._encrypted
150167
else flight.Location.for_grpc_tcp(self._host, self._port)
151168
)
152-
client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}
169+
170+
client_options = self._arrow_client_options.copy()
171+
153172
if self._auth:
154173
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
155174
if self._user_agent:
156175
user_agent = self._user_agent
157176

158-
client_options["middleware"] = [
159-
AuthFactory(self._auth_middleware),
160-
UserAgentFactory(useragent=user_agent),
161-
]
162-
if self._tls_root_certs:
163-
client_options["tls_root_certs"] = self._tls_root_certs
177+
if "middleware" in client_options:
178+
if not isinstance(client_options["middleware"], list):
179+
raise TypeError("client_options['middleware'] must be a list")
180+
else:
181+
client_options["middleware"] = []
182+
183+
client_options["middleware"].extend(
184+
[
185+
AuthFactory(self._auth_middleware),
186+
UserAgentFactory(useragent=user_agent),
187+
]
188+
) # type: ignore
189+
164190
return flight.FlightClient(location, **client_options)
165191

166192
def connection_info(self) -> tuple[str, int]:

graphdatascience/session/aura_graph_data_science.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def create(
3838
arrow_authentication: Optional[ArrowAuthentication],
3939
db_endpoint: Optional[Union[Neo4jQueryRunner, DbmsConnectionInfo]],
4040
delete_fn: Callable[[], bool],
41-
arrow_disable_server_verification: bool = False,
42-
arrow_tls_root_certs: Optional[bytes] = None,
41+
arrow_client_options: Optional[dict[str, Any]] = None,
4342
bookmarks: Optional[Any] = None,
4443
show_progress: bool = True,
4544
) -> AuraGraphDataScience:
@@ -55,17 +54,15 @@ def create(
5554
arrow_info=arrow_info,
5655
arrow_authentication=arrow_authentication,
5756
encrypted=session_bolt_query_runner.encrypted(),
58-
disable_server_verification=arrow_disable_server_verification,
59-
tls_root_certs=arrow_tls_root_certs,
57+
arrow_client_options=arrow_client_options,
6058
)
6159

6260
# TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
6361
session_arrow_client = GdsArrowClient.create(
6462
arrow_info,
6563
arrow_authentication,
6664
session_bolt_query_runner.encrypted(),
67-
arrow_disable_server_verification,
68-
arrow_tls_root_certs,
65+
arrow_client_options=arrow_client_options,
6966
)
7067

7168
gds_version = session_bolt_query_runner.server_version()

graphdatascience/session/dedicated_sessions.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def get_or_create(
6363
cloud_location: Optional[CloudLocation] = None,
6464
timeout: Optional[int] = None,
6565
neo4j_driver_options: Optional[dict[str, Any]] = None,
66-
tls_root_certs: Optional[bytes] = None,
67-
disable_server_verification: bool = False,
66+
arrow_client_options: Optional[dict[str, Any]] = None,
6867
) -> AuraGraphDataScience:
6968
if db_connection is None:
7069
if not cloud_location:
@@ -106,8 +105,7 @@ def get_or_create(
106105
session_bolt_connection_info=session_bolt_connection_info,
107106
arrow_authentication=arrow_authentication,
108107
db_runner=db_runner,
109-
tls_root_certs=tls_root_certs,
110-
disable_server_verification=disable_server_verification,
108+
arrow_client_options=arrow_client_options,
111109
)
112110

113111
def _create_db_runner(
@@ -213,14 +211,12 @@ def _construct_client(
213211
session_bolt_connection_info: DbmsConnectionInfo,
214212
arrow_authentication: ArrowAuthentication,
215213
db_runner: Optional[Neo4jQueryRunner],
216-
tls_root_certs: Optional[bytes],
217-
disable_server_verification: bool,
214+
arrow_client_options: Optional[dict[str, Any]] = None,
218215
) -> AuraGraphDataScience:
219216
return AuraGraphDataScience.create(
220217
session_bolt_connection_info=session_bolt_connection_info,
221218
arrow_authentication=arrow_authentication,
222219
db_endpoint=db_runner,
223220
delete_fn=lambda: self._aura_api.delete_session(session_id=session_id),
224-
arrow_tls_root_certs=tls_root_certs,
225-
arrow_disable_server_verification=disable_server_verification,
221+
arrow_client_options=arrow_client_options,
226222
)

graphdatascience/session/gds_sessions.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def get_or_create(
106106
cloud_location: Optional[CloudLocation] = None,
107107
timeout: Optional[int] = None,
108108
neo4j_driver_config: Optional[dict[str, Any]] = None,
109-
tls_root_certs: Optional[bytes] = None,
110-
disable_server_verification: bool = False,
109+
arrow_client_options: Optional[dict[str, Any]] = None,
111110
) -> AuraGraphDataScience:
112111
"""
113112
Retrieves an existing session with the given session name and database connection,
@@ -124,8 +123,7 @@ def get_or_create(
124123
cloud_location (Optional[CloudLocation]): The cloud location. Required if the GDS session is for a self-managed database.
125124
timeout (Optional[int]): Optional timeout (in seconds) when waiting for session to become ready. If unset the method will wait forever. If set and session does not become ready an exception will be raised. It is user responsibility to ensure resource gets cleaned up in this situation.
126125
neo4j_driver_config (Optional[dict[str, Any]]): Optional configuration for the Neo4j driver.
127-
tls_root_certs (Optional[bytes]): Manually specify PEM-encoded root certificates used for verifying server certificate. If not specified, platform-specific default root certificates will be used.
128-
disable_server_verification (bool): Set to True to disable server certificate verification. Use with caution.
126+
arrow_client_options (Optional[dict[str, Any]]): Optional configuration for the Arrow Flight client.
129127
Returns:
130128
AuraGraphDataScience: The session.
131129
"""
@@ -137,8 +135,7 @@ def get_or_create(
137135
cloud_location=cloud_location,
138136
timeout=timeout,
139137
neo4j_driver_options=neo4j_driver_config,
140-
tls_root_certs=tls_root_certs,
141-
disable_server_verification=disable_server_verification,
138+
arrow_client_options=arrow_client_options,
142139
)
143140

144141
def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str] = None) -> bool:

0 commit comments

Comments
 (0)