Commit ef09cde

fix: avoid using globals for record counting during concurrent declarative stream reads (#732)
1 parent e4b34b6 commit ef09cde
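
The bug in miniature: the old implementation kept the Connector Builder record count in a module-level global, so every partition read in the process shared one counter, and a second stream (or a second test read) started counting where the first one stopped. The sketch below is illustrative only, not CDK code; the names total_record_counter and read_with_limit are made up for the example.

# Illustrative sketch of the old failure mode -- not CDK code.
total_record_counter = 0  # module-level: shared by every read in the process


def read_with_limit(records, limit):
    """Yield records until the shared global counter reaches the limit."""
    global total_record_counter
    for record in records:
        if total_record_counter >= limit:
            return  # stops early even if the limit was hit by a *previous* read
        total_record_counter += 1
        yield record


first = list(read_with_limit(range(5), limit=2))   # [0, 1]; counter is now 2
second = list(read_with_limit(range(5), limit=2))  # []: the stale count blocks it
assert (first, second) == ([0, 1], [])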

File tree

4 files changed: +90 -19 lines changed

airbyte_cdk/manifest_server/command_processor/processor.py
airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py

airbyte_cdk/manifest_server/command_processor/processor.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ def test_read(
     """
     Test the read method of the source.
    """
-
    test_read_handler = TestReader(
        max_pages_per_slice=page_limit,
        max_slices=slice_limit,

airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py

Lines changed: 21 additions & 7 deletions
@@ -14,10 +14,21 @@
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils.slice_hasher import SliceHasher
 
+
 # For Connector Builder test read operations, we track the total number of records
-# read for the stream at the global level so that we can stop reading early if we
-# exceed the record limit
-total_record_counter = 0
+# read for the stream so that we can stop reading early if we exceed the record limit.
+class RecordCounter:
+    def __init__(self) -> None:
+        self.total_record_counter = 0
+
+    def increment(self) -> None:
+        self.total_record_counter += 1
+
+    def reset(self) -> None:
+        self.total_record_counter = 0
+
+    def get_total_records(self) -> int:
+        return self.total_record_counter
 
 
 class SchemaLoaderCachingDecorator(SchemaLoader):
@@ -51,6 +62,7 @@ def __init__(
         self._retriever = retriever
         self._message_repository = message_repository
         self._max_records_limit = max_records_limit
+        self._record_counter = RecordCounter()
 
     def create(self, stream_slice: StreamSlice) -> Partition:
         return DeclarativePartition(
@@ -60,6 +72,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
             message_repository=self._message_repository,
             max_records_limit=self._max_records_limit,
             stream_slice=stream_slice,
+            record_counter=self._record_counter,
         )
 
 
@@ -72,6 +85,7 @@ def __init__(
         message_repository: MessageRepository,
         max_records_limit: Optional[int],
         stream_slice: StreamSlice,
+        record_counter: RecordCounter,
     ):
         self._stream_name = stream_name
         self._schema_loader = schema_loader
@@ -80,17 +94,17 @@ def __init__(
         self._max_records_limit = max_records_limit
         self._stream_slice = stream_slice
         self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
+        self._record_counter = record_counter
 
     def read(self) -> Iterable[Record]:
         if self._max_records_limit is not None:
-            global total_record_counter
-            if total_record_counter >= self._max_records_limit:
+            if self._record_counter.get_total_records() >= self._max_records_limit:
                 return
         for stream_data in self._retriever.read_records(
             self._schema_loader.get_json_schema(), self._stream_slice
         ):
             if self._max_records_limit is not None:
-                if total_record_counter >= self._max_records_limit:
+                if self._record_counter.get_total_records() >= self._max_records_limit:
                     break
 
             if isinstance(stream_data, Mapping):
@@ -108,7 +122,7 @@ def read(self) -> Iterable[Record]:
                 self._message_repository.emit_message(stream_data)
 
             if self._max_records_limit is not None:
-                total_record_counter += 1
+                self._record_counter.increment()
 
     def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._stream_slice
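
Taken together, the change gives each DeclarativePartitionFactory its own RecordCounter, shared by every partition that factory creates: the record limit still applies across all slices of one stream, while separate factories (and therefore separate streams or test reads) no longer interfere. A minimal behavioral sketch using only the methods defined above; the import path is the one used by the test files in this commit.

from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
    RecordCounter,
)

# Each factory owns one counter; all of its partitions share it.
stream_a_counter = RecordCounter()
stream_a_counter.increment()  # partition 1 reads a record
stream_a_counter.increment()  # partition 2 reads a record
assert stream_a_counter.get_total_records() == 2  # limit checks see the stream total

# A different factory's counter is independent, unlike the old global.
stream_b_counter = RecordCounter()
assert stream_b_counter.get_total_records() == 0

# reset() lets a harness reuse a counter between reads.
stream_a_counter.reset()
assert stream_a_counter.get_total_records() == 0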

unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py

Lines changed: 15 additions & 0 deletions
@@ -23,6 +23,7 @@
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
 from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
     DeclarativePartition,
+    RecordCounter,
 )
 from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
 from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
@@ -3624,6 +3625,7 @@ def test_given_no_partitions_processed_when_close_partition_then_no_state_update
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -3709,6 +3711,7 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
     cursor.ensure_at_least_one_state_emitted()
@@ -3804,6 +3807,7 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
     cursor.ensure_at_least_one_state_emitted()
@@ -3894,6 +3898,7 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -3968,6 +3973,7 @@ def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_glob
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
     cursor.ensure_at_least_one_state_emitted()
@@ -4053,6 +4059,7 @@ def test_semaphore_cleanup():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=s,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4173,6 +4180,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=first_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4185,6 +4193,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=two,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4197,6 +4206,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=second_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4258,6 +4268,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=first_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4270,6 +4281,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=two,
+            record_counter=RecordCounter(),
        )
    )
 
@@ -4283,6 +4295,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=second_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4341,6 +4354,7 @@ def test_duplicate_partition_while_processing():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=generated[1],
+            record_counter=RecordCounter(),
         )
     )
     # Now close the initial “1”
@@ -4352,6 +4366,7 @@ def test_duplicate_partition_while_processing():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=generated[0],
+            record_counter=RecordCounter(),
         )
     )
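
Because record_counter is now a required constructor argument, every test that builds a DeclarativePartition by hand gains the extra kwarg, as the hunks above show. A minimal construction sketch under the assumptions visible in this diff: the first three parameter names are inferred from the self._stream_name / self._schema_loader / self._retriever assignments and are not confirmed by the hunks; the remaining names appear verbatim in the signature.

from unittest.mock import MagicMock

from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
    DeclarativePartition,
    RecordCounter,
)
from airbyte_cdk.sources.types import StreamSlice

partition = DeclarativePartition(
    stream_name="stream_name",
    schema_loader=MagicMock(),  # parameter name inferred, not shown in the diff
    retriever=MagicMock(),      # parameter name inferred, not shown in the diff
    message_repository=MagicMock(),
    max_records_limit=None,
    stream_slice=StreamSlice(partition={}, cursor_slice={}),
    record_counter=RecordCounter(),  # new in this commit
)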

unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py

Lines changed: 54 additions & 11 deletions
@@ -4,8 +4,6 @@
 from unittest import TestCase
 from unittest.mock import Mock
 
-# This allows for the global total_record_counter to be reset between tests
-import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator as declarative_partition_generator
 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
 from airbyte_cdk.sources.declarative.retrievers import Retriever
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
@@ -35,7 +33,7 @@ class StreamSlicerPartitionGeneratorTest(TestCase):
     def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None:
         retriever = self._mock_retriever([])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -50,7 +48,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self)
     def test_given_a_mapping_when_read_then_yield_record(self) -> None:
         retriever = self._mock_retriever([_A_RECORD])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -68,7 +66,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None:
     def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
         retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -80,8 +78,6 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> N
         message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE)
 
     def test_max_records_reached_stops_reading(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -97,7 +93,7 @@ def test_max_records_reached_stops_reading(self) -> None:
 
         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -113,8 +109,6 @@ def test_max_records_reached_stops_reading(self) -> None:
         assert actual_records == expected_records
 
     def test_max_records_reached_on_previous_partition(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -128,7 +122,7 @@ def test_max_records_reached_on_previous_partition(self) -> None:
 
         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -151,6 +145,55 @@ def test_max_records_reached_on_previous_partition(self) -> None:
         # called for the first partition read and not the second
         retriever.read_records.assert_called_once()
 
+    def test_record_counter_isolation_between_different_factories(self) -> None:
+        """Test that record counters are isolated between different DeclarativePartitionFactory instances."""
+
+        # Create mock records that exceed the limit
+        records = [
+            Record(data={"id": 1, "name": "Record1"}, stream_name="stream_name"),
+            Record(data={"id": 2, "name": "Record2"}, stream_name="stream_name"),
+            Record(
+                data={"id": 3, "name": "Record3"}, stream_name="stream_name"
+            ),  # Should be blocked by limit
+        ]
+
+        # Create first factory with record limit of 2
+        retriever1 = self._mock_retriever(records)
+        message_repository1 = Mock(spec=MessageRepository)
+        factory1 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever1,
+            message_repository1,
+            max_records_limit=2,
+        )
+
+        # First factory should read up to limit (2 records)
+        partition1 = factory1.create(_A_STREAM_SLICE)
+        first_factory_records = list(partition1.read())
+        assert len(first_factory_records) == 2
+
+        # Create second factory with same limit - should be independent
+        retriever2 = self._mock_retriever(records)
+        message_repository2 = Mock(spec=MessageRepository)
+        factory2 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever2,
+            message_repository2,
+            max_records_limit=2,
+        )
+
+        # Second factory should also be able to read up to limit (2 records)
+        # This would fail before the fix because record counter was global
+        partition2 = factory2.create(_A_STREAM_SLICE)
+        second_factory_records = list(partition2.read())
+        assert len(second_factory_records) == 2
+
+        # Verify both retrievers were called (confirming isolation)
+        retriever1.read_records.assert_called_once()
+        retriever2.read_records.assert_called_once()
+
     @staticmethod
     def _mock_retriever(read_return_value: List[StreamData]) -> Mock:
         retriever = Mock(spec=Retriever)
