Skip to content

Commit 38e7e54

Browse files
Merge pull request #229 from bendavidsteel/model-kwargs-hf
Add the ability to set model kwargs for HF local models
2 parents f663f54 + f9a0362 commit 38e7e54

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

dsp/modules/hf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
"sequential",
4141
] = "auto",
4242
token: Optional[str] = None,
43+
model_kwargs: Optional[dict] = {},
4344
):
4445
"""wrapper for Hugging Face models
4546
@@ -49,6 +50,7 @@ def __init__(
4950
is_client (bool, optional): whether to access models via client. Defaults to False.
5051
hf_device_map (str, optional): HF config strategy to load the model.
5152
Recommended to use "auto", which helps load large models using accelerate. Defaults to "auto".
53+
model_kwargs (dict, optional): additional kwargs to pass to the model constructor. A "device_map" entry is silently ignored; use hf_device_map instead. Defaults to an empty dict.
5254
"""
5355

5456
super().__init__(model)
@@ -59,6 +61,11 @@ def __init__(
5961
hf_autoconfig_kwargs = dict(token=token or os.environ.get("HF_TOKEN"))
6062
hf_autotokenizer_kwargs = hf_autoconfig_kwargs.copy()
6163
hf_automodel_kwargs = hf_autoconfig_kwargs.copy()
64+
65+
# silently remove device_map from model_kwargs if it is present, as the option is provided in the constructor
66+
if "device_map" in model_kwargs:
67+
model_kwargs.pop("device_map")
68+
hf_automodel_kwargs.update(model_kwargs)
6269
if not self.is_client:
6370
try:
6471
import torch

0 commit comments

Comments
 (0)