Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit c1b8fee

Browse files
committed
Revert prior model test batch size
1 parent 861a2b0 commit c1b8fee

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

OnnxStack.Converter/stable_cascade/models.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import config
66
import torch
77
from typing import Union, Optional, Tuple
8-
from diffusers import AutoencoderKL, StableCascadeUNet, ControlNetModel
9-
from diffusers.models.controlnet import ControlNetOutput, BaseOutput as ControlNetBaseOutput
10-
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
8+
from diffusers import AutoencoderKL, StableCascadeUNet
119
from transformers.models.clip.modeling_clip import CLIPTextModelWithProjection
1210
from dataclasses import dataclass
1311

@@ -89,10 +87,10 @@ def decoder_data_loader(data_dir, batchsize, *args, **kwargs):
8987
def prior_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
9088
inputs = {
9189
"sample": torch.rand((batchsize, 16, 24, 24), dtype=torch_dtype),
92-
"timestep_ratio": torch.rand(((batchsize *2),), dtype=torch_dtype),
93-
"clip_text_pooled": torch.rand(((batchsize *2) , 1, 1280), dtype=torch_dtype),
94-
"clip_text": torch.rand(((batchsize *2) , 77, 1280), dtype=torch_dtype),
95-
"clip_img": torch.rand(((batchsize *2) , 1, 768), dtype=torch_dtype)
90+
"timestep_ratio": torch.rand((batchsize,), dtype=torch_dtype),
91+
"clip_text_pooled": torch.rand((batchsize , 1, 1280), dtype=torch_dtype),
92+
"clip_text": torch.rand((batchsize , 77, 1280), dtype=torch_dtype),
93+
"clip_img": torch.rand((batchsize , 1, 768), dtype=torch_dtype)
9694
}
9795

9896
# use as kwargs since they won't be in the correct position if passed along with the tuple of inputs

0 commit comments

Comments (0)