Commit b830019

Pijukatel and janbuchar authored
fix: Make context helpers work in FailedRequestHandler and ErrorHandler (#1570)
### Description

- Make context helpers `push_data`, `get_key_value_store`, and `add_requests` work in `FailedRequestHandler` and `ErrorHandler`. Previously, they had no effect because the result of a failed request was never committed.

### Issues

- Closes: #1532

### Testing

- Added unit tests

### Checklist

- [ ] CI passed

Co-authored-by: Jan Buchar <jan.buchar@apify.com>
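For illustration, a minimal usage sketch of the now-working helpers inside a failed request handler. This assumes the public `crawlee.crawlers` import path; the URLs and handler bodies are invented, and the pattern simply mirrors the unit test added in this commit:

```python
import asyncio

from crawlee import Request
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    crawler = BasicCrawler(max_request_retries=1)

    @crawler.router.default_handler
    async def request_handler(context: BasicCrawlingContext) -> None:
        raise RuntimeError('Simulated failure')

    @crawler.failed_request_handler
    async def failed_handler(context: BasicCrawlingContext, error: Exception) -> None:
        # With this fix, these helpers write straight to the storages instead of
        # ending up in a request handler result that is never committed.
        await context.push_data({'url': context.request.url, 'error': str(error)})
        await context.add_requests([Request.from_url('https://crawlee.dev/docs')])
        kvs = await context.get_key_value_store()
        await kvs.set_value('last_error', str(error))

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```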
1 parent af1527b commit b830019

File tree

3 files changed (+120, -45 lines)

- src/crawlee/_types.py
- src/crawlee/crawlers/_basic/_basic_crawler.py
- tests/unit/crawlers/_basic/test_basic_crawler.py

src/crawlee/_types.py

Lines changed: 20 additions & 1 deletion
@@ -15,7 +15,7 @@
 import re
 from collections.abc import Callable, Coroutine, Sequence
 
-from typing_extensions import NotRequired, Required, Unpack
+from typing_extensions import NotRequired, Required, Self, Unpack
 
 from crawlee import Glob, Request
 from crawlee._request import RequestOptions
@@ -643,6 +643,25 @@ def __hash__(self) -> int:
         """Return hash of the context. Each context is considered unique."""
         return id(self)
 
+    def create_modified_copy(
+        self,
+        push_data: PushDataFunction | None = None,
+        add_requests: AddRequestsFunction | None = None,
+        get_key_value_store: GetKeyValueStoreFromRequestHandlerFunction | None = None,
+    ) -> Self:
+        """Create a modified copy of the crawling context with specified changes."""
+        original_fields = {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}
+        modified_fields = {
+            key: value
+            for key, value in {
+                'push_data': push_data,
+                'add_requests': add_requests,
+                'get_key_value_store': get_key_value_store,
+            }.items()
+            if value
+        }
+        return self.__class__(**{**original_fields, **modified_fields})
+
 
 class GetDataKwargs(TypedDict):
     """Keyword arguments for dataset's `get_data` method."""

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 60 additions & 44 deletions
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import functools
 import logging
 import signal
 import sys
@@ -14,7 +15,7 @@
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast
 from urllib.parse import ParseResult, urlparse
 from weakref import WeakKeyDictionary
 
@@ -96,6 +97,9 @@
 TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext)
 TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
 TRequestIterator = TypeVar('TRequestIterator', str, Request)
+TParams = ParamSpec('TParams')
+T = TypeVar('T')
+
 ErrorHandler = Callable[[TCrawlingContext, Exception], Awaitable[Request | None]]
 FailedRequestHandler = Callable[[TCrawlingContext, Exception], Awaitable[None]]
 SkippedRequestCallback = Callable[[str, SkippedReason], Awaitable[None]]
@@ -520,6 +524,24 @@ def stop(self, reason: str = 'Stop was called externally.') -> None:
         self._logger.info(f'Crawler.stop() was called with following reason: {reason}.')
         self._unexpected_stop = True
 
+    def _wrap_handler_with_error_context(
+        self, handler: Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]
+    ) -> Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]:
+        """Decorate error handlers to make their context helpers usable."""
+
+        @functools.wraps(handler)
+        async def wrapped_handler(context: TCrawlingContext | BasicCrawlingContext, exception: Exception) -> T:
+            # Original context helpers that are from `RequestHandlerRunResult` will not be committed as the request
+            # failed. Modified context provides context helpers with direct access to the storages.
+            error_context = context.create_modified_copy(
+                push_data=self._push_data,
+                get_key_value_store=self.get_key_value_store,
+                add_requests=functools.partial(self._add_requests, context),
+            )
+            return await handler(error_context, exception)
+
+        return wrapped_handler
+
     def _stop_if_max_requests_count_exceeded(self) -> None:
         """Call `stop` when the maximum number of requests to crawl has been reached."""
         if self._max_requests_per_crawl is None:
@@ -618,7 +640,7 @@ def error_handler(
 
         The error handler is invoked after a request handler error occurs and before a retry attempt.
         """
-        self._error_handler = handler
+        self._error_handler = self._wrap_handler_with_error_context(handler)
         return handler
 
     def failed_request_handler(
@@ -628,7 +650,7 @@ def failed_request_handler(
 
         The failed request handler is invoked when a request has failed all retry attempts.
         """
-        self._failed_request_handler = handler
+        self._failed_request_handler = self._wrap_handler_with_error_context(handler)
         return handler
 
     def on_skipped_request(self, callback: SkippedRequestCallback) -> SkippedRequestCallback:
@@ -1256,52 +1278,46 @@ def _convert_url_to_request_iterator(self, urls: Sequence[str | Request], base_u
         else:
             yield Request.from_url(url)
 
-    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
-        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
-        result = self._context_result_map[context]
-
-        base_request_manager = await self.get_request_manager()
-
-        origin = context.request.loaded_url or context.request.url
-
-        for add_requests_call in result.add_requests_calls:
-            rq_id = add_requests_call.get('rq_id')
-            rq_name = add_requests_call.get('rq_name')
-            rq_alias = add_requests_call.get('rq_alias')
-            specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None)
-            if specified_params > 1:
-                raise ValueError('You can only provide one of `rq_id`, `rq_name` or `rq_alias` arguments.')
-            if rq_id or rq_name or rq_alias:
-                request_manager: RequestManager | RequestQueue = await RequestQueue.open(
-                    id=rq_id,
-                    name=rq_name,
-                    alias=rq_alias,
-                    storage_client=self._service_locator.get_storage_client(),
-                    configuration=self._service_locator.get_configuration(),
-                )
-            else:
-                request_manager = base_request_manager
-
-            requests = list[Request]()
-
-            base_url = url if (url := add_requests_call.get('base_url')) else origin
-
-            requests_iterator = self._convert_url_to_request_iterator(add_requests_call['requests'], base_url)
+    async def _add_requests(
+        self,
+        context: BasicCrawlingContext,
+        requests: Sequence[str | Request],
+        rq_id: str | None = None,
+        rq_name: str | None = None,
+        rq_alias: str | None = None,
+        **kwargs: Unpack[EnqueueLinksKwargs],
+    ) -> None:
+        """Add requests method aware of the crawling context."""
+        if rq_id or rq_name or rq_alias:
+            request_manager: RequestManager = await RequestQueue.open(
+                id=rq_id,
+                name=rq_name,
+                alias=rq_alias,
+                storage_client=self._service_locator.get_storage_client(),
+                configuration=self._service_locator.get_configuration(),
+            )
+        else:
+            request_manager = await self.get_request_manager()
 
-            enqueue_links_kwargs: EnqueueLinksKwargs = {k: v for k, v in add_requests_call.items() if k != 'requests'}  # type: ignore[assignment]
+        context_aware_requests = list[Request]()
+        base_url = kwargs.get('base_url') or context.request.loaded_url or context.request.url
+        requests_iterator = self._convert_url_to_request_iterator(requests, base_url)
+        filter_requests_iterator = self._enqueue_links_filter_iterator(requests_iterator, context.request.url, **kwargs)
+        for dst_request in filter_requests_iterator:
+            # Update the crawl depth of the request.
+            dst_request.crawl_depth = context.request.crawl_depth + 1
 
-            filter_requests_iterator = self._enqueue_links_filter_iterator(
-                requests_iterator, context.request.url, **enqueue_links_kwargs
-            )
+            if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
+                context_aware_requests.append(dst_request)
 
-            for dst_request in filter_requests_iterator:
-                # Update the crawl depth of the request.
-                dst_request.crawl_depth = context.request.crawl_depth + 1
+        return await request_manager.add_requests(context_aware_requests)
 
-                if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
-                    requests.append(dst_request)
+    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
+        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
+        result = self._context_result_map[context]
 
-            await request_manager.add_requests(requests)
+        for add_requests_call in result.add_requests_calls:
+            await self._add_requests(context, **add_requests_call)
 
         for push_data_call in result.push_data_calls:
             await self._push_data(**push_data_call)
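The wrapper above pre-binds the crawling context into the storage-backed `_add_requests` with `functools.partial`, so the callable handed to the context keeps the plain `add_requests(requests, ...)` shape. A standalone sketch of that binding pattern (toy function names, not crawlee code):

```python
import asyncio
import functools


async def add_requests_with_context(context_url: str, requests: list[str]) -> None:
    # Toy stand-in for BasicCrawler._add_requests: the first parameter is the pre-bound context.
    print(f'enqueuing {requests} (discovered while handling {context_url})')


async def main() -> None:
    # After pre-binding the context, the callable matches the context-helper signature.
    add_requests = functools.partial(add_requests_with_context, 'https://example.com')
    await add_requests(['https://example.com/next'])


asyncio.run(main())
```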

tests/unit/crawlers/_basic/test_basic_crawler.py

Lines changed: 40 additions & 0 deletions
@@ -284,6 +284,46 @@ async def failed_request_handler(context: BasicCrawlingContext, error: Exception
     assert isinstance(calls[0][1], RuntimeError)
 
 
+@pytest.mark.parametrize('handler', ['failed_request_handler', 'error_handler'])
+async def test_handlers_use_context_helpers(tmp_path: Path, handler: str) -> None:
+    """Test that context helpers used in `failed_request_handler` and in `error_handler` have effect."""
+    # Prepare crawler
+    storage_client = FileSystemStorageClient()
+    crawler = BasicCrawler(
+        max_request_retries=1, storage_client=storage_client, configuration=Configuration(storage_dir=str(tmp_path))
+    )
+    # Test data
+    rq_alias = 'other'
+    test_data = {'some': 'data'}
+    test_key = 'key'
+    test_value = 'value'
+    test_request = Request.from_url('https://d.placeholder.com')
+
+    # Request handler with injected error
+    @crawler.router.default_handler
+    async def request_handler(context: BasicCrawlingContext) -> None:
+        raise RuntimeError('Arbitrary crash for testing purposes')
+
+    # Apply one of the handlers
+    @getattr(crawler, handler)  # type:ignore[misc]  # Untyped decorator is ok to make the test concise
+    async def handler_implementation(context: BasicCrawlingContext, error: Exception) -> None:
+        await context.push_data(test_data)
+        await context.add_requests(requests=[test_request], rq_alias=rq_alias)
+        kvs = await context.get_key_value_store()
+        await kvs.set_value(test_key, test_value)
+
+    await crawler.run(['https://b.placeholder.com'])
+
+    # Verify that the context helpers used in handlers had effect on used storages
+    dataset = await Dataset.open(storage_client=storage_client)
+    kvs = await KeyValueStore.open(storage_client=storage_client)
+    rq = await RequestQueue.open(alias=rq_alias, storage_client=storage_client)
+
+    assert test_value == await kvs.get_value(test_key)
+    assert [test_data] == (await dataset.get_data()).items
+    assert test_request == await rq.fetch_next_request()
+
+
 async def test_handles_error_in_failed_request_handler() -> None:
     crawler = BasicCrawler(max_request_retries=3)
 
