Skip to content

Commit 599cba2

Browse files
siwikmMichał Siwik
andauthored
Remove message name conflict and adds some base typing + fix bug with {} as default arg for fuction #105 (issue) (#107)
* Removed name conflict for Message and added some basic typing * version bump * Revert "Removed name conflict for Message and added some basic typing" This reverts commit 696bece. * removed formatting and add more typing * removed formatting and add more typing * Removed bug with empty dict as default argument * Removed unused import of Dict and Any --------- Co-authored-by: Michał Siwik <michalsiwik@MacBook-Pro-epruf-2.local>
1 parent ff0602d commit 599cba2

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

memphis/memphis.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
# Credit for The NATS.IO Authors
23
# Copyright 2021-2022 The Memphis Authors
34
# Licensed under the Apache License, Version 2.0 (the “License”);
@@ -12,6 +13,7 @@
1213
# See the License for the specific language governing permissions and
1314
# limitations under the License.
1415

16+
from typing import Iterable, Callable, Union
1517
import random
1618
import json
1719
import ssl
@@ -27,7 +29,6 @@
2729
from jsonschema import validate
2830
from google.protobuf import descriptor_pb2, descriptor_pool
2931
from google.protobuf.message_factory import MessageFactory
30-
from google.protobuf.message import Message
3132

3233
import memphis.retention_types as retention_types
3334
import memphis.storage_types as storage_types
@@ -36,7 +37,7 @@
3637

3738

3839
class set_interval():
39-
def __init__(self, func, sec):
40+
def __init__(self, func: Callable, sec: int):
4041
def func_wrapper():
4142
self.t = Timer(sec, func_wrapper)
4243
self.t.start()
@@ -63,7 +64,7 @@ def __init__(self):
6364
self.update_configurations_sub = {}
6465
self.configuration_tasks = {}
6566

66-
async def get_msgs_update_configurations(self, iterable):
67+
async def get_msgs_update_configurations(self, iterable: Iterable):
6768
try:
6869
async for msg in iterable:
6970
message = msg.data.decode("utf-8")
@@ -87,7 +88,7 @@ async def configurations_listener(self):
8788
except Exception as err:
8889
raise MemphisError(err)
8990

90-
async def connect(self, host, username, connection_token, port=6666, reconnect=True, max_reconnect=10, reconnect_interval_ms=1500, timeout_ms=15000, cert_file='', key_file='', ca_file=''):
91+
async def connect(self, host: str, username: str, connection_token: str, port: int = 6666, reconnect: bool = True, max_reconnect: int = 10, reconnect_interval_ms: int = 1500, timeout_ms: int = 15000, cert_file: str = "", key_file: str = "", ca_file: str = ""):
9192
"""Creates connection with Memphis.
9293
Args:
9394
host (str): memphis host.
@@ -157,7 +158,8 @@ async def send_notification(self, title, msg, failedMsg, type):
157158
msgToSend = json.dumps(msg).encode('utf-8')
158159
await self.broker_manager.publish("$memphis_notifications", msgToSend)
159160

160-
async def station(self, name, retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS, retention_value=604800, storage_type=storage_types.DISK, replicas=1, idempotency_window_ms=120000, schema_name="", send_poison_msg_to_dls=True, send_schema_failed_msg_to_dls=True):
161+
async def station(self, name: str,
162+
retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS, retention_value: int = 604800, storage_type: str = storage_types.DISK, replicas: int = 1, idempotency_window_ms: int = 120000, schema_name: str = "", send_poison_msg_to_dls: bool = True, send_schema_failed_msg_to_dls: bool = True,):
161163
"""Creates a station.
162164
Args:
163165
name (str): station name.
@@ -296,7 +298,7 @@ def __normalize_host(self, host):
296298
else:
297299
return host
298300

299-
async def producer(self, station_name, producer_name, generate_random_suffix=False):
301+
async def producer(self, station_name: str, producer_name: str, generate_random_suffix: bool =False):
300302
"""Creates a producer.
301303
Args:
302304
station_name (str): station name to produce messages into.
@@ -405,7 +407,7 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data
405407
station_name, self.schema_updates_subs[station_name].messages))
406408
self.schema_tasks[station_name] = task
407409

408-
async def consumer(self, station_name, consumer_name, consumer_group="", pull_interval_ms=1000, batch_size=10, batch_max_time_to_wait_ms=5000, max_ack_time_ms=30000, max_msg_deliveries=10, generate_random_suffix=False, start_consume_from_sequence=1, last_messages=-1):
410+
async def consumer(self, station_name: str, consumer_name: str, consumer_group: str ="", pull_interval_ms: int = 1000, batch_size: int = 10, batch_max_time_to_wait_ms: int =5000, max_ack_time_ms: int=30000, max_msg_deliveries: int=10, generate_random_suffix: bool=False, start_consume_from_sequence: int=1, last_messages: int=-1):
409411
"""Creates a consumer.
410412
Args:.
411413
station_name (str): station name to consume messages from.
@@ -486,7 +488,7 @@ def add(self, key, value):
486488

487489

488490
class Station:
489-
def __init__(self, connection, name):
491+
def __init__(self, connection, name: str):
490492
self.connection = connection
491493
self.name = name.lower()
492494

@@ -531,7 +533,7 @@ def get_internal_name(name: str) -> str:
531533

532534

533535
class Producer:
534-
def __init__(self, connection, producer_name, station_name):
536+
def __init__(self, connection, producer_name: str, station_name: str):
535537
self.connection = connection
536538
self.producer_name = producer_name.lower()
537539
self.station_name = station_name
@@ -628,10 +630,10 @@ def validate_graphql(self, message):
628630
e = "Invalid message format, expected GraphQL"
629631
raise Exception("Schema validation has failed: " + str(e))
630632

631-
def get_dls_msg_id(self, station_name, producer_name, unix_time):
633+
def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
632634
return station_name + '~' + producer_name + '~0~' + unix_time
633635

634-
async def produce(self, message, ack_wait_sec=15, headers={}, async_produce=False, msg_id=None):
636+
async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers, None] = None, async_produce: bool=False, msg_id: Union[str, None]= None):
635637
"""Produces a message into a station.
636638
Args:
637639
message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
@@ -650,10 +652,10 @@ async def produce(self, message, ack_wait_sec=15, headers={}, async_produce=Fals
650652
"$memphis_producedBy": self.producer_name,
651653
"$memphis_connectionId": self.connection.connection_id}
652654

653-
if msg_id != None:
655+
if msg_id is not None:
654656
memphis_headers["msg-id"] = msg_id
655657

656-
if headers != {}:
658+
if headers is not None:
657659
headers = headers.headers
658660
headers.update(memphis_headers)
659661
else:
@@ -760,7 +762,7 @@ async def default_error_handler(e):
760762

761763

762764
class Consumer:
763-
def __init__(self, connection, station_name, consumer_name, consumer_group, pull_interval_ms, batch_size, batch_max_time_to_wait_ms, max_ack_time_ms, max_msg_deliveries=10, error_callback=None, start_consume_from_sequence=1, last_messages=-1):
765+
def __init__(self, connection, station_name: str, consumer_name, consumer_group, pull_interval_ms: int, batch_size: int, batch_max_time_to_wait_ms: int, max_ack_time_ms: int, max_msg_deliveries: int=10, error_callback=None, start_consume_from_sequence: int=1, last_messages: int=-1):
764766
self.connection = connection
765767
self.station_name = station_name.lower()
766768
self.consumer_name = consumer_name.lower()
@@ -775,6 +777,7 @@ def __init__(self, connection, station_name, consumer_name, consumer_group, pull
775777
error_callback = default_error_handler
776778
self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
777779
self.start_consume_from_sequence = start_consume_from_sequence
780+
778781
self.last_messages= last_messages
779782
self.context = {}
780783

@@ -804,6 +807,7 @@ async def __consume(self, callback):
804807
Message(msg, self.connection, self.consumer_group))
805808
await callback(memphis_messages, None, self.context)
806809
await asyncio.sleep(self.pull_interval_ms/1000)
810+
807811
except asyncio.TimeoutError:
808812
await callback([], MemphisError("Memphis: TimeoutError"), self.context)
809813
continue

0 commit comments

Comments
 (0)