Skip to content

Commit f5935df

Browse files
committed
Collect arrow do_action result for eager response check
1 parent f946973 commit f5935df

File tree

4 files changed

+18
-21
lines changed

4 files changed

+18
-21
lines changed

graphdatascience/arrow_client/authenticated_flight_client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,16 +180,18 @@ def do_action(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[
180180

181181
return self._flight_client.do_action(Action(endpoint, payload_bytes)) # type: ignore
182182

183-
def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> Iterator[Result]:
183+
def do_action_with_retry(self, endpoint: str, payload: bytes | dict[str, Any]) -> list[Result]:
184184
@retry(
185185
reraise=True,
186186
before=before_log("Send action", self._logger, logging.DEBUG),
187187
retry=self._retry_config.retry,
188188
stop=self._retry_config.stop,
189189
wait=self._retry_config.wait,
190190
)
191-
def run_with_retry() -> Iterator[Result]:
192-
return self.do_action(endpoint, payload)
191+
def run_with_retry() -> list[Result]:
192+
# the Flight response error code is only checked on iterator consumption
193+
# we eagerly collect iterator here to trigger retry in case of an error
194+
return list(self.do_action(endpoint, payload))
193195

194196
return run_with_retry()
195197

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import json
2-
from typing import Any, Iterator
2+
from typing import Any
33

44
from pyarrow._flight import Result
55

66

7-
def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]:
7+
def deserialize_single(input_stream: list[Result]) -> dict[str, Any]:
88
rows = deserialize(input_stream)
99
if len(rows) != 1:
1010
raise ValueError(f"Expected exactly one result, got {len(rows)}")
1111

1212
return rows[0]
1313

1414

15-
def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]:
15+
def deserialize(input_stream: list[Result]) -> list[dict[str, Any]]:
1616
def deserialize_row(row: Result): # type:ignore
1717
return json.loads(row.body.to_pybytes().decode())
1818

19-
return [deserialize_row(row) for row in list(input_stream)]
19+
return [deserialize_row(row) for row in input_stream]

graphdatascience/tests/integrationV2/procedure_surface/arrow/graph_creation_helper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ def create_graph(
1515
arrow_client: AuthenticatedArrowClient, graph_name: str, gdl: str, undirected: tuple[str, str] | None = None
1616
) -> Generator[GraphV2, Any, None]:
1717
try:
18-
raw_res = arrow_client.do_action("v2/graph.fromGDL", {"graphName": graph_name, "gdlGraph": gdl})
19-
deserialize_single(raw_res)
18+
raw_res = list(arrow_client.do_action("v2/graph.fromGDL", {"graphName": graph_name, "gdlGraph": gdl}))
2019

2120
if undirected is not None:
2221
JobClient.run_job_and_wait(
@@ -26,9 +25,11 @@ def create_graph(
2625
show_progress=False,
2726
)
2827

29-
raw_res = arrow_client.do_action(
30-
"v2/graph.relationships.drop",
31-
{"graphName": graph_name, "relationshipType": undirected[0]},
28+
raw_res = list(
29+
arrow_client.do_action(
30+
"v2/graph.relationships.drop",
31+
{"graphName": graph_name, "relationshipType": undirected[0]},
32+
)
3233
)
3334
deserialize_single(raw_res)
3435

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
1-
from typing import Iterator
2-
31
import pytest
4-
from pyarrow._flight import Result
52

63
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
74
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
85

96

107
def test_deserialize_single_success() -> None:
11-
input_stream = iter([ArrowTestResult({"key": "value"})])
128
expected = {"key": "value"}
13-
actual = deserialize_single(input_stream)
9+
actual = deserialize_single([ArrowTestResult({"key": "value"})])
1410
assert expected == actual
1511

1612

1713
def test_deserialize_single_raises_on_empty_stream() -> None:
18-
input_stream: Iterator[Result] = iter([])
1914
with pytest.raises(ValueError, match="Expected exactly one result, got 0"):
20-
deserialize_single(input_stream)
15+
deserialize_single([])
2116

2217

2318
def test_deserialize_single_raises_on_multiple_results() -> None:
24-
input_stream = iter([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})])
2519
with pytest.raises(ValueError, match="Expected exactly one result, got 2"):
26-
deserialize_single(input_stream)
20+
deserialize_single([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})])

0 commit comments

Comments
 (0)