diff --git a/cf_pipelines/base/pipeline.py b/cf_pipelines/base/pipeline.py
index 36b220c..7d77b6e 100644
--- a/cf_pipelines/base/pipeline.py
+++ b/cf_pipelines/base/pipeline.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Set, Union
 
+import networkx as nx
 from ploomber import DAG
 from ploomber.executors import Serial
 from ploomber.io import serializer_pickle, unserializer_pickle
@@ -15,6 +16,7 @@
 
 from cf_pipelines.base.helper_classes import FunctionDetails, ProductLineage
 from cf_pipelines.base.utils import get_return_keys_from_function, remove_extension, wrap_preserving_signature
+from cf_pipelines.exceptions import CycledPipelineError
 
 
 class Pipeline:
@@ -287,6 +289,16 @@ def make_dag(self) -> DAG:
         for function_name, dependencies in solved_dependencies.items():
             for dependency in dependencies:
                 callables[dependency] >> callables[function_name]
+
+        try:
+            # TODO: investigate how reliable is accessing `_G`
+            # TODO: ask for a way to find cycles eagerly without building (executing) the pipeline?
+            cycles = nx.find_cycle(dag._G)
+            raise CycledPipelineError(cycles)
+        except nx.NetworkXNoCycle:
+            # Ironically, this means everything is good! so we just ignore this exception
+            pass
+
         return dag
 
     def generate_run_id(self) -> str:
diff --git a/cf_pipelines/exceptions.py b/cf_pipelines/exceptions.py
new file mode 100644
index 0000000..3b1a95a
--- /dev/null
+++ b/cf_pipelines/exceptions.py
@@ -0,0 +1,24 @@
+from typing import List, Tuple
+
+
+class CycledPipelineError(Exception):
+    """Raised when the steps of a pipeline form a dependency cycle.
+
+    The edges received are kept in ``self.cycles`` for programmatic access.
+    """
+
+    def __init__(self, cycles: List[Tuple[str, str]]):
+        """Store the cycle edges and build a readable message from them.
+
+        Args:
+            cycles: pairs of step names, one pair per edge of the cycle.
+        """
+        self.cycles = cycles
+
+        # NOTE(review): make_dag wires `callables[dependency] >> callables[function_name]`
+        # and passes `nx.find_cycle(dag._G)` here, so each pair is presumably
+        # (dependency, dependent) -- confirm against ploomber's graph orientation.
+        details = "\n".join(f'"{f2}" depends on "{f1}".' for f1, f2 in cycles)
+
+        # Header on its own line so the first edge is not glued to it.
+        super().__init__("Your pipeline contains cycles:\n" + details)
diff --git a/tests/base/test_pipelines_loops.py b/tests/base/test_pipelines_loops.py
new file mode 100644
index 0000000..2edc4ff
--- /dev/null
+++ b/tests/base/test_pipelines_loops.py
@@ -0,0 +1,49 @@
+import pytest
+
+from cf_pipelines import Pipeline
+from cf_pipelines.exceptions import CycledPipelineError
+
+
+@pytest.fixture
+def simple_looped_pipeline(parse_indented):
+    """Pipeline whose only step takes ``one`` as input and returns ``one.txt``."""
+    simple = Pipeline("Simple Pipeline")
+
+    @simple.step("step_1")
+    def one(*, one):
+        return {"one.txt": None}
+
+    return simple
+
+
+@pytest.fixture
+def looped_pipeline(parse_indented):
+    """Pipeline with a loop spanning two steps.
+
+    ``f_two`` takes ``four`` while returning ``three.txt``; ``f_three`` takes
+    ``three`` while returning ``four.txt``.
+    """
+    simple = Pipeline("Simple Pipeline")
+
+    @simple.step("step")
+    def f_one():
+        return {"two.txt": None}
+
+    @simple.step("step")
+    def f_two(*, two, four):
+        return {"three.txt": None}
+
+    @simple.step("step")
+    def f_three(*, three):
+        return {"four.txt": None}
+
+    return simple
+
+
+@pytest.mark.parametrize("pipeline", ["simple_looped_pipeline", "looped_pipeline"])
+def test_make_dag_fail_when_looped(pipeline, request):
+    """``make_dag`` must raise ``CycledPipelineError`` for every looped fixture."""
+    actual_pipeline = request.getfixturevalue(pipeline)
+
+    with pytest.raises(CycledPipelineError):
+        actual_pipeline.make_dag()