1919 Component ,
2020 Pipeline ,
2121)
22- from neo4j_graphrag .experimental .pipeline .exceptions import PipelineDefinitionError
22+ from neo4j_graphrag .experimental .pipeline .exceptions import (
23+ PipelineDefinitionError ,
24+ PipelineMissingDependencyError ,
25+ PipelineStatusUpdateError ,
26+ )
2327from neo4j_graphrag .experimental .pipeline .orchestrator import Orchestrator
2428from neo4j_graphrag .experimental .pipeline .types import RunStatus
2529
@@ -34,8 +38,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None
3438 pipe .add_component (ComponentPassThrough (), "a" )
3539 pipe .add_component (ComponentPassThrough (), "b" )
3640 orchestrator = Orchestrator (pipe )
37- with pytest .raises (PipelineDefinitionError ):
41+ with pytest .raises (PipelineDefinitionError ) as exc :
3842 orchestrator .get_input_config_for_task (pipe .get_node_by_name ("a" ))
43+ assert "You must validate the pipeline input config first" in str (exc .value )
3944
4045
4146@pytest .mark .asyncio
@@ -59,10 +64,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None:
5964 "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component"
6065)
6166@pytest .mark .asyncio
62- async def test_pipeline_get_component_inputs_from_parent_specific (
67+ async def test_orchestrator_get_component_inputs_from_parent_specific (
6368 mock_result : Mock ,
6469) -> None :
65- """Propagate one specific output field from 'a' to the next component."""
70+ """Propagate one specific output field from parent to a child component."""
6671 pipe = Pipeline ()
6772 pipe .add_component (ComponentPassThrough (), "a" )
6873 pipe .add_component (ComponentPassThrough (), "b" )
@@ -164,6 +169,56 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_
164169 )
165170
166171
172+ @pytest .mark .asyncio
173+ @patch (
174+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
175+ )
176+ @pytest .mark .parametrize (
177+ "old_status, new_status, result" ,
178+ [
179+ # Normal path: from UNKNOWN to RUNNING to DONE
180+ (RunStatus .UNKNOWN , RunStatus .RUNNING , "ok" ),
181+ (RunStatus .RUNNING , RunStatus .DONE , "ok" ),
182+ # Error: status is already set to this value
183+ (RunStatus .RUNNING , RunStatus .RUNNING , "Status is already RunStatus.RUNNING" ),
184+ (RunStatus .DONE , RunStatus .DONE , "Status is already RunStatus.DONE" ),
185+ # Error: can't go back in time
186+ (
187+ RunStatus .DONE ,
188+ RunStatus .RUNNING ,
189+ "Can't go from RunStatus.DONE to RunStatus.RUNNING" ,
190+ ),
191+ (
192+ RunStatus .RUNNING ,
193+ RunStatus .UNKNOWN ,
194+ "Can't go from RunStatus.RUNNING to RunStatus.UNKNOWN" ,
195+ ),
196+ (
197+ RunStatus .DONE ,
198+ RunStatus .UNKNOWN ,
199+ "Can't go from RunStatus.DONE to RunStatus.UNKNOWN" ,
200+ ),
201+ ],
202+ )
203+ async def test_orchestrator_set_component_status (
204+ mock_status : Mock ,
205+ old_status : RunStatus ,
206+ new_status : RunStatus ,
207+ result : str ,
208+ ) -> None :
209+ pipe = Pipeline ()
210+ orchestrator = Orchestrator (pipeline = pipe )
211+ mock_status .side_effect = [
212+ old_status ,
213+ ]
214+ if result == "ok" :
215+ await orchestrator .set_task_status ("task_name" , new_status )
216+ else :
217+ with pytest .raises (PipelineStatusUpdateError ) as exc :
218+ await orchestrator .set_task_status ("task_name" , new_status )
219+ assert result in str (exc )
220+
221+
167222@pytest .fixture (scope = "function" )
168223def pipeline_branch () -> Pipeline :
169224 pipe = Pipeline ()
@@ -190,21 +245,45 @@ def pipeline_aggregation() -> Pipeline:
190245@patch (
191246 "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
192247)
193- async def test_orchestrator_branch (
248+ async def test_orchestrator_check_dependency_complete (
194249 mock_status : Mock , pipeline_branch : Pipeline
250+ ) -> None :
251+ """a -> b, c"""
252+ orchestrator = Orchestrator (pipeline = pipeline_branch )
253+ node_a = pipeline_branch .get_node_by_name ("a" )
254+ await orchestrator .check_dependencies_complete (node_a )
255+ node_b = pipeline_branch .get_node_by_name ("b" )
256+ # dependency is DONE:
257+ mock_status .side_effect = [RunStatus .DONE ]
258+ await orchestrator .check_dependencies_complete (node_b )
259+ # dependency is not DONE:
260+ mock_status .side_effect = [RunStatus .RUNNING ]
261+ with pytest .raises (PipelineMissingDependencyError ):
262+ await orchestrator .check_dependencies_complete (node_b )
263+
264+
265+ @pytest .mark .asyncio
266+ @patch (
267+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
268+ )
269+ @patch (
270+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete" ,
271+ )
272+ async def test_orchestrator_next_task_branch_no_missing_dependencies (
273+ mock_dep : Mock , mock_status : Mock , pipeline_branch : Pipeline
195274) -> None :
196275 """a -> b, c"""
197276 orchestrator = Orchestrator (pipeline = pipeline_branch )
198277 node_a = pipeline_branch .get_node_by_name ("a" )
199278 mock_status .side_effect = [
200- # next b
279+ # next "b"
201280 RunStatus .UNKNOWN ,
202- # dep of b = a
203- RunStatus .DONE ,
204- # next c
281+ # next "c"
205282 RunStatus .UNKNOWN ,
206- # dep of c = a
207- RunStatus .DONE ,
283+ ]
284+ mock_dep .side_effect = [
285+ None , # "b" has no missing dependencies
286+ None , # "c" has no missing dependencies
208287 ]
209288 next_tasks = [n async for n in orchestrator .next (node_a )]
210289 next_task_names = [n .name for n in next_tasks ]
@@ -215,31 +294,48 @@ async def test_orchestrator_branch(
215294@patch (
216295 "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
217296)
218- async def test_orchestrator_aggregation (
219- mock_status : Mock , pipeline_aggregation : Pipeline
297+ @patch (
298+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete" ,
299+ )
300+ async def test_orchestrator_next_task_branch_missing_dependencies (
301+ mock_dep : Mock , mock_status : Mock , pipeline_branch : Pipeline
220302) -> None :
221- """a, b -> c"""
222- orchestrator = Orchestrator (pipeline = pipeline_aggregation )
223- node_a = pipeline_aggregation .get_node_by_name ("a" )
303+ """a -> b, c"""
304+ orchestrator = Orchestrator (pipeline = pipeline_branch )
305+ node_a = pipeline_branch .get_node_by_name ("a" )
224306 mock_status .side_effect = [
225- # next c:
307+ # next "b"
226308 RunStatus .UNKNOWN ,
227- # dep of c = a
228- RunStatus .DONE ,
229- # dep of c = b
309+ # next "c"
230310 RunStatus .UNKNOWN ,
231311 ]
232- next_task_names = [n .name async for n in orchestrator .next (node_a )]
233- # "c" dependencies not ready yet
234- assert next_task_names == []
235- # set "b" to DONE
312+ mock_dep .side_effect = [
313+ PipelineMissingDependencyError , # "b" has missing dependencies
314+ None , # "c" has no missing dependencies
315+ ]
316+ next_tasks = [n async for n in orchestrator .next (node_a )]
317+ next_task_names = [n .name for n in next_tasks ]
318+ assert next_task_names == ["c" ]
319+
320+
321+ @pytest .mark .asyncio
322+ @patch (
323+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
324+ )
325+ @patch (
326+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete" ,
327+ )
328+ async def test_orchestrator_next_task_aggregation_no_missing_dependencies (
329+ mock_dep : Mock , mock_status : Mock , pipeline_aggregation : Pipeline
330+ ) -> None :
331+ """a, b -> c"""
332+ orchestrator = Orchestrator (pipeline = pipeline_aggregation )
333+ node_a = pipeline_aggregation .get_node_by_name ("a" )
236334 mock_status .side_effect = [
237- # next c:
238- RunStatus .UNKNOWN ,
239- # dep of c = a
240- RunStatus .DONE ,
241- # dep of c = b
242- RunStatus .DONE ,
335+ RunStatus .UNKNOWN , # status for "c", not started
336+ ]
337+ mock_dep .side_effect = [
338+ None , # no missing deps
243339 ]
244340 # then "c" can start
245341 next_tasks = [n async for n in orchestrator .next (node_a )]
@@ -248,8 +344,41 @@ async def test_orchestrator_aggregation(
248344
249345
250346@pytest .mark .asyncio
251- async def test_orchestrator_aggregation_waiting (pipeline_aggregation : Pipeline ) -> None :
347+ @patch (
348+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
349+ )
350+ @patch (
351+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete" ,
352+ )
353+ async def test_orchestrator_next_task_aggregation_missing_dependency (
354+ mock_dep : Mock , mock_status : Mock , pipeline_aggregation : Pipeline
355+ ) -> None :
356+ """a, b -> c"""
252357 orchestrator = Orchestrator (pipeline = pipeline_aggregation )
253358 node_a = pipeline_aggregation .get_node_by_name ("a" )
254- next_tasks = [n async for n in orchestrator .next (node_a )]
255- assert next_tasks == []
359+ mock_status .side_effect = [
360+ RunStatus .UNKNOWN , # status for "c" is unknown, it's a possible next
361+ ]
362+ mock_dep .side_effect = [
363+ PipelineMissingDependencyError , # some dependencies are not done yet
364+ ]
365+ next_task_names = [n .name async for n in orchestrator .next (node_a )]
366+ # "c" dependencies not ready yet
367+ assert next_task_names == []
368+
369+
370+ @pytest .mark .asyncio
371+ @patch (
372+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
373+ )
374+ async def test_orchestrator_next_task_aggregation_next_already_started (
375+ mock_status : Mock , pipeline_aggregation : Pipeline
376+ ) -> None :
377+ """a, b -> c"""
378+ orchestrator = Orchestrator (pipeline = pipeline_aggregation )
379+ node_a = pipeline_aggregation .get_node_by_name ("a" )
380+ mock_status .side_effect = [
381+ RunStatus .RUNNING , # status for "c" is already running, do not start it again
382+ ]
383+ next_task_names = [n .name async for n in orchestrator .next (node_a )]
384+ assert next_task_names == []
0 commit comments