Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 6a7077c

Browse files
committed
hardcode tensor squeeze
1 parent d65cf47 commit 6a7077c

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

utils/convert_diffusers_to_original_ms_text_to_video.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,24 @@ def convert_unet_state_dict(unet_state_dict):
191191
for sd_part, hf_part in unet_conversion_map_layer:
192192
v = v.replace(hf_part, sd_part)
193193
mapping[k] = v
194-
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
194+
195+
196+
# there must be a pattern, but I don't want to bother atm
197+
do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]]
198+
print (do_not_unsqueeze)
199+
200+
new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()}
201+
# HACK: idk why the hell it does not work with list comprehension
202+
for k, v in new_state_dict.items():
203+
has_k = False
204+
for n in do_not_unsqueeze:
205+
if k == n:
206+
has_k = True
207+
208+
if has_k:
209+
v = v.squeeze(-1)
210+
new_state_dict[k] = v
211+
195212
return new_state_dict
196213

197214
# TODO: VAE conversion. We doesn't train it in the most cases, but may be handy for the future --kabachuha

0 commit comments

Comments
 (0)