|
9 | 9 | Any, |
10 | 10 | Awaitable, |
11 | 11 | Callable, |
| 12 | + Iterator, |
12 | 13 | Protocol, |
13 | 14 | dataclass_transform, |
14 | 15 | Annotated, |
| 16 | + TypeVar, |
| 17 | + Generic, |
| 18 | + Literal, |
15 | 19 | get_args, |
16 | 20 | ) |
| 21 | +from collections.abc import AsyncIterator |
17 | 22 |
|
18 | 23 | from . import _engine # type: ignore |
19 | 24 | from .subprocess_exec import executor_stub |
20 | 25 | from .engine_object import dump_engine_object, load_engine_object |
21 | 26 | from .engine_value import ( |
| 27 | + make_engine_key_encoder, |
22 | 28 | make_engine_value_encoder, |
23 | 29 | make_engine_value_decoder, |
24 | 30 | make_engine_key_decoder, |
25 | 31 | make_engine_struct_decoder, |
26 | 32 | ) |
27 | 33 | from .typing import ( |
| 34 | + KEY_FIELD_NAME, |
| 35 | + AnalyzedTypeInfo, |
| 36 | + StructSchema, |
| 37 | + StructType, |
| 38 | + TableType, |
28 | 39 | TypeAttr, |
29 | 40 | encode_enriched_type_info, |
30 | 41 | resolve_forward_ref, |
@@ -96,12 +107,12 @@ class Executor(Protocol): |
96 | 107 | op_category: OpCategory |
97 | 108 |
|
98 | 109 |
|
99 | | -def _get_required_method(cls: type, name: str) -> Callable[..., Any]: |
100 | | - method = getattr(cls, name, None) |
| 110 | +def _get_required_method(obj: type, name: str) -> Callable[..., Any]: |
| 111 | + method = getattr(obj, name, None) |
101 | 112 | if method is None: |
102 | | - raise ValueError(f"Method {name}() is required for {cls.__name__}") |
103 | | - if not inspect.isfunction(method): |
104 | | - raise ValueError(f"Method {cls.__name__}.{name}() is not a function") |
| 113 | + raise ValueError(f"Method {name}() is required for {obj}") |
| 114 | + if not inspect.isfunction(method) and not inspect.ismethod(method): |
| 115 | + raise ValueError(f"{obj}.{name}() is not a function; {method}") |
105 | 116 | return method |
106 | 117 |
|
107 | 118 |
|
@@ -421,6 +432,233 @@ def _inner(fn: Callable[..., Any]) -> Callable[..., Any]: |
421 | 432 | return _inner |
422 | 433 |
|
423 | 434 |
|
| 435 | +######################################################## |
| 436 | +# Custom source connector |
| 437 | +######################################################## |
| 438 | + |
| 439 | + |
@dataclasses.dataclass
class SourceReadOptions:
    """
    Options indicating which per-row fields a source read should populate.

    Each flag corresponds to a field of PartialSourceRowData; presumably a
    field is only populated when its flag is set — TODO confirm against the
    engine side.
    """

    # Corresponds to PartialSourceRowData.ordinal.
    include_ordinal: bool = False
    # Corresponds to PartialSourceRowData.content_version_fp.
    include_content_version_fp: bool = False
    # Corresponds to PartialSourceRowData.value.
    include_value: bool = False
| 446 | + |
# Type variables for source rows: K is the key type, V is the value type.
K = TypeVar("K")
V = TypeVar("V")

# Sentinel values used in PartialSourceRowData fields. They are typed as
# Literal strings so they can appear directly in the field annotations below.
NON_EXISTENCE: Literal["NON_EXISTENCE"] = "NON_EXISTENCE"
NO_ORDINAL: Literal["NO_ORDINAL"] = "NO_ORDINAL"
| 452 | + |
| 453 | + |
@dataclasses.dataclass
class PartialSourceRowData(Generic[V]):
    """
    The data of a source row.

    - value: The value of the source row. NON_EXISTENCE means the row does not
      exist. None presumably means the value was not requested/populated (see
      SourceReadOptions.include_value) — TODO confirm.
    - ordinal: The ordinal of the source row. NO_ORDINAL means ordinal is not
      available for the source; None presumably means not populated.
    - content_version_fp: The content version fingerprint of the source row.
    """

    value: V | Literal["NON_EXISTENCE"] | None = None
    ordinal: int | Literal["NO_ORDINAL"] | None = None
    content_version_fp: bytes | None = None
| 468 | + |
@dataclasses.dataclass
class PartialSourceRow(Generic[K, V]):
    """A single source row: its key plus (possibly partial) row data."""

    # Key identifying the row within the source.
    key: K
    # Row data; individual fields may be unpopulated (None).
    data: PartialSourceRowData[V]
| 473 | + |
| 474 | + |
class _SourceExecutorContext:
    """
    Adapts a user-defined source executor to the engine-facing interface:
    encodes keys/values into engine form and normalizes sync/async listing.
    """

    _executor: Any

    _key_encoder: Callable[[Any], Any]
    _key_decoder: Callable[[Any], Any]

    _value_encoder: Callable[[Any], Any]

    _list_fn: Callable[
        [SourceReadOptions],
        AsyncIterator[PartialSourceRow[Any, Any]]
        | Iterator[PartialSourceRow[Any, Any]],
    ]
    _get_value_fn: Callable[
        [Any, SourceReadOptions], Awaitable[PartialSourceRowData[Any]]
    ]
    _provides_ordinal_fn: Callable[[], bool] | None

    def __init__(
        self,
        executor: Any,
        key_type_info: AnalyzedTypeInfo,
        key_decoder: Callable[[Any], Any],
        value_type_info: AnalyzedTypeInfo,
    ):
        self._executor = executor

        self._key_encoder = make_engine_key_encoder(key_type_info)
        self._key_decoder = key_decoder
        self._value_encoder = make_engine_value_encoder(value_type_info)

        # `list` and `get_value` are mandatory on the executor;
        # `provides_ordinal` is optional.
        self._list_fn = _get_required_method(executor, "list")
        self._get_value_fn = to_async_call(_get_required_method(executor, "get_value"))
        self._provides_ordinal_fn = getattr(executor, "provides_ordinal", None)

    def provides_ordinal(self) -> bool:
        """Whether the underlying executor reports ordinals (False if absent)."""
        fn = self._provides_ordinal_fn
        if fn is None:
            return False
        return bool(fn())

    async def list_async(
        self, options: dict[str, Any]
    ) -> AsyncIterator[tuple[Any, dict[str, Any]]]:
        """
        Return an async iterator that yields individual rows one by one.
        Each yielded item is a tuple of (key, data).
        """
        opts = load_engine_object(SourceReadOptions, options)
        rows = self._list_fn(opts)

        # The user's `list` may return either an async or a sync iterator.
        if hasattr(rows, "__aiter__"):
            async for row in rows:
                yield (
                    self._key_encoder(row.key),
                    self._encode_source_row_data(row.data),
                )
        else:
            for row in rows:
                yield (
                    self._key_encoder(row.key),
                    self._encode_source_row_data(row.data),
                )

    async def get_value_async(
        self,
        raw_key: Any,
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Fetch a single row's data for an engine-encoded key."""
        decoded_key = self._key_decoder(raw_key)
        opts = load_engine_object(SourceReadOptions, options)
        row_data = await self._get_value_fn(decoded_key, opts)
        return self._encode_source_row_data(row_data)

    def _encode_source_row_data(
        self, row_data: PartialSourceRowData[Any]
    ) -> dict[str, Any]:
        """Convert Python PartialSourceRowData to the format expected by Rust."""
        if row_data.value == NON_EXISTENCE:
            encoded_value: Any = NON_EXISTENCE
        else:
            encoded_value = self._value_encoder(row_data.value)
        return {
            "ordinal": row_data.ordinal,
            "content_version_fp": row_data.content_version_fp,
            "value": encoded_value,
        }
| 569 | + |
class _SourceConnector:
    """
    The connector class passed to the engine.

    Analyzes the declared Python key/value types into an engine KTable schema
    at construction time, and creates a _SourceExecutorContext per spec on
    demand.
    """

    _spec_cls: type[Any]  # SourceSpec subclass carrying the source's config
    _key_type_info: AnalyzedTypeInfo  # analyzed Python-side key type
    _key_decoder: Callable[[Any], Any]  # engine key -> Python key
    _value_type_info: AnalyzedTypeInfo  # analyzed Python-side value type (must map to a struct)
    _table_type: EnrichedValueType  # KTable schema exposed via get_table_type()
    _connector_cls: type[Any]  # user-provided connector class

    _create_fn: Callable[[Any], Awaitable[Any]]  # async-wrapped connector_cls.create

    def __init__(
        self,
        spec_cls: type[Any],
        key_type: Any,
        value_type: Any,
        connector_cls: type[Any],
    ):
        self._spec_cls = spec_cls
        self._key_type_info = analyze_type_info(key_type)
        self._value_type_info = analyze_type_info(value_type)
        self._connector_cls = connector_cls

        # Round-trip through the encoded form to obtain EnrichedValueType.
        # TODO: We can save the intermediate step after #1083 is fixed.
        encoded_engine_key_type = encode_enriched_type_info(self._key_type_info)
        engine_key_type = EnrichedValueType.decode(encoded_engine_key_type)

        # TODO: We can save the intermediate step after #1083 is fixed.
        encoded_engine_value_type = encode_enriched_type_info(self._value_type_info)
        engine_value_type = EnrichedValueType.decode(encoded_engine_value_type)

        # The value type must be a struct so its fields can be appended after
        # the key fields in the row schema below.
        if not isinstance(engine_value_type.type, StructType):
            raise ValueError(f"Expected a StructType, got {engine_value_type.type}")

        # A struct key contributes its own fields; any other key type becomes a
        # single field under the reserved key field name.
        if isinstance(engine_key_type.type, StructType):
            key_fields_schema = engine_key_type.type.fields
        else:
            key_fields_schema = [
                FieldSchema(name=KEY_FIELD_NAME, value_type=engine_key_type)
            ]
        self._key_decoder = make_engine_key_decoder(
            [], key_fields_schema, self._key_type_info
        )
        # KTable rows are keyed by their first num_key_parts fields.
        self._table_type = EnrichedValueType(
            type=TableType(
                kind="KTable",
                row=StructSchema(
                    fields=key_fields_schema + engine_value_type.type.fields
                ),
                num_key_parts=len(key_fields_schema),
            ),
        )

        self._create_fn = to_async_call(_get_required_method(connector_cls, "create"))

    async def create_executor(self, raw_spec: dict[str, Any]) -> _SourceExecutorContext:
        """Build the user's executor from a raw (engine-encoded) spec and wrap it."""
        spec = load_engine_object(self._spec_cls, raw_spec)
        executor = await self._create_fn(spec)
        return _SourceExecutorContext(
            executor, self._key_type_info, self._key_decoder, self._value_type_info
        )

    def get_table_type(self) -> Any:
        """Return the engine-encoded KTable type for this source."""
        return dump_engine_object(self._table_type)
| 638 | + |
def source_connector(
    *,
    spec_cls: type[Any],
    key_type: Any = Any,
    value_type: Any = Any,
) -> Callable[[type], type]:
    """
    Decorate a class to provide a source connector for an op.

    The decorated class is registered with the engine under the spec class's
    name and returned unchanged.
    """
    # Validate the spec_cls is a SourceSpec.
    if not issubclass(spec_cls, SourceSpec):
        raise ValueError(f"Expect a SourceSpec, got {spec_cls}")

    def _register(connector_cls: type) -> type:
        # Register the source connector with the engine.
        _engine.register_source_connector(
            spec_cls.__name__,
            _SourceConnector(spec_cls, key_type, value_type, connector_cls),
        )
        return connector_cls

    return _register
| 660 | + |
| 661 | + |
424 | 662 | ######################################################## |
425 | 663 | # Custom target connector |
426 | 664 | ######################################################## |
|
0 commit comments