Skip to content

Commit cfddeb5

Browse files
author
Peyton
committed
Updated logic for importing static libraries to generated script.
1 parent 54c26fb commit cfddeb5

File tree

1 file changed

+332
-0
lines changed

1 file changed

+332
-0
lines changed
Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
import glob
2+
import inspect
3+
import json
4+
import logging
5+
import os
6+
from typing import Dict, List, Any, Callable
7+
import sys
8+
9+
sys.path.append('../')
10+
11+
from nodes import NODE_CLASS_MAPPINGS
12+
13+
logging.basicConfig(level=logging.INFO)
14+
15+
16+
def read_json_file(file_path: str) -> dict:
17+
"""
18+
Reads a JSON file and returns its contents as a dictionary.
19+
20+
Args:
21+
file_path (str): The path to the JSON file.
22+
23+
Returns:
24+
dict: The contents of the JSON file as a dictionary.
25+
26+
Raises:
27+
FileNotFoundError: If the file is not found, it lists all JSON files in the directory of the file path.
28+
ValueError: If the file is not a valid JSON.
29+
"""
30+
31+
try:
32+
with open(file_path, 'r') as file:
33+
data = json.load(file)
34+
return data
35+
36+
except FileNotFoundError:
37+
# Get the directory from the file_path
38+
directory = os.path.dirname(file_path)
39+
40+
# If the directory is an empty string (which means file is in the current directory),
41+
# get the current working directory
42+
if not directory:
43+
directory = os.getcwd()
44+
45+
# Find all JSON files in the directory
46+
json_files = glob.glob(f"{directory}/*.json")
47+
48+
# Format the list of JSON files as a string
49+
json_files_str = "\n".join(json_files)
50+
51+
raise FileNotFoundError(f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}")
52+
53+
except json.JSONDecodeError:
54+
raise ValueError(f"Invalid JSON format in file: {file_path}")
55+
56+
57+
def determine_load_order(data: Dict) -> List:
58+
"""
59+
Determine the load order of each key in the provided dictionary. This code will place the
60+
nodes without node dependencies first, then ensure that any node whose result is used
61+
in another node will be added to the list in the order it should be executed.
62+
63+
Args:
64+
data (Dict):
65+
The dictionary for which to determine the load order.
66+
67+
Returns:
68+
List:
69+
A list of tuples where each tuple contains a key, its corresponding dictionary,
70+
and a boolean indicating whether or not the function is dependent on the output of
71+
a previous function, ordered by load order.
72+
"""
73+
74+
# Create a dictionary to keep track of visited nodes.
75+
visited = {}
76+
# Create a list to store the load order for functions
77+
load_order = []
78+
# Boolean to indicate whether or not the class is a loader class that should not be
79+
# reloaded during every loop
80+
is_loader = False
81+
82+
def dfs(key: str) -> None:
83+
"""
84+
Depth-First Search function.
85+
86+
Args:
87+
key (str): The key from which to start the DFS.
88+
89+
Returns:
90+
None
91+
"""
92+
# Mark the node as visited.
93+
visited[key] = True
94+
inputs = data[key]['inputs']
95+
96+
# Loop over each input key.
97+
for input_key, val in inputs.items():
98+
# If the value is a list and the first item in the list (which should be a key)
99+
# has not been visited yet, then recursively apply dfs on the dependency.
100+
if isinstance(val, list) and val[0] not in visited:
101+
dfs(val[0])
102+
103+
# Add the key and its corresponding data to the load order list.
104+
load_order.append((key, data[key], is_loader))
105+
106+
# Load Loader keys first
107+
for key in data:
108+
class_def = NODE_CLASS_MAPPINGS[data[key]['class_type']]()
109+
if class_def.CATEGORY == 'loaders' or class_def.FUNCTION in ['encode'] or not any(isinstance(val, list) for val in data[key]['inputs'].values()):
110+
is_loader = True
111+
# If the key has not been visited, perform a DFS from that key.
112+
if key not in visited:
113+
dfs(key)
114+
115+
# Reset is_loader bool
116+
is_loader = False
117+
# Loop over each key in the data.
118+
for key in data:
119+
# If the key has not been visited, perform a DFS from that key.
120+
if key not in visited:
121+
dfs(key)
122+
123+
return load_order
124+
125+
126+
def create_function_call_code(obj_name: str, func: str, variable_name: str, is_loader: bool, **kwargs) -> str:
127+
"""
128+
This function generates Python code for a function call.
129+
130+
Args:
131+
obj_name (str): The name of the initialized object.
132+
func (str): The function to be called.
133+
variable_name (str): The name of the variable that the function result should be assigned to.
134+
is_loader (bool): Determines the code indentation.
135+
**kwargs: The keyword arguments for the function.
136+
137+
Returns:
138+
str: The generated Python code.
139+
"""
140+
141+
def format_arg(key: str, value: any) -> str:
142+
"""Formats arguments based on key and value."""
143+
if key == 'noise_seed':
144+
return f'{key}=random.randint(1, 2**64)'
145+
elif isinstance(value, str):
146+
value = value.replace("\n", "\\n")
147+
return f'{key}="{value}"'
148+
elif key == 'images' and "saveimage" in obj_name and isinstance(value, dict) and 'variable_name' in value:
149+
return f'{key}={value["variable_name"]}.detach()'
150+
elif isinstance(value, dict) and 'variable_name' in value:
151+
return f'{key}={value["variable_name"]}'
152+
return f'{key}={value}'
153+
154+
args = ', '.join(format_arg(key, value) for key, value in kwargs.items())
155+
156+
# Generate the Python code
157+
code = f'{variable_name} = {obj_name}.{func}({args})'
158+
159+
# If the code contains dependencies, indent the code because it will be placed inside
160+
# of a for loop
161+
if not is_loader:
162+
code = f'\t{code}'
163+
164+
return code
165+
166+
167+
def update_inputs(inputs: Dict, executed_variables: Dict) -> Dict:
168+
"""
169+
Update inputs based on the executed variables.
170+
171+
Args:
172+
inputs (Dict): Inputs dictionary to update.
173+
executed_variables (Dict): Dictionary storing executed variable names.
174+
175+
Returns:
176+
Dict: Updated inputs dictionary.
177+
"""
178+
for key in inputs.keys():
179+
if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys():
180+
inputs[key] = {'variable_name': f"{executed_variables[inputs[key][0]]}[{inputs[key][1]}]"}
181+
return inputs
182+
183+
184+
def get_class_info(class_type: str) -> (str, str, str):
185+
"""
186+
Generates and returns necessary information about class type.
187+
188+
Args:
189+
class_type (str): Class type
190+
191+
Returns:
192+
class_type (str): Updated class type
193+
import_statement (str): Import statement string
194+
class_code (str): Class initialization code
195+
"""
196+
# If the class is 'VAEDecode', adjust the class name
197+
if class_type == 'VAEDecode':
198+
class_type = 'VAEDecodeTiled'
199+
200+
import_statement = class_type
201+
class_code = f'{class_type.lower()} = {class_type}()'
202+
203+
return class_type, import_statement, class_code
204+
205+
206+
def assemble_python_code(import_statements: set, loader_code: List[str], code: List[str], query_size: int) -> str:
207+
"""
208+
Generates final code string.
209+
210+
Args:
211+
import_statements (set): A set of unique import statements
212+
code (List[str]): A list of code strings
213+
214+
Returns:
215+
final_code (str): Generated final code as a string
216+
"""
217+
static_imports = ['import random']
218+
imports_code = [f"from nodes import {class_name}" for class_name in import_statements]
219+
main_function_code = f"def main():\n\t" + '\n\t'.join(loader_code) + f'\n\tfor q in {range(1, query_size)}:\n\t' + '\n\t'.join(code)
220+
final_code = '\n'.join(static_imports + ['import sys\nsys.path.append("../")'] + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
221+
222+
return final_code
223+
224+
225+
def write_code_to_file(filename: str, code: str) -> None:
226+
"""
227+
Writes given code to a .py file. If the directory does not exist, it creates it.
228+
229+
Args:
230+
filename (str): The name of the Python file to save the code to.
231+
code (str): The code to save.
232+
"""
233+
234+
# Extract directory from the filename
235+
directory = os.path.dirname(filename)
236+
237+
# If the directory does not exist, create it
238+
if directory and not os.path.exists(directory):
239+
os.makedirs(directory)
240+
241+
# Save the code to a .py file
242+
with open(filename, 'w') as file:
243+
file.write(code)
244+
245+
246+
def get_function_parameters(func: Callable[..., Any]) -> List:
247+
"""Get the names of a function's parameters.
248+
249+
Args:
250+
func (Callable[..., Any]): The function whose parameters we want to inspect.
251+
252+
Returns:
253+
List: A dictionary containing the names of the function's parameters.
254+
"""
255+
signature = inspect.signature(func)
256+
parameters = {name: param.default if param.default != param.empty else None
257+
for name, param in signature.parameters.items()}
258+
return list(parameters.keys())
259+
260+
261+
def generate_workflow(load_order: List, filename: str = 'generated_code_workflow.py', query_size: int = 10) -> str:
262+
"""
263+
Generate the execution code based on the load order.
264+
265+
Args:
266+
load_order (List): A list of tuples representing the load order.
267+
filename (str): The name of the Python file to which the code should be saved.
268+
Defaults to 'generated_code_workflow.py'.
269+
270+
Returns:
271+
str: Generated execution code as a string.
272+
"""
273+
274+
# Create the necessary data structures to hold imports and generated code
275+
import_statements, executed_variables, loader_code, code = set(), {}, [], []
276+
# This dictionary will store the names of the objects that we have already initialized
277+
initialized_objects = {}
278+
279+
# Loop over each dictionary in the load order list
280+
for idx, data, is_loader in load_order:
281+
282+
# Generate class definition and inputs from the data
283+
inputs, class_type = data['inputs'], data['class_type']
284+
class_def = NODE_CLASS_MAPPINGS[class_type]()
285+
286+
# If the class hasn't been initialized yet, initialize it and generate the import statements
287+
if class_type not in initialized_objects:
288+
class_type, import_statement, class_code = get_class_info(class_type)
289+
initialized_objects[class_type] = class_type.lower()
290+
import_statements.add(import_statement)
291+
loader_code.append(class_code)
292+
293+
# Get all possible parameters for class_def
294+
class_def_params = get_function_parameters(getattr(class_def, class_def.FUNCTION))
295+
296+
# Remove any keyword arguments from **inputs if they are not in class_def_params
297+
inputs = {key: value for key, value in inputs.items() if key in class_def_params}
298+
299+
# Create executed variable and generate code
300+
executed_variables[idx] = f'{class_type.lower()}_{idx}'
301+
inputs = update_inputs(inputs, executed_variables)
302+
303+
if is_loader:
304+
loader_code.append(create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_loader, **inputs))
305+
else:
306+
code.append(create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_loader, **inputs))
307+
308+
# Generate final code by combining imports and code, and wrap them in a main function
309+
final_code = assemble_python_code(import_statements, loader_code, code, query_size)
310+
311+
# Save the code to a .py file
312+
write_code_to_file(filename, final_code)
313+
314+
return final_code
315+
316+
317+
def main(input, query_size=10):
318+
"""
319+
Main function to be executed.
320+
"""
321+
# Load JSON data from the input file
322+
prompt = read_json_file(input)
323+
load_order = determine_load_order(prompt)
324+
output_file = input.replace('.json', '.py')
325+
code = generate_workflow(load_order, filename=output_file, query_size=query_size)
326+
logging.info(code)
327+
328+
329+
if __name__ == '__main__':
330+
input = 'workflow_api.json'
331+
query_size = 10
332+
main(input, query_size)

0 commit comments

Comments
 (0)