This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit c6b097b (1 parent: 74834fc)

StableCascade: Fix missing decoder input, use VQGAN decode

File tree: 3 files changed (+31, -20 lines)

OnnxStack.Converter/stable_cascade/config_decoder.json

Lines changed: 2 additions & 2 deletions

@@ -6,13 +6,13 @@
     "model_loader": "decoder_load",
     "model_script": "models.py",
     "io_config": {
-        "input_names": [ "sample", "timestep_ratio", "clip_text_pooled", "effnet", "return_dict" ],
+        "input_names": [ "sample", "timestep_ratio", "clip_text_pooled", "effnet" ],
         "output_names": [ "out_sample" ],
         "dynamic_axes": {
             "sample": {"0": "unet_sample_batch", "1": "unet_sample_channels", "2": "unet_sample_height", "3": "unet_sample_width"},
             "timestep_ratio": {"0": "unet_timestep_ratio"},
             "clip_text_pooled": {"0": "unet_clip_text_pooled_batch", "1": "unet_clip_text_pooled_size"},
-            "effnet": {"0": "unet_hidden_batch", "1": "unet_hidden_size"}
+            "effnet": {"0": "effnet_sample_batch", "1": "effnet_sample_channels", "2": "effnet_sample_height", "3": "effnet_sample_width"}
         }
     },
     "dummy_inputs_func": "decoder_conversion_inputs"

OnnxStack.Converter/stable_cascade/config_prior.json

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
     "model_loader": "prior_load",
     "model_script": "models.py",
     "io_config": {
-        "input_names": [ "sample", "timestep_ratio", "clip_text_pooled", "clip_text", "clip_img", "return_dict" ],
+        "input_names": [ "sample", "timestep_ratio", "clip_text_pooled", "clip_text", "clip_img" ],
        "output_names": [ "out_sample" ],
        "dynamic_axes": {
            "sample": {"0": "unet_sample_batch", "1": "unet_sample_channels", "2": "unet_sample_height", "3": "unet_sample_width"},

OnnxStack.Converter/stable_cascade/models.py

Lines changed: 28 additions & 17 deletions

@@ -49,26 +49,33 @@ def text_encoder_data_loader(data_dir, batchsize, *args, **kwargs):
 # -----------------------------------------------------------------------------
 # DECODER UNET
 # -----------------------------------------------------------------------------
+class DecoderStableCascadeUNet(StableCascadeUNet):
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep_ratio: torch.Tensor,
+        clip_text_pooled: torch.Tensor,
+        effnet: torch.Tensor,
+    ) -> Union[StableCascadeUNet, Tuple]:
+        return super().forward(
+            sample=sample,
+            timestep_ratio=timestep_ratio,
+            clip_text_pooled=clip_text_pooled,
+            effnet=effnet
+        )

 def decoder_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
-    # TODO(jstoecker): Rename onnx::Concat_4 to text_embeds and onnx::Shape_5 to time_ids
     inputs = {
         "sample": torch.rand((batchsize, 4, 256, 256), dtype=torch_dtype),
         "timestep_ratio": torch.rand((batchsize,), dtype=torch_dtype),
         "clip_text_pooled": torch.rand((batchsize, 1, 1280), dtype=torch_dtype),
         "effnet": torch.rand((batchsize, 16, 24, 24), dtype=torch_dtype)
     }
-
-    # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
-    kwargs = {
-        "return_dict": False,
-    }
-
     return inputs


 def decoder_load(model_name):
-    model = StableCascadeUNet.from_pretrained(model_name, subfolder="decoder")
+    model = DecoderStableCascadeUNet.from_pretrained(model_name, subfolder="decoder")
     return model
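With the fixed-signature wrapper in place, the export path can be smoke-tested by pushing the dummy inputs through it positionally, in the order declared in config_decoder.json. A minimal sketch; the model id is an assumption, substitute whatever checkpoint you are converting:

import torch

# Hypothetical pre-conversion sanity check, not part of the commit:
# confirm the wrapper accepts the dummy inputs positionally.
model = decoder_load("stabilityai/stable-cascade")
inputs = decoder_inputs(batchsize=1, torch_dtype=torch.float32)
with torch.no_grad():
    out = model(inputs["sample"], inputs["timestep_ratio"],
                inputs["clip_text_pooled"], inputs["effnet"])
print(out.sample.shape)  # expected to match the latent shape (1, 4, 256, 256)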

@@ -92,14 +99,8 @@ def prior_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
         "timestep_ratio": torch.rand((batchsize,), dtype=torch_dtype),
         "clip_text_pooled": torch.rand((batchsize, 1, 1280), dtype=torch_dtype),
         "clip_text": torch.rand((batchsize, 77, 1280), dtype=torch_dtype),
-        "clip_img": torch.rand((batchsize, 1, 768), dtype=torch_dtype)
-    }
-
-    # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
-    kwargs = {
-        "return_dict": False,
+        "clip_img": torch.rand((batchsize, 1, 768), dtype=torch_dtype),
     }
-
     return inputs

@@ -147,16 +148,26 @@ def image_encoder_data_loader(data_dir, batchsize, *args, **kwargs):
 # -----------------------------------------------------------------------------
 # VQGAN
 # -----------------------------------------------------------------------------
+class DecodePaellaVQModel(PaellaVQModel):
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+    ) -> Union[PaellaVQModel, Tuple]:
+        return super().decode(
+            h=sample,
+            force_not_quantize=True,
+            return_dict=True,
+        )

 def vqgan_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
     inputs = {
-        "sample": torch.rand((batchsize, 3, 256, 256), dtype=torch_dtype)
+        "sample": torch.rand((batchsize, 4, 256, 256), dtype=torch_dtype)
     }
     return inputs


 def vqgan_load(model_name):
-    model = PaellaVQModel.from_pretrained(model_name, subfolder="vqgan", use_safetensors=True)
+    model = DecodePaellaVQModel.from_pretrained(model_name, subfolder="vqgan", use_safetensors=True)
     return model
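Because forward() now routes to decode(), the exported VQGAN takes the 4-channel stage-B latent (hence the dummy input change from 3 to 4 channels) and returns the decoded RGB image. A rough usage sketch against the Python wrapper; the model id and the 4x upsampling factor are assumptions based on diffusers defaults, not part of this commit:

import torch

vqgan = vqgan_load("stabilityai/stable-cascade")
latent = vqgan_inputs(batchsize=1, torch_dtype=torch.float32)["sample"]
with torch.no_grad():
    # decode(..., return_dict=True) yields a DecoderOutput, so the image
    # tensor lives on .sample.
    image = vqgan(latent).sample
print(image.shape)  # e.g. (1, 3, 1024, 1024) for a (1, 4, 256, 256) latent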
