4040
4141from .audit_logs import AuditLogEntry
4242from .errors import NoMoreItems
43- from .monetization import Entitlement
4443from .object import Object
4544from .utils import maybe_coroutine , snowflake_time , time_snowflake
4645
5251 "MemberIterator" ,
5352 "ScheduledEventSubscribersIterator" ,
5453 "EntitlementIterator" ,
54+ "SubscriptionIterator" ,
5555)
5656
5757if TYPE_CHECKING :
5858 from .abc import Snowflake
5959 from .guild import BanEntry , Guild
6060 from .member import Member
6161 from .message import Message
62+ from .monetization import Entitlement , Subscription
6263 from .scheduled_events import ScheduledEvent
6364 from .threads import Thread
6465 from .types .audit_log import AuditLog as AuditLogPayload
6566 from .types .guild import Guild as GuildPayload
6667 from .types .message import Message as MessagePayload
6768 from .types .monetization import Entitlement as EntitlementPayload
69+ from .types .monetization import Subscription as SubscriptionPayload
6870 from .types .threads import Thread as ThreadPayload
6971 from .types .user import PartialUser as PartialUserPayload
7072 from .user import User
@@ -1031,6 +1033,11 @@ def _get_retrieve(self):
10311033 self .retrieve = r
10321034 return r > 0
10331035
1036+ def create_entitlement (self , data ) -> Entitlement :
1037+ from .monetization import Entitlement
1038+
1039+ return Entitlement (data = data , state = self .state )
1040+
10341041 async def fill_entitlements (self ):
10351042 if not self ._get_retrieve ():
10361043 return
@@ -1044,9 +1051,9 @@ async def fill_entitlements(self):
10441051 self .limit = 0 # terminate loop
10451052
10461053 for element in data :
1047- await self .entitlements .put (Entitlement ( data = element , state = self .state ))
1054+ await self .entitlements .put (self .create_entitlement ( element ))
10481055
1049- async def _retrieve_entitlements (self , retrieve ) -> list [Entitlement ]:
1056+ async def _retrieve_entitlements (self , retrieve ) -> list [EntitlementPayload ]:
10501057 """Retrieve entitlements and update next parameters."""
10511058 raise NotImplementedError
10521059
@@ -1089,3 +1096,105 @@ async def _retrieve_entitlements_after_strategy(
10891096 self .limit -= retrieve
10901097 self .after = Object (id = int (data [- 1 ]["id" ]))
10911098 return data
1099+
1100+
1101+ class SubscriptionIterator (_AsyncIterator ["Subscription" ]):
1102+ def __init__ (
1103+ self ,
1104+ state ,
1105+ sku_id : int ,
1106+ limit : int = None ,
1107+ before : datetime .datetime | None = None ,
1108+ after : datetime .datetime | None = None ,
1109+ user_id : int | None = None ,
1110+ ):
1111+ if isinstance (before , datetime .datetime ):
1112+ before = Object (id = time_snowflake (before , high = False ))
1113+ if isinstance (after , datetime .datetime ):
1114+ after = Object (id = time_snowflake (after , high = True ))
1115+
1116+ self .state = state
1117+ self .sku_id = sku_id
1118+ self .limit = limit
1119+ self .before = before
1120+ self .after = after
1121+ self .user_id = user_id
1122+
1123+ self ._filter = None
1124+
1125+ self .get_subscriptions = state .http .list_sku_subscriptions
1126+ self .subscriptions = asyncio .Queue ()
1127+
1128+ if self .before and self .after :
1129+ self ._retrieve_subscriptions = self ._retrieve_subscriptions_before_strategy
1130+ self ._filter = lambda m : int (m ["id" ]) > self .after .id
1131+ elif self .after :
1132+ self ._retrieve_subscriptions = self ._retrieve_subscriptions_after_strategy
1133+ else :
1134+ self ._retrieve_subscriptions = self ._retrieve_subscriptions_before_strategy
1135+
1136+ async def next (self ) -> Guild :
1137+ if self .subscriptions .empty ():
1138+ await self .fill_subscriptions ()
1139+
1140+ try :
1141+ return self .subscriptions .get_nowait ()
1142+ except asyncio .QueueEmpty :
1143+ raise NoMoreItems ()
1144+
1145+ def _get_retrieve (self ):
1146+ l = self .limit
1147+ if l is None or l > 100 :
1148+ r = 100
1149+ else :
1150+ r = l
1151+ self .retrieve = r
1152+ return r > 0
1153+
1154+ def create_subscription (self , data ) -> Subscription :
1155+ from .monetization import Subscription
1156+
1157+ return Subscription (state = self .state , data = data )
1158+
1159+ async def fill_subscriptions (self ):
1160+ if self ._get_retrieve ():
1161+ data = await self ._retrieve_subscriptions (self .retrieve )
1162+ if self .limit is None or len (data ) < 100 :
1163+ self .limit = 0
1164+
1165+ if self ._filter :
1166+ data = filter (self ._filter , data )
1167+
1168+ for element in data :
1169+ await self .subscriptions .put (self .create_subscription (element ))
1170+
1171+ async def _retrieve_subscriptions (self , retrieve ) -> list [SubscriptionPayload ]:
1172+ raise NotImplementedError
1173+
1174+ async def _retrieve_subscriptions_before_strategy (self , retrieve ):
1175+ before = self .before .id if self .before else None
1176+ data : list [SubscriptionPayload ] = await self .get_subscriptions (
1177+ self .sku_id ,
1178+ limit = retrieve ,
1179+ before = before ,
1180+ user_id = self .user_id ,
1181+ )
1182+ if len (data ):
1183+ if self .limit is not None :
1184+ self .limit -= retrieve
1185+ self .before = Object (id = int (data [- 1 ]["id" ]))
1186+ return data
1187+
1188+ async def _retrieve_subscriptions_after_strategy (self , retrieve ):
1189+ after = self .after .id if self .after else None
1190+ data : list [SubscriptionPayload ] = await self .get_subscriptions (
1191+ self .sku_id ,
1192+ limit = retrieve ,
1193+ after = after ,
1194+ user_id = self .user_id ,
1195+ )
1196+ if len (data ):
1197+ if self .limit is not None :
1198+ self .limit -= retrieve
1199+ self .after = Object (id = int (data [0 ]["id" ]))
1200+ return data
0 commit comments