1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515from __future__ import annotations
16- from abc import ABC , abstractmethod
17- from typing import Optional , Callable , Any
16+ import types
17+ import inspect
18+ from abc import ABC , abstractmethod , ABCMeta
19+ from typing import Optional , Callable , Any , TypeVar
20+ from typing_extensions import ParamSpec
21+
1822import neo4j
1923
2024from neo4j_genai .types import RawSearchResult , RetrieverResult , RetrieverResultItem
2125from neo4j_genai .exceptions import Neo4jVersionError
2226
27+ T = ParamSpec ("T" )
28+ P = TypeVar ("P" )
29+
30+
31+ def copy_function (f : Callable [T , P ]) -> Callable [T , P ]:
32+ """Based on https://stackoverflow.com/a/30714299"""
33+ g = types .FunctionType (
34+ f .__code__ ,
35+ f .__globals__ ,
36+ name = f .__name__ ,
37+ argdefs = f .__defaults__ ,
38+ closure = f .__closure__ ,
39+ )
40+ # in case f was given attrs (note this dict is a shallow copy):
41+ g .__dict__ .update (f .__dict__ )
42+ return g
43+
44+
45+ class RetrieverMetaclass (ABCMeta ):
46+ """This metaclass is used to copy the docstring from the
47+ `get_search_results` method, instantiated in all subclasses,
48+ to the `search` method in the base class.
49+ """
50+
51+ def __new__ (
52+ meta , name : str , bases : tuple [type , ...], attrs : dict [str , Any ]
53+ ) -> type :
54+ if "search" in attrs :
55+ # search method was explicitly overridden, do nothing
56+ return type .__new__ (meta , name , bases , attrs )
57+ # otherwise, we copy the signature and doc of the get_search_results
58+ # method to a copy of the search method
59+ get_search_results_method = attrs .get ("get_search_results" )
60+ search_method = None
61+ for b in bases :
62+ search_method = getattr (b , "search" , None )
63+ if search_method is not None :
64+ break
65+ if search_method and get_search_results_method :
66+ new_search_method = copy_function (search_method )
67+ new_search_method .__doc__ = get_search_results_method .__doc__
68+ new_search_method .__signature__ = inspect .signature ( # type: ignore
69+ get_search_results_method
70+ )
71+ attrs ["search" ] = new_search_method
72+ return type .__new__ (meta , name , bases , attrs )
73+
2374
24- class Retriever (ABC ):
75+ class Retriever (ABC , metaclass = RetrieverMetaclass ):
2576 """
2677 Abstract class for Neo4j retrievers
2778 """
@@ -78,11 +129,11 @@ def _fetch_index_infos(self) -> None:
78129 raise Exception (f"No index with name { self .index_name } found" ) from e
79130
80131 def search (self , * args : Any , ** kwargs : Any ) -> RetrieverResult :
132+ """Search method. Call the `get_search_results` method that returns
133+ a list of `neo4j.Record`, and format them using the function returned by
134+ `get_result_formatter` to return `RetrieverResult`.
81135 """
82- Search method. Call the get_search_result method that returns
83- a list of neo4j.Record, and format them to return RetrieverResult.
84- """
85- raw_result = self ._get_search_results (* args , ** kwargs )
136+ raw_result = self .get_search_results (* args , ** kwargs )
86137 formatter = self .get_result_formatter ()
87138 search_items = [formatter (record ) for record in raw_result .records ]
88139 metadata = raw_result .metadata or {}
@@ -93,7 +144,20 @@ def search(self, *args: Any, **kwargs: Any) -> RetrieverResult:
93144 )
94145
95146 @abstractmethod
96- def _get_search_results (self , * args : Any , ** kwargs : Any ) -> RawSearchResult :
147+ def get_search_results (self , * args : Any , ** kwargs : Any ) -> RawSearchResult :
148+ """This method must be implemented in each child class. It will
149+ receive the same parameters provided to the public interface via
150+ the `search` method, after validation. It returns a `RawSearchResult`
151+ object which comprises a list of `neo4j.Record` objects and an optional
152+ `metadata` dictionary that can contain retriever-level information.
153+
154+ Note that, even though this method is not intended to be called from
155+ outside the class, we make it public to make it clearer for the developers
156+ that it should be implemented in child classes.
157+
158+ Returns:
159+ RawSearchResult: List of Neo4j Records and optional metadata dict
160+ """
97161 pass
98162
99163 def get_result_formatter (self ) -> Callable [[neo4j .Record ], RetrieverResultItem ]:
@@ -127,7 +191,7 @@ def __init__(
127191 self .id_property_neo4j = id_property_neo4j
128192
129193 @abstractmethod
130- def _get_search_results (
194+ def get_search_results (
131195 self ,
132196 query_vector : Optional [list [float ]] = None ,
133197 query_text : Optional [str ] = None ,
@@ -137,7 +201,7 @@ def _get_search_results(
137201 """
138202
139203 Returns:
140- list[neo4j.Record] : List of Neo4j Records
204+ RawSearchResult : List of Neo4j Records and optional metadata dict
141205
142206 """
143207 pass
0 commit comments