diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 0a37fa299..050819a22 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -11,7 +11,6 @@ ) from huggingface_hub import snapshot_download from transformers.integrations.deepspeed import HfDeepSpeedConfig -from transformers.modeling_utils import no_init_weights from dschat.utils.model.reward_model import RewardModel from dschat.utils.utils import load_state_dict_into_model, print_rank_0 @@ -100,8 +99,7 @@ def create_hf_model(model_class, dschf = None if rlhf_training: # the weight loading is handled by create critic model - with no_init_weights(): - model = model_class.from_config(model_config) + model = model_class.from_config(model_config) else: model = model_class.from_pretrained( model_name_or_path,