diff --git a/README.md b/README.md
index 89d008a..0e7de4c 100644
--- a/README.md
+++ b/README.md
@@ -119,7 +119,7 @@ See [`Development` on Online Document](https://pgmq-sqlalchemy.readthedocs.io/en
 
 ## TODO
 
-- [ ] Add **time-based** partition option and validation to `create_partitioned_queue` method.
-- [ ] Read(single/batch) Archive Table ( `read_archive` method )
-- [ ] Detach Archive Table ( `detach_archive` method )
-- [ ] Add `set_vt` utils method.
\ No newline at end of file
+- [x] Add **time-based** partition option and validation to `create_partitioned_queue` method.
+- [x] Read(single/batch) Archive Table ( `read_archive` method )
+- [x] Detach Archive Table ( `detach_archive` method )
+- [x] Add `set_vt` utils method.
\ No newline at end of file
diff --git a/pgmq_sqlalchemy/queue.py b/pgmq_sqlalchemy/queue.py
index 9b393a0..20442f0 100644
--- a/pgmq_sqlalchemy/queue.py
+++ b/pgmq_sqlalchemy/queue.py
@@ -1,5 +1,6 @@
 import asyncio
-from typing import List, Optional
+import re
+from typing import List, Optional, Union
 
 from sqlalchemy import create_engine, text
 from sqlalchemy.orm import sessionmaker
@@ -239,11 +240,40 @@ async def _create_partitioned_queue_async(
         )
         await session.commit()
 
+    def _validate_partition_interval(self, interval: Union[int, str]) -> str:
+        """Validate partition interval format.
+
+        Args:
+            interval: Either an integer for numeric partitioning or a string for time-based partitioning
+                (e.g., '1 day', '1 hour', '7 days')
+
+        Returns:
+            The validated interval as a string
+
+        Raises:
+            ValueError: If the interval format is invalid
+        """
+        if isinstance(interval, int):
+            if interval <= 0:
+                raise ValueError("Numeric partition interval must be positive")
+            return str(interval)
+
+        # Validate time-based interval format
+        # Valid PostgreSQL interval formats: '1 day', '7 days', '1 hour', '1 month', etc.
+        time_pattern = r"^\d+\s+(microsecond|millisecond|second|minute|hour|day|week|month|year)s?$"
+        if not re.match(time_pattern, interval.strip(), re.IGNORECASE):
+            raise ValueError(
+                f"Invalid time-based partition interval: '{interval}'. "
+                "Expected format: '<number> <unit>' where unit is one of: "
+                "microsecond, millisecond, second, minute, hour, day, week, month, year"
+            )
+        return interval.strip()
+
     def create_partitioned_queue(
         self,
         queue_name: str,
-        partition_interval: int = 10000,
-        retention_interval: int = 100000,
+        partition_interval: Union[int, str] = 10000,
+        retention_interval: Union[int, str] = 100000,
     ) -> None:
         """Create a new **partitioned** queue.
 
@@ -252,16 +282,23 @@ def create_partitioned_queue(
 
         .. code-block:: python
 
+            # Numeric partitioning (by msg_id)
             pgmq_client.create_partitioned_queue('my_partitioned_queue', partition_interval=10000, retention_interval=100000)
 
+            # Time-based partitioning (by enqueued_at)
+            pgmq_client.create_partitioned_queue('my_time_queue', partition_interval='1 day', retention_interval='7 days')
+
         Args:
             queue_name (str): The name of the queue, should be less than 48 characters.
-            partition_interval (int): Will create a new partition every ``partition_interval`` messages.
-            retention_interval (int): The interval for retaining partitions. Any messages that have a `msg_id` less than ``max(msg_id)`` - ``retention_interval`` will be dropped.
+            partition_interval (Union[int, str]): For numeric partitioning, the number of messages per partition.
+                For time-based partitioning, a PostgreSQL interval string (e.g., '1 day', '1 hour').
+            retention_interval (Union[int, str]): For numeric partitioning, messages with msg_id less than max(msg_id) - retention_interval will be dropped.
+                For time-based partitioning, a PostgreSQL interval string (e.g., '7 days').
 
         .. note::
-            | Currently, only support for partitioning by **msg_id**.
-            | Will add **time-based partitioning** in the future ``pgmq-sqlalchemy`` release.
+            | Supports both **numeric** (by ``msg_id``) and **time-based** (by ``enqueued_at``) partitioning.
+            | For time-based partitioning, use interval strings like '1 day', '1 hour', '7 days', etc.
+            | For numeric partitioning, use integer values.
 
         .. important::
             | You must make sure that the ``pg_partman`` extension already **installed** in the Postgres.
@@ -273,14 +310,24 @@ def create_partitioned_queue(
         # check if the pg_partman extension exists before creating a partitioned queue at runtime
         self._check_pg_partman_ext()
 
+        # Validate partition intervals
+        validated_partition_interval = self._validate_partition_interval(
+            partition_interval
+        )
+        validated_retention_interval = self._validate_partition_interval(
+            retention_interval
+        )
+
         if self.is_async:
             return self.loop.run_until_complete(
                 self._create_partitioned_queue_async(
-                    queue_name, str(partition_interval), str(retention_interval)
+                    queue_name,
+                    validated_partition_interval,
+                    validated_retention_interval,
                 )
             )
         return self._create_partitioned_queue_sync(
-            queue_name, str(partition_interval), str(retention_interval)
+            queue_name, validated_partition_interval, validated_retention_interval
         )
 
     def _validate_queue_name_sync(self, queue_name: str) -> None:
@@ -1214,3 +1261,238 @@ def metrics_all(self) -> Optional[List[QueueMetrics]]:
         if self.is_async:
             return self.loop.run_until_complete(self._metrics_all_async())
         return self._metrics_all_sync()
+
+    def _set_vt_sync(self, queue_name: str, msg_id: int, vt: int) -> Optional[Message]:
+        """Set the visibility timeout of a message synchronously."""
+        with self.session_maker() as session:
+            row = session.execute(
+                text("select * from pgmq.set_vt(:queue_name, :msg_id, :vt);"),
+                {"queue_name": queue_name, "msg_id": msg_id, "vt": vt},
+            ).fetchone()
+            session.commit()
+        if row is None:
+            return None
+        return Message(
+            msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=row[4]
+        )
+
+    async def _set_vt_async(
+        self, queue_name: str, msg_id: int, vt: int
+    ) -> Optional[Message]:
+        """Set the visibility timeout of a message asynchronously."""
+        async with self.session_maker() as session:
+            row = (
+                await session.execute(
+                    text("select * from pgmq.set_vt(:queue_name, :msg_id, :vt);"),
+                    {"queue_name": queue_name, "msg_id": msg_id, "vt": vt},
+                )
+            ).fetchone()
+            await session.commit()
+        if row is None:
+            return None
+        return Message(
+            msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=row[4]
+        )
+
+    def set_vt(self, queue_name: str, msg_id: int, vt: int) -> Optional[Message]:
+        """
+        Set the visibility timeout of a message.
+
+        Args:
+            queue_name (str): The name of the queue.
+            msg_id (int): The message ID.
+            vt (int): The new visibility timeout in seconds.
+
+        Returns:
+            |schema_message_class|_ or ``None`` if the message does not exist.
+
+        Usage:
+
+        .. code-block:: python
+
+            msg_id = pgmq_client.send('my_queue', {'key': 'value'})
+            msg = pgmq_client.read('my_queue', vt=10)
+            # extend the visibility timeout
+            msg = pgmq_client.set_vt('my_queue', msg_id, 20)
+            assert msg is not None
+
+        """
+        if self.is_async:
+            return self.loop.run_until_complete(
+                self._set_vt_async(queue_name, msg_id, vt)
+            )
+        return self._set_vt_sync(queue_name, msg_id, vt)
+
+    def _detach_archive_sync(self, queue_name: str) -> None:
+        """Detach the archive table for a queue synchronously."""
+        with self.session_maker() as session:
+            session.execute(
+                text("select pgmq.detach_archive(:queue_name);"),
+                {"queue_name": queue_name},
+            )
+            session.commit()
+
+    async def _detach_archive_async(self, queue_name: str) -> None:
+        """Detach the archive table for a queue asynchronously."""
+        async with self.session_maker() as session:
+            await session.execute(
+                text("select pgmq.detach_archive(:queue_name);"),
+                {"queue_name": queue_name},
+            )
+            await session.commit()
+
+    def detach_archive(self, queue_name: str) -> None:
+        """
+        Detach the archive table for a queue.
+
+        * The archive table (``pgmq.a_<queue_name>``) will be detached from the ``pgmq`` extension.
+        * The archive table will remain in the database but will no longer be part of the extension.
+        * This is useful when you want to retain the archived messages even if the queue or the ``pgmq`` extension is dropped.
+
+        .. code-block:: python
+
+            pgmq_client.detach_archive('my_queue')
+
+        """
+        if self.is_async:
+            return self.loop.run_until_complete(self._detach_archive_async(queue_name))
+        return self._detach_archive_sync(queue_name)
+
+    def _read_archive_sync(self, queue_name: str) -> Optional[Message]:
+        """Read a single message from the archive table synchronously."""
+        with self.session_maker() as session:
+            row = session.execute(
+                text(
+                    f"select msg_id, read_ct, enqueued_at, vt, message from pgmq.a_{queue_name} limit 1;"
+                )
+            ).fetchone()
+            session.commit()
+        if row is None:
+            return None
+        return Message(
+            msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=row[4]
+        )
+
+    async def _read_archive_async(self, queue_name: str) -> Optional[Message]:
+        """Read a single message from the archive table asynchronously."""
+        async with self.session_maker() as session:
+            row = (
+                await session.execute(
+                    text(
+                        f"select msg_id, read_ct, enqueued_at, vt, message from pgmq.a_{queue_name} limit 1;"
+                    )
+                )
+            ).fetchone()
+            await session.commit()
+        if row is None:
+            return None
+        return Message(
+            msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=row[4]
+        )
+
+    def read_archive(self, queue_name: str) -> Optional[Message]:
+        """
+        Read a single message from the archive table.
+
+        Returns:
+            |schema_message_class|_ or ``None`` if the archive is empty.
+
+        Usage:
+
+        .. code-block:: python
+
+            msg_id = pgmq_client.send('my_queue', {'key': 'value'})
+            pgmq_client.archive('my_queue', msg_id)
+            archived_msg = pgmq_client.read_archive('my_queue')
+            print(archived_msg.message)
+
+        """
+        # Validate queue name first to prevent SQL injection
+        self.validate_queue_name(queue_name)
+        if self.is_async:
+            return self.loop.run_until_complete(self._read_archive_async(queue_name))
+        return self._read_archive_sync(queue_name)
+
+    def _read_archive_batch_sync(
+        self, queue_name: str, batch_size: int = 1
+    ) -> Optional[List[Message]]:
+        """Read multiple messages from the archive table synchronously."""
+        with self.session_maker() as session:
+            rows = session.execute(
+                text(
+                    f"select msg_id, read_ct, enqueued_at, vt, message from pgmq.a_{queue_name} limit :batch_size;"
+                ),
+                {"batch_size": batch_size},
+            ).fetchall()
+            session.commit()
+        if not rows:
+            return None
+        return [
+            Message(
+                msg_id=row[0],
+                read_ct=row[1],
+                enqueued_at=row[2],
+                vt=row[3],
+                message=row[4],
+            )
+            for row in rows
+        ]
+
+    async def _read_archive_batch_async(
+        self, queue_name: str, batch_size: int = 1
+    ) -> Optional[List[Message]]:
+        """Read multiple messages from the archive table asynchronously."""
+        async with self.session_maker() as session:
+            rows = (
+                await session.execute(
+                    text(
+                        f"select msg_id, read_ct, enqueued_at, vt, message from pgmq.a_{queue_name} limit :batch_size;"
+                    ),
+                    {"batch_size": batch_size},
+                )
+            ).fetchall()
+            await session.commit()
+        if not rows:
+            return None
+        return [
+            Message(
+                msg_id=row[0],
+                read_ct=row[1],
+                enqueued_at=row[2],
+                vt=row[3],
+                message=row[4],
+            )
+            for row in rows
+        ]
+
+    def read_archive_batch(
+        self, queue_name: str, batch_size: int = 1
+    ) -> Optional[List[Message]]:
+        """
+        Read multiple messages from the archive table.
+
+        Args:
+            queue_name (str): The name of the queue.
+            batch_size (int): The number of messages to read.
+
+        Returns:
+            List of |schema_message_class|_ or ``None`` if the archive is empty.
+
+        Usage:
+
+        .. code-block:: python
+
+            msg_ids = pgmq_client.send_batch('my_queue', [{'key': 'value'}, {'key': 'value'}])
+            pgmq_client.archive_batch('my_queue', msg_ids)
+            archived_msgs = pgmq_client.read_archive_batch('my_queue', batch_size=10)
+            for msg in archived_msgs:
+                print(msg.message)
+
+        """
+        # Validate queue name first to prevent SQL injection
+        self.validate_queue_name(queue_name)
+        if self.is_async:
+            return self.loop.run_until_complete(
+                self._read_archive_batch_async(queue_name, batch_size)
+            )
+        return self._read_archive_batch_sync(queue_name, batch_size)
diff --git a/tests/test_queue.py b/tests/test_queue.py
index fd53e38..781b234 100644
--- a/tests/test_queue.py
+++ b/tests/test_queue.py
@@ -2,6 +2,7 @@
 import pytest
 import time
 
+from sqlalchemy import text
 from sqlalchemy.exc import ProgrammingError
 from filelock import FileLock
 from pgmq_sqlalchemy import PGMQueue
@@ -383,3 +384,177 @@ def test_metrics_all_queues(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
     assert queue_2.queue_length == 2
     assert queue_1.total_messages == 3
     assert queue_2.total_messages == 2
+
+
+# Tests for set_vt method
+def test_set_vt(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    msg = MSG
+    msg_id = pgmq.send(queue_name, msg)
+    msg_read = pgmq.read(queue_name, vt=10)
+    assert msg_read.msg_id == msg_id
+    # extend the visibility timeout
+    msg_updated = pgmq.set_vt(queue_name, msg_id, 20)
+    assert msg_updated is not None
+    assert msg_updated.msg_id == msg_id
+
+
+def test_set_vt_not_exist(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    msg_updated = pgmq.set_vt(queue_name, 999, 20)
+    assert msg_updated is None
+
+
+# Tests for detach_archive method
+@pgmq_deps
+def test_detach_archive(pgmq_fixture, db_session):
+    """Test detach_archive method - detaches archive table from queue."""
+    pgmq: PGMQueue = pgmq_fixture
+    queue_name = f"test_queue_{uuid.uuid4().hex}"
+    pgmq.create_queue(queue_name)
+    msg = MSG
+    msg_id = pgmq.send(queue_name, msg)
+    pgmq.archive(queue_name, msg_id)
+
+    # Detach archive should not raise an error
+    pgmq.detach_archive(queue_name)
+
+    # Read the archive to ensure it still exists after detaching
+    archived_msg = pgmq.read_archive(queue_name)
+    assert archived_msg is not None
+    assert archived_msg.msg_id == msg_id
+
+    # Cleanup: Drop the archive and queue tables
+    # After detaching, the archive is no longer part of the extension
+    # We need to drop both tables manually by first removing them from the extension
+    if pgmq.is_async:
+
+        async def cleanup():
+            async with pgmq.session_maker() as session:
+                # Drop archive table (already detached)
+                await session.execute(
+                    text(f"DROP TABLE IF EXISTS pgmq.a_{queue_name} CASCADE;")
+                )
+                # Detach and drop queue table
+                await session.execute(
+                    text(f"ALTER EXTENSION pgmq DROP TABLE pgmq.q_{queue_name};")
+                )
+                await session.execute(
+                    text(f"DROP TABLE IF EXISTS pgmq.q_{queue_name} CASCADE;")
+                )
+                await session.commit()
+
+        pgmq.loop.run_until_complete(cleanup())
+    else:
+        with pgmq.session_maker() as session:
+            # Drop archive table (already detached)
+            session.execute(text(f"DROP TABLE IF EXISTS pgmq.a_{queue_name} CASCADE;"))
+            # Detach and drop queue table
+            session.execute(
+                text(f"ALTER EXTENSION pgmq DROP TABLE pgmq.q_{queue_name};")
+            )
+            session.execute(text(f"DROP TABLE IF EXISTS pgmq.q_{queue_name} CASCADE;"))
+            session.commit()
+
+
+# Tests for read_archive methods
+def test_read_archive(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    msg = MSG
+    msg_ids = pgmq.send_batch(queue_name, [msg, msg, msg])
+    pgmq.archive(queue_name, msg_ids[0])
+    archived_msg = pgmq.read_archive(queue_name)
+    assert archived_msg is not None
+    assert archived_msg.msg_id == msg_ids[0]
+    assert archived_msg.message == msg
+
+
+def test_read_archive_empty(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    archived_msg = pgmq.read_archive(queue_name)
+    assert archived_msg is None
+
+
+def test_read_archive_batch(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    msg = MSG
+    msg_ids = pgmq.send_batch(queue_name, [msg, msg, msg])
+    pgmq.archive_batch(queue_name, msg_ids)
+    archived_msgs = pgmq.read_archive_batch(queue_name, batch_size=10)
+    assert archived_msgs is not None
+    assert len(archived_msgs) == 3
+    assert [m.msg_id for m in archived_msgs] == msg_ids
+    for m in archived_msgs:
+        assert m.message == msg
+
+
+def test_read_archive_batch_empty(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    archived_msgs = pgmq.read_archive_batch(queue_name, batch_size=10)
+    assert archived_msgs is None
+
+
+def test_read_archive_batch_limit(pgmq_setup_teardown: PGMQ_WITH_QUEUE):
+    pgmq, queue_name = pgmq_setup_teardown
+    msg = MSG
+    msg_ids = pgmq.send_batch(queue_name, [msg, msg, msg, msg, msg])
+    pgmq.archive_batch(queue_name, msg_ids)
+    archived_msgs = pgmq.read_archive_batch(queue_name, batch_size=3)
+    assert archived_msgs is not None
+    assert len(archived_msgs) == 3
+
+
+# Tests for time-based partitioned queues
+@pgmq_deps
+def test_create_time_based_partitioned_queue(pgmq_fixture, db_session):
+    pgmq: PGMQueue = pgmq_fixture
+    queue_name = f"test_queue_{uuid.uuid4().hex}"
+    pgmq.create_partitioned_queue(
+        queue_name, partition_interval="1 day", retention_interval="7 days"
+    )
+    assert check_queue_exists(db_session, queue_name) is True
+
+
+@pgmq_deps
+def test_create_time_based_partitioned_queue_various_intervals(
+    pgmq_fixture, db_session
+):
+    pgmq: PGMQueue = pgmq_fixture
+
+    # Test with hour
+    queue_name_hour = f"test_queue_{uuid.uuid4().hex}"
+    pgmq.create_partitioned_queue(
+        queue_name_hour, partition_interval="1 hour", retention_interval="24 hours"
+    )
+    assert check_queue_exists(db_session, queue_name_hour) is True
+
+    # Test with week
+    queue_name_week = f"test_queue_{uuid.uuid4().hex}"
+    pgmq.create_partitioned_queue(
+        queue_name_week, partition_interval="1 week", retention_interval="4 weeks"
+    )
+    assert check_queue_exists(db_session, queue_name_week) is True
+
+
+@pgmq_deps
+def test_create_partitioned_queue_invalid_time_interval(pgmq_fixture):
+    pgmq: PGMQueue = pgmq_fixture
+    queue_name = f"test_queue_{uuid.uuid4().hex}"
+    with pytest.raises(ValueError) as e:
+        pgmq.create_partitioned_queue(
+            queue_name,
+            partition_interval="invalid interval",
+            retention_interval="7 days",
+        )
+    assert "Invalid time-based partition interval" in str(e.value)
+
+
+@pgmq_deps
+def test_create_partitioned_queue_invalid_numeric_interval(pgmq_fixture):
+    pgmq: PGMQueue = pgmq_fixture
+    queue_name = f"test_queue_{uuid.uuid4().hex}"
+    with pytest.raises(ValueError) as e:
+        pgmq.create_partitioned_queue(
+            queue_name, partition_interval=-100, retention_interval=100000
+        )
+    assert "Numeric partition interval must be positive" in str(e.value)
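
For reviewers: a minimal end-to-end sketch of the APIs this diff adds (`create_partitioned_queue` with time-based intervals, `set_vt`, `read_archive`/`read_archive_batch`, `detach_archive`). It assumes a reachable Postgres with the `pgmq` and `pg_partman` extensions installed; the DSN and queue name below are placeholders, and client construction follows the `PGMQueue(dsn=...)` pattern from the README.

```python
from pgmq_sqlalchemy import PGMQueue

# Placeholder DSN; adjust for your environment.
pgmq = PGMQueue(dsn="postgresql://postgres:postgres@localhost:5432/postgres")

# Time-based partitioning: one partition per day of enqueued_at,
# with partitions older than seven days eligible for retention cleanup.
pgmq.create_partitioned_queue(
    "events", partition_interval="1 day", retention_interval="7 days"
)

msg_id = pgmq.send("events", {"key": "value"})

# Read with a 10-second visibility timeout, then extend it to 60 seconds.
msg = pgmq.read("events", vt=10)
if msg is not None:
    pgmq.set_vt("events", msg.msg_id, 60)

# Archive the message, then read it back from the pgmq.a_events table.
pgmq.archive("events", msg_id)
archived = pgmq.read_archive("events")
archived_batch = pgmq.read_archive_batch("events", batch_size=10)

# Keep the archive table but detach it from the pgmq extension.
pgmq.detach_archive("events")
```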