Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit 3375f80

Browse files
Store error on transaction stream (#248)
## What is the goal of this PR? We store any exceptions received against the transaction stream (as well as query streams) in order to propagate the error to any future transaction operations, not just open query streams. ## What are the changes implemented in this PR? * store errors against query streams and transaction streams, so we can error on transaction operations against a closed transaction
1 parent 768edfb commit 3375f80

File tree

4 files changed

+20
-26
lines changed

4 files changed

+20
-26
lines changed

typedb/connection/transaction.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@
2323

2424
import typedb_protocol.common.transaction_pb2 as transaction_proto
2525
from grpc import RpcError
26-
2726
from typedb.api.connection.options import TypeDBOptions
28-
from typedb.api.query.future import QueryFuture
2927
from typedb.api.connection.transaction import _TypeDBTransactionExtended, TransactionType
28+
from typedb.api.query.future import QueryFuture
3029
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, TRANSACTION_CLOSED_WITH_ERRORS
3130
from typedb.common.rpc.request_builder import transaction_commit_req, transaction_rollback_req, transaction_open_req
3231
from typedb.concept.concept_manager import _ConceptManager
@@ -113,8 +112,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
113112
return False
114113

115114
def _raise_transaction_closed(self):
116-
errors = self._bidirectional_stream.get_errors()
117-
if len(errors) == 0:
115+
error = self._bidirectional_stream.get_error()
116+
if error is None:
118117
raise TypeDBClientException.of(TRANSACTION_CLOSED)
119118
else:
120-
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, errors)
119+
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, error)

typedb/stream/bidirectional_stream.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, stub: TypeDBStub, transmitter: RequestTransmitter):
4444
self._response_iterator = stub.transaction(self._request_iterator)
4545
self._dispatcher = transmitter.dispatcher(self._request_iterator)
4646
self._is_open = AtomicBoolean(True)
47+
self._error: TypeDBClientException = None
4748

4849
def single(self, req: transaction_proto.Transaction.Req, batch: bool) -> "BidirectionalStream.Single[transaction_proto.Transaction.Res]":
4950
request_id = uuid4()
@@ -60,7 +61,7 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction
6061
req.req_id = request_id.bytes
6162
self._response_collector.new_queue(request_id)
6263
self._dispatcher.dispatch(req)
63-
return ResponsePartIterator(request_id, self, self._dispatcher)
64+
return ResponsePartIterator(request_id, self)
6465

6566
def done(self, request_id: UUID):
6667
self._response_collector.remove(request_id)
@@ -104,11 +105,15 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio
104105
else:
105106
raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id)
106107

107-
def get_errors(self) -> List[TypeDBClientException]:
108-
return self._response_collector.get_errors()
108+
def dispatcher(self):
109+
return self._dispatcher
110+
111+
def get_error(self) -> TypeDBClientException:
112+
return self._error
109113

110114
def close(self, error: TypeDBClientException = None):
111115
if self._is_open.compare_and_set(True, False):
116+
self._error = error
112117
self._response_collector.close(error)
113118
try:
114119
self._dispatcher.close()

typedb/stream/response_collector.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,6 @@ def close(self, error: Optional[TypeDBClientException]):
5454
for collector in self._response_queues.values():
5555
collector.close(error)
5656

57-
def get_errors(self) -> [TypeDBClientException]:
58-
errors = []
59-
with self._collectors_lock:
60-
for collector in self._response_queues.values():
61-
error = collector.get_error()
62-
if error is not None:
63-
errors.append(error)
64-
return errors
65-
6657
class Queue(Generic[R]):
6758

6859
def __init__(self):
@@ -87,8 +78,6 @@ def close(self, error: Optional[TypeDBClientException]):
8778
self._error = error
8879
self._response_queue.put(DoneResponse())
8980

90-
def get_error(self) -> TypeDBClientException:
91-
return self._error
9281

9382

9483
class Response:

typedb/stream/response_part_iterator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,15 @@
2525
from enum import Enum
2626
from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE
2727
from typedb.common.rpc.request_builder import transaction_stream_req
28-
from typedb.stream.request_transmitter import RequestTransmitter
2928

3029
if TYPE_CHECKING:
3130
from typedb.stream.bidirectional_stream import BidirectionalStream
3231

3332

3433
class ResponsePartIterator(Iterator[transaction_proto.Transaction.ResPart]):
3534

36-
def __init__(self, request_id: UUID, bidirectional_stream: "BidirectionalStream", request_dispatcher: RequestTransmitter.Dispatcher):
35+
def __init__(self, request_id: UUID, bidirectional_stream: "BidirectionalStream"):
3736
self._request_id = request_id
38-
self._dispatcher = request_dispatcher
3937
self._bidirectional_stream = bidirectional_stream
4038
self._state = ResponsePartIterator.State.EMPTY
4139
self._next: transaction_proto.Transaction.ResPart = None
@@ -54,7 +52,7 @@ def _fetch_and_check(self) -> bool:
5452
self._state = ResponsePartIterator.State.DONE
5553
return False
5654
elif state == transaction_proto.Transaction.Stream.State.Value("CONTINUE"):
57-
self._dispatcher.dispatch(transaction_stream_req(self._request_id))
55+
self._bidirectional_stream.dispatcher().dispatch(transaction_stream_req(self._request_id))
5856
return self._fetch_and_check()
5957
else:
6058
raise TypeDBClientException.of(ILLEGAL_ARGUMENT)
@@ -76,8 +74,11 @@ def _has_next(self) -> bool:
7674
raise TypeDBClientException.of(ILLEGAL_STATE)
7775

7876
def __next__(self) -> transaction_proto.Transaction.ResPart:
79-
if not self._has_next():
77+
if self._bidirectional_stream.get_error() is not None:
78+
raise self._bidirectional_stream.get_error()
79+
elif not self._has_next():
8080
self._bidirectional_stream.done(self._request_id)
8181
raise StopIteration
82-
self._state = ResponsePartIterator.State.EMPTY
83-
return self._next
82+
else:
83+
self._state = ResponsePartIterator.State.EMPTY
84+
return self._next

0 commit comments

Comments
 (0)