Skip to content

Commit 2a95445

Browse files
authored
Merge pull request #912 from neo4j/tls-session
Allow passing in root certificates and disable server verification
2 parents 59252c6 + f27e55e commit 2a95445

File tree

10 files changed

+153
-46
lines changed

10 files changed

+153
-46
lines changed

changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
## New features
66

7+
- `sessions.get_or_create()` now supports passing additional configuration options for the Arrow Flight Client
8+
9+
710
## Bug fixes
811

912
- Fix reporting error based on http responses from the Aura-API with an invalid JSON body. Earlier the client would report JSONDecodeError instead of showing the actual issue.

doc/modules/ROOT/pages/graph-analytics-serverless.adoc

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ This is a hard limit and cannot be changed.
7777

7878
==== Syntax
7979

80+
self,
81+
session_name: str,
82+
memory: SessionMemory,
83+
db_connection: Optional[DbmsConnectionInfo] = None,
84+
ttl: Optional[timedelta] = None,
85+
cloud_location: Optional[CloudLocation] = None,
86+
timeout: Optional[int] = None,
87+
neo4j_driver_options: Optional[dict[str, Any]] = None,
88+
arrow_client_options: Optional[dict[str, Any]] = None,
89+
8090
[source, role=no-test]
8191
----
8292
sessions.get_or_create(
@@ -86,19 +96,23 @@ sessions.get_or_create(
8696
ttl: Optional[timedelta] = None,
8797
cloud_location: Optional[CloudLocation] = None,
8898
timeout: Optional[int] = None,
99+
neo4j_driver_options: Optional[dict[str, Any]] = None,
100+
arrow_client_options: Optional[dict[str, Any]] = None,
89101
): AuraGraphDataScience
90102
----
91103

92104
.Parameters:
93105
[opts="header",cols="3m,1m,1,1m,6", role="no-break"]
94106
|===
95-
| Name | Type | Optional | Default | Description
96-
| session_name | str | no | - | Name of the session. Must be unique within the project.
97-
| memory | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/session_memory[SessionMemory] | no | - | Amount of memory available to the session.
98-
| db_connection | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/DbmsConnectionInfo[DbmsConnectionInfo] | yes | None | Bolt server URL, username, and password to a Neo4j DBMS. Required for the Attached and Self-managed types. Alternatively to username and password, you can provide a `neo4j.Auth` https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods[object].
99-
| ttl | datetime.timedelta | yes | 1h | Time-to-live for the session.
100-
| cloud_location| https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/cloud_location[CloudLocation] | yes | None | Aura-supported cloud provider and region where the GDS Session will run. Required for the Self-managed and Standalone types.
101-
| timeout | int | yes | None | Seconds to wait for the session to enter Ready state. If the time is exceeded, an error will be returned.
107+
| Name | Type | Optional | Default | Description
108+
| session_name | str | no | - | Name of the session. Must be unique within the project.
109+
| memory | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/session_memory[SessionMemory] | no | - | Amount of memory available to the session.
110+
| db_connection | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/DbmsConnectionInfo[DbmsConnectionInfo] | yes | None | Bolt server URL, username, and password to a Neo4j DBMS. Required for the Attached and Self-managed types. Alternatively to username and password, you can provide a `neo4j.Auth` https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods[object].
111+
| ttl | datetime.timedelta | yes | 1h | Time-to-live for the session.
112+
| cloud_location | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/cloud_location[CloudLocation] | yes | None | Aura-supported cloud provider and region where the GDS Session will run. Required for the Self-managed and Standalone types.
113+
| timeout | int | yes | None | Seconds to wait for the session to enter Ready state. If the time is exceeded, an error will be returned.
114+
| neo4j_driver_options | dict[str, any] | yes | None | Additional options passed to the Neo4j driver to the Neo4j DBMS. Only relevant if `db_connection` is specified.
115+
| arrow_client_options | dict[str, any] | yes | None | Additional options passed to the Arrow Flight Client used to connect to the Session.
102116
|===
103117

104118

graphdatascience/graph/graph_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def database(self) -> str:
6868
"""
6969
return self._graph_info(["database"]) # type: ignore
7070

71-
def configuration(self) -> "Series[Any]":
71+
def configuration(self) -> Series[Any]:
7272
"""
7373
Returns:
7474
the configuration of the graph

graphdatascience/graph_data_science.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
arrow_tls_root_certs: Optional[bytes] = None,
4343
bookmarks: Optional[Any] = None,
4444
show_progress: bool = True,
45+
arrow_client_options: Optional[dict[str, Any]] = None,
4546
):
4647
"""
4748
Construct a new GraphDataScience object.
@@ -65,14 +66,20 @@ def __init__(
6566
- True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint.
6667
- False will make the client use Bolt for all operations.
6768
arrow_disable_server_verification : bool, default True
69+
.. deprecated:: 1.16
70+
Use arrow_client_options instead
6871
A flag that overrides other TLS settings and disables server verification for TLS connections.
6972
arrow_tls_root_certs : Optional[bytes], default None
73+
.. deprecated:: 1.16
74+
Use arrow_client_options instead
7075
PEM-encoded certificates that are used for the connection to the
7176
GDS Arrow Flight server.
7277
bookmarks : Optional[Any], default None
7378
The Neo4j bookmarks to require a certain state before the next query gets executed.
7479
show_progress : bool, default True
7580
A flag to indicate whether to show progress bars for running procedures.
81+
arrow_client_options : Optional[dict[str, Any]], default None
82+
Additional options to be passed to the Arrow Flight client.
7683
"""
7784
if aura_ds:
7885
GraphDataScience._validate_endpoint(endpoint)
@@ -105,14 +112,19 @@ def __init__(
105112
username, password = auth
106113
arrow_auth = UsernamePasswordAuthentication(username, password)
107114

115+
if arrow_client_options is None:
116+
arrow_client_options = {}
117+
if arrow_disable_server_verification:
118+
arrow_client_options["disable_server_verification"] = True
119+
if arrow_tls_root_certs is not None:
120+
arrow_client_options["tls_root_certs"] = arrow_tls_root_certs
108121
self._query_runner = ArrowQueryRunner.create(
109122
self._query_runner,
110-
arrow_info,
111-
arrow_auth,
112-
self._query_runner.encrypted(),
113-
arrow_disable_server_verification,
114-
arrow_tls_root_certs,
115-
None if arrow is True else arrow,
123+
arrow_info=arrow_info,
124+
arrow_authentication=arrow_auth,
125+
encrypted=self._query_runner.encrypted(),
126+
arrow_client_options=arrow_client_options,
127+
connection_string_override=None if arrow is True else arrow,
116128
)
117129

118130
self._query_runner.set_show_progress(show_progress)

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,20 @@ def create(
2525
arrow_info: ArrowInfo,
2626
arrow_authentication: Optional[ArrowAuthentication] = None,
2727
encrypted: bool = False,
28-
disable_server_verification: bool = False,
29-
tls_root_certs: Optional[bytes] = None,
28+
arrow_client_options: Optional[dict[str, Any]] = None,
3029
connection_string_override: Optional[str] = None,
3130
retry_config: Optional[RetryConfig] = None,
3231
) -> ArrowQueryRunner:
3332
if not arrow_info.enabled:
3433
raise ValueError("Arrow is not enabled on the server")
3534

3635
gds_arrow_client = GdsArrowClient.create(
37-
arrow_info,
38-
arrow_authentication,
39-
encrypted,
40-
disable_server_verification,
41-
tls_root_certs,
42-
connection_string_override,
36+
arrow_info=arrow_info,
37+
auth=arrow_authentication,
38+
encrypted=encrypted,
39+
connection_string_override=connection_string_override,
4340
retry_config=retry_config,
41+
arrow_client_options=arrow_client_options,
4442
)
4543

4644
return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def create(
5555
tls_root_certs: Optional[bytes] = None,
5656
connection_string_override: Optional[str] = None,
5757
retry_config: Optional[RetryConfig] = None,
58+
arrow_client_options: Optional[dict[str, Any]] = None,
5859
) -> GdsArrowClient:
5960
connection_string: str
6061
if connection_string_override is not None:
@@ -78,14 +79,15 @@ def create(
7879
)
7980

8081
return GdsArrowClient(
81-
host,
82-
retry_config,
83-
int(port),
84-
auth,
85-
encrypted,
86-
disable_server_verification,
87-
tls_root_certs,
88-
arrow_endpoint_version,
82+
host=host,
83+
retry_config=retry_config,
84+
port=int(port),
85+
auth=auth,
86+
encrypted=encrypted,
87+
disable_server_verification=disable_server_verification,
88+
tls_root_certs=tls_root_certs,
89+
arrow_endpoint_version=arrow_endpoint_version,
90+
arrow_client_options=arrow_client_options,
8991
)
9092

9193
def __init__(
@@ -99,6 +101,7 @@ def __init__(
99101
tls_root_certs: Optional[bytes] = None,
100102
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.V1,
101103
user_agent: Optional[str] = None,
104+
arrow_client_options: Optional[dict[str, Any]] = None,
102105
):
103106
"""Creates a new GdsArrowClient instance.
104107
@@ -113,27 +116,39 @@ def __init__(
113116
encrypted: bool
114117
A flag that indicates whether the connection should be encrypted (default is False)
115118
disable_server_verification: bool
119+
.. deprecated:: 1.16
120+
Use arrow_client_options instead
116121
A flag that disables server verification for TLS connections (default is False)
117122
tls_root_certs: Optional[bytes]
123+
.. deprecated:: 1.16
124+
Use arrow_client_options instead
118125
PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
119126
arrow_endpoint_version:
120127
The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
121128
user_agent: Optional[str]
122129
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
123130
retry_config: Optional[RetryConfig]
124131
The retry configuration to use for the Arrow requests send by the client.
132+
arrow_client_options: Optional[dict[str, Any]]
133+
Additional configuration for the Arrow flight client.
134+
125135
"""
126136
self._arrow_endpoint_version = arrow_endpoint_version
127137
self._host = host
128138
self._port = port
129139
self._auth = None
130140
self._encrypted = encrypted
131-
self._disable_server_verification = disable_server_verification
132-
self._tls_root_certs = tls_root_certs
133141
self._user_agent = user_agent
134142
self._retry_config = retry_config
135143
self._logger = logging.getLogger("gds_arrow_client")
136144

145+
self._arrow_client_options = arrow_client_options if arrow_client_options is not None else {}
146+
147+
if disable_server_verification:
148+
self._arrow_client_options["disable_server_verification"] = True
149+
if tls_root_certs is not None:
150+
self._arrow_client_options["tls_root_certs"] = tls_root_certs
151+
137152
if auth:
138153
if not isinstance(auth, ArrowAuthentication):
139154
username, password = auth
@@ -149,18 +164,27 @@ def _instantiate_flight_client(self) -> flight.FlightClient:
149164
if self._encrypted
150165
else flight.Location.for_grpc_tcp(self._host, self._port)
151166
)
152-
client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}
167+
168+
client_options = self._arrow_client_options.copy()
169+
153170
if self._auth:
154171
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
155172
if self._user_agent:
156173
user_agent = self._user_agent
157174

158-
client_options["middleware"] = [
159-
AuthFactory(self._auth_middleware),
160-
UserAgentFactory(useragent=user_agent),
161-
]
162-
if self._tls_root_certs:
163-
client_options["tls_root_certs"] = self._tls_root_certs
175+
if "middleware" in client_options:
176+
if not isinstance(client_options["middleware"], list):
177+
raise TypeError("client_options['middleware'] must be a list")
178+
else:
179+
client_options["middleware"] = []
180+
181+
client_options["middleware"].extend(
182+
[
183+
AuthFactory(self._auth_middleware),
184+
UserAgentFactory(useragent=user_agent),
185+
]
186+
)
187+
164188
return flight.FlightClient(location, **client_options)
165189

166190
def connection_info(self) -> tuple[str, int]:

graphdatascience/session/aura_graph_data_science.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def create(
3838
arrow_authentication: Optional[ArrowAuthentication],
3939
db_endpoint: Optional[Union[Neo4jQueryRunner, DbmsConnectionInfo]],
4040
delete_fn: Callable[[], bool],
41-
arrow_disable_server_verification: bool = False,
42-
arrow_tls_root_certs: Optional[bytes] = None,
41+
arrow_client_options: Optional[dict[str, Any]] = None,
4342
bookmarks: Optional[Any] = None,
4443
show_progress: bool = True,
4544
) -> AuraGraphDataScience:
@@ -55,17 +54,15 @@ def create(
5554
arrow_info=arrow_info,
5655
arrow_authentication=arrow_authentication,
5756
encrypted=session_bolt_query_runner.encrypted(),
58-
disable_server_verification=arrow_disable_server_verification,
59-
tls_root_certs=arrow_tls_root_certs,
57+
arrow_client_options=arrow_client_options,
6058
)
6159

6260
# TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
6361
session_arrow_client = GdsArrowClient.create(
6462
arrow_info,
6563
arrow_authentication,
6664
session_bolt_query_runner.encrypted(),
67-
arrow_disable_server_verification,
68-
arrow_tls_root_certs,
65+
arrow_client_options=arrow_client_options,
6966
)
7067

7168
gds_version = session_bolt_query_runner.server_version()

graphdatascience/session/dedicated_sessions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def get_or_create(
6363
cloud_location: Optional[CloudLocation] = None,
6464
timeout: Optional[int] = None,
6565
neo4j_driver_options: Optional[dict[str, Any]] = None,
66+
arrow_client_options: Optional[dict[str, Any]] = None,
6667
) -> AuraGraphDataScience:
6768
if db_connection is None:
6869
if not cloud_location:
@@ -104,6 +105,7 @@ def get_or_create(
104105
session_bolt_connection_info=session_bolt_connection_info,
105106
arrow_authentication=arrow_authentication,
106107
db_runner=db_runner,
108+
arrow_client_options=arrow_client_options,
107109
)
108110

109111
def _create_db_runner(
@@ -209,10 +211,12 @@ def _construct_client(
209211
session_bolt_connection_info: DbmsConnectionInfo,
210212
arrow_authentication: ArrowAuthentication,
211213
db_runner: Optional[Neo4jQueryRunner],
214+
arrow_client_options: Optional[dict[str, Any]] = None,
212215
) -> AuraGraphDataScience:
213216
return AuraGraphDataScience.create(
214217
session_bolt_connection_info=session_bolt_connection_info,
215218
arrow_authentication=arrow_authentication,
216219
db_endpoint=db_runner,
217220
delete_fn=lambda: self._aura_api.delete_session(session_id=session_id),
221+
arrow_client_options=arrow_client_options,
218222
)

graphdatascience/session/gds_sessions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def get_or_create(
106106
cloud_location: Optional[CloudLocation] = None,
107107
timeout: Optional[int] = None,
108108
neo4j_driver_config: Optional[dict[str, Any]] = None,
109+
arrow_client_options: Optional[dict[str, Any]] = None,
109110
) -> AuraGraphDataScience:
110111
"""
111112
Retrieves an existing session with the given session name and database connection,
@@ -121,7 +122,8 @@ def get_or_create(
121122
ttl: (Optional[timedelta]): The sessions time to live after inactivity in seconds.
122123
cloud_location (Optional[CloudLocation]): The cloud location. Required if the GDS session is for a self-managed database.
123124
timeout (Optional[int]): Optional timeout (in seconds) when waiting for session to become ready. If unset the method will wait forever. If set and session does not become ready an exception will be raised. It is user responsibility to ensure resource gets cleaned up in this situation.
124-
neo4j_driver_config (Optional[dict[str, Any]]): Optional configuration for the Neo4j driver.
125+
neo4j_driver_config (Optional[dict[str, Any]]): Optional configuration for the Neo4j driver to the Neo4j DBMS. Only relevant if `db_connection` is specified..
126+
arrow_client_options (Optional[dict[str, Any]]): Optional configuration for the Arrow Flight client.
125127
Returns:
126128
AuraGraphDataScience: The session.
127129
"""
@@ -133,6 +135,7 @@ def get_or_create(
133135
cloud_location=cloud_location,
134136
timeout=timeout,
135137
neo4j_driver_options=neo4j_driver_config,
138+
arrow_client_options=arrow_client_options,
136139
)
137140

138141
def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str] = None) -> bool:

0 commit comments

Comments
 (0)