Skip to content

Commit 8e5eb2e

Browse files
authored
feat(custom-source): make custom source API more robust and expose more error information (#1198)
feat(custom-source): make custom source API more robust
1 parent 3d878b8 commit 8e5eb2e

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

python/cocoindex/op.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,8 @@ class _SourceExecutorContext:
485485
AsyncIterator[PartialSourceRow[Any, Any]]
486486
| Iterator[PartialSourceRow[Any, Any]],
487487
]
488-
_get_value_fn: Callable[
489-
[Any, SourceReadOptions], Awaitable[PartialSourceRowData[Any]]
490-
]
488+
_orig_get_value_fn: Callable[..., Any]
489+
_get_value_fn: Callable[..., Awaitable[PartialSourceRowData[Any]]]
491490
_provides_ordinal_fn: Callable[[], bool] | None
492491

493492
def __init__(
@@ -504,7 +503,8 @@ def __init__(
504503
self._value_encoder = make_engine_value_encoder(value_type_info)
505504

506505
self._list_fn = _get_required_method(executor, "list")
507-
self._get_value_fn = to_async_call(_get_required_method(executor, "get_value"))
506+
self._orig_get_value_fn = _get_required_method(executor, "get_value")
507+
self._get_value_fn = to_async_call(self._orig_get_value_fn)
508508
self._provides_ordinal_fn = getattr(executor, "provides_ordinal", None)
509509

510510
def provides_ordinal(self) -> bool:
@@ -521,11 +521,9 @@ async def list_async(
521521
Return an async iterator that yields individual rows one by one.
522522
Each yielded item is a tuple of (key, data).
523523
"""
524-
# Convert the options dict to SourceReadOptions
525524
read_options = load_engine_object(SourceReadOptions, options)
526-
527-
# Call the user's list method
528-
list_result = self._list_fn(read_options)
525+
args = _build_args(self._list_fn, 0, options=read_options)
526+
list_result = self._list_fn(*args)
529527

530528
# Handle both sync and async iterators
531529
if hasattr(list_result, "__aiter__"):
@@ -548,8 +546,8 @@ async def get_value_async(
548546
) -> dict[str, Any]:
549547
key = self._key_decoder(raw_key)
550548
read_options = load_engine_object(SourceReadOptions, options)
551-
552-
row_data = await self._get_value_fn(key, read_options)
549+
args = _build_args(self._orig_get_value_fn, 1, key=key, options=read_options)
550+
row_data = await self._get_value_fn(*args)
553551
return self._encode_source_row_data(row_data)
554552

555553
def _encode_source_row_data(

src/ops/py_factory.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ impl PySourceExecutor {
355355
if py_err.is_instance_of::<pyo3::exceptions::PyStopAsyncIteration>(py) {
356356
Ok(None)
357357
} else {
358-
Err(anyhow!("Error from async iterator: {}", py_err))
358+
Err(py_err).to_result_with_py_trace(py)
359359
}
360360
}
361361
}

0 commit comments

Comments
 (0)