This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 7979acf: StableCascade vqgan
1 parent 8c76a77

File tree: 2 files changed, +130 −2 lines
New file: Olive conversion config for the Stable Cascade VQGAN

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
{
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "model_path": "stabilityai/stable-cascade",
            "model_loader": "vqgan_load",
            "model_script": "models.py",
            "io_config": {
                "input_names": [ "sample", "return_dict" ],
                "output_names": [ "latent_sample" ],
                "dynamic_axes": { "sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" } }
            },
            "dummy_inputs_func": "vqgan_conversion_inputs"
        }
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "config": {
                "accelerators": [
                    {
                        "device": "gpu",
                        "execution_providers": [
                            "DmlExecutionProvider"
                        ]
                    }
                ]
            }
        }
    },
    "evaluators": {
        "common_evaluator": {
            "metrics": [
                {
                    "name": "latency",
                    "type": "latency",
                    "sub_types": [{"name": "avg"}],
                    "user_config": {
                        "user_script": "models.py",
                        "dataloader_func": "vqgan_data_loader",
                        "batch_size": 1
                    }
                }
            ]
        }
    },
    "passes": {
        "convert": {
            "type": "OnnxConversion",
            "config": {
                "target_opset": 16
            }
        },
        "optimize": {
            "type": "OrtTransformersOptimization",
            "config": {
                "model_type": "vae",
                "opt_level": 0,
                "float16": true,
                "use_gpu": true,
                "keep_io_types": false,
                "optimization_options": {
                    "enable_gelu": true,
                    "enable_layer_norm": true,
                    "enable_attention": true,
                    "use_multi_head_attention": true,
                    "enable_skip_layer_norm": false,
                    "enable_embed_layer_norm": true,
                    "enable_bias_skip_layer_norm": false,
                    "enable_bias_gelu": true,
                    "enable_gelu_approximation": false,
                    "enable_qordered_matmul": false,
                    "enable_shape_inference": true,
                    "enable_gemm_fast_gelu": false,
                    "enable_nhwc_conv": false,
                    "enable_group_norm": true,
                    "enable_bias_splitgelu": false,
                    "enable_packed_qkv": true,
                    "enable_packed_kv": true,
                    "enable_bias_add": false,
                    "group_norm_channels_last": false
                },
                "force_fp32_ops": ["RandomNormalLike"],
                "force_fp16_inputs": {
                    "GroupNorm": [0, 1, 2]
                }
            }
        }
    },
    "pass_flows": [
        ["convert", "optimize"]
    ],
    "engine": {
        "log_severity_level": 0,
        "evaluator": "common_evaluator",
        "evaluate_input_model": false,
        "host": "local_system",
        "target": "local_system",
        "cache_dir": "cache",
        "output_name": "vqgan",
        "output_dir": "footprints"
    }
}
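This config is a standard Olive workflow definition, so it can be driven from Python. A minimal sketch, assuming the JSON above is saved as config_vqgan.json (a placeholder name; the commit does not show the new file's path) and that the olive-ai package is installed:

from olive.workflows import run as olive_run

# Runs the "convert" then "optimize" pass flow defined above and writes
# the results under output_dir ("footprints").
olive_run("config_vqgan.json")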

OnnxStack.Converter/stable_cascade/models.py

Lines changed: 27 additions & 2 deletions
@@ -6,6 +6,7 @@
 import torch
 from typing import Union, Optional, Tuple
 from diffusers import AutoencoderKL, StableCascadeUNet
+from diffusers.pipelines.wuerstchen import PaellaVQModel
 from transformers.models.clip.modeling_clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection
 from dataclasses import dataclass
 
@@ -129,7 +130,7 @@ def image_encoder_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
 
 
 def image_encoder_load(model_name):
-    model = CLIPVisionModelWithProjection.from_pretrained(model_name, subfolder="image_encoder")
+    model = CLIPVisionModelWithProjection.from_pretrained(model_name, subfolder="image_encoder", use_safetensors=True)
     return model
 
 
@@ -138,4 +139,28 @@ def image_encoder_conversion_inputs(model=None):
 
 
 def image_encoder_data_loader(data_dir, batchsize, *args, **kwargs):
-    return RandomDataLoader(image_encoder_inputs, batchsize, torch.float16)
+    return RandomDataLoader(image_encoder_inputs, batchsize, torch.float16)
+
+
+# -----------------------------------------------------------------------------
+# vqgan
+# -----------------------------------------------------------------------------
+
+def vqgan_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
+    inputs = {
+        "sample": torch.rand((batchsize, 3, 256, 256), dtype=torch_dtype)
+    }
+    return inputs
+
+
+def vqgan_load(model_name):
+    model = PaellaVQModel.from_pretrained(model_name, subfolder="vqgan", use_safetensors=True)
+    return model
+
+
+def vqgan_conversion_inputs(model=None):
+    return tuple(vqgan_inputs(1, torch.float32, True).values())
+
+
+def vqgan_data_loader(data_dir, batchsize, *args, **kwargs):
+    return RandomDataLoader(vqgan_inputs, batchsize, torch.float16)
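Once the workflow completes, the optimized fp16 model can be smoke-tested with onnxruntime on DirectML. A minimal sketch, assuming a hypothetical output path under "footprints" and that the exported graph exposes only the "sample" input (the traced dummy inputs above supply no "return_dict") plus the "latent_sample" output named in the io_config:

import numpy as np
import onnxruntime as ort

# Hypothetical path; Olive writes optimized models under output_dir.
session = ort.InferenceSession(
    "footprints/vqgan.onnx",
    providers=["DmlExecutionProvider"],
)

# Random fp16 input matching vqgan_inputs; keep_io_types is false,
# so the optimized model's inputs and outputs are float16.
sample = np.random.rand(1, 3, 256, 256).astype(np.float16)
(latent_sample,) = session.run(["latent_sample"], {"sample": sample})
print(latent_sample.shape)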

0 commit comments

Comments
 (0)