File tree Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -40,6 +40,7 @@ def __init__(
4040 "sequential" ,
4141 ] = "auto" ,
4242 token : Optional [str ] = None ,
43+ model_kwargs : Optional [dict ] = {},
4344 ):
4445 """wrapper for Hugging Face models
4546
@@ -49,6 +50,7 @@ def __init__(
4950 is_client (bool, optional): whether to access models via client. Defaults to False.
5051 hf_device_map (str, optional): HF config strategy to load the model.
5152 Recommeded to use "auto", which will help loading large models using accelerate. Defaults to "auto".
53+ model_kwargs (dict, optional): additional kwargs to pass to the model constructor. Defaults to empty dict.
5254 """
5355
5456 super ().__init__ (model )
@@ -59,6 +61,11 @@ def __init__(
5961 hf_autoconfig_kwargs = dict (token = token or os .environ .get ("HF_TOKEN" ))
6062 hf_autotokenizer_kwargs = hf_autoconfig_kwargs .copy ()
6163 hf_automodel_kwargs = hf_autoconfig_kwargs .copy ()
64+
65+ # silently remove device_map from model_kwargs if it is present, as the option is provided in the constructor
66+ if "device_map" in model_kwargs :
67+ model_kwargs .pop ("device_map" )
68+ hf_automodel_kwargs .update (model_kwargs )
6269 if not self .is_client :
6370 try :
6471 import torch
You can’t perform that action at this time.
0 commit comments