@@ -49,26 +49,33 @@ def text_encoder_data_loader(data_dir, batchsize, *args, **kwargs):
 # -----------------------------------------------------------------------------
 # DECODER UNET
 # -----------------------------------------------------------------------------
+class DecoderStableCascadeUNet(StableCascadeUNet):
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep_ratio: torch.Tensor,
+        clip_text_pooled: torch.Tensor,
+        effnet: torch.Tensor,
+    ) -> Union[StableCascadeUNet, Tuple]:
+        return super().forward(
+            sample=sample,
+            timestep_ratio=timestep_ratio,
+            clip_text_pooled=clip_text_pooled,
+            effnet=effnet,
+        )

 def decoder_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
-    # TODO(jstoecker): Rename onnx::Concat_4 to text_embeds and onnx::Shape_5 to time_ids
     inputs = {
         "sample": torch.rand((batchsize, 4, 256, 256), dtype=torch_dtype),
         "timestep_ratio": torch.rand((batchsize,), dtype=torch_dtype),
         "clip_text_pooled": torch.rand((batchsize, 1, 1280), dtype=torch_dtype),
         "effnet": torch.rand((batchsize, 16, 24, 24), dtype=torch_dtype)
     }
-
-    # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
-    kwargs = {
-        "return_dict": False,
-    }
-
     return inputs


 def decoder_load(model_name):
-    model = StableCascadeUNet.from_pretrained(model_name, subfolder="decoder")
+    model = DecoderStableCascadeUNet.from_pretrained(model_name, subfolder="decoder")
     return model


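The new `DecoderStableCascadeUNet` wrapper pins the decoder's forward signature to exactly the four tensors that `decoder_inputs` produces, so the model can be traced positionally for conversion; the old `kwargs` block built a `return_dict` dict that was never returned, so it is simply dropped. A minimal smoke-test sketch (not part of this commit; the `stabilityai/stable-cascade` model id is an assumption, the diff only names the `decoder` subfolder):

```python
import torch

# Sketch only: load the wrapped decoder and push the dummy conversion inputs
# through it. The model id below is assumed, not taken from the diff.
decoder = decoder_load("stabilityai/stable-cascade")
decoder.eval()

dummy = decoder_inputs(batchsize=1, torch_dtype=torch.float32)
with torch.no_grad():
    out = decoder(**dummy)

# The diffusers output supports index access; entry 0 is the predicted latent,
# with the same spatial size as the "sample" input.
print(out[0].shape)
```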
@@ -92,14 +99,8 @@ def prior_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
9299 "timestep_ratio" : torch .rand ((batchsize ,), dtype = torch_dtype ),
93100 "clip_text_pooled" : torch .rand ((batchsize , 1 , 1280 ), dtype = torch_dtype ),
94101 "clip_text" : torch .rand ((batchsize , 77 , 1280 ), dtype = torch_dtype ),
95- "clip_img" : torch .rand ((batchsize , 1 , 768 ), dtype = torch_dtype )
96- }
97-
98- # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
99- kwargs = {
100- "return_dict" : False ,
102+ "clip_img" : torch .rand ((batchsize , 1 , 768 ), dtype = torch_dtype ),
101103 }
102-
103104 return inputs
104105
105106
@@ -147,16 +148,26 @@ def image_encoder_data_loader(data_dir, batchsize, *args, **kwargs):
 # -----------------------------------------------------------------------------
 # VQGAN
 # -----------------------------------------------------------------------------
+class DecodePaellaVQModel(PaellaVQModel):
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+    ) -> Union[PaellaVQModel, Tuple]:
+        return super().decode(
+            h=sample,
+            force_not_quantize=True,
+            return_dict=True,
+        )

 def vqgan_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
     inputs = {
-        "sample": torch.rand((batchsize, 3, 256, 256), dtype=torch_dtype)
+        "sample": torch.rand((batchsize, 4, 256, 256), dtype=torch_dtype)
     }
     return inputs


 def vqgan_load(model_name):
-    model = PaellaVQModel.from_pretrained(model_name, subfolder="vqgan", use_safetensors=True)
+    model = DecodePaellaVQModel.from_pretrained(model_name, subfolder="vqgan", use_safetensors=True)
     return model


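`DecodePaellaVQModel` routes `forward` to `decode` with `force_not_quantize=True`, and the dummy `sample` now has the decoder's 4 latent channels instead of a 3-channel image. A matching sketch (again not part of the commit, same assumed model id):

```python
import torch

# Sketch only: decode a dummy 4-channel latent with the wrapped VQGAN.
# The model id below is assumed, not taken from the diff.
vqgan = vqgan_load("stabilityai/stable-cascade")
vqgan.eval()

latent = vqgan_inputs(batchsize=1, torch_dtype=torch.float32)["sample"]
with torch.no_grad():
    decoded = vqgan(latent)

# decode(..., return_dict=True) returns a DecoderOutput; entry 0 is the
# reconstructed RGB image tensor.
print(decoded[0].shape)
```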