@@ -166,7 +166,7 @@ class CodeGenerator:
166166 base_node_class_mappings (Dict): Base mappings of node classes.
167167 """
168168
169- def __init__ (self , node_class_mappings : Dict , base_node_class_mappings : Dict ):
169+ def __init__ (self , node_class_mappings : Dict , base_node_class_mappings : Dict , prompt : Dict ):
170170 """Initialize the CodeGenerator with given node class mappings.
171171
172172 Args:
@@ -175,6 +175,7 @@ def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict):
175175 """
176176 self .node_class_mappings = node_class_mappings
177177 self .base_node_class_mappings = base_node_class_mappings
178+ self .prompt = prompt
178179
179180 def can_be_imported (self , import_name : str ):
180181 if import_name in self .base_node_class_mappings .keys ():
@@ -195,6 +196,7 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
195196 Returns:
196197 str: Generated execution code as a string.
197198 """
199+ include_prompt_data = False
198200 # Create the necessary data structures to hold imports and generated code
199201 import_statements , executed_variables , arg_inputs , special_functions_code , code = set (['NODE_CLASS_MAPPINGS' ]), {}, [], [], []
200202 # This dictionary will store the names of the objects that we have already initialized
@@ -206,8 +208,9 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
206208 # Generate class definition and inputs from the data
207209 inputs , class_type = data ['inputs' ], data ['class_type' ]
208210
211+ input_types = self .node_class_mappings [class_type ].INPUT_TYPES ()
209212 missing = []
210- for i , input in enumerate (self . node_class_mappings [ class_type ]. INPUT_TYPES () .get ("required" , {}).keys ()):
213+ for i , input in enumerate (input_types .get ("required" , {}).keys ()):
211214 if input not in inputs :
212215 input_var = f"{ input } { len (arg_inputs )+ 1 } "
213216 arg_inputs .append ((input_var , f"Argument { i } , input `{ input } ` for node \\ \" { data ['_meta' ].get ('title' , class_type )} \\ \" id { idx } " ))
@@ -233,14 +236,18 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
233236
234237 # Get all possible parameters for class_def
235238 class_def_params = self .get_function_parameters (getattr (class_def , class_def .FUNCTION ))
239+ no_params = class_def_params is None
236240
237241 # Remove any keyword arguments from **inputs if they are not in class_def_params
238- inputs = {key : value for key , value in inputs .items () if key in class_def_params }
242+ inputs = {key : value for key , value in inputs .items () if no_params or key in class_def_params }
239243 for input , input_var , arg in missing :
240244 inputs [input ] = {"variable_name" : f"parse_arg(args." + input_var + ")" }
241245 # Deal with hidden variables
242- if 'unique_id' in class_def_params :
246+ if no_params or 'unique_id' in class_def_params :
243247 inputs ['unique_id' ] = random .randint (1 , 2 ** 64 )
248+ if no_params or 'prompt' in class_def_params :
249+ inputs ["prompt" ] = {"variable_name" : "PROMPT_DATA" }
250+ include_prompt_data = True
244251
245252 # Create executed variable and generate code
246253 executed_variables [idx ] = f'{ self .clean_variable_name (class_type )} _{ idx } '
@@ -261,7 +268,7 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
261268 code .append (self .create_function_call_code (initialized_objects [class_type ], class_def .FUNCTION , executed_variables [idx ], is_special_function , ** inputs ))
262269
263270 # Generate final code by combining imports and code, and wrap them in a main function
264- final_code = self .assemble_python_code (import_statements , special_functions_code , arg_inputs , code , queue_size , custom_nodes )
271+ final_code = self .assemble_python_code (import_statements , special_functions_code , arg_inputs , code , queue_size , custom_nodes , include_prompt_data )
265272
266273 return final_code
267274
@@ -304,7 +311,7 @@ def format_arg(self, key: str, value: any) -> str:
304311 return f'{ key } ={ value ["variable_name" ]} '
305312 return f'{ key } ={ value } '
306313
307- def assemble_python_code (self , import_statements : set , special_functions_code : List [str ], arg_inputs : List [Tuple [str , str ]], code : List [str ], queue_size : int , custom_nodes = False ) -> str :
314+ def assemble_python_code (self , import_statements : set , special_functions_code : List [str ], arg_inputs : List [Tuple [str , str ]], code : List [str ], queue_size : int , custom_nodes = False , include_prompt_data = True ) -> str :
308315 """Generates the final code string.
309316
310317 Args:
@@ -349,6 +356,8 @@ def assemble_python_code(self, import_statements: set, special_functions_code: L
349356 # Define static import statements required for the script
350357 static_imports = ['import os' , 'import random' , 'import sys' , 'import json' , 'import argparse' , 'import contextlib' , 'from typing import Sequence, Mapping, Any, Union' ,
351358 'import torch' ] + func_strings + argparse_code
359+ if include_prompt_data :
360+ static_imports .append (f'PROMPT_DATA = json.loads({ repr (json .dumps (self .prompt ))} )' )
352361 # Check if custom nodes should be included
353362 if custom_nodes :
354363 static_imports .append (f'\n { inspect .getsource (import_custom_nodes )} \n ' )
@@ -459,7 +468,8 @@ def get_function_parameters(self, func: Callable) -> List:
459468 signature = inspect .signature (func )
460469 parameters = {name : param .default if param .default != param .empty else None
461470 for name , param in signature .parameters .items ()}
462- return list (parameters .keys ())
471+ catch_all = any (param .kind == inspect .Parameter .VAR_KEYWORD for param in signature .parameters .values ())
472+ return list (parameters .keys ()) if not catch_all else None
463473
464474 def update_inputs (self , inputs : Dict , executed_variables : Dict ) -> Dict :
465475 """Update inputs based on the executed variables.
@@ -542,7 +552,7 @@ def execute(self):
542552 load_order = load_order_determiner .determine_load_order ()
543553
544554 # Step 4: Generate the workflow code
545- code_generator = CodeGenerator (self .node_class_mappings , self .base_node_class_mappings )
555+ code_generator = CodeGenerator (self .node_class_mappings , self .base_node_class_mappings , data )
546556 generated_code = code_generator .generate_workflow (load_order , queue_size = self .queue_size )
547557
548558 # Step 5: Write the generated code to a file
0 commit comments