Skip to content

Commit f66183a

Browse files
authored
[fix] Fixes non-async public API access (#10857)
It looks like the synchronous version of the public API broke due to an addition of `from __future__ import annotations`. This change updates the async-to-sync adapter to work with both types of type annotations.
1 parent cbd68e3 commit f66183a

File tree

2 files changed

+184
-16
lines changed

2 files changed

+184
-16
lines changed

comfy_api/internal/async_to_sync.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import textwrap
99
import threading
1010
from enum import Enum
11-
from typing import Optional, Type, get_origin, get_args
11+
from typing import Optional, Type, get_origin, get_args, get_type_hints
1212

1313

1414
class TypeTracker:
@@ -220,11 +220,18 @@ def __init__(self, *args, **kwargs):
220220
self._async_instance = async_class(*args, **kwargs)
221221

222222
# Handle annotated class attributes (like execution: Execution)
223-
# Get all annotations from the class hierarchy
224-
all_annotations = {}
225-
for base_class in reversed(inspect.getmro(async_class)):
226-
if hasattr(base_class, "__annotations__"):
227-
all_annotations.update(base_class.__annotations__)
223+
# Get all annotations from the class hierarchy and resolve string annotations
224+
try:
225+
# get_type_hints resolves string annotations to actual type objects
226+
# This handles classes using 'from __future__ import annotations'
227+
all_annotations = get_type_hints(async_class)
228+
except Exception:
229+
# Fallback to raw annotations if get_type_hints fails
230+
# (e.g., for undefined forward references)
231+
all_annotations = {}
232+
for base_class in reversed(inspect.getmro(async_class)):
233+
if hasattr(base_class, "__annotations__"):
234+
all_annotations.update(base_class.__annotations__)
228235

229236
# For each annotated attribute, check if it needs to be created or wrapped
230237
for attr_name, attr_type in all_annotations.items():
@@ -625,15 +632,19 @@ def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
625632
"""Extract class attributes that are classes themselves."""
626633
class_attributes = []
627634

635+
# Get resolved type hints to handle string annotations
636+
try:
637+
type_hints = get_type_hints(async_class)
638+
except Exception:
639+
type_hints = {}
640+
628641
# Look for class attributes that are classes
629642
for name, attr in sorted(inspect.getmembers(async_class)):
630643
if isinstance(attr, type) and not name.startswith("_"):
631644
class_attributes.append((name, attr))
632-
elif (
633-
hasattr(async_class, "__annotations__")
634-
and name in async_class.__annotations__
635-
):
636-
annotation = async_class.__annotations__[name]
645+
elif name in type_hints:
646+
# Use resolved type hint instead of raw annotation
647+
annotation = type_hints[name]
637648
if isinstance(annotation, type):
638649
class_attributes.append((name, annotation))
639650

@@ -908,11 +919,15 @@ def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
908919
attribute_mappings = {}
909920

910921
# First check annotations for typed attributes (including from parent classes)
911-
# Collect all annotations from the class hierarchy
912-
all_annotations = {}
913-
for base_class in reversed(inspect.getmro(async_class)):
914-
if hasattr(base_class, "__annotations__"):
915-
all_annotations.update(base_class.__annotations__)
922+
# Resolve string annotations to actual types
923+
try:
924+
all_annotations = get_type_hints(async_class)
925+
except Exception:
926+
# Fallback to raw annotations
927+
all_annotations = {}
928+
for base_class in reversed(inspect.getmro(async_class)):
929+
if hasattr(base_class, "__annotations__"):
930+
all_annotations.update(base_class.__annotations__)
916931

917932
for attr_name, attr_type in sorted(all_annotations.items()):
918933
for class_name, class_type in class_attributes:

tests/execution/test_public_api.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
Tests for public ComfyAPI and ComfyAPISync functions.
3+
4+
These tests verify that the public API methods work correctly in both sync and async contexts,
5+
ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly
6+
handles string annotations from 'from __future__ import annotations'.
7+
"""
8+
9+
import pytest
10+
import time
11+
import subprocess
12+
import torch
13+
from pytest import fixture
14+
from comfy_execution.graph_utils import GraphBuilder
15+
from tests.execution.test_execution import ComfyClient
16+
17+
18+
@pytest.mark.execution
19+
class TestPublicAPI:
20+
"""Test suite for public ComfyAPI and ComfyAPISync methods."""
21+
22+
@fixture(scope="class", autouse=True)
23+
def _server(self, args_pytest):
24+
"""Start ComfyUI server for testing."""
25+
pargs = [
26+
'python', 'main.py',
27+
'--output-directory', args_pytest["output_dir"],
28+
'--listen', args_pytest["listen"],
29+
'--port', str(args_pytest["port"]),
30+
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
31+
'--cpu',
32+
]
33+
p = subprocess.Popen(pargs)
34+
yield
35+
p.kill()
36+
torch.cuda.empty_cache()
37+
38+
@fixture(scope="class", autouse=True)
39+
def shared_client(self, args_pytest, _server):
40+
"""Create shared client with connection retry."""
41+
client = ComfyClient()
42+
n_tries = 5
43+
for i in range(n_tries):
44+
time.sleep(4)
45+
try:
46+
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
47+
break
48+
except ConnectionRefusedError:
49+
if i == n_tries - 1:
50+
raise
51+
yield client
52+
del client
53+
torch.cuda.empty_cache()
54+
55+
@fixture
56+
def client(self, shared_client, request):
57+
"""Set test name for each test."""
58+
shared_client.set_test_name(f"public_api[{request.node.name}]")
59+
yield shared_client
60+
61+
@fixture
62+
def builder(self, request):
63+
"""Create GraphBuilder for each test."""
64+
yield GraphBuilder(prefix=request.node.name)
65+
66+
def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
67+
"""Test that TestSyncProgressUpdate executes without errors.
68+
69+
This test validates that api_sync.execution.set_progress() works correctly,
70+
which is the primary code path fixed by adding get_type_hints() to async_to_sync.py.
71+
"""
72+
g = builder
73+
image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
74+
75+
# Use TestSyncProgressUpdate with short sleep
76+
progress_node = g.node("TestSyncProgressUpdate",
77+
value=image.out(0),
78+
sleep_seconds=0.5)
79+
output = g.node("SaveImage", images=progress_node.out(0))
80+
81+
# Execute workflow
82+
result = client.run(g)
83+
84+
# Verify execution
85+
assert result.did_run(progress_node), "Progress node should have executed"
86+
assert result.did_run(output), "Output node should have executed"
87+
88+
# Verify output
89+
images = result.get_images(output)
90+
assert len(images) == 1, "Should have produced 1 image"
91+
92+
def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
93+
"""Test that TestAsyncProgressUpdate executes without errors.
94+
95+
This test validates that await api.execution.set_progress() works correctly
96+
in async contexts.
97+
"""
98+
g = builder
99+
image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
100+
101+
# Use TestAsyncProgressUpdate with short sleep
102+
progress_node = g.node("TestAsyncProgressUpdate",
103+
value=image.out(0),
104+
sleep_seconds=0.5)
105+
output = g.node("SaveImage", images=progress_node.out(0))
106+
107+
# Execute workflow
108+
result = client.run(g)
109+
110+
# Verify execution
111+
assert result.did_run(progress_node), "Async progress node should have executed"
112+
assert result.did_run(output), "Output node should have executed"
113+
114+
# Verify output
115+
images = result.get_images(output)
116+
assert len(images) == 1, "Should have produced 1 image"
117+
118+
def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder):
119+
"""Test both sync and async progress updates in same workflow.
120+
121+
This test ensures that both ComfyAPISync and ComfyAPI can coexist and work
122+
correctly in the same workflow execution.
123+
"""
124+
g = builder
125+
image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
126+
image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
127+
128+
# Use both types of progress nodes
129+
sync_progress = g.node("TestSyncProgressUpdate",
130+
value=image1.out(0),
131+
sleep_seconds=0.3)
132+
async_progress = g.node("TestAsyncProgressUpdate",
133+
value=image2.out(0),
134+
sleep_seconds=0.3)
135+
136+
# Create outputs
137+
output1 = g.node("SaveImage", images=sync_progress.out(0))
138+
output2 = g.node("SaveImage", images=async_progress.out(0))
139+
140+
# Execute workflow
141+
result = client.run(g)
142+
143+
# Both should execute successfully
144+
assert result.did_run(sync_progress), "Sync progress node should have executed"
145+
assert result.did_run(async_progress), "Async progress node should have executed"
146+
assert result.did_run(output1), "First output node should have executed"
147+
assert result.did_run(output2), "Second output node should have executed"
148+
149+
# Verify outputs
150+
images1 = result.get_images(output1)
151+
images2 = result.get_images(output2)
152+
assert len(images1) == 1, "Should have produced 1 image from sync node"
153+
assert len(images2) == 1, "Should have produced 1 image from async node"

0 commit comments

Comments
 (0)