Skip to content

Commit 793854d

Browse files
committed
Fixed Windows paths, missing inputs used as args
1 parent f0bc65e commit 793854d

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

comfyui_to_python.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def can_be_imported(self, import_name: str):
172172

173173
return False
174174

175-
def generate_workflow(self, load_order: List, filename: str = 'generated_code_workflow.py', queue_size: int = 10) -> str:
175+
def generate_workflow(self, load_order: List, queue_size: int = 10) -> str:
176176
"""Generate the execution code based on the load order.
177177
178178
Args:
@@ -185,16 +185,25 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
185185
str: Generated execution code as a string.
186186
"""
187187
# Create the necessary data structures to hold imports and generated code
188-
import_statements, executed_variables, special_functions_code, code = set(['NODE_CLASS_MAPPINGS']), {}, [], []
188+
import_statements, executed_variables, arg_inputs, special_functions_code, code = set(['NODE_CLASS_MAPPINGS']), {}, [], [], []
189189
# This dictionary will store the names of the objects that we have already initialized
190190
initialized_objects = {}
191191

192192
custom_nodes = False
193193
# Loop over each dictionary in the load order list
194194
for idx, data, is_special_function in load_order:
195-
196195
# Generate class definition and inputs from the data
197196
inputs, class_type = data['inputs'], data['class_type']
197+
198+
missing = []
199+
for i, input in enumerate(self.node_class_mappings[class_type].INPUT_TYPES().get("required", {}).keys()):
200+
if input not in inputs:
201+
input_var = f"{input}{len(arg_inputs)+1}"
202+
arg_inputs.append((input_var, f"Argument {i}, input `{input}` for node \\\"{data['_meta'].get('title', class_type)}\\\" id {idx}"))
203+
print("WARNING: Missing required input", input, "for", class_type)
204+
print("That will be CLI arg " + str(len(arg_inputs)))
205+
missing.append((input_var, len(arg_inputs)))
206+
198207
class_def = self.node_class_mappings[class_type]()
199208

200209
# If the class hasn't been initialized yet, initialize it and generate the import statements
@@ -216,6 +225,8 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
216225

217226
# Remove any keyword arguments from **inputs if they are not in class_def_params
218227
inputs = {key: value for key, value in inputs.items() if key in class_def_params}
228+
for input, arg in missing:
229+
inputs[input] = {"variable_name": f"argv." + input}
219230
# Deal with hidden variables
220231
if 'unique_id' in class_def_params:
221232
inputs['unique_id'] = random.randint(1, 2**64)
@@ -230,7 +241,7 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
230241
code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
231242

232243
# Generate final code by combining imports and code, and wrap them in a main function
233-
final_code = self.assemble_python_code(import_statements, special_functions_code, code, queue_size, custom_nodes)
244+
final_code = self.assemble_python_code(import_statements, special_functions_code, arg_inputs, code, queue_size, custom_nodes)
234245

235246
return final_code
236247

@@ -272,13 +283,12 @@ def format_arg(self, key: str, value: any) -> str:
272283
if key == 'noise_seed' or key == 'seed':
273284
return f'{key}=random.randint(1, 2**64)'
274285
elif isinstance(value, str):
275-
value = value.replace("\n", "\\n").replace('"', "'")
276-
return f'{key}="{value}"'
286+
return f'{key}={repr(value)}'
277287
elif isinstance(value, dict) and 'variable_name' in value:
278288
return f'{key}={value["variable_name"]}'
279289
return f'{key}={value}'
280290

281-
def assemble_python_code(self, import_statements: set, speical_functions_code: List[str], code: List[str], queue_size: int, custom_nodes=False) -> str:
291+
def assemble_python_code(self, import_statements: set, speical_functions_code: List[str], arg_inputs: List[Tuple[str, str]], code: List[str], queue_size: int, custom_nodes=False) -> str:
282292
"""Generates the final code string.
283293
284294
Args:
@@ -295,22 +305,31 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
295305
func_strings = []
296306
for func in [get_value_at_index, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths, import_custom_nodes]:
297307
func_strings.append(f'\n{inspect.getsource(func)}')
308+
309+
argparse_code = ""
310+
if arg_inputs:
311+
argparse_code = f'argv = sys.argv[:]\n\nparser = argparse.ArgumentParser(description="A converted ComfyUI workflow. Required inputs listed below. Values passed should be in JSON")\n'
312+
for i, (input_name, arg_desc) in enumerate(arg_inputs):
313+
argparse_code += f'parser.add_argument("{input_name}", help="{arg_desc} (autogenerated)")\n'
314+
argparse_code += 'args = parser.parse_args()\nsys.argv = [sys.argv[0]]\n'
315+
298316
# Define static import statements required for the script
299-
static_imports = ['import os', 'import random', 'import sys', 'from typing import Sequence, Mapping, Any, Union',
300-
'import torch'] + func_strings + ['\n\nadd_comfyui_directory_to_sys_path()\nadd_extra_model_paths()\n']
317+
static_imports = ['import os', 'import random', 'import sys', 'import argparse', 'from typing import Sequence, Mapping, Any, Union',
318+
'import torch'] + func_strings + [argparse_code]
301319
# Check if custom nodes should be included
302320
if custom_nodes:
303321
static_imports.append(f'\n{inspect.getsource(import_custom_nodes)}\n')
304322
custom_nodes = 'import_custom_nodes()\n\t'
305323
else:
306324
custom_nodes = ''
325+
static_imports += ['\n\nadd_comfyui_directory_to_sys_path()\nadd_extra_model_paths()\n']
307326
# Create import statements for node classes
308-
imports_code = [f"from nodes import {', '.join([class_name for class_name in import_statements])}" ]
327+
imports_code = [f"from nodes import {', '.join([class_name for class_name in import_statements])}", '']
309328
# Assemble the main function code, including custom nodes if applicable
310329
main_function_code = "def main():\n\t" + f'{custom_nodes}with torch.inference_mode():\n\t\t' + '\n\t\t'.join(speical_functions_code) \
311330
+ f'\n\n\t\tfor q in range({queue_size}):\n\t\t' + '\n\t\t'.join(code)
312331
# Concatenate all parts to form the final code
313-
final_code = '\n'.join(static_imports + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
332+
final_code = '\n'.join(static_imports + imports_code + [main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
314333
# Format the final code according to PEP 8 using the Black library
315334
final_code = black.format_str(final_code, mode=black.Mode())
316335

@@ -453,7 +472,7 @@ def execute(self):
453472

454473
# Step 4: Generate the workflow code
455474
code_generator = CodeGenerator(self.node_class_mappings, self.base_node_class_mappings)
456-
generated_code = code_generator.generate_workflow(load_order, filename=self.output_file, queue_size=self.queue_size)
475+
generated_code = code_generator.generate_workflow(load_order, queue_size=self.queue_size)
457476

458477
# Step 5: Write the generated code to a file
459478
FileHandler.write_code_to_file(self.output_file, generated_code)

0 commit comments

Comments
 (0)