5 | 5 | import config |
6 | 6 | import torch |
7 | 7 | from typing import Union, Optional, Tuple |
8 | | -from diffusers import AutoencoderKL, StableCascadeUNet, ControlNetModel |
9 | | -from diffusers.models.controlnet import ControlNetOutput, BaseOutput as ControlNetBaseOutput |
10 | | -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
| 8 | +from diffusers import AutoencoderKL, StableCascadeUNet |
11 | 9 | from transformers.models.clip.modeling_clip import CLIPTextModelWithProjection |
12 | 10 | from dataclasses import dataclass |
13 | 11 |
@@ -89,10 +87,10 @@ def decoder_data_loader(data_dir, batchsize, *args, **kwargs): |
89 | 87 | def prior_inputs(batchsize, torch_dtype, is_conversion_inputs=False): |
90 | 88 | inputs = { |
91 | 89 | "sample": torch.rand((batchsize, 16, 24, 24), dtype=torch_dtype), |
92 | | - "timestep_ratio": torch.rand(((batchsize *2),), dtype=torch_dtype), |
93 | | - "clip_text_pooled": torch.rand(((batchsize *2) , 1, 1280), dtype=torch_dtype), |
94 | | - "clip_text": torch.rand(((batchsize *2) , 77, 1280), dtype=torch_dtype), |
95 | | - "clip_img": torch.rand(((batchsize *2) , 1, 768), dtype=torch_dtype) |
| 90 | + "timestep_ratio": torch.rand((batchsize,), dtype=torch_dtype), |
| 91 | + "clip_text_pooled": torch.rand((batchsize, 1, 1280), dtype=torch_dtype), |
| 92 | + "clip_text": torch.rand((batchsize, 77, 1280), dtype=torch_dtype), |
| 93 | + "clip_img": torch.rand((batchsize, 1, 768), dtype=torch_dtype) |
96 | 94 | } |
97 | 95 |
98 | 96 | # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs |
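For context, the keys and corrected shapes above are meant to line up with the keyword arguments of the Stable Cascade prior UNet, which is why they are passed as kwargs rather than positionally. A minimal sketch of a shape smoke test under that assumption; the checkpoint ID and subfolder are illustrative and not taken from this commit:

```python
# Sketch: feed dummy tensors with the shapes from prior_inputs to the Stable
# Cascade prior UNet as keyword arguments. Checkpoint ID and subfolder are
# assumptions for illustration; substitute the model actually used in config.
import torch
from diffusers import StableCascadeUNet

prior_unet = StableCascadeUNet.from_pretrained(
    "stabilityai/stable-cascade-prior",  # assumed checkpoint
    subfolder="prior",
    torch_dtype=torch.float32,
).eval()

batchsize, dtype = 1, torch.float32
dummy = {
    "sample": torch.rand((batchsize, 16, 24, 24), dtype=dtype),
    "timestep_ratio": torch.rand((batchsize,), dtype=dtype),
    "clip_text_pooled": torch.rand((batchsize, 1, 1280), dtype=dtype),
    "clip_text": torch.rand((batchsize, 77, 1280), dtype=dtype),
    "clip_img": torch.rand((batchsize, 1, 768), dtype=dtype),
}
with torch.no_grad():
    out = prior_unet(**dummy)  # passed as kwargs, matching the note above
print(out[0].shape)  # predicted sample; expected to keep shape (1, 16, 24, 24)
```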