|
9 | 9 | import logging |
10 | 10 | import ssl |
11 | 11 | import time |
| 12 | +import warnings |
12 | 13 | from unittest.mock import AsyncMock |
13 | 14 | from hashlib import blake2b |
14 | 15 | from typing import ( |
@@ -531,15 +532,31 @@ def __init__( |
531 | 532 | self._exit_task = None |
532 | 533 | self._open_subscriptions = 0 |
533 | 534 | self._options = options if options else {} |
534 | | - self.last_received = time.time() |
| 535 | + try: |
| 536 | + now = asyncio.get_running_loop().time() |
| 537 | + except RuntimeError: |
| 538 | + warnings.warn( |
| 539 | + "You are instantiating the AsyncSubstrateInterface Websocket outside of an event loop. " |
| 540 | + "Verify this is intended." |
| 541 | + ) |
| 542 | + now = asyncio.new_event_loop().time() |
| 543 | + self.last_received = now |
| 544 | + self.last_sent = now |
535 | 545 |
|
536 | 546 | async def __aenter__(self): |
537 | 547 | async with self._lock: |
538 | 548 | self._in_use += 1 |
539 | 549 | await self.connect() |
540 | 550 | return self |
541 | 551 |
|
| 552 | + @staticmethod |
| 553 | + async def loop_time() -> float: |
| 554 | + return asyncio.get_running_loop().time() |
| 555 | + |
542 | 556 | async def connect(self, force=False): |
| 557 | + now = await self.loop_time() |
| 558 | + self.last_received = now |
| 559 | + self.last_sent = now |
543 | 560 | if self._exit_task: |
544 | 561 | self._exit_task.cancel() |
545 | 562 | if not self._initialized or force: |
@@ -595,7 +612,7 @@ async def _recv(self) -> None: |
595 | 612 | try: |
596 | 613 | # TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic |
597 | 614 | response = json.loads(await self.ws.recv(decode=False)) |
598 | | - self.last_received = time.time() |
| 615 | + self.last_received = await self.loop_time() |
599 | 616 | async with self._lock: |
600 | 617 | # note that these 'subscriptions' are all waiting sent messages which have not received |
601 | 618 | # responses, and thus are not the same as RPC 'subscriptions', which are unique |
@@ -631,12 +648,12 @@ async def send(self, payload: dict) -> int: |
631 | 648 | Returns: |
632 | 649 | id: the internal ID of the request (incremented int) |
633 | 650 | """ |
634 | | - # async with self._lock: |
635 | 651 | original_id = get_next_id() |
636 | 652 | # self._open_subscriptions += 1 |
637 | 653 | await self.max_subscriptions.acquire() |
638 | 654 | try: |
639 | 655 | await self.ws.send(json.dumps({**payload, **{"id": original_id}})) |
| 656 | + self.last_sent = await self.loop_time() |
640 | 657 | return original_id |
641 | 658 | except (ConnectionClosed, ssl.SSLError, EOFError): |
642 | 659 | async with self._lock: |
@@ -2126,7 +2143,11 @@ async def _make_rpc_request( |
2126 | 2143 |
|
2127 | 2144 | if request_manager.is_complete: |
2128 | 2145 | break |
2129 | | - if time.time() - self.ws.last_received >= self.retry_timeout: |
| 2146 | + if ( |
| 2147 | + (current_time := await self.ws.loop_time()) - self.ws.last_received |
| 2148 | + >= self.retry_timeout |
| 2149 | + and current_time - self.ws.last_sent >= self.retry_timeout |
| 2150 | + ): |
2130 | 2151 | if attempt >= self.max_retries: |
2131 | 2152 | logger.warning( |
2132 | 2153 | f"Timed out waiting for RPC requests {attempt} times. Exiting." |
|
0 commit comments