|
21 | 21 | from __future__ import annotations |
22 | 22 |
|
23 | 23 | import asyncio |
| 24 | +from typing import List |
24 | 25 |
|
25 | 26 | import neo4j |
26 | 27 | from neo4j_genai.embeddings.openai import OpenAIEmbeddings |
27 | 28 | from neo4j_genai.experimental.pipeline import Component, Pipeline |
28 | 29 | from neo4j_genai.experimental.pipeline.component import DataModel |
| 30 | +from neo4j_genai.experimental.pipeline.pipeline import PipelineResult |
29 | 31 | from neo4j_genai.experimental.pipeline.types import ( |
30 | 32 | ComponentConfig, |
31 | 33 | ConnectionConfig, |
|
37 | 39 | from neo4j_genai.retrievers.base import Retriever |
38 | 40 |
|
39 | 41 |
|
class ComponentResultDataModel(DataModel):
    """A simple DataModel with a single text field.

    Used as the uniform result type for every component in this example
    pipeline, so each stage can read the previous stage's output via the
    ``text`` attribute.
    """

    # Free-form text payload produced by a pipeline component.
    text: str
42 | 46 |
|
43 | 47 |
|
class RetrieverComponent(Component):
    """Pipeline component wrapping a ``Retriever``.

    Runs the retriever against a query string and flattens the retrieved
    items into a single newline-joined text block.
    """

    def __init__(self, retriever: Retriever) -> None:
        # Retriever instance used to fetch context for a query.
        self.retriever = retriever

    async def run(self, query: str) -> ComponentResultDataModel:
        """Search for *query* and return the concatenated item contents.

        Args:
            query: Free-text search query passed to the retriever.

        Returns:
            A ``ComponentResultDataModel`` whose ``text`` is the contents of
            all retrieved items joined by newlines.
        """
        res = self.retriever.search(query_text=query)
        return ComponentResultDataModel(text="\n".join(c.content for c in res.items))
51 | 55 |
|
52 | 56 |
|
class PromptTemplateComponent(Component):
    """Pipeline component that fills a ``PromptTemplate``.

    Combines the user query with retrieved context into a fully formatted
    prompt string for the downstream LLM component.
    """

    def __init__(self, prompt: PromptTemplate) -> None:
        # Template used to build the final LLM prompt.
        self.prompt = prompt

    async def run(self, query: str, context: List[str]) -> ComponentResultDataModel:
        """Format the template with *query* and *context*.

        Args:
            query: The user's question.
            context: Retrieved context passages to ground the answer.

        Returns:
            A ``ComponentResultDataModel`` whose ``text`` is the formatted
            prompt (``examples`` is intentionally left empty here).
        """
        prompt = self.prompt.format(query, context, examples="")
        return ComponentResultDataModel(text=prompt)
60 | 64 |
|
61 | 65 |
|
class LLMComponent(Component):
    """Pipeline component that calls an LLM with a prepared prompt."""

    def __init__(self, llm: LLMInterface) -> None:
        # LLM backend used to generate the final answer.
        self.llm = llm

    async def run(self, prompt: str) -> ComponentResultDataModel:
        """Invoke the LLM on *prompt* and wrap its answer.

        Args:
            prompt: Fully formatted prompt produced by the previous stage.

        Returns:
            A ``ComponentResultDataModel`` whose ``text`` is the LLM
            response content.
        """
        # NOTE(review): self.llm.invoke is a synchronous call inside an async
        # run() — it will block the event loop; acceptable for this example.
        llm_response = self.llm.invoke(prompt)
        return ComponentResultDataModel(text=llm_response.content)
69 | 73 |
|
70 | 74 |
|
71 | 75 | if __name__ == "__main__": |
@@ -96,21 +100,21 @@ async def run(self, prompt: str) -> StringDataModel: |
96 | 100 | ConnectionConfig( |
97 | 101 | start="retrieve", |
98 | 102 | end="augment", |
99 | | - input_config={"context": "retrieve.result"}, |
| 103 | + input_config={"context": "retrieve.text"}, |
100 | 104 | ), |
101 | 105 | ConnectionConfig( |
102 | 106 | start="augment", |
103 | 107 | end="generate", |
104 | | - input_config={"prompt": "augment.result"}, |
| 108 | + input_config={"prompt": "augment.text"}, |
105 | 109 | ), |
106 | 110 | ], |
107 | 111 | ) |
108 | 112 | ) |
109 | 113 |
|
110 | 114 | query = "A movie about the US presidency" |
111 | | - result = asyncio.run( |
| 115 | + pipe_output: PipelineResult = asyncio.run( |
112 | 116 | pipe.run({"retrieve": {"query": query}, "augment": {"query": query}}) |
113 | 117 | ) |
114 | | - print(result["generate"]["result"]) |
| 118 | + print(pipe_output.result["generate"]["text"]) |
115 | 119 |
|
116 | 120 | driver.close() |
0 commit comments