From 4ff838c8f5858de1fc8552bb51b3b1f91d3e8015 Mon Sep 17 00:00:00 2001
From: Judah Rand <17158624+judahrand@users.noreply.github.com>
Date: Mon, 24 Jan 2022 14:04:02 +0000
Subject: [PATCH] Add ability to use `wal2json` `format-version` 2

This option removes the need to use `write-in-chunks`.
---
 README.md                                     |   1 +
 tap_postgres/__init__.py                      |   1 +
 .../sync_strategies/logical_replication.py    | 157 +++++++++++++++---
 tests/test_full_table_interruption.py         |  10 +-
 tests/test_logical_replication.py             | 124 ++++++++++++-
 tests/utils.py                                |   5 +-
 6 files changed, 257 insertions(+), 41 deletions(-)

diff --git a/README.md b/README.md
index 9a787307..d523cab7 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ Full list of options in `config.json`:
 | use_secondary | Boolean | No | Use a database replica for `INCREMENTAL` and `FULL_TABLE` replication (Default : False) |
 | secondary_host | String | No | PostgreSQL Replica host (required if `use_secondary` is `True`) |
 | secondary_port | Integer | No | PostgreSQL Replica port (required if `use_secondary` is `True`) |
+| wal2json_message_format | Integer | No | Which `wal2json` message format to use (Default: 1) |
 
 ### Run the tap in Discovery Mode
 
diff --git a/tap_postgres/__init__.py b/tap_postgres/__init__.py
index 5c7c2ac1..173314b1 100644
--- a/tap_postgres/__init__.py
+++ b/tap_postgres/__init__.py
@@ -407,6 +407,7 @@ def main_impl():
         'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
         'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0)),
         'use_secondary': args.config.get('use_secondary', False),
+        'wal2json_message_format': args.config.get('wal2json_message_format', 1)
     }
 
     if conn_config['use_secondary']:
diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py
index eb690956..cc7fb3f2 100644
--- a/tap_postgres/sync_strategies/logical_replication.py
+++ b/tap_postgres/sync_strategies/logical_replication.py
@@ -377,18 +377,24 @@ def row_to_singer_message(stream, row, version, columns, time_extracted, md_map
                                time_extracted=time_extracted)
 
 
-# pylint: disable=unused-argument,too-many-locals
-def consume_message(streams, state, msg, time_extracted, conn_info):
-    # Strip leading comma generated by write-in-chunks and parse valid JSON
-    try:
-        payload = json.loads(msg.payload.lstrip(','))
-    except Exception:
-        return state
+def check_for_new_columns(columns, target_stream, conn_info):
+    diff = set(columns).difference(target_stream['schema']['properties'].keys())
 
-    lsn = msg.data_start
+    if diff:
+        LOGGER.info('Detected new columns "%s", refreshing schema of stream %s', diff, target_stream['stream'])
+        # encountered a column that is not in the schema
+        # refresh the stream schema and metadata by running discovery
+        refresh_streams_schema(conn_info, [target_stream])
 
-    streams_lookup = {s['tap_stream_id']: s for s in streams}
+        # add the automatic properties back to the stream
+        add_automatic_properties(target_stream, conn_info.get('debug_lsn', False))
 
+        # publish new schema
+        sync_common.send_schema_message(target_stream, ['lsn'])
+
+
+# pylint: disable=too-many-locals
+def consume_message_format_1(payload, conn_info, streams_lookup, state, time_extracted, lsn):
     tap_stream_id = post_db.compute_tap_stream_id(payload['schema'], payload['table'])
     if streams_lookup.get(tap_stream_id) is None:
         return state
@@ -400,22 +406,8 @@ def consume_message(streams, state, msg, time_extracted, conn_info):
     # Get the additional fields in payload that are not in schema properties:
     # only inserts and updates have the list of columns that can be used to detect any different in columns
-    diff = set()
     if payload['kind'] in {'insert', 'update'}:
-        diff = set(payload['columnnames']).difference(target_stream['schema']['properties'].keys())
-
-    # if there is new columns in the payload that are not in the schema properties then refresh the stream schema
-    if diff:
-        LOGGER.info('Detected new columns "%s", refreshing schema of stream %s', diff, target_stream['stream'])
-        # encountered a column that is not in the schema
-        # refresh the stream schema and metadata by running discovery
-        refresh_streams_schema(conn_info, [target_stream])
-
-        # add the automatic properties back to the stream
-        add_automatic_properties(target_stream, conn_info.get('debug_lsn', False))
-
-        # publish new schema
-        sync_common.send_schema_message(target_stream, ['lsn'])
+        check_for_new_columns(payload['columnnames'], target_stream, conn_info)
 
     stream_version = get_stream_version(target_stream['tap_stream_id'], state)
     stream_md_map = metadata.to_map(target_stream['metadata'])
 
@@ -476,6 +468,103 @@ def consume_message(streams, state, msg, time_extracted, conn_info):
     return state
 
 
+def consume_message_format_2(payload, conn_info, streams_lookup, state, time_extracted, lsn):
+    ## Action Types:
+    #      I = Insert
+    #      U = Update
+    #      D = Delete
+    #      B = Begin Transaction
+    #      C = Commit Transaction
+    #      M = Message
+    #      T = Truncate
+    action = payload['action']
+    if action not in {'U', 'I', 'D'}:
+        raise UnsupportedPayloadKindError(f"unrecognized replication operation: {action}")
+
+    tap_stream_id = post_db.compute_tap_stream_id(payload['schema'], payload['table'])
+    if streams_lookup.get(tap_stream_id) is not None:
+        target_stream = streams_lookup[tap_stream_id]
+
+        # Get the additional fields in payload that are not in schema properties:
+        # only inserts and updates have the list of columns that can be used to detect any difference in columns
+        if payload['action'] in {'I', 'U'}:
+            check_for_new_columns({column['name'] for column in payload['columns']}, target_stream, conn_info)
+
+        stream_version = get_stream_version(target_stream['tap_stream_id'], state)
+        stream_md_map = metadata.to_map(target_stream['metadata'])
+
+        desired_columns = {
+            col for col in target_stream['schema']['properties'].keys()
+            if sync_common.should_sync_column(stream_md_map, col)
+        }
+
+        col_names = []
+        col_vals = []
+        if payload['action'] in ['I', 'U']:
+            for column in payload['columns']:
+                if column['name'] in desired_columns:
+                    col_names.append(column['name'])
+                    col_vals.append(column['value'])
+
+            col_names = col_names + ['_sdc_deleted_at']
+            col_vals = col_vals + [None]
+
+            if conn_info.get('debug_lsn'):
+                col_names = col_names + ['_sdc_lsn']
+                col_vals = col_vals + [str(lsn)]
+
+        elif payload['action'] == 'D':
+            for column in payload['identity']:
+                if column['name'] in desired_columns:
+                    col_names.append(column['name'])
+                    col_vals.append(column['value'])
+
+            col_names = col_names + ['_sdc_deleted_at']
+            col_vals = col_vals + [singer.utils.strftime(singer.utils.strptime_to_utc(payload['timestamp']))]
+
+            if conn_info.get('debug_lsn'):
+                col_vals = col_vals + [str(lsn)]
+                col_names = col_names + ['_sdc_lsn']
+
+        # Write one record to match the behaviour of the format-version 1 API
+        record_message = row_to_singer_message(
+            target_stream,
+            col_vals,
+            stream_version,
+            col_names,
+            time_extracted,
+            stream_md_map,
+            conn_info,
+        )
+
+        singer.write_message(record_message)
+        state = singer.write_bookmark(state, target_stream['tap_stream_id'], 'lsn', lsn)
+
+    return state
+
+
+def consume_message(streams, state, msg, time_extracted, conn_info):
+    # Strip the leading comma generated by write-in-chunks (format 1) and parse the JSON payload
+    try:
+        payload = json.loads(msg.payload.lstrip(','))
+    except Exception:
+        return state
+
+    lsn = msg.data_start
+
+    streams_lookup = {s['tap_stream_id']: s for s in streams}
+
+    message_format = conn_info['wal2json_message_format']
+    if message_format == 1:
+        state = consume_message_format_1(payload, conn_info, streams_lookup, state, time_extracted, lsn)
+    elif message_format == 2:
+        state = consume_message_format_2(payload, conn_info, streams_lookup, state, time_extracted, lsn)
+    else:
+        raise Exception(f"Unknown wal2json message format version: {message_format}")
+
+    return state
+
+
 def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'):
     """Generate replication slot name with
 
@@ -591,14 +680,26 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file):
                 int_to_lsn(end_lsn),
                 slot)
     # psycopg2 2.8.4 will send a keep-alive message to postgres every status_interval
+    options = {
+        'add-tables': streams_to_wal2json_tables(logical_streams),
+        'include-timestamp': True,
+        'include-types': False,
+    }
+    if conn_info['wal2json_message_format'] == 1:
+        options.update({'write-in-chunks': 1})
+    else:
+        options.update(
+            {
+                'format-version': conn_info['wal2json_message_format'],
+                'include-transaction': False,
+                'actions': 'insert,update,delete',
+            }
+        )
     try:
         cur.start_replication(slot_name=slot,
                               decode=True,
                               start_lsn=start_lsn,
                               status_interval=poll_interval,
-                              options={
-                                  'write-in-chunks': 1,
-                                  'add-tables': streams_to_wal2json_tables(logical_streams)
-                              })
+                              options=options)
     except psycopg2.ProgrammingError as ex:
         raise Exception(f"Unable to start replication with logical replication (slot {ex})") from ex
 
diff --git a/tests/test_full_table_interruption.py b/tests/test_full_table_interruption.py
index c1417468..b99e9eea 100644
--- a/tests/test_full_table_interruption.py
+++ b/tests/test_full_table_interruption.py
@@ -48,7 +48,7 @@ def do_not_dump_catalog(catalog):
 tap_postgres.dump_catalog = do_not_dump_catalog
 full_table.UPDATE_BOOKMARK_PERIOD = 1
 
-@pytest.mark.parametrize('use_secondary', [False, True])
+@pytest.mark.parametrize('use_secondary,message_format', [(False, 1), (True, 2)])
 @unittest.mock.patch('psycopg2.connect', wraps=psycopg2.connect)
 class TestLogicalInterruption:
     maxDiff = None
@@ -67,11 +67,11 @@ def setup_method(self):
         global CAUGHT_MESSAGES
         CAUGHT_MESSAGES.clear()
 
-    def test_catalog(self, mock_connect, use_secondary):
+    def test_catalog(self, mock_connect, use_secondary, message_format):
         singer.write_message = singer_write_message_no_cow
         pg_common.write_schema_message = singer_write_message_ok
 
-        conn_config = get_test_connection_config(use_secondary=use_secondary)
+        conn_config = get_test_connection_config(use_secondary=use_secondary, message_format=message_format)
         streams = tap_postgres.do_discovery(conn_config)
 
         # Assert that we connected to the correct database
@@ -115,7 +115,7 @@ def test_catalog(self, mock_connect, use_secondary):
         #the initial phase of cows logical replication will be a full table.
#it will sync the first record and then blow up on the 2nd record try: - tap_postgres.do_sync(get_test_connection_config(use_secondary=use_secondary), {'streams' : streams}, None, state) + tap_postgres.do_sync(conn_config, {'streams' : streams}, None, state) except Exception: blew_up_on_cow = True @@ -171,7 +171,7 @@ def test_catalog(self, mock_connect, use_secondary): global COW_RECORD_COUNT COW_RECORD_COUNT = 0 CAUGHT_MESSAGES.clear() - tap_postgres.do_sync(get_test_connection_config(use_secondary=use_secondary), {'streams' : streams}, None, old_state) + tap_postgres.do_sync(conn_config, {'streams' : streams}, None, old_state) mock_connect.assert_called_with(**expected_connection) mock_connect.reset_mock() diff --git a/tests/test_logical_replication.py b/tests/test_logical_replication.py index ae0f5832..21d9e804 100644 --- a/tests/test_logical_replication.py +++ b/tests/test_logical_replication.py @@ -137,7 +137,7 @@ def test_consume_with_message_payload_is_not_json_expect_same_state(self): self.WalMessage(payload='this is an invalid json message', data_start=None), None, - {} + {'wal2json_message_format': 1}, ) self.assertDictEqual({}, output) @@ -148,7 +148,19 @@ def test_consume_with_message_stream_in_payload_is_not_selected_expect_same_stat self.WalMessage(payload='{"schema": "myschema", "table": "notmytable"}', data_start='some lsn'), None, - {} + {'wal2json_message_format': 1}, + ) + + self.assertDictEqual({}, output) + + def test_consume_with_message_stream_in_payload_is_not_selected_expect_same_state_format_2(self): + output = logical_replication.consume_message( + [{'tap_stream_id': 'myschema-mytable'}], + {}, + self.WalMessage(payload='{"action": "U", "schema": "myschema", "table": "notmytable"}', + data_start='some lsn'), + None, + {'wal2json_message_format': 2}, ) self.assertDictEqual({}, output) @@ -161,7 +173,18 @@ def test_consume_with_payload_kind_is_not_supported_expect_exception(self): self.WalMessage(payload='{"kind":"truncate", "schema": "myschema", "table": "mytable"}', data_start='some lsn'), None, - {} + {'wal2json_message_format': 1}, + ) + + def test_consume_with_payload_kind_is_not_supported_expect_exception_format_2(self): + with self.assertRaises(UnsupportedPayloadKindError): + logical_replication.consume_message( + [{'tap_stream_id': 'myschema-mytable'}], + {}, + self.WalMessage(payload='{"action":"T", "schema": "myschema", "table": "mytable"}', + data_start='some lsn'), + None, + {'wal2json_message_format': 2}, ) @patch('tap_postgres.logical_replication.singer.write_message') @@ -231,11 +254,100 @@ def test_consume_message_with_new_column_in_payload_will_refresh_schema(self, '"schema": "myschema", ' '"table": "mytable",' '"columnnames": ["id", "date_created", "new_col"],' - '"columnnames": [1, null, "some random text"]' + '"columnvalues": [1, null, "some random text"]' + '}', + data_start='some lsn'), + None, + {'wal2json_message_format': 1}, + ) + + self.assertDictEqual(return_v, + { + 'bookmarks': { + "myschema-mytable": { + "last_replication_method": "LOG_BASED", + "lsn": "some lsn", + "version": 1000, + "xmin": None + } + } + }) + + refresh_schema_mock.assert_called_once_with({'wal2json_message_format': 1}, [streams[0]]) + send_schema_mock.assert_called_once() + write_message_mock.assert_called_once() + + @patch('tap_postgres.logical_replication.singer.write_message') + @patch('tap_postgres.logical_replication.sync_common.send_schema_message') + @patch('tap_postgres.logical_replication.refresh_streams_schema') + def 
test_consume_message_with_new_column_in_payload_will_refresh_schema_format_2(self, + refresh_schema_mock, + send_schema_mock, + write_message_mock): + streams = [ + { + 'tap_stream_id': 'myschema-mytable', + 'stream': 'mytable', + 'schema': { + 'properties': { + 'id': {}, + 'date_created': {} + } + }, + 'metadata': [ + { + 'breadcrumb': [], + 'metadata': { + 'is-view': False, + 'table-key-properties': ['id'], + 'schema-name': 'myschema' + } + }, + { + "breadcrumb": [ + "properties", + "id" + ], + "metadata": { + "sql-datatype": "integer", + "inclusion": "automatic", + } + }, + { + "breadcrumb": [ + "properties", + "date_created" + ], + "metadata": { + "sql-datatype": "datetime", + "inclusion": "available", + "selected": True + } + } + ], + } + ] + + return_v = logical_replication.consume_message( + streams, + { + 'bookmarks': { + "myschema-mytable": { + "last_replication_method": "LOG_BASED", + "lsn": None, + "version": 1000, + "xmin": None + } + } + }, + self.WalMessage(payload='{"action": "I", ' + '"schema": "myschema", ' + '"table": "mytable",' + '"columns": [{"name": "id", "value": 1}, {"name": "date_created", "value": null}, {"name": "new_col", "value": "some random text"}]' '}', data_start='some lsn'), None, - {} + {'wal2json_message_format': 2}, ) self.assertDictEqual(return_v, @@ -250,7 +362,7 @@ def test_consume_message_with_new_column_in_payload_will_refresh_schema(self, } }) - refresh_schema_mock.assert_called_once_with({}, [streams[0]]) + refresh_schema_mock.assert_called_once_with({'wal2json_message_format': 2}, [streams[0]]) send_schema_mock.assert_called_once() write_message_mock.assert_called_once() diff --git a/tests/utils.py b/tests/utils.py index a3c2e2cb..ccfbcf9b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,14 +10,15 @@ LOGGER = get_logger() -def get_test_connection_config(target_db='postgres', use_secondary=False): +def get_test_connection_config(target_db='postgres', use_secondary=False, message_format=1): try: conn_config = {'host': os.environ['TAP_POSTGRES_HOST'], 'user': os.environ['TAP_POSTGRES_USER'], 'password': os.environ['TAP_POSTGRES_PASSWORD'], 'port': os.environ['TAP_POSTGRES_PORT'], 'dbname': target_db, - 'use_secondary': use_secondary,} + 'use_secondary': use_secondary, + 'wal2json_message_format': message_format} except KeyError as exc: raise Exception( "set TAP_POSTGRES_HOST, TAP_POSTGRES_USER, TAP_POSTGRES_PASSWORD, TAP_POSTGRES_PORT"
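
For reference, a minimal `config.json` opting in to the new behaviour might look like the sketch below. All connection values are illustrative placeholders (host, credentials, and `dbname` are not taken from this patch); only `wal2json_message_format` is the option added here:

    {
        "host": "localhost",
        "port": 5432,
        "user": "my_user",
        "password": "my_password",
        "dbname": "my_database",
        "wal2json_message_format": 2
    }

With `wal2json_message_format` set to 2, `sync_tables` starts replication with `format-version: 2`, `include-transaction: false`, and `actions: insert,update,delete` instead of `write-in-chunks: 1`, so each replication message arrives as one complete JSON document and `consume_message` no longer needs to strip a leading comma.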