Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit d6fd9b0

Browse files
committed
convert temp_convs in right way
1 parent 2e641f1 commit d6fd9b0

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

utils/convert_diffusers_to_original_ms_text_to_video.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
("skip_connection", "conv_shortcut"),
5050

5151
# MS
52-
("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha
52+
#("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha
5353
]
5454

5555
unet_conversion_map_layer = []
@@ -99,13 +99,13 @@
9999

100100
# Temporal MS stuff
101101
hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}."
102-
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
102+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv."
103103
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
104104

105105
if i < 3:
106106
# no attention layers in down_blocks.3
107107
hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}."
108-
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
108+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2."
109109
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
110110

111111
for j in range(3):
@@ -124,13 +124,13 @@
124124

125125
# loop over resnets/attentions for upblocks
126126
hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}."
127-
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
127+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv."
128128
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
129129

130130
if i > 0:
131131
# no attention layers in up_blocks.0
132132
hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}."
133-
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
133+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.2."
134134
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
135135

136136
# Up/Downsamplers are 2D, so don't need to touch them
@@ -155,17 +155,17 @@
155155

156156
for j in range(2):
157157
hf_mid_res_prefix = f"mid_block.resnets.{j}."
158-
sd_mid_res_prefix = f"middle_block.{2*j}."
158+
sd_mid_res_prefix = f"middle_block.{2*j+1}."
159159
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
160160

161161
# Temporal
162162
hf_mid_atn_prefix = "mid_block.temp_attentions.0."
163-
sd_mid_atn_prefix = "middle_block.1."
163+
sd_mid_atn_prefix = "middle_block.2."
164164
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
165165

166166
for j in range(2):
167-
hf_mid_res_prefix = f"mid_block.temp_convs.{j}."
168-
sd_mid_res_prefix = f"middle_block.{2*j}."
167+
hf_mid_res_prefix = f"mid_block.temp_convs.{j+1}."
168+
sd_mid_res_prefix = f"middle_block.{2*j+1}.temopral_conv."
169169
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
170170

171171
# The pipeline
@@ -183,10 +183,10 @@ def convert_unet_state_dict(unet_state_dict):
183183
for sd_part, hf_part in unet_conversion_map_resnet:
184184
v = v.replace(hf_part, sd_part)
185185
mapping[k] = v
186-
elif "temp_convs" in k:
187-
for sd_part, hf_part in unet_conversion_map_resnet:
188-
v = v.replace(hf_part, sd_part)
189-
mapping[k] = v
186+
# elif "temp_convs" in k:
187+
# for sd_part, hf_part in unet_conversion_map_resnet:
188+
# v = v.replace(hf_part, sd_part)
189+
# mapping[k] = v
190190
for k, v in mapping.items():
191191
for sd_part, hf_part in unet_conversion_map_layer:
192192
v = v.replace(hf_part, sd_part)

0 commit comments

Comments
 (0)