|
| 1 | +import glob |
| 2 | +import inspect |
| 3 | +import json |
| 4 | +import logging |
| 5 | +import os |
| 6 | +from typing import Dict, List, Any, Callable |
| 7 | +import sys |
| 8 | + |
| 9 | +sys.path.append('../') |
| 10 | + |
| 11 | +from nodes import NODE_CLASS_MAPPINGS |
| 12 | + |
| 13 | +logging.basicConfig(level=logging.INFO) |
| 14 | + |
| 15 | + |
| 16 | +def read_json_file(file_path: str) -> dict: |
| 17 | + """ |
| 18 | + Reads a JSON file and returns its contents as a dictionary. |
| 19 | +
|
| 20 | + Args: |
| 21 | + file_path (str): The path to the JSON file. |
| 22 | +
|
| 23 | + Returns: |
| 24 | + dict: The contents of the JSON file as a dictionary. |
| 25 | +
|
| 26 | + Raises: |
| 27 | + FileNotFoundError: If the file is not found, it lists all JSON files in the directory of the file path. |
| 28 | + ValueError: If the file is not a valid JSON. |
| 29 | + """ |
| 30 | + |
| 31 | + try: |
| 32 | + with open(file_path, 'r') as file: |
| 33 | + data = json.load(file) |
| 34 | + return data |
| 35 | + |
| 36 | + except FileNotFoundError: |
| 37 | + # Get the directory from the file_path |
| 38 | + directory = os.path.dirname(file_path) |
| 39 | + |
| 40 | + # If the directory is an empty string (which means file is in the current directory), |
| 41 | + # get the current working directory |
| 42 | + if not directory: |
| 43 | + directory = os.getcwd() |
| 44 | + |
| 45 | + # Find all JSON files in the directory |
| 46 | + json_files = glob.glob(f"{directory}/*.json") |
| 47 | + |
| 48 | + # Format the list of JSON files as a string |
| 49 | + json_files_str = "\n".join(json_files) |
| 50 | + |
| 51 | + raise FileNotFoundError(f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}") |
| 52 | + |
| 53 | + except json.JSONDecodeError: |
| 54 | + raise ValueError(f"Invalid JSON format in file: {file_path}") |
| 55 | + |
| 56 | + |
| 57 | +def determine_load_order(data: Dict) -> List: |
| 58 | + """ |
| 59 | + Determine the load order of each key in the provided dictionary. This code will place the |
| 60 | + nodes without node dependencies first, then ensure that any node whose result is used |
| 61 | + in another node will be added to the list in the order it should be executed. |
| 62 | +
|
| 63 | + Args: |
| 64 | + data (Dict): |
| 65 | + The dictionary for which to determine the load order. |
| 66 | +
|
| 67 | + Returns: |
| 68 | + List: |
| 69 | + A list of tuples where each tuple contains a key, its corresponding dictionary, |
| 70 | + and a boolean indicating whether or not the function is dependent on the output of |
| 71 | + a previous function, ordered by load order. |
| 72 | + """ |
| 73 | + |
| 74 | + # Create a dictionary to keep track of visited nodes. |
| 75 | + visited = {} |
| 76 | + # Create a list to store the load order for functions |
| 77 | + load_order = [] |
| 78 | + # Boolean to indicate whether or not the class is a loader class that should not be |
| 79 | + # reloaded during every loop |
| 80 | + is_loader = False |
| 81 | + |
| 82 | + def dfs(key: str) -> None: |
| 83 | + """ |
| 84 | + Depth-First Search function. |
| 85 | +
|
| 86 | + Args: |
| 87 | + key (str): The key from which to start the DFS. |
| 88 | +
|
| 89 | + Returns: |
| 90 | + None |
| 91 | + """ |
| 92 | + # Mark the node as visited. |
| 93 | + visited[key] = True |
| 94 | + inputs = data[key]['inputs'] |
| 95 | + |
| 96 | + # Loop over each input key. |
| 97 | + for input_key, val in inputs.items(): |
| 98 | + # If the value is a list and the first item in the list (which should be a key) |
| 99 | + # has not been visited yet, then recursively apply dfs on the dependency. |
| 100 | + if isinstance(val, list) and val[0] not in visited: |
| 101 | + dfs(val[0]) |
| 102 | + |
| 103 | + # Add the key and its corresponding data to the load order list. |
| 104 | + load_order.append((key, data[key], is_loader)) |
| 105 | + |
| 106 | + # Load Loader keys first |
| 107 | + for key in data: |
| 108 | + class_def = NODE_CLASS_MAPPINGS[data[key]['class_type']]() |
| 109 | + if class_def.CATEGORY == 'loaders' or class_def.FUNCTION in ['encode'] or not any(isinstance(val, list) for val in data[key]['inputs'].values()): |
| 110 | + is_loader = True |
| 111 | + # If the key has not been visited, perform a DFS from that key. |
| 112 | + if key not in visited: |
| 113 | + dfs(key) |
| 114 | + |
| 115 | + # Reset is_loader bool |
| 116 | + is_loader = False |
| 117 | + # Loop over each key in the data. |
| 118 | + for key in data: |
| 119 | + # If the key has not been visited, perform a DFS from that key. |
| 120 | + if key not in visited: |
| 121 | + dfs(key) |
| 122 | + |
| 123 | + return load_order |
| 124 | + |
| 125 | + |
| 126 | +def create_function_call_code(obj_name: str, func: str, variable_name: str, is_loader: bool, **kwargs) -> str: |
| 127 | + """ |
| 128 | + This function generates Python code for a function call. |
| 129 | +
|
| 130 | + Args: |
| 131 | + obj_name (str): The name of the initialized object. |
| 132 | + func (str): The function to be called. |
| 133 | + variable_name (str): The name of the variable that the function result should be assigned to. |
| 134 | + is_loader (bool): Determines the code indentation. |
| 135 | + **kwargs: The keyword arguments for the function. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + str: The generated Python code. |
| 139 | + """ |
| 140 | + |
| 141 | + def format_arg(key: str, value: any) -> str: |
| 142 | + """Formats arguments based on key and value.""" |
| 143 | + if key == 'noise_seed': |
| 144 | + return f'{key}=random.randint(1, 2**64)' |
| 145 | + elif isinstance(value, str): |
| 146 | + value = value.replace("\n", "\\n") |
| 147 | + return f'{key}="{value}"' |
| 148 | + elif key == 'images' and "saveimage" in obj_name and isinstance(value, dict) and 'variable_name' in value: |
| 149 | + return f'{key}={value["variable_name"]}.detach()' |
| 150 | + elif isinstance(value, dict) and 'variable_name' in value: |
| 151 | + return f'{key}={value["variable_name"]}' |
| 152 | + return f'{key}={value}' |
| 153 | + |
| 154 | + args = ', '.join(format_arg(key, value) for key, value in kwargs.items()) |
| 155 | + |
| 156 | + # Generate the Python code |
| 157 | + code = f'{variable_name} = {obj_name}.{func}({args})' |
| 158 | + |
| 159 | + # If the code contains dependencies, indent the code because it will be placed inside |
| 160 | + # of a for loop |
| 161 | + if not is_loader: |
| 162 | + code = f'\t{code}' |
| 163 | + |
| 164 | + return code |
| 165 | + |
| 166 | + |
| 167 | +def update_inputs(inputs: Dict, executed_variables: Dict) -> Dict: |
| 168 | + """ |
| 169 | + Update inputs based on the executed variables. |
| 170 | +
|
| 171 | + Args: |
| 172 | + inputs (Dict): Inputs dictionary to update. |
| 173 | + executed_variables (Dict): Dictionary storing executed variable names. |
| 174 | +
|
| 175 | + Returns: |
| 176 | + Dict: Updated inputs dictionary. |
| 177 | + """ |
| 178 | + for key in inputs.keys(): |
| 179 | + if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys(): |
| 180 | + inputs[key] = {'variable_name': f"{executed_variables[inputs[key][0]]}[{inputs[key][1]}]"} |
| 181 | + return inputs |
| 182 | + |
| 183 | + |
| 184 | +def get_class_info(class_type: str) -> (str, str, str): |
| 185 | + """ |
| 186 | + Generates and returns necessary information about class type. |
| 187 | +
|
| 188 | + Args: |
| 189 | + class_type (str): Class type |
| 190 | +
|
| 191 | + Returns: |
| 192 | + class_type (str): Updated class type |
| 193 | + import_statement (str): Import statement string |
| 194 | + class_code (str): Class initialization code |
| 195 | + """ |
| 196 | + # If the class is 'VAEDecode', adjust the class name |
| 197 | + if class_type == 'VAEDecode': |
| 198 | + class_type = 'VAEDecodeTiled' |
| 199 | + |
| 200 | + import_statement = class_type |
| 201 | + class_code = f'{class_type.lower()} = {class_type}()' |
| 202 | + |
| 203 | + return class_type, import_statement, class_code |
| 204 | + |
| 205 | + |
| 206 | +def assemble_python_code(import_statements: set, loader_code: List[str], code: List[str], queue_size: int) -> str: |
| 207 | + """ |
| 208 | + Generates final code string. |
| 209 | +
|
| 210 | + Args: |
| 211 | + import_statements (set): A set of unique import statements |
| 212 | + code (List[str]): A list of code strings |
| 213 | + queue_size (int): Number of photos that will be generated by the script. |
| 214 | +
|
| 215 | + Returns: |
| 216 | + final_code (str): Generated final code as a string |
| 217 | + """ |
| 218 | + static_imports = ['import random'] |
| 219 | + imports_code = [f"from nodes import {class_name}" for class_name in import_statements] |
| 220 | + main_function_code = f"def main():\n\t" + '\n\t'.join(loader_code) + f'\n\tfor q in {range(1, queue_size)}:\n\t' + '\n\t'.join(code) |
| 221 | + final_code = '\n'.join(static_imports + ['import sys\nsys.path.append("../")'] + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()']) |
| 222 | + |
| 223 | + return final_code |
| 224 | + |
| 225 | + |
| 226 | +def write_code_to_file(filename: str, code: str) -> None: |
| 227 | + """ |
| 228 | + Writes given code to a .py file. If the directory does not exist, it creates it. |
| 229 | +
|
| 230 | + Args: |
| 231 | + filename (str): The name of the Python file to save the code to. |
| 232 | + code (str): The code to save. |
| 233 | + """ |
| 234 | + |
| 235 | + # Extract directory from the filename |
| 236 | + directory = os.path.dirname(filename) |
| 237 | + |
| 238 | + # If the directory does not exist, create it |
| 239 | + if directory and not os.path.exists(directory): |
| 240 | + os.makedirs(directory) |
| 241 | + |
| 242 | + # Save the code to a .py file |
| 243 | + with open(filename, 'w') as file: |
| 244 | + file.write(code) |
| 245 | + |
| 246 | + |
| 247 | +def get_function_parameters(func: Callable) -> List: |
| 248 | + """Get the names of a function's parameters. |
| 249 | +
|
| 250 | + Args: |
| 251 | + func (Callable): The function whose parameters we want to inspect. |
| 252 | +
|
| 253 | + Returns: |
| 254 | + List: A list containing the names of the function's parameters. |
| 255 | + """ |
| 256 | + signature = inspect.signature(func) |
| 257 | + parameters = {name: param.default if param.default != param.empty else None |
| 258 | + for name, param in signature.parameters.items()} |
| 259 | + return list(parameters.keys()) |
| 260 | + |
| 261 | + |
| 262 | +def generate_workflow(load_order: List, filename: str = 'generated_code_workflow.py', queue_size: int = 10) -> str: |
| 263 | + """ |
| 264 | + Generate the execution code based on the load order. |
| 265 | +
|
| 266 | + Args: |
| 267 | + load_order (List): A list of tuples representing the load order. |
| 268 | + filename (str): The name of the Python file to which the code should be saved. |
| 269 | + Defaults to 'generated_code_workflow.py'. |
| 270 | + queue_size (int): The number of photos that will be created by the script. |
| 271 | +
|
| 272 | + Returns: |
| 273 | + str: Generated execution code as a string. |
| 274 | + """ |
| 275 | + |
| 276 | + # Create the necessary data structures to hold imports and generated code |
| 277 | + import_statements, executed_variables, loader_code, code = set(), {}, [], [] |
| 278 | + # This dictionary will store the names of the objects that we have already initialized |
| 279 | + initialized_objects = {} |
| 280 | + |
| 281 | + # Loop over each dictionary in the load order list |
| 282 | + for idx, data, is_loader in load_order: |
| 283 | + |
| 284 | + # Generate class definition and inputs from the data |
| 285 | + inputs, class_type = data['inputs'], data['class_type'] |
| 286 | + class_def = NODE_CLASS_MAPPINGS[class_type]() |
| 287 | + |
| 288 | + # If the class hasn't been initialized yet, initialize it and generate the import statements |
| 289 | + if class_type not in initialized_objects: |
| 290 | + class_type, import_statement, class_code = get_class_info(class_type) |
| 291 | + initialized_objects[class_type] = class_type.lower() |
| 292 | + import_statements.add(import_statement) |
| 293 | + loader_code.append(class_code) |
| 294 | + |
| 295 | + # Get all possible parameters for class_def |
| 296 | + class_def_params = get_function_parameters(getattr(class_def, class_def.FUNCTION)) |
| 297 | + |
| 298 | + # Remove any keyword arguments from **inputs if they are not in class_def_params |
| 299 | + inputs = {key: value for key, value in inputs.items() if key in class_def_params} |
| 300 | + |
| 301 | + # Create executed variable and generate code |
| 302 | + executed_variables[idx] = f'{class_type.lower()}_{idx}' |
| 303 | + inputs = update_inputs(inputs, executed_variables) |
| 304 | + |
| 305 | + if is_loader: |
| 306 | + loader_code.append(create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_loader, **inputs)) |
| 307 | + else: |
| 308 | + code.append(create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_loader, **inputs)) |
| 309 | + |
| 310 | + # Generate final code by combining imports and code, and wrap them in a main function |
| 311 | + final_code = assemble_python_code(import_statements, loader_code, code, queue_size) |
| 312 | + |
| 313 | + # Save the code to a .py file |
| 314 | + write_code_to_file(filename, final_code) |
| 315 | + |
| 316 | + return final_code |
| 317 | + |
| 318 | + |
| 319 | +def main(input, queue_size=10): |
| 320 | + """ |
| 321 | + Main function to be executed. |
| 322 | + """ |
| 323 | + # Load JSON data from the input file |
| 324 | + prompt = read_json_file(input) |
| 325 | + load_order = determine_load_order(prompt) |
| 326 | + output_file = input.replace('.json', '.py') |
| 327 | + code = generate_workflow(load_order, filename=output_file, queue_size=queue_size) |
| 328 | + logging.info(code) |
| 329 | + |
| 330 | + |
| 331 | +if __name__ == '__main__': |
| 332 | + input = 'workflow_api.json' |
| 333 | + queue_size = 10 |
| 334 | + main(input, queue_size) |
0 commit comments