diff --git a/invokeai/app/services/shared/README.md b/invokeai/app/services/shared/README.md new file mode 100644 index 00000000000..a65a8ebe495 --- /dev/null +++ b/invokeai/app/services/shared/README.md @@ -0,0 +1,194 @@ +# InvokeAI Graph - Design Overview + +High-level design for the graph module. Focuses on responsibilities, data flow, and how traversal works. + +## 1) Purpose + +Provide a typed, acyclic workflow model (**Graph**) plus a runtime scheduler (**GraphExecutionState**) that expands +iterator patterns, tracks readiness via indegree (the number of incoming edges to a node in the directed graph), and +executes nodes in class-grouped batches. Source graphs remain immutable during a run; runtime expansion happens in a +separate execution graph. + +## 2) Major Data Types + +### EdgeConnection + +* Fields: `node_id: str`, `field: str`. +* Hashable; printed as `node.field` for readable diagnostics. + +### Edge + +* Fields: `source: EdgeConnection`, `destination: EdgeConnection`. +* One directed connection from a specific output port to a specific input port. + +### AnyInvocation / AnyInvocationOutput + +* Pydantic wrappers that carry concrete invocation models and outputs. +* No registry logic in this file; they are permissive containers for heterogeneous nodes. + +### IterateInvocation / CollectInvocation + +* Control nodes used by validation and execution: + + * **IterateInvocation**: input `collection`, outputs include `item` (and index/total). + * **CollectInvocation**: many `item` inputs aggregated to one `collection` output. + +## 3) Graph (author-time model) + +A container for declared nodes and edges. Does **not** perform iteration expansion. + +### 3.1 Data + +* `nodes: dict[str, AnyInvocation]` - key must equal `node.id`. +* `edges: list[Edge]` - zero or more. +* Utility: `_get_input_edges(node_id, field?)`, `_get_output_edges(node_id, field?)` + These scan `self.edges` (no adjacency indices in the current code). 
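The shapes described in sections 2 and 3.1 can be sketched with plain dataclasses. This is illustrative only: the real `EdgeConnection`/`Edge` are pydantic models in `graph.py`, and `get_input_edges` here just mirrors the linear scan that `Graph._get_input_edges` performs (no adjacency index is maintained):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)  # frozen -> hashable, like the pydantic model
class EdgeConnection:
    node_id: str
    field: str

    def __str__(self) -> str:
        # Printed as `node.field` for readable diagnostics
        return f"{self.node_id}.{self.field}"

@dataclass(frozen=True)
class Edge:
    source: EdgeConnection
    destination: EdgeConnection

def get_input_edges(edges: list[Edge], node_id: str, field: Optional[str] = None) -> list[Edge]:
    # Linear scan over the whole edge list, optionally narrowed to one
    # destination field; lookups are O(len(edges)), as noted above.
    return [
        e for e in edges
        if e.destination.node_id == node_id
        and (field is None or e.destination.field == field)
    ]
```

Because both dataclasses are frozen, edges can be placed in sets and dicts, which is what makes the `{(src, dst)}` dedup in `nx_graph_flat()` cheap.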
+ +### 3.2 Validation (`validate_self`) + +Runs a sequence of checks: + +1. **Node ID uniqueness** + No duplicate IDs; map key equals `node.id`. +2. **Endpoint existence** + Source and destination node IDs must exist. +3. **Port existence** + Input ports must exist on the node class; output ports on the node's output model. +4. **Type compatibility** + `get_output_field_type` vs `get_input_field_type` and `are_connection_types_compatible`. +5. **DAG constraint** + Build a *flat* `DiGraph` (no runtime expansion) and assert acyclicity. +6. **Iterator / collector structure** + Enforce special rules: + + * Iterator's input must be `collection`; its outgoing edges use `item`. + * Collector accepts many `item` inputs; outputs a single `collection`. + * Edge fan-in to a non-collector input is rejected. + +### 3.3 Edge admission (`_validate_edge`) + +Checks a single prospective edge before insertion: + +* Endpoints/ports exist. +* Destination port is not already occupied unless it's a collector `item`. +* Adding the edge to the flat DAG must keep it acyclic. +* Iterator/collector constraints re-checked when the edge creates relevant patterns. + +### 3.4 Topology utilities + +* `nx_graph()` - DiGraph of declared nodes and edges. +* `nx_graph_with_data()` - includes node/edge attributes. +* `nx_graph_flat()` - "flattened" DAG (still author-time; no runtime copies). + Used in validation and in `_prepare()` during execution planning. + +### 3.5 Mutation helpers + +* `add_node`, `update_node` (preserve edges, rewrite endpoints if id changes), `delete_node`. +* `add_edge`, `delete_edge` (with validation). + +## 4) GraphExecutionState (runtime) + +Holds the state for a single run. Keeps the source graph intact; materializes a separate execution graph. + +### 4.1 Data + +* `graph: Graph` - immutable source during a run. +* `execution_graph: Graph` - materialized runtime nodes/edges. +* `executed: set[str]`, `executed_history: list[str]`. 
+* `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`. +* `prepared_source_mapping: dict[str, str]` - exec id → source id. +* `source_prepared_mapping: dict[str, set[str]]` - source id → exec ids. +* `indegree: dict[str, int]` - unmet inputs per exec node. +* **Ready queues grouped by class** (private attrs): + `_ready_queues: dict[class_name, deque[str]]`, `_active_class: Optional[str]`. Optional `ready_order: list[str]` to + prioritize classes. + +### 4.2 Core methods + +* `next()` + Returns the next ready exec node. If none, calls `_prepare()` to materialize more, then retries. Before returning a + node, `_prepare_inputs()` deep-copies inbound values into the node fields. +* `complete(node_id, output)` + Record result; mark exec node executed; if all exec copies of the same **source** are done, mark the source executed. + For each outgoing exec edge, decrement child indegree and enqueue when it reaches zero. + +### 4.3 Preparation (`_prepare()`) + +* Build a flat DAG from the **source** graph. +* Choose the **next source node** in topological order that: + + 1. has not been prepared, + 2. if it is an iterator, *its inputs are already executed*, + 3. it has *no unexecuted iterator ancestors*. +* If the node is a **CollectInvocation**: collapse all prepared parents into one mapping and create **one** exec node. +* Otherwise: compute all combinations of prepared iterator ancestors. For each combination, pick the matching prepared parent per upstream and create **one** exec node. +* For each new exec node: + + * Deep-copy the source node; assign a fresh ID (and `index` for iterators). + * Wire edges from chosen prepared parents. + * Set `indegree = number of unmet inputs` (i.e., parents not yet executed). + * If `indegree == 0`, enqueue into its class queue. + +### 4.4 Readiness and batching + +* `_enqueue_if_ready(nid)` enqueues by class name only when `indegree == 0` and not executed. 
+* `_get_next_node()` drains the `_active_class` queue FIFO; when empty, selects the next nonempty class queue (by `ready_order` if set, else alphabetical), and continues. Optional fairness knobs can limit batch size per class; default is drain fully. + +#### 4.4.1 Indegree (what it is and how it's used) + +**Indegree** is the number of incoming edges to a node in the execution graph that are still unmet. In this engine: +* For every materialized exec node, `indegree[node]` equals the count of its prerequisite parents that have **not** finished yet. +* A node is "ready" exactly when `indegree[node] == 0`; only then is it enqueued. +* When a node completes, the scheduler decrements `indegree[child]` for each outgoing edge. Any child that reaches 0 is enqueued. + +Example: edges `A→C`, `B→C`, `C→D`. Start: `A:0, B:0, C:2, D:1`. Run `A` → `C:1`. Run `B` → `C:0` → enqueue `C`. Run `C` +→ `D:0` → enqueue `D`. Run `D` → done. + +### 4.5 Input hydration (`_prepare_inputs()`) + +* For **CollectInvocation**: gather all incoming `item` values into `collection`. +* For all others: deep-copy each incoming edge's value into the destination field. + This prevents cross-node mutation through shared references. + +## 5) Traversal Summary + +1. Author builds a valid **Graph**. +2. Create **GraphExecutionState** with that graph. +3. Loop: + + * `node = state.next()` → may trigger `_prepare()` expansion. + * Execute node externally → `output`. + * `state.complete(node.id, output)` → updates indegrees and queues. +4. Finish when `next()` returns `None`. + +The source graph is never mutated; all expansion occurs in `execution_graph` with traceability back to source nodes. + +## 6) Invariants + +* Source **Graph** remains a DAG and type-consistent. +* `execution_graph` remains a DAG. +* Nodes are enqueued only when `indegree == 0`. +* `results` and `errors` are keyed by **exec node id**. +* Collectors only aggregate `item` inputs; other inputs behave one-to-one. 
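The indegree walkthrough in section 4.4.1 can be reproduced in a few lines. This is a standalone sketch of the counting scheme, not the engine's actual scheduler (node ids are bare strings and there is no class batching):

```python
from collections import deque

# The example DAG from 4.4.1: A→C, B→C, C→D
edges = [("A", "C"), ("B", "C"), ("C", "D")]
nodes = ["A", "B", "C", "D"]

indegree = {n: 0 for n in nodes}
children = {n: [] for n in nodes}
for u, v in edges:
    indegree[v] += 1
    children[u].append(v)
# Start state matches the prose: A:0, B:0, C:2, D:1

ready = deque(n for n in nodes if indegree[n] == 0)  # A and B
order = []
while ready:
    n = ready.popleft()
    order.append(n)  # "execute" n
    for c in children[n]:
        indegree[c] -= 1          # one prerequisite satisfied
        if indegree[c] == 0:
            ready.append(c)       # child is now ready
```

Running this yields the order `A, B, C, D`, matching the step-by-step decrement trace above; each enqueue decision is O(1), which is the point of the indegree-plus-queues design.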
+ +## 7) Extensibility + +* **New node types**: implement as Pydantic models with typed fields and outputs. Register per your invocation system; this file accepts them as `AnyInvocation`. +* **Scheduling policy**: adjust `ready_order` to batch by class; add a batch cap for fairness without changing complexity. +* **Dynamic behaviors** (future): can be added in `GraphExecutionState` by creating exec nodes and edges at `complete()` time, as long as the DAG invariant holds. + +## 8) Error Model (selected) + +* `DuplicateNodeIdError`, `NodeAlreadyInGraphError` +* `NodeNotFoundError`, `NodeFieldNotFoundError` +* `InvalidEdgeError`, `CyclicalGraphError` +* `NodeInputError` (raised when preparing inputs for execution) + +Messages favor short, precise diagnostics (node id, field, and failing condition). + +## 9) Rationale + +* **Two-graph approach** isolates authoring from execution expansion and keeps validation simple. +* **Indegree + queues** gives O(1) scheduling decisions with clear batching semantics. +* **Iterator/collector separation** keeps fan-out/fan-in explicit and testable. +* **Deep-copy hydration** avoids incidental aliasing bugs between nodes. 
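The traversal loop in section 5 reduces to a small driver. `FakeState` below is a purely hypothetical stand-in for `GraphExecutionState` — it honors the same `next()`/`complete()` protocol but has none of the preparation or expansion logic — so the sketch stays self-contained:

```python
class FakeState:
    """Illustrative stand-in for GraphExecutionState (not the real class)."""

    def __init__(self, order):
        self._pending = list(order)  # ids in ready order
        self.executed = []

    def next(self):
        # Real next() may trigger _prepare(); here we just peek.
        return self._pending[0] if self._pending else None

    def complete(self, node_id, output):
        assert node_id == self._pending.pop(0)
        self.executed.append(node_id)

def drive(state, run):
    # The section-5 loop: pull ready nodes, execute externally, report back.
    # In the real engine `node` is an invocation and we'd pass `node.id`;
    # in this sketch nodes are bare id strings.
    while (node := state.next()) is not None:
        state.complete(node, run(node))

state = FakeState(["a", "b", "c"])
drive(state, run=lambda n: f"out-{n}")
```

The driver never touches the source graph, which is what lets the engine guarantee that all mutation is confined to `execution_graph`.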
diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 3cc62e9ecea..2501169edb5 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,7 +2,8 @@ import copy import itertools -from typing import Any, Optional, TypeVar, Union, get_args, get_origin +from collections import deque +from typing import Any, Deque, Iterable, Optional, Type, TypeVar, Union, get_args, get_origin import networkx as nx from pydantic import ( @@ -10,6 +11,7 @@ ConfigDict, GetCoreSchemaHandler, GetJsonSchemaHandler, + PrivateAttr, ValidationError, field_validator, ) @@ -33,6 +35,10 @@ # in 3.10 this would be "from types import NoneType" NoneType = type(None) +# Port name constants +ITEM_FIELD = "item" +COLLECTION_FIELD = "collection" + class EdgeConnection(BaseModel): node_id: str = Field(description="The id of the node for this edge connection") @@ -395,7 +401,7 @@ def delete_edge(self, edge: Edge) -> None: try: self.edges.remove(edge) - except KeyError: + except ValueError: pass def validate_self(self) -> None: @@ -414,7 +420,8 @@ def validate_self(self) -> None: # Validate that all node ids are unique node_ids = [n.id for n in self.nodes.values()] - duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2} + seen = set() + duplicate_node_ids = {nid for nid in node_ids if (nid in seen) or seen.add(nid)} if duplicate_node_ids: raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}") @@ -529,19 +536,19 @@ def _validate_edge(self, edge: Edge): raise InvalidEdgeError(f"Field types are incompatible ({edge})") # Validate if iterator output type matches iterator input type (if this edge results in both being set) - if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": + if isinstance(to_node, IterateInvocation) and edge.destination.field == COLLECTION_FIELD: err = 
self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source) if err is not None: raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}") # Validate if iterator input type matches output type (if this edge results in both being set) - if isinstance(from_node, IterateInvocation) and edge.source.field == "item": + if isinstance(from_node, IterateInvocation) and edge.source.field == ITEM_FIELD: err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination) if err is not None: raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}") # Validate if collector input type matches output type (if this edge results in both being set) - if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": + if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD: err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source) if err is not None: raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}") @@ -549,7 +556,7 @@ def _validate_edge(self, edge: Edge): # Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any] if ( isinstance(from_node, CollectInvocation) - and edge.source.field == "collection" + and edge.source.field == COLLECTION_FIELD and not self._is_destination_field_list_of_Any(edge) and not self._is_destination_field_Any(edge) ): @@ -639,8 +646,8 @@ def _is_iterator_connection_valid( new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> str | None: - inputs = [e.source for e in self._get_input_edges(node_id, "collection")] - outputs = [e.destination for e in self._get_output_edges(node_id, "item")] + inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)] + outputs = 
[e.destination for e in self._get_output_edges(node_id, ITEM_FIELD)] if new_input is not None: inputs.append(new_input) @@ -670,7 +677,7 @@ def _is_iterator_connection_valid( if isinstance(input_node, CollectInvocation): # Traverse the graph to find the first collector input edge. Collectors validate that their collection # inputs are all of the same type, so we can use the first input edge to determine the collector's type - first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0] + first_collector_input_edge = self._get_input_edges(input_node.id, ITEM_FIELD)[0] first_collector_input_type = get_output_field_type( self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field ) @@ -690,8 +697,8 @@ def _is_collector_connection_valid( new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> str | None: - inputs = [e.source for e in self._get_input_edges(node_id, "item")] - outputs = [e.destination for e in self._get_output_edges(node_id, "collection")] + inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)] + outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)] if new_input is not None: inputs.append(new_input) @@ -761,7 +768,7 @@ def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph: # TODO: figure out if iteration nodes need to be expanded unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges} - g.add_edges_from([(e[0], e[1]) for e in unique_edges]) + g.add_edges_from(unique_edges) return g @@ -802,6 +809,41 @@ class GraphExecutionState(BaseModel): description="The map of original graph nodes to prepared nodes", default_factory=dict, ) + # Ready queues grouped by node class name (internal only) + _ready_queues: dict[str, Deque[str]] = PrivateAttr(default_factory=dict) + # Current class being drained; stays until its queue empties + _active_class: Optional[str] = 
PrivateAttr(default=None) + # Optional priority; others follow in name order + ready_order: list[str] = Field(default_factory=list) + indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes") + + def _type_key(self, node_obj: BaseInvocation) -> str: + return node_obj.__class__.__name__ + + def _queue_for(self, cls_name: str) -> Deque[str]: + q = self._ready_queues.get(cls_name) + if q is None: + q = deque() + self._ready_queues[cls_name] = q + return q + + def set_ready_order(self, order: Iterable[Type[BaseInvocation] | str]) -> None: + names: list[str] = [] + for x in order: + names.append(x.__name__ if hasattr(x, "__name__") else str(x)) + self.ready_order = names + + def _enqueue_if_ready(self, nid: str) -> None: + """Push nid to its class queue if unmet inputs == 0.""" + # Invariants: exec node exists and has an indegree entry + if nid not in self.execution_graph.nodes: + raise KeyError(f"exec node {nid} missing from execution_graph") + if nid not in self.indegree: + raise KeyError(f"indegree missing for exec node {nid}") + if self.indegree[nid] != 0 or nid in self.executed: + return + node_obj = self.execution_graph.nodes[nid] + self._queue_for(self._type_key(node_obj)).append(nid) model_config = ConfigDict( json_schema_extra={ @@ -834,12 +876,14 @@ def next(self) -> Optional[BaseInvocation]: # If there are no prepared nodes, prepare some nodes next_node = self._get_next_node() if next_node is None: - prepared_id = self._prepare() + base_g = self.graph.nx_graph_flat() + prepared_id = self._prepare(base_g) # Prepare as many nodes as we can while prepared_id is not None: - prepared_id = self._prepare() - next_node = self._get_next_node() + prepared_id = self._prepare(base_g) + if next_node is None: + next_node = self._get_next_node() # Get values from edges if next_node is not None: @@ -869,6 +913,18 @@ def complete(self, node_id: str, output: BaseInvocationOutput) -> None: self.executed.add(source_node) 
self.executed_history.append(source_node) + # Decrement children indegree and enqueue when ready + for e in self.execution_graph._get_output_edges(node_id): + child = e.destination.node_id + if child not in self.indegree: + raise KeyError(f"indegree missing for exec node {child}") + # Only decrement if there's something to satisfy + if self.indegree[child] == 0: + raise RuntimeError(f"indegree underflow for {child} from parent {node_id}") + self.indegree[child] -= 1 + if self.indegree[child] == 0: + self._enqueue_if_ready(child) + def set_node_error(self, node_id: str, error: str): """Marks a node as errored""" self.errors[node_id] = error @@ -892,7 +948,7 @@ def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[st # If this is an iterator node, we must create a copy for each iteration if isinstance(node, IterateInvocation): # Get input collection edge (should error if there are no inputs) - input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection"))) + input_collection_edge = next(iter(self.graph._get_input_edges(node_id, COLLECTION_FIELD))) input_collection_prepared_node_id = next( n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id ) @@ -922,7 +978,7 @@ def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[st # Create a new node (or one for each iteration of this iterator) for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]: # Create a new node - new_node = copy.deepcopy(node) + new_node = node.model_copy(deep=True) # Create the node id (use a random uuid) new_node.id = uuid_string() @@ -946,53 +1002,55 @@ def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[st ) self.execution_graph.add_edge(new_edge) + # Initialize indegree as unmet inputs only and enqueue if ready + inputs = self.execution_graph._get_input_edges(new_node.id) + unmet = sum(1 for e in inputs if e.source.node_id not in self.executed) + 
self.indegree[new_node.id] = unmet + self._enqueue_if_ready(new_node.id) + new_nodes.append(new_node.id) return new_nodes - def _iterator_graph(self) -> nx.DiGraph: + def _iterator_graph(self, base: Optional[nx.DiGraph] = None) -> nx.DiGraph: """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" - g = self.graph.nx_graph_flat() + g = base.copy() if base is not None else self.graph.nx_graph_flat() collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation)) for c in collectors: g.remove_edges_from(list(g.in_edges(c))) return g - def _get_node_iterators(self, node_id: str) -> list[str]: + def _get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None) -> list[str]: """Gets iterators for a node""" - g = self._iterator_graph() - iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)] - return iterators + g = it_graph or self._iterator_graph() + return [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)] - def _prepare(self) -> Optional[str]: + def _prepare(self, base_g: Optional[nx.DiGraph] = None) -> Optional[str]: # Get flattened source graph - g = self.graph.nx_graph_flat() + g = base_g or self.graph.nx_graph_flat() # Find next node that: # - was not already prepared # - is not an iterate node whose inputs have not been executed # - does not have an unexecuted iterate ancestor sorted_nodes = nx.topological_sort(g) + + def unprepared(n: str) -> bool: + return n not in self.source_prepared_mapping + + def iter_inputs_ready(n: str) -> bool: + if not isinstance(self.graph.get_node(n), IterateInvocation): + return True + return all(u in self.executed for u, _ in g.in_edges(n)) + + def no_unexecuted_iter_ancestors(n: str) -> bool: + return not any( + isinstance(self.graph.get_node(a), IterateInvocation) and a not in self.executed + for a in nx.ancestors(g, n) 
+ ) + next_node_id = next( - ( - n - for n in sorted_nodes - # exclude nodes that have already been prepared - if n not in self.source_prepared_mapping - # exclude iterate nodes whose inputs have not been executed - and not ( - isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node... - and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs - ) - # exclude nodes who have unexecuted iterate ancestors - and not any( - ( - isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`... - and a not in self.executed # ...that is not executed - for a in nx.ancestors(g, n) # for all ancestors `a` of node `n` - ) - ) - ), + (n for n in sorted_nodes if unprepared(n) and iter_inputs_ready(n) and no_unexecuted_iter_ancestors(n)), None, ) @@ -1000,7 +1058,7 @@ def _prepare(self) -> Optional[str]: return None # Get all parents of the next node - next_node_parents = [e[0] for e in g.in_edges(next_node_id)] + next_node_parents = [u for u, _ in g.in_edges(next_node_id)] # Create execution nodes next_node = self.graph.get_node(next_node_id) @@ -1018,7 +1076,8 @@ def _prepare(self) -> Optional[str]: else: # Iterators or normal nodes # Get all iterator combinations for this node # Will produce a list of lists of prepared iterator nodes, from which results can be iterated - iterator_nodes = self._get_node_iterators(next_node_id) + it_g = self._iterator_graph(g) + iterator_nodes = self._get_node_iterators(next_node_id, it_g) iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes] iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared)) @@ -1066,45 +1125,41 @@ def _get_iteration_node( ) def _get_next_node(self) -> Optional[BaseInvocation]: - """Gets the deepest node that is ready to be executed""" - g = self.execution_graph.nx_graph() - - # Perform a topological sort using depth-first search - topo_order = 
list(nx.dfs_postorder_nodes(g)) - - # Get all IterateInvocation nodes - iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)] - - # Sort the IterateInvocation nodes based on their index attribute - iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index) - - # Prioritize IterateInvocation nodes and their children - for iterate_node in iterate_nodes: - if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))): - return self.execution_graph.nodes[iterate_node] - - # Check the children of the IterateInvocation node - for child_node in nx.dfs_postorder_nodes(g, iterate_node): - if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))): - return self.execution_graph.nodes[child_node] - - # If no IterateInvocation node or its children are ready, return the first ready node in the topological order - for node in topo_order: - if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))): - return self.execution_graph.nodes[node] - - # If no node is found, return None + """Gets the next ready node: FIFO within class, drain class before switching.""" + # 1) Continue draining the active class + if self._active_class: + q = self._ready_queues.get(self._active_class) + while q: + nid = q.popleft() + if nid not in self.executed: + return self.execution_graph.nodes[nid] + # emptied: release active class + self._active_class = None + + # 2) Pick next class by priority, then by class name + seen = set(self.ready_order) + for cls_name in self.ready_order: + q = self._ready_queues.get(cls_name) + if q: + self._active_class = cls_name + # recurse to drain newly set active class + return self._get_next_node() + for cls_name in sorted(k for k in self._ready_queues.keys() if k not in seen): + q = self._ready_queues[cls_name] + if q: + self._active_class = cls_name + return self._get_next_node() return None def 
_prepare_inputs(self, node: BaseInvocation): - input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id] + input_edges = self.execution_graph._get_input_edges(node.id) # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input # will see the mutation. if isinstance(node, CollectInvocation): output_collection = [ copydeep(getattr(self.results[edge.source.node_id], edge.source.field)) for edge in input_edges - if edge.destination.field == "item" + if edge.destination.field == ITEM_FIELD ] node.collection = output_collection else: diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 40214ffa554..756e3c10ccb 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -9281,6 +9281,15 @@ export type components = { source_prepared_mapping: { [key: string]: string[]; }; + /** Ready Order */ + ready_order?: string[]; + /** + * Indegree + * @description Remaining unmet input count for exec nodes + */ + indegree?: { + [key: string]: number; + }; }; /** * Grounding DINO (Text Prompt Object Detection) diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 9aee5febc90..381c4c73482 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -152,6 +152,29 @@ def test_graph_state_prepares_eagerly(): def test_graph_executes_depth_first(): """Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch""" + + def assert_topo_order_and_all_executed(state: GraphExecutionState, order: list[str]): + """ + Validates: + 1) Every materialized exec node executed exactly once. + 2) Execution order respects all exec-graph dependencies (u→v ⇒ u before v). 
+ """ + # order must be EXEC node ids in run order + exec_nodes = set(state.execution_graph.nodes.keys()) + + # 1) coverage: all exec nodes ran, and no duplicates + pos = {nid: i for i, nid in enumerate(order)} + assert set(pos.keys()) == exec_nodes, ( + f"Executed {len(pos)} of {len(exec_nodes)} nodes. Missing: {sorted(exec_nodes - set(pos))[:10]}" + ) + assert len(pos) == len(order), "Duplicate execution detected" + + # 2) topo order: parents before children + for e in state.execution_graph.edges: + u = e.source.node_id + v = e.destination.node_id + assert pos[u] < pos[v], f"child {v} ran before parent {u}" + graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] @@ -164,36 +187,17 @@ def test_graph_executes_depth_first(): graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) g = GraphExecutionState(graph=graph) - _ = invoke_next(g) - _ = invoke_next(g) - _ = invoke_next(g) - _ = invoke_next(g) - - # Because ordering is not guaranteed, we cannot compare results directly. - # Instead, we must count the number of results. 
- def get_completed_count(g: GraphExecutionState, id: str): - ids = list(g.source_prepared_mapping[id]) - completed_ids = [i for i in g.executed if i in ids] - return len(completed_ids) - - # Check at each step that the number of executed nodes matches the expectation for depth-first execution - assert get_completed_count(g, "prompt_iterated") == 1 - assert get_completed_count(g, "prompt_successor") == 0 + order: list[str] = [] - _ = invoke_next(g) - - assert get_completed_count(g, "prompt_iterated") == 1 - assert get_completed_count(g, "prompt_successor") == 1 - - _ = invoke_next(g) - - assert get_completed_count(g, "prompt_iterated") == 2 - assert get_completed_count(g, "prompt_successor") == 1 - - _ = invoke_next(g) + while True: + n = g.next() + if n is None: + break + o = n.invoke(Mock(InvocationContext)) + g.complete(n.id, o) + order.append(n.id) - assert get_completed_count(g, "prompt_iterated") == 2 - assert get_completed_count(g, "prompt_successor") == 2 + assert_topo_order_and_all_executed(g, order) # Because this tests deterministic ordering, we run it multiple times