Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 7381750

Browse files
committed
Add only_unet option
1 parent 78d1593 commit 7381750

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

OnnxStack.Converter/stable_diffusion_xl/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python convert.py --model_input "D:\Models\stable-diffusion-xl-base-1.0" --contr
2222

2323
`--tempDir` - Directory for temp Olive files
2424

25+
`--only_unet` - Only convert UNET model
2526

2627
## Extra Requirements
2728
To successfully optimize SDXL models you will need the patched `vae` from repository below otherwise you may get black image results

OnnxStack.Converter/stable_diffusion_xl/convert.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def optimize(
2727
model_input: str,
2828
model_output: Path,
2929
provider: str,
30-
controlnet: bool
30+
submodel_names: list[str]
3131
):
3232
from google.protobuf import __version__ as protobuf_version
3333

@@ -51,10 +51,6 @@ def optimize(
5151
config.unet_sample_size = pipeline.unet.config.sample_size
5252

5353
model_info = {}
54-
submodel_names = ["tokenizer", "tokenizer_2", "vae_encoder", "vae_decoder", "unet", "text_encoder", "text_encoder_2"]
55-
56-
if controlnet:
57-
submodel_names.append("controlnet")
5854

5955
for submodel_name in submodel_names:
6056
if submodel_name == "tokenizer" or submodel_name == "tokenizer_2":
@@ -81,16 +77,18 @@ def save_onnx_Models(model_dir, model_info, model_output, submodel_names):
8177
conversion_dir = model_output / conversion_type
8278
conversion_dir.mkdir(parents=True, exist_ok=True)
8379

80+
only_unet = "unet" in submodel_names and len(submodel_names) <= 2
8481
# Copy the config and other files required by some applications
85-
model_index_path = model_dir / "model_index.json"
86-
if os.path.exists(model_index_path):
87-
shutil.copy(model_index_path, conversion_dir)
88-
if os.path.exists(model_dir / "tokenizer"):
89-
shutil.copytree(model_dir / "tokenizer", conversion_dir / "tokenizer")
90-
if os.path.exists(model_dir / "tokenizer_2"):
91-
shutil.copytree(model_dir / "tokenizer_2", conversion_dir / "tokenizer_2")
92-
if os.path.exists(model_dir / "scheduler"):
93-
shutil.copytree(model_dir / "scheduler", conversion_dir / "scheduler")
82+
if only_unet is False:
83+
model_index_path = model_dir / "model_index.json"
84+
if os.path.exists(model_index_path):
85+
shutil.copy(model_index_path, conversion_dir)
86+
if os.path.exists(model_dir / "tokenizer"):
87+
shutil.copytree(model_dir / "tokenizer", conversion_dir / "tokenizer")
88+
if os.path.exists(model_dir / "tokenizer_2"):
89+
shutil.copytree(model_dir / "tokenizer_2", conversion_dir / "tokenizer_2")
90+
if os.path.exists(model_dir / "scheduler"):
91+
shutil.copytree(model_dir / "scheduler", conversion_dir / "scheduler")
9492

9593
# Save models files
9694
for submodel_name in submodel_names:
@@ -212,6 +210,8 @@ def parse_common_args(raw_args):
212210
parser.add_argument("--controlnet", action="store_true", help="Create ControlNet Unet Model")
213211
parser.add_argument("--clean", action="store_true", help="Deletes the Olive cache")
214212
parser.add_argument("--tempdir", default=None, type=str, help="Root directory for tempfile directories and files")
213+
parser.add_argument("--only_unet", action="store_true", help="Only convert UNET model")
214+
215215
return parser.parse_known_args(raw_args)
216216

217217

@@ -237,10 +237,17 @@ def main(raw_args=None):
237237

238238
set_tempdir(common_args.tempdir)
239239

240+
submodel_names = ["tokenizer", "tokenizer_2", "vae_encoder", "vae_decoder", "unet", "text_encoder", "text_encoder_2"]
241+
242+
if common_args.only_unet:
243+
submodel_names = ["unet"]
244+
245+
if common_args.controlnet:
246+
submodel_names.append("controlnet")
247+
240248
with warnings.catch_warnings():
241249
warnings.simplefilter("ignore")
242-
optimize(script_dir, common_args.model_input,
243-
model_output, provider, common_args.controlnet)
250+
optimize(script_dir, common_args.model_input, model_output, provider, submodel_names)
244251

245252

246253
if __name__ == "__main__":

0 commit comments

Comments
 (0)