Skip to content

Commit 30ac269

Browse files
author
Peyton
committed
Renamed folder.
1 parent dc884df commit 30ac269

File tree

1 file changed

+334
-0
lines changed

1 file changed

+334
-0
lines changed
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
import glob
2+
import inspect
3+
import json
4+
import logging
5+
import os
6+
from typing import Dict, List, Any, Callable
7+
import sys
8+
9+
sys.path.append('../')
10+
11+
from nodes import NODE_CLASS_MAPPINGS
12+
13+
logging.basicConfig(level=logging.INFO)
14+
15+
16+
def read_json_file(file_path: str) -> dict:
17+
"""
18+
Reads a JSON file and returns its contents as a dictionary.
19+
20+
Args:
21+
file_path (str): The path to the JSON file.
22+
23+
Returns:
24+
dict: The contents of the JSON file as a dictionary.
25+
26+
Raises:
27+
FileNotFoundError: If the file is not found, it lists all JSON files in the directory of the file path.
28+
ValueError: If the file is not a valid JSON.
29+
"""
30+
31+
try:
32+
with open(file_path, 'r') as file:
33+
data = json.load(file)
34+
return data
35+
36+
except FileNotFoundError:
37+
# Get the directory from the file_path
38+
directory = os.path.dirname(file_path)
39+
40+
# If the directory is an empty string (which means file is in the current directory),
41+
# get the current working directory
42+
if not directory:
43+
directory = os.getcwd()
44+
45+
# Find all JSON files in the directory
46+
json_files = glob.glob(f"{directory}/*.json")
47+
48+
# Format the list of JSON files as a string
49+
json_files_str = "\n".join(json_files)
50+
51+
raise FileNotFoundError(f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}")
52+
53+
except json.JSONDecodeError:
54+
raise ValueError(f"Invalid JSON format in file: {file_path}")
55+
56+
57+
def determine_load_order(data: Dict) -> List:
58+
"""
59+
Determine the load order of each key in the provided dictionary. This code will place the
60+
nodes without node dependencies first, then ensure that any node whose result is used
61+
in another node will be added to the list in the order it should be executed.
62+
63+
Args:
64+
data (Dict):
65+
The dictionary for which to determine the load order.
66+
67+
Returns:
68+
List:
69+
A list of tuples where each tuple contains a key, its corresponding dictionary,
70+
and a boolean indicating whether or not the function is dependent on the output of
71+
a previous function, ordered by load order.
72+
"""
73+
74+
# Create a dictionary to keep track of visited nodes.
75+
visited = {}
76+
# Create a list to store the load order for functions
77+
load_order = []
78+
# Boolean to indicate whether or not the class is a loader class that should not be
79+
# reloaded during every loop
80+
is_loader = False
81+
82+
def dfs(key: str) -> None:
83+
"""
84+
Depth-First Search function.
85+
86+
Args:
87+
key (str): The key from which to start the DFS.
88+
89+
Returns:
90+
None
91+
"""
92+
# Mark the node as visited.
93+
visited[key] = True
94+
inputs = data[key]['inputs']
95+
96+
# Loop over each input key.
97+
for input_key, val in inputs.items():
98+
# If the value is a list and the first item in the list (which should be a key)
99+
# has not been visited yet, then recursively apply dfs on the dependency.
100+
if isinstance(val, list) and val[0] not in visited:
101+
dfs(val[0])
102+
103+
# Add the key and its corresponding data to the load order list.
104+
load_order.append((key, data[key], is_loader))
105+
106+
# Load Loader keys first
107+
for key in data:
108+
class_def = NODE_CLASS_MAPPINGS[data[key]['class_type']]()
109+
if class_def.CATEGORY == 'loaders' or class_def.FUNCTION in ['encode'] or not any(isinstance(val, list) for val in data[key]['inputs'].values()):
110+
is_loader = True
111+
# If the key has not been visited, perform a DFS from that key.
112+
if key not in visited:
113+
dfs(key)
114+
115+
# Reset is_loader bool
116+
is_loader = False
117+
# Loop over each key in the data.
118+
for key in data:
119+
# If the key has not been visited, perform a DFS from that key.
120+
if key not in visited:
121+
dfs(key)
122+
123+
return load_order
124+
125+
126+
def create_function_call_code(obj_name: str, func: str, variable_name: str, is_loader: bool, **kwargs) -> str:
127+
"""
128+
This function generates Python code for a function call.
129+
130+
Args:
131+
obj_name (str): The name of the initialized object.
132+
func (str): The function to be called.
133+
variable_name (str): The name of the variable that the function result should be assigned to.
134+
is_loader (bool): Determines the code indentation.
135+
**kwargs: The keyword arguments for the function.
136+
137+
Returns:
138+
str: The generated Python code.
139+
"""
140+
141+
def format_arg(key: str, value: any) -> str:
142+
"""Formats arguments based on key and value."""
143+
if key == 'noise_seed':
144+
return f'{key}=random.randint(1, 2**64)'
145+
elif isinstance(value, str):
146+
value = value.replace("\n", "\\n")
147+
return f'{key}="{value}"'
148+
elif key == 'images' and "saveimage" in obj_name and isinstance(value, dict) and 'variable_name' in value:
149+
return f'{key}={value["variable_name"]}.detach()'
150+
elif isinstance(value, dict) and 'variable_name' in value:
151+
return f'{key}={value["variable_name"]}'
152+
return f'{key}={value}'
153+
154+
args = ', '.join(format_arg(key, value) for key, value in kwargs.items())
155+
156+
# Generate the Python code
157+
code = f'{variable_name} = {obj_name}.{func}({args})'
158+
159+
# If the code contains dependencies, indent the code because it will be placed inside
160+
# of a for loop
161+
if not is_loader:
162+
code = f'\t{code}'
163+
164+
return code
165+
166+
167+
def update_inputs(inputs: Dict, executed_variables: Dict) -> Dict:
168+
"""
169+
Update inputs based on the executed variables.
170+
171+
Args:
172+
inputs (Dict): Inputs dictionary to update.
173+
executed_variables (Dict): Dictionary storing executed variable names.
174+
175+
Returns:
176+
Dict: Updated inputs dictionary.
177+
"""
178+
for key in inputs.keys():
179+
if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys():
180+
inputs[key] = {'variable_name': f"{executed_variables[inputs[key][0]]}[{inputs[key][1]}]"}
181+
return inputs
182+
183+
184+
def get_class_info(class_type: str) -> (str, str, str):
185+
"""
186+
Generates and returns necessary information about class type.
187+
188+
Args:
189+
class_type (str): Class type
190+
191+
Returns:
192+
class_type (str): Updated class type
193+
import_statement (str): Import statement string
194+
class_code (str): Class initialization code
195+
"""
196+
# If the class is 'VAEDecode', adjust the class name
197+
if class_type == 'VAEDecode':
198+
class_type = 'VAEDecodeTiled'
199+
200+
import_statement = class_type
201+
class_code = f'{class_type.lower()} = {class_type}()'
202+
203+
return class_type, import_statement, class_code
204+
205+
206+
def assemble_python_code(import_statements: set, loader_code: List[str], code: List[str], queue_size: int) -> str:
207+
"""
208+
Generates final code string.
209+
210+
Args:
211+
import_statements (set): A set of unique import statements
212+
code (List[str]): A list of code strings
213+
queue_size (int): Number of photos that will be generated by the script.
214+
215+
Returns:
216+
final_code (str): Generated final code as a string
217+
"""
218+
static_imports = ['import random']
219+
imports_code = [f"from nodes import {class_name}" for class_name in import_statements]
220+
main_function_code = f"def main():\n\t" + '\n\t'.join(loader_code) + f'\n\tfor q in {range(1, queue_size)}:\n\t' + '\n\t'.join(code)
221+
final_code = '\n'.join(static_imports + ['import sys\nsys.path.append("../")'] + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
222+
223+
return final_code
224+
225+
226+
def write_code_to_file(filename: str, code: str) -> None:
227+
"""
228+
Writes given code to a .py file. If the directory does not exist, it creates it.
229+
230+
Args:
231+
filename (str): The name of the Python file to save the code to.
232+
code (str): The code to save.
233+
"""
234+
235+
# Extract directory from the filename
236+
directory = os.path.dirname(filename)
237+
238+
# If the directory does not exist, create it
239+
if directory and not os.path.exists(directory):
240+
os.makedirs(directory)
241+
242+
# Save the code to a .py file
243+
with open(filename, 'w') as file:
244+
file.write(code)
245+
246+
247+
def get_function_parameters(func: Callable) -> List:
248+
"""Get the names of a function's parameters.
249+
250+
Args:
251+
func (Callable): The function whose parameters we want to inspect.
252+
253+
Returns:
254+
List: A list containing the names of the function's parameters.
255+
"""
256+
signature = inspect.signature(func)
257+
parameters = {name: param.default if param.default != param.empty else None
258+
for name, param in signature.parameters.items()}
259+
return list(parameters.keys())
260+
261+
262+
def generate_workflow(load_order: List, filename: str = 'generated_code_workflow.py', queue_size: int = 10) -> str:
263+
"""
264+
Generate the execution code based on the load order.
265+
266+
Args:
267+
load_order (List): A list of tuples representing the load order.
268+
filename (str): The name of the Python file to which the code should be saved.
269+
Defaults to 'generated_code_workflow.py'.
270+
queue_size (int): The number of photos that will be created by the script.
271+
272+
Returns:
273+
str: Generated execution code as a string.
274+
"""
275+
276+
# Create the necessary data structures to hold imports and generated code
277+
import_statements, executed_variables, loader_code, code = set(), {}, [], []
278+
# This dictionary will store the names of the objects that we have already initialized
279+
initialized_objects = {}
280+
281+
# Loop over each dictionary in the load order list
282+
for idx, data, is_loader in load_order:
283+
284+
# Generate class definition and inputs from the data
285+
inputs, class_type = data['inputs'], data['class_type']
286+
class_def = NODE_CLASS_MAPPINGS[class_type]()
287+
288+
# If the class hasn't been initialized yet, initialize it and generate the import statements
289+
if class_type not in initialized_objects:
290+
class_type, import_statement, class_code = get_class_info(class_type)
291+
initialized_objects[class_type] = class_type.lower()
292+
import_statements.add(import_statement)
293+
loader_code.append(class_code)
294+
295+
# Get all possible parameters for class_def
296+
class_def_params = get_function_parameters(getattr(class_def, class_def.FUNCTION))
297+
298+
# Remove any keyword arguments from **inputs if they are not in class_def_params
299+
inputs = {key: value for key, value in inputs.items() if key in class_def_params}
300+
301+
# Create executed variable and generate code
302+
executed_variables[idx] = f'{class_type.lower()}_{idx}'
303+
inputs = update_inputs(inputs, executed_variables)
304+
305+
if is_loader:
306+
loader_code.append(create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_loader, **inputs))
307+
else:
308+
code.append(create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_loader, **inputs))
309+
310+
# Generate final code by combining imports and code, and wrap them in a main function
311+
final_code = assemble_python_code(import_statements, loader_code, code, queue_size)
312+
313+
# Save the code to a .py file
314+
write_code_to_file(filename, final_code)
315+
316+
return final_code
317+
318+
319+
def main(input, queue_size=10):
320+
"""
321+
Main function to be executed.
322+
"""
323+
# Load JSON data from the input file
324+
prompt = read_json_file(input)
325+
load_order = determine_load_order(prompt)
326+
output_file = input.replace('.json', '.py')
327+
code = generate_workflow(load_order, filename=output_file, queue_size=queue_size)
328+
logging.info(code)
329+
330+
331+
if __name__ == '__main__':
332+
input = 'workflow_api.json'
333+
queue_size = 10
334+
main(input, queue_size)

0 commit comments

Comments
 (0)