@@ -81,14 +81,14 @@ def _load_pretrained_weights(self):
 
         # Mapping (`hf: ours`) of decoder layers
         for i in range(12):
-            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
-            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
             mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
             mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
             mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
             mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
-            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
-            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
             mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
             mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
             mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
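For readers skimming the diff, here is a minimal, self-contained sketch of how a name mapping like this is typically applied before loading; the toy `hf_state_dict`, single layer, and tensor shapes below are illustrative, not the actual GPT-2 checkpoint:

```python
import torch

# Toy HF-style checkpoint; tensor shapes are arbitrary for the sketch.
hf_state_dict = {
    'transformer.h.0.ln_1.weight': torch.ones(4),
    'transformer.h.0.ln_1.bias': torch.zeros(4),
}

# Same idea as the hunk above: checkpoint name -> our module name.
mapping = {
    'transformer.h.0.ln_1.weight': 'blocks.0.attn_norm.weight',
    'transformer.h.0.ln_1.bias': 'blocks.0.attn_norm.bias',
}

# Rebuild the state dict under the new names before calling load_state_dict.
new_state_dict = {mapping[name]: tensor for name, tensor in hf_state_dict.items()}
print(sorted(new_state_dict))  # ['blocks.0.attn_norm.bias', 'blocks.0.attn_norm.weight']
```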
@@ -110,7 +110,11 @@ def _load_pretrained_weights(self):
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
 
         # Load our model. We use `strict=False` because the state dict does not contain LoRA weights
-        self.model.load_state_dict(new_state_dict, strict=False)
+        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+
+        # Make sure that only the LoRA weights are left unloaded
+        assert all('lora' in key for key in missing_keys)
+        assert not unexpected_keys
 
     def initialize(self):
         """
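For reference, a self-contained sketch (not the repository's code; the `LoRALinear` module below is hypothetical) of why `load_state_dict(strict=False)` plus the two assertions catches mapping mistakes: PyTorch returns the keys it could not match in either direction, so only the LoRA parameters should show up as missing and nothing from the checkpoint should be left unexpected.

```python
import torch
import torch.nn as nn

# Hypothetical module: a base linear layer plus LoRA factors that are
# NOT present in the pretrained checkpoint.
class LoRALinear(nn.Module):
    def __init__(self, d_in, d_out, rank=4):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_out, d_in))
        self.bias = nn.Parameter(torch.zeros(d_out))
        self.lora_a = nn.Parameter(torch.zeros(rank, d_in))
        self.lora_b = nn.Parameter(torch.zeros(d_out, rank))

    def forward(self, x):
        return x @ (self.weight + self.lora_b @ self.lora_a).t() + self.bias

model = LoRALinear(8, 8)

# The "pretrained" checkpoint only carries the base weights, no LoRA parameters.
pretrained_state = {'weight': torch.randn(8, 8), 'bias': torch.randn(8)}

# With strict=False, load_state_dict reports what it could not match.
missing_keys, unexpected_keys = model.load_state_dict(pretrained_state, strict=False)

# Only LoRA parameters may be missing, and nothing in the checkpoint may be left over.
assert all('lora' in key for key in missing_keys), missing_keys
assert not unexpected_keys, unexpected_keys
print(missing_keys)  # ['lora_a', 'lora_b']
```

If the name mapping drifted out of sync with the module (say, `ln_1` were still mapped to `pre_norm`), those tensors would surface in `missing_keys` and the first assertion would fail, instead of silently leaving layers randomly initialized.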