Skip to content

Commit 3d878b8

Browse files
authored
feat(custom-source): add custom source support (#1197)
* feat(custom-source): python SDK types * feat(custom-source): add custom source support * cleanup
1 parent 90f6567 commit 3d878b8

File tree

6 files changed

+620
-33
lines changed

6 files changed

+620
-33
lines changed

python/cocoindex/engine_value.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
7070
ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)
7171

7272

73+
def make_engine_key_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
    """
    Build the encoder closure used for key values of the given type.

    For a key whose type is a basic type, the returned encoder produces a
    single-element list holding the encoded value. For any other key type the
    plain value encoder is returned as is.
    """
    encode_value = make_engine_value_encoder(type_info)
    if not isinstance(type_info.variant, AnalyzedBasicType):
        return encode_value

    def encode_basic_key(value: Any) -> Any:
        # A basic-typed key is wrapped into a one-element list.
        return [encode_value(value)]

    return encode_basic_key
82+
83+
7384
def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
7485
"""
7586
Create an encoder closure for a specific type.
@@ -94,6 +105,9 @@ def encode_struct_list(value: Any) -> Any:
94105
# Otherwise it's a vector, falling into basic type in the engine.
95106

96107
if isinstance(variant, AnalyzedDictType):
108+
key_type_info = analyze_type_info(variant.key_type)
109+
key_encoder = make_engine_key_encoder(key_type_info)
110+
97111
value_type_info = analyze_type_info(variant.value_type)
98112
if not isinstance(value_type_info.variant, AnalyzedStructType):
99113
raise ValueError(
@@ -102,22 +116,10 @@ def encode_struct_list(value: Any) -> Any:
102116
)
103117
value_encoder = make_engine_value_encoder(value_type_info)
104118

105-
key_type_info = analyze_type_info(variant.key_type)
106-
key_encoder = make_engine_value_encoder(key_type_info)
107-
if isinstance(key_type_info.variant, AnalyzedBasicType):
108-
109-
def encode_row(k: Any, v: Any) -> Any:
110-
return [key_encoder(k)] + value_encoder(v)
111-
112-
else:
113-
114-
def encode_row(k: Any, v: Any) -> Any:
115-
return key_encoder(k) + value_encoder(v)
116-
117119
def encode_struct_dict(value: Any) -> Any:
118120
if not value:
119121
return []
120-
return [encode_row(k, v) for k, v in value.items()]
122+
return [key_encoder(k) + value_encoder(v) for k, v in value.items()]
121123

122124
return encode_struct_dict
123125

python/cocoindex/op.py

Lines changed: 243 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,33 @@
99
Any,
1010
Awaitable,
1111
Callable,
12+
Iterator,
1213
Protocol,
1314
dataclass_transform,
1415
Annotated,
16+
TypeVar,
17+
Generic,
18+
Literal,
1519
get_args,
1620
)
21+
from collections.abc import AsyncIterator
1722

1823
from . import _engine # type: ignore
1924
from .subprocess_exec import executor_stub
2025
from .engine_object import dump_engine_object, load_engine_object
2126
from .engine_value import (
27+
make_engine_key_encoder,
2228
make_engine_value_encoder,
2329
make_engine_value_decoder,
2430
make_engine_key_decoder,
2531
make_engine_struct_decoder,
2632
)
2733
from .typing import (
34+
KEY_FIELD_NAME,
35+
AnalyzedTypeInfo,
36+
StructSchema,
37+
StructType,
38+
TableType,
2839
TypeAttr,
2940
encode_enriched_type_info,
3041
resolve_forward_ref,
@@ -96,12 +107,12 @@ class Executor(Protocol):
96107
op_category: OpCategory
97108

98109

99-
def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
100-
method = getattr(cls, name, None)
110+
def _get_required_method(obj: type, name: str) -> Callable[..., Any]:
111+
method = getattr(obj, name, None)
101112
if method is None:
102-
raise ValueError(f"Method {name}() is required for {cls.__name__}")
103-
if not inspect.isfunction(method):
104-
raise ValueError(f"Method {cls.__name__}.{name}() is not a function")
113+
raise ValueError(f"Method {name}() is required for {obj}")
114+
if not inspect.isfunction(method) and not inspect.ismethod(method):
115+
raise ValueError(f"{obj}.{name}() is not a function; {method}")
105116
return method
106117

107118

@@ -421,6 +432,233 @@ def _inner(fn: Callable[..., Any]) -> Callable[..., Any]:
421432
return _inner
422433

423434

435+
########################################################
436+
# Custom source connector
437+
########################################################
438+
439+
440+
@dataclasses.dataclass
class SourceReadOptions:
    """
    Flags selecting which parts of a source row a read should include.

    Every flag defaults to False.
    """

    # Whether the row's ordinal is requested.
    include_ordinal: bool = False
    # Whether the row's content version fingerprint is requested.
    include_content_version_fp: bool = False
    # Whether the row's value is requested.
    include_value: bool = False
445+
446+
447+
# Type variables for the key type and the value type of a source row.
K = TypeVar("K")
V = TypeVar("V")

# Sentinel used in place of a row value: the row does not exist.
NON_EXISTENCE: Literal["NON_EXISTENCE"] = "NON_EXISTENCE"
# Sentinel used in place of a row ordinal: the source has no ordinal available.
NO_ORDINAL: Literal["NO_ORDINAL"] = "NO_ORDINAL"
452+
453+
454+
@dataclasses.dataclass
class PartialSourceRowData(Generic[V]):
    """
    Data carried by one source row; every field is optional and defaults to None.

    Attributes:
        value: The row's value, or NON_EXISTENCE when the row does not exist.
        ordinal: The row's ordinal, or NO_ORDINAL when the source has none.
        content_version_fp: The row's content version fingerprint.
    """

    value: V | Literal["NON_EXISTENCE"] | None = None
    ordinal: int | Literal["NO_ORDINAL"] | None = None
    content_version_fp: bytes | None = None
467+
468+
469+
@dataclasses.dataclass
class PartialSourceRow(Generic[K, V]):
    """
    One source row: its key together with its (possibly partial) data.
    """

    # Key identifying the row.
    key: K
    # Data of the row; each of its fields may be left unset (None).
    data: PartialSourceRowData[V]
473+
474+
475+
class _SourceExecutorContext:
    """
    Wraps a user-provided source executor for use by the engine.

    Keys and row data are converted between their Python form and the engine
    form around calls to the executor's `list()` and `get_value()` methods.
    """

    _executor: Any

    _key_encoder: Callable[[Any], Any]
    _key_decoder: Callable[[Any], Any]

    _value_encoder: Callable[[Any], Any]

    _list_fn: Callable[
        [SourceReadOptions],
        AsyncIterator[PartialSourceRow[Any, Any]]
        | Iterator[PartialSourceRow[Any, Any]],
    ]
    _get_value_fn: Callable[
        [Any, SourceReadOptions], Awaitable[PartialSourceRowData[Any]]
    ]
    _provides_ordinal_fn: Callable[[], bool] | None

    def __init__(
        self,
        executor: Any,
        key_type_info: AnalyzedTypeInfo,
        key_decoder: Callable[[Any], Any],
        value_type_info: AnalyzedTypeInfo,
    ):
        """
        Bind the executor and prepare the key/value converters.

        `list` and `get_value` are required on the executor; `provides_ordinal`
        is optional.
        """
        self._executor = executor

        # Converters between Python values and engine values.
        self._key_encoder = make_engine_key_encoder(key_type_info)
        self._key_decoder = key_decoder
        self._value_encoder = make_engine_value_encoder(value_type_info)

        # `list` is used as is (it may return a sync or an async iterator);
        # `get_value` is wrapped so that it can always be awaited.
        self._list_fn = _get_required_method(executor, "list")
        get_value_method = _get_required_method(executor, "get_value")
        self._get_value_fn = to_async_call(get_value_method)
        self._provides_ordinal_fn = getattr(executor, "provides_ordinal", None)

    def provides_ordinal(self) -> bool:
        """
        Report the executor's `provides_ordinal()` result as a bool.

        Returns False when the executor does not define `provides_ordinal`.
        """
        ordinal_fn = self._provides_ordinal_fn
        if ordinal_fn is None:
            return False
        return bool(ordinal_fn())

    async def list_async(
        self, options: dict[str, Any]
    ) -> AsyncIterator[tuple[Any, dict[str, Any]]]:
        """
        Yield the rows listed by the executor, one at a time.

        Each yielded item is a tuple of (encoded key, encoded row data).
        """
        # The options arrive as a dict; turn them into SourceReadOptions.
        read_options = load_engine_object(SourceReadOptions, options)

        # Invoke the user's list method.
        rows = self._list_fn(read_options)

        # The result may be an async iterator or a plain (sync) iterator.
        if hasattr(rows, "__aiter__"):
            async for row in rows:
                yield self._encode_row(row)
            return

        for row in rows:
            yield self._encode_row(row)

    async def get_value_async(
        self,
        raw_key: Any,
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Fetch the data of a single row, identified by its engine-side key.

        The key is decoded before being handed to the executor's `get_value`,
        and the returned row data is encoded for the engine.
        """
        decoded_key = self._key_decoder(raw_key)
        read_options = load_engine_object(SourceReadOptions, options)

        fetched = await self._get_value_fn(decoded_key, read_options)
        return self._encode_source_row_data(fetched)

    def _encode_row(self, row: PartialSourceRow[Any, Any]) -> tuple[Any, dict[str, Any]]:
        """Encode one listed row into an (encoded key, encoded row data) tuple."""
        encoded_key = self._key_encoder(row.key)
        encoded_data = self._encode_source_row_data(row.data)
        return (encoded_key, encoded_data)

    def _encode_source_row_data(
        self, row_data: PartialSourceRowData[Any]
    ) -> dict[str, Any]:
        """Convert Python PartialSourceRowData to the format expected by Rust."""
        raw_value = row_data.value
        # The NON_EXISTENCE marker is passed through untouched; anything else
        # goes through the value encoder.
        if raw_value == NON_EXISTENCE:
            encoded_value = NON_EXISTENCE
        else:
            encoded_value = self._value_encoder(raw_value)

        return {
            "ordinal": row_data.ordinal,
            "content_version_fp": row_data.content_version_fp,
            "value": encoded_value,
        }
568+
569+
570+
class _SourceConnector:
    """
    Connector object handed to the engine for a custom source.

    It derives the table type of the source from the key and value types, and
    creates executor contexts from raw specs.
    """

    _spec_cls: type[Any]
    _key_type_info: AnalyzedTypeInfo
    _key_decoder: Callable[[Any], Any]
    _value_type_info: AnalyzedTypeInfo
    _table_type: EnrichedValueType
    _connector_cls: type[Any]

    _create_fn: Callable[[Any], Awaitable[Any]]

    def __init__(
        self,
        spec_cls: type[Any],
        key_type: Any,
        value_type: Any,
        connector_cls: type[Any],
    ):
        """
        Analyze the key/value types and build the key decoder and table type.

        Raises:
            ValueError: If the value type does not resolve to a StructType.
        """
        self._spec_cls = spec_cls
        self._key_type_info = analyze_type_info(key_type)
        self._value_type_info = analyze_type_info(value_type)
        self._connector_cls = connector_cls

        key_engine_type = self._to_engine_type(self._key_type_info)
        value_engine_type = self._to_engine_type(self._value_type_info)

        value_struct = value_engine_type.type
        if not isinstance(value_struct, StructType):
            raise ValueError(f"Expected a StructType, got {value_struct}")

        # A struct key contributes its own fields; any other key becomes a
        # single field named KEY_FIELD_NAME.
        key_fields = (
            key_engine_type.type.fields
            if isinstance(key_engine_type.type, StructType)
            else [FieldSchema(name=KEY_FIELD_NAME, value_type=key_engine_type)]
        )

        self._key_decoder = make_engine_key_decoder(
            [], key_fields, self._key_type_info
        )

        # Table rows are the key fields followed by the value struct's fields.
        row_schema = StructSchema(fields=key_fields + value_struct.fields)
        self._table_type = EnrichedValueType(
            type=TableType(
                kind="KTable",
                row=row_schema,
                num_key_parts=len(key_fields),
            ),
        )

        create_method = _get_required_method(connector_cls, "create")
        self._create_fn = to_async_call(create_method)

    @staticmethod
    def _to_engine_type(type_info: AnalyzedTypeInfo) -> EnrichedValueType:
        """Turn analyzed type info into an EnrichedValueType via its encoded form."""
        # TODO: We can save the intermediate step after #1083 is fixed.
        encoded = encode_enriched_type_info(type_info)
        return EnrichedValueType.decode(encoded)

    async def create_executor(self, raw_spec: dict[str, Any]) -> _SourceExecutorContext:
        """
        Load the spec, create the executor via the connector class's `create`,
        and wrap it into a `_SourceExecutorContext`.
        """
        loaded_spec = load_engine_object(self._spec_cls, raw_spec)
        created = await self._create_fn(loaded_spec)
        return _SourceExecutorContext(
            created, self._key_type_info, self._key_decoder, self._value_type_info
        )

    def get_table_type(self) -> Any:
        """Return the table type of this source, dumped as an engine object."""
        return dump_engine_object(self._table_type)
637+
638+
639+
def source_connector(
    *,
    spec_cls: type[Any],
    key_type: Any = Any,
    value_type: Any = Any,
) -> Callable[[type], type]:
    """
    Class decorator that registers the decorated class as a source connector.

    Args:
        spec_cls: The spec class of the source; must be a subclass of SourceSpec.
        key_type: The Python type of row keys.
        value_type: The Python type of row values.

    Returns:
        A decorator that registers the class with the engine under
        `spec_cls.__name__` and returns the class unchanged.

    Raises:
        ValueError: If `spec_cls` is not a subclass of SourceSpec.
    """

    # The spec class is validated eagerly, before any class gets decorated.
    if not issubclass(spec_cls, SourceSpec):
        raise ValueError(f"Expect a SourceSpec, got {spec_cls}")

    def _register(connector_cls: type) -> type:
        # Build the connector for this class and hand it to the engine.
        registered = _SourceConnector(spec_cls, key_type, value_type, connector_cls)
        _engine.register_source_connector(spec_cls.__name__, registered)
        return connector_cls

    return _register
660+
661+
424662
########################################################
425663
# Custom target connector
426664
########################################################

src/ops/interface.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub struct FlowInstanceContext {
1212
pub py_exec_ctx: Option<Arc<crate::py::PythonExecutionContext>>,
1313
}
1414

15-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
15+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
1616
pub struct Ordinal(pub Option<i64>);
1717

1818
impl Ordinal {
@@ -114,7 +114,7 @@ pub struct SourceChangeMessage {
114114
pub ack_fn: Option<Box<dyn FnOnce() -> BoxFuture<'static, Result<()>> + Send + Sync>>,
115115
}
116116

117-
#[derive(Debug, Default)]
117+
#[derive(Debug, Default, Serialize)]
118118
pub struct SourceExecutorReadOptions {
119119
/// When set to true, the implementation must return a non-None `ordinal`.
120120
pub include_ordinal: bool,

0 commit comments

Comments
 (0)