1+ from __future__ import annotations
12# Credit for The NATS.IO Authors
23# Copyright 2021-2022 The Memphis Authors
34# Licensed under the Apache License, Version 2.0 (the “License”);
1213# See the License for the specific language governing permissions and
1314# limitations under the License.
1415
16+ from typing import Iterable , Callable , Union
1517import random
1618import json
1719import ssl
2729from jsonschema import validate
2830from google .protobuf import descriptor_pb2 , descriptor_pool
2931from google .protobuf .message_factory import MessageFactory
30- from google .protobuf .message import Message
3132
3233import memphis .retention_types as retention_types
3334import memphis .storage_types as storage_types
3637
3738
3839class set_interval ():
39- def __init__ (self , func , sec ):
40+ def __init__ (self , func : Callable , sec : int ):
4041 def func_wrapper ():
4142 self .t = Timer (sec , func_wrapper )
4243 self .t .start ()
@@ -63,7 +64,7 @@ def __init__(self):
6364 self .update_configurations_sub = {}
6465 self .configuration_tasks = {}
6566
66- async def get_msgs_update_configurations (self , iterable ):
67+ async def get_msgs_update_configurations (self , iterable : Iterable ):
6768 try :
6869 async for msg in iterable :
6970 message = msg .data .decode ("utf-8" )
@@ -87,7 +88,7 @@ async def configurations_listener(self):
8788 except Exception as err :
8889 raise MemphisError (err )
8990
90- async def connect (self , host , username , connection_token , port = 6666 , reconnect = True , max_reconnect = 10 , reconnect_interval_ms = 1500 , timeout_ms = 15000 , cert_file = '' , key_file = '' , ca_file = '' ):
91+ async def connect (self , host : str , username : str , connection_token : str , port : int = 6666 , reconnect : bool = True , max_reconnect : int = 10 , reconnect_interval_ms : int = 1500 , timeout_ms : int = 15000 , cert_file : str = "" , key_file : str = "" , ca_file : str = "" ):
9192 """Creates connection with Memphis.
9293 Args:
9394 host (str): memphis host.
@@ -157,7 +158,8 @@ async def send_notification(self, title, msg, failedMsg, type):
157158 msgToSend = json .dumps (msg ).encode ('utf-8' )
158159 await self .broker_manager .publish ("$memphis_notifications" , msgToSend )
159160
160- async def station (self , name , retention_type = retention_types .MAX_MESSAGE_AGE_SECONDS , retention_value = 604800 , storage_type = storage_types .DISK , replicas = 1 , idempotency_window_ms = 120000 , schema_name = "" , send_poison_msg_to_dls = True , send_schema_failed_msg_to_dls = True ):
161+ async def station (self , name : str ,
162+ retention_type : str = retention_types .MAX_MESSAGE_AGE_SECONDS , retention_value : int = 604800 , storage_type : str = storage_types .DISK , replicas : int = 1 , idempotency_window_ms : int = 120000 , schema_name : str = "" , send_poison_msg_to_dls : bool = True , send_schema_failed_msg_to_dls : bool = True ,):
161163 """Creates a station.
162164 Args:
163165 name (str): station name.
@@ -296,7 +298,7 @@ def __normalize_host(self, host):
296298 else :
297299 return host
298300
299- async def producer (self , station_name , producer_name , generate_random_suffix = False ):
301+ async def producer (self , station_name : str , producer_name : str , generate_random_suffix : bool = False ):
300302 """Creates a producer.
301303 Args:
302304 station_name (str): station name to produce messages into.
@@ -405,7 +407,7 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data
405407 station_name , self .schema_updates_subs [station_name ].messages ))
406408 self .schema_tasks [station_name ] = task
407409
408- async def consumer (self , station_name , consumer_name , consumer_group = "" , pull_interval_ms = 1000 , batch_size = 10 , batch_max_time_to_wait_ms = 5000 , max_ack_time_ms = 30000 , max_msg_deliveries = 10 , generate_random_suffix = False , start_consume_from_sequence = 1 , last_messages = - 1 ):
410+ async def consumer (self , station_name : str , consumer_name : str , consumer_group : str = "" , pull_interval_ms : int = 1000 , batch_size : int = 10 , batch_max_time_to_wait_ms : int = 5000 , max_ack_time_ms : int = 30000 , max_msg_deliveries : int = 10 , generate_random_suffix : bool = False , start_consume_from_sequence : int = 1 , last_messages : int = - 1 ):
409411 """Creates a consumer.
410412 Args:.
411413 station_name (str): station name to consume messages from.
@@ -486,7 +488,7 @@ def add(self, key, value):
486488
487489
488490class Station :
489- def __init__ (self , connection , name ):
491+ def __init__ (self , connection , name : str ):
490492 self .connection = connection
491493 self .name = name .lower ()
492494
@@ -531,7 +533,7 @@ def get_internal_name(name: str) -> str:
531533
532534
533535class Producer :
534- def __init__ (self , connection , producer_name , station_name ):
536+ def __init__ (self , connection , producer_name : str , station_name : str ):
535537 self .connection = connection
536538 self .producer_name = producer_name .lower ()
537539 self .station_name = station_name
@@ -628,10 +630,10 @@ def validate_graphql(self, message):
628630 e = "Invalid message format, expected GraphQL"
629631 raise Exception ("Schema validation has failed: " + str (e ))
630632
631- def get_dls_msg_id (self , station_name , producer_name , unix_time ):
633+ def get_dls_msg_id (self , station_name : str , producer_name : str , unix_time : str ):
632634 return station_name + '~' + producer_name + '~0~' + unix_time
633635
634- async def produce (self , message , ack_wait_sec = 15 , headers = {} , async_produce = False , msg_id = None ):
636+ async def produce (self , message , ack_wait_sec : int = 15 , headers : Union [ Headers , None ] = None , async_produce : bool = False , msg_id : Union [ str , None ] = None ):
635637 """Produces a message into a station.
636638 Args:
637639 message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
@@ -650,10 +652,10 @@ async def produce(self, message, ack_wait_sec=15, headers={}, async_produce=Fals
650652 "$memphis_producedBy" : self .producer_name ,
651653 "$memphis_connectionId" : self .connection .connection_id }
652654
653- if msg_id != None :
655+ if msg_id is not None :
654656 memphis_headers ["msg-id" ] = msg_id
655657
656- if headers != {} :
658+ if headers is not None :
657659 headers = headers .headers
658660 headers .update (memphis_headers )
659661 else :
@@ -760,7 +762,7 @@ async def default_error_handler(e):
760762
761763
762764class Consumer :
763- def __init__ (self , connection , station_name , consumer_name , consumer_group , pull_interval_ms , batch_size , batch_max_time_to_wait_ms , max_ack_time_ms , max_msg_deliveries = 10 , error_callback = None , start_consume_from_sequence = 1 , last_messages = - 1 ):
765+ def __init__ (self , connection , station_name : str , consumer_name , consumer_group , pull_interval_ms : int , batch_size : int , batch_max_time_to_wait_ms : int , max_ack_time_ms : int , max_msg_deliveries : int = 10 , error_callback = None , start_consume_from_sequence : int = 1 , last_messages : int = - 1 ):
764766 self .connection = connection
765767 self .station_name = station_name .lower ()
766768 self .consumer_name = consumer_name .lower ()
@@ -775,6 +777,7 @@ def __init__(self, connection, station_name, consumer_name, consumer_group, pull
775777 error_callback = default_error_handler
776778 self .t_ping = asyncio .create_task (self .__ping_consumer (error_callback ))
777779 self .start_consume_from_sequence = start_consume_from_sequence
780+
778781 self .last_messages = last_messages
779782 self .context = {}
780783
@@ -804,6 +807,7 @@ async def __consume(self, callback):
804807 Message (msg , self .connection , self .consumer_group ))
805808 await callback (memphis_messages , None , self .context )
806809 await asyncio .sleep (self .pull_interval_ms / 1000 )
810+
807811 except asyncio .TimeoutError :
808812 await callback ([], MemphisError ("Memphis: TimeoutError" ), self .context )
809813 continue
0 commit comments