Skip to content

Commit 22a1614

Browse files
authored
Merge pull request #739 from DarthMax/improve_arrow_error_messages
Try to extract better error messages from arrow errors
2 parents a182ad5 + 4b08a9b commit 22a1614

File tree

4 files changed

+83 -18 lines changed

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,16 @@ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> Non
8888

8989
writer, _ = self._client.start_put(flight_descriptor, table.schema)
9090

91-
with writer:
92-
# Write table in chunks
93-
for partition in batches:
94-
writer.write_batch(partition)
95-
pbar.update(partition.num_rows)
96-
# Force a refresh to avoid the progress bar getting stuck at 0%
97-
pbar.refresh()
91+
try:
92+
with writer:
93+
# Write table in chunks
94+
for partition in batches:
95+
writer.write_batch(partition)
96+
pbar.update(partition.num_rows)
97+
# Force a refresh to avoid the progress bar getting stuck at 0%
98+
pbar.refresh()
99+
except Exception as e:
100+
GdsArrowClient.handle_flight_error(e)
98101

99102
def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
100103
desc = "Uploading Nodes" if entity_type == "node" else "Uploading Relationships"

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import base64
22
import json
3+
import re
34
import time
45
import warnings
56
from typing import Any, Dict, Optional, Tuple
67

8+
from neo4j.exceptions import ClientError
79
from pandas import DataFrame
810
from pyarrow import ChunkedArray, Schema, Table, chunked_array, flight
911
from pyarrow._flight import FlightStreamReader, FlightStreamWriter
@@ -128,8 +130,12 @@ def get_property(
128130
}
129131

130132
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
131-
get = self._flight_client.do_get(ticket)
132-
arrow_table = get.read_all()
133+
134+
try:
135+
get = self._flight_client.do_get(ticket)
136+
arrow_table = get.read_all()
137+
except Exception as e:
138+
self.handle_flight_error(e)
133139

134140
if configuration.get("list_node_labels", False):
135141
# GDS 2.5 had an inconsistent naming of the node labels column
@@ -147,13 +153,17 @@ def get_property(
147153

148154
def send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
149155
action_type = self._versioned_action_type(action_type)
150-
result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
151156

152-
# Consume result fully to sanity check and avoid cancelled streams
153-
collected_result = list(result)
154-
assert len(collected_result) == 1
157+
try:
158+
result = self._flight_client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
159+
160+
# Consume result fully to sanity check and avoid cancelled streams
161+
collected_result = list(result)
162+
assert len(collected_result) == 1
155163

156-
json.loads(collected_result[0].body.to_pybytes().decode())
164+
json.loads(collected_result[0].body.to_pybytes().decode())
165+
except Exception as e:
166+
self.handle_flight_error(e)
157167

158168
def start_put(self, payload: Dict[str, Any], schema: Schema) -> Tuple[FlightStreamWriter, FlightStreamReader]:
159169
flight_descriptor = self._versioned_flight_descriptor(payload)
@@ -199,6 +209,30 @@ def _sanitize_arrow_table(arrow_table: Table) -> Table:
199209
arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
200210
return arrow_table
201211

212+
@staticmethod
def handle_flight_error(e: Exception) -> None:
    """Re-raise ``e``, with a cleaned-up message for Arrow Flight errors.

    For ``flight.FlightServerError``, ``flight.FlightInternalError`` and
    ``ClientError`` the known wrapper prefixes and the trailing
    "gRPC client debug context" section are removed from the message, and a
    ``flight.FlightServerError`` carrying the shortened message is raised
    with ``e`` as its cause. Any other exception is re-raised unchanged.

    This function never returns normally.

    Args:
        e: The exception caught around a Flight or remote projection call.

    Raises:
        flight.FlightServerError: For the three handled exception types.
        Exception: ``e`` itself for every other exception type.
    """
    if not isinstance(e, (flight.FlightServerError, flight.FlightInternalError, ClientError)):
        raise e

    # Fall back to str(e) when the exception carries no args, or when the
    # first arg is not a string, instead of failing with IndexError or
    # AttributeError while trying to improve the message.
    first_arg = e.args[0] if e.args else None
    improved_message = first_arg if isinstance(first_arg, str) else str(e)

    # Wrapper prefixes that only repeat transport details.
    noise_prefixes = (
        "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ",
        "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ",
        "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ",
    )
    for prefix in noise_prefixes:
        improved_message = improved_message.replace(prefix, "")

    # Drop the gRPC debug context (and the ". " joining it to the message).
    improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message)

    # Chain explicitly so the original error stays available as __cause__.
    raise flight.FlightServerError(improved_message) from e
235+
202236

203237
class AuthFactory(ClientMiddlewareFactory): # type: ignore
204238
def __init__(self, middleware: "AuthMiddleware", *args: Any, **kwargs: Any) -> None:

graphdatascience/query_runner/session_query_runner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ def _remote_projection(
143143

144144
versioned_endpoint = self._resolved_protocol_version.versioned_procedure_name(endpoint)
145145

146-
return self._db_query_runner.call_procedure(
147-
versioned_endpoint, remote_project_proc_params, yields, database, logging, False
148-
)
146+
try:
147+
return self._db_query_runner.call_procedure(
148+
versioned_endpoint, remote_project_proc_params, yields, database, logging, False
149+
)
150+
except Exception as e:
151+
GdsArrowClient.handle_flight_error(e)
149152

150153
@staticmethod
151154
def _project_params_v2(

graphdatascience/tests/unit/test_gds_arrow_client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import re
2+
13
import pytest
4+
from pyarrow import flight
25

3-
from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware
6+
from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient
47

58

69
def test_auth_middleware() -> None:
@@ -27,3 +30,25 @@ def test_auth_middleware_bad_headers() -> None:
2730

2831
with pytest.raises(ValueError, match="Incompatible header value received from server: `12342`"):
2932
middleware.received_headers({"authorization": [12342]})
33+
34+
35+
def test_handle_flight_error() -> None:
    """Check that handle_flight_error strips Flight wrapper text from messages."""
    # A server error with a wrapper prefix and a trailing gRPC debug context.
    # The expected message is escaped so that "." is matched literally,
    # as in the second assertion below.
    with pytest.raises(
        flight.FlightServerError,
        match=re.escape(
            "FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."
        ),
    ):
        GdsArrowClient.handle_flight_error(
            flight.FlightServerError(
                'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal'
            )
        )

    # An internal-error wrapper prefix without any debug context.
    with pytest.raises(
        flight.FlightServerError,
        match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
    ):
        GdsArrowClient.handle_flight_error(
            flight.FlightServerError(
                "FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"
            )
        )

0 commit comments

Comments
 (0)