4949 ("skip_connection" , "conv_shortcut" ),
5050
5151 # MS
52- ("temopral_conv" , "temp_convs" ), # ROFL, they have a typo here --kabachuha
52+ # ("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha
5353]
5454
5555unet_conversion_map_layer = []
9999
100100 # Temporal MS stuff
101101 hf_down_res_prefix = f"down_blocks.{ i } .temp_convs.{ j } ."
102- sd_down_res_prefix = f"input_blocks.{ 3 * i + j + 1 } .0."
102+ sd_down_res_prefix = f"input_blocks.{ 3 * i + j + 1 } .0.temopral_conv. "
103103 unet_conversion_map_layer .append ((sd_down_res_prefix , hf_down_res_prefix ))
104104
105105 if i < 3 :
106106 # no attention layers in down_blocks.3
107107 hf_down_atn_prefix = f"down_blocks.{ i } .temp_attentions.{ j } ."
108- sd_down_atn_prefix = f"input_blocks.{ 3 * i + j + 1 } .1 ."
108+ sd_down_atn_prefix = f"input_blocks.{ 3 * i + j + 1 } .2 ."
109109 unet_conversion_map_layer .append ((sd_down_atn_prefix , hf_down_atn_prefix ))
110110
111111 for j in range (3 ):
124124
125125 # loop over resnets/attentions for upblocks
126126 hf_up_res_prefix = f"up_blocks.{ i } .temp_convs.{ j } ."
127- sd_up_res_prefix = f"output_blocks.{ 3 * i + j } .0."
127+ sd_up_res_prefix = f"output_blocks.{ 3 * i + j } .0.temopral_conv. "
128128 unet_conversion_map_layer .append ((sd_up_res_prefix , hf_up_res_prefix ))
129129
130130 if i > 0 :
131131 # no attention layers in up_blocks.0
132132 hf_up_atn_prefix = f"up_blocks.{ i } .temp_attentions.{ j } ."
133- sd_up_atn_prefix = f"output_blocks.{ 3 * i + j } .1 ."
133+ sd_up_atn_prefix = f"output_blocks.{ 3 * i + j } .2 ."
134134 unet_conversion_map_layer .append ((sd_up_atn_prefix , hf_up_atn_prefix ))
135135
136136 # Up/Downsamplers are 2D, so don't need to touch them
155155
156156for j in range (2 ):
157157 hf_mid_res_prefix = f"mid_block.resnets.{ j } ."
158- sd_mid_res_prefix = f"middle_block.{ 2 * j } ."
158+ sd_mid_res_prefix = f"middle_block.{ 2 * j + 1 } ."
159159 unet_conversion_map_layer .append ((sd_mid_res_prefix , hf_mid_res_prefix ))
160160
161161# Temporal
162162hf_mid_atn_prefix = "mid_block.temp_attentions.0."
163- sd_mid_atn_prefix = "middle_block.1 ."
163+ sd_mid_atn_prefix = "middle_block.2 ."
164164unet_conversion_map_layer .append ((sd_mid_atn_prefix , hf_mid_atn_prefix ))
165165
166166for j in range (2 ):
167- hf_mid_res_prefix = f"mid_block.temp_convs.{ j } ."
168- sd_mid_res_prefix = f"middle_block.{ 2 * j } ."
167+ hf_mid_res_prefix = f"mid_block.temp_convs.{ j + 1 } ."
168+ sd_mid_res_prefix = f"middle_block.{ 2 * j + 1 } .temopral_conv ."
169169 unet_conversion_map_layer .append ((sd_mid_res_prefix , hf_mid_res_prefix ))
170170
171171# The pipeline
@@ -183,10 +183,10 @@ def convert_unet_state_dict(unet_state_dict):
183183 for sd_part , hf_part in unet_conversion_map_resnet :
184184 v = v .replace (hf_part , sd_part )
185185 mapping [k ] = v
186- elif "temp_convs" in k :
187- for sd_part , hf_part in unet_conversion_map_resnet :
188- v = v .replace (hf_part , sd_part )
189- mapping [k ] = v
186+ # elif "temp_convs" in k:
187+ # for sd_part, hf_part in unet_conversion_map_resnet:
188+ # v = v.replace(hf_part, sd_part)
189+ # mapping[k] = v
190190 for k , v in mapping .items ():
191191 for sd_part , hf_part in unet_conversion_map_layer :
192192 v = v .replace (hf_part , sd_part )
0 commit comments