@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import functools
 import logging
 import signal
 import sys
@@ -14,7 +15,7 @@
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast
 from urllib.parse import ParseResult, urlparse
 from weakref import WeakKeyDictionary
 
@@ -96,6 +97,9 @@
 TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext)
 TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
 TRequestIterator = TypeVar('TRequestIterator', str, Request)
+TParams = ParamSpec('TParams')
+T = TypeVar('T')
+
 ErrorHandler = Callable[[TCrawlingContext, Exception], Awaitable[Request | None]]
 FailedRequestHandler = Callable[[TCrawlingContext, Exception], Awaitable[None]]
 SkippedRequestCallback = Callable[[str, SkippedReason], Awaitable[None]]
@@ -520,6 +524,24 @@ def stop(self, reason: str = 'Stop was called externally.') -> None:
         self._logger.info(f'Crawler.stop() was called with following reason: {reason}.')
         self._unexpected_stop = True
 
+    def _wrap_handler_with_error_context(
+        self, handler: Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]
+    ) -> Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]:
+        """Decorate error handlers to make their context helpers usable."""
+
+        @functools.wraps(handler)
+        async def wrapped_handler(context: TCrawlingContext | BasicCrawlingContext, exception: Exception) -> T:
+            # The original context helpers backed by `RequestHandlerRunResult` will not be committed,
+            # because the request failed. The modified context provides helpers with direct access to the storages.
+            error_context = context.create_modified_copy(
+                push_data=self._push_data,
+                get_key_value_store=self.get_key_value_store,
+                add_requests=functools.partial(self._add_requests, context),
+            )
+            return await handler(error_context, exception)
+
+        return wrapped_handler
+
     def _stop_if_max_requests_count_exceeded(self) -> None:
         """Call `stop` when the maximum number of requests to crawl has been reached."""
        if self._max_requests_per_crawl is None:
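Why the wrapper is needed: helpers on the original context record their effects into a `RequestHandlerRunResult`, which is only committed when the handler succeeds. For a failed request that result is discarded, so anything an error handler wrote through it would silently vanish. A minimal usage sketch of the behavior this enables (the import path follows current crawlee releases and is an assumption, not part of this diff):

```python
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext

crawler = BasicCrawler()

@crawler.failed_request_handler
async def record_failure(context: BasicCrawlingContext, exception: Exception) -> None:
    # With the wrapped context, this write goes straight to the dataset
    # rather than through the discarded RequestHandlerRunResult.
    await context.push_data({'url': context.request.url, 'error': repr(exception)})
```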
@@ -618,7 +640,7 @@ def error_handler(
 
         The error handler is invoked after a request handler error occurs and before a retry attempt.
         """
-        self._error_handler = handler
+        self._error_handler = self._wrap_handler_with_error_context(handler)
         return handler
 
     def failed_request_handler(
@@ -628,7 +650,7 @@ def failed_request_handler(
 
         The failed request handler is invoked when a request has failed all retry attempts.
         """
-        self._failed_request_handler = handler
+        self._failed_request_handler = self._wrap_handler_with_error_context(handler)
         return handler
 
     def on_skipped_request(self, callback: SkippedRequestCallback) -> SkippedRequestCallback:
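Both registration decorators now pass the user handler through `_wrap_handler_with_error_context` before storing it, while still returning the original handler, so the decorator stays transparent to callers. A hedged sketch of an `error_handler` that uses a context helper before a retry (names and key are illustrative):

```python
from crawlee import Request
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext

crawler = BasicCrawler()

@crawler.error_handler
async def before_retry(context: BasicCrawlingContext, exception: Exception) -> Request | None:
    # The wrapped context gives this helper direct access to the key-value store.
    kvs = await context.get_key_value_store()
    await kvs.set_value('last-error', repr(exception))
    return None  # Returning None keeps the original request for the retry attempt.
```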
@@ -1256,52 +1278,46 @@ def _convert_url_to_request_iterator(self, urls: Sequence[str | Request], base_u
             else:
                 yield Request.from_url(url)
 
-    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
-        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
-        result = self._context_result_map[context]
-
-        base_request_manager = await self.get_request_manager()
-
-        origin = context.request.loaded_url or context.request.url
-
-        for add_requests_call in result.add_requests_calls:
-            rq_id = add_requests_call.get('rq_id')
-            rq_name = add_requests_call.get('rq_name')
-            rq_alias = add_requests_call.get('rq_alias')
-            specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None)
-            if specified_params > 1:
-                raise ValueError('You can only provide one of `rq_id`, `rq_name` or `rq_alias` arguments.')
-            if rq_id or rq_name or rq_alias:
-                request_manager: RequestManager | RequestQueue = await RequestQueue.open(
-                    id=rq_id,
-                    name=rq_name,
-                    alias=rq_alias,
-                    storage_client=self._service_locator.get_storage_client(),
-                    configuration=self._service_locator.get_configuration(),
-                )
-            else:
-                request_manager = base_request_manager
-
-            requests = list[Request]()
-
-            base_url = url if (url := add_requests_call.get('base_url')) else origin
-
-            requests_iterator = self._convert_url_to_request_iterator(add_requests_call['requests'], base_url)
+    async def _add_requests(
+        self,
+        context: BasicCrawlingContext,
+        requests: Sequence[str | Request],
+        rq_id: str | None = None,
+        rq_name: str | None = None,
+        rq_alias: str | None = None,
+        **kwargs: Unpack[EnqueueLinksKwargs],
+    ) -> None:
+        """Add requests, using the crawling context to resolve the base URL and crawl depth."""
+        if rq_id or rq_name or rq_alias:
+            request_manager: RequestManager = await RequestQueue.open(
+                id=rq_id,
+                name=rq_name,
+                alias=rq_alias,
+                storage_client=self._service_locator.get_storage_client(),
+                configuration=self._service_locator.get_configuration(),
+            )
+        else:
+            request_manager = await self.get_request_manager()
 
-            enqueue_links_kwargs: EnqueueLinksKwargs = {k: v for k, v in add_requests_call.items() if k != 'requests'}  # type: ignore[assignment]
+        context_aware_requests = list[Request]()
+        base_url = kwargs.get('base_url') or context.request.loaded_url or context.request.url
+        requests_iterator = self._convert_url_to_request_iterator(requests, base_url)
+        filter_requests_iterator = self._enqueue_links_filter_iterator(requests_iterator, context.request.url, **kwargs)
+        for dst_request in filter_requests_iterator:
+            # Update the crawl depth of the request.
+            dst_request.crawl_depth = context.request.crawl_depth + 1
 
-            filter_requests_iterator = self._enqueue_links_filter_iterator(
-                requests_iterator, context.request.url, **enqueue_links_kwargs
-            )
+            if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
+                context_aware_requests.append(dst_request)
 
-            for dst_request in filter_requests_iterator:
-                # Update the crawl depth of the request.
-                dst_request.crawl_depth = context.request.crawl_depth + 1
+        return await request_manager.add_requests(context_aware_requests)
 
-                if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
-                    requests.append(dst_request)
+    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
+        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
+        result = self._context_result_map[context]
 
-            await request_manager.add_requests(requests)
+        for add_requests_call in result.add_requests_calls:
+            await self._add_requests(context, **add_requests_call)
 
         for push_data_call in result.push_data_calls:
             await self._push_data(**push_data_call)
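With `_add_requests` extracted, a single code path serves both the commit of a successful handler's `add_requests_calls` and the direct `add_requests` helper handed to error handlers: one of `rq_id`/`rq_name`/`rq_alias` routes requests to a non-default queue, relative URLs resolve against the loaded page, and requests beyond `max_crawl_depth` are dropped before enqueuing. A usage sketch (the URL and alias are illustrative):

```python
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext

crawler = BasicCrawler(max_crawl_depth=3)

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
    # The relative URL resolves against the loaded page; anything deeper
    # than max_crawl_depth is filtered out before it reaches the queue.
    await context.add_requests(['/next-page'], rq_alias='follow-ups')
```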