Skip to content

Commit 2e2b725

Browse files
committed
update test function source
1 parent 240767c commit 2e2b725

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

tests/test_modeling_common.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,13 @@
105105
CONFIG_NAME,
106106
GENERATION_CONFIG_NAME,
107107
SAFE_WEIGHTS_NAME,
108-
is_accelerate_available,
109108
is_torch_bf16_available_on_device,
110109
is_torch_fp16_available_on_device,
111110
)
112111

113112
from .generation.test_utils import GenerationTesterMixin
114113

115114

116-
if is_accelerate_available():
117-
from accelerate.utils import compute_module_sizes
118-
119-
120115
if is_torch_available():
121116
import torch
122117
from safetensors import safe_open
@@ -125,6 +120,7 @@
125120
from torch import nn
126121

127122
from transformers import MODEL_MAPPING
123+
from transformers.integrations.accelerate import compute_module_sizes
128124
from transformers.integrations.tensor_parallel import _get_parameter_tp_plan
129125
from transformers.modeling_utils import load_state_dict
130126
from transformers.pytorch_utils import id_tensor_storage
@@ -2370,7 +2366,7 @@ def test_disk_offload_bin(self):
23702366
torch.manual_seed(0)
23712367
base_output = model(**inputs_dict_class)
23722368

2373-
model_size = compute_module_sizes(model)[""]
2369+
model_size = compute_module_sizes(model)[0][""]
23742370
with tempfile.TemporaryDirectory() as tmp_dir:
23752371
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
23762372

@@ -2416,7 +2412,7 @@ def test_disk_offload_safetensors(self):
24162412
torch.manual_seed(0)
24172413
base_output = model(**inputs_dict_class)
24182414

2419-
model_size = compute_module_sizes(model)[""]
2415+
model_size = compute_module_sizes(model)[0][""]
24202416
with tempfile.TemporaryDirectory() as tmp_dir:
24212417
model.cpu().save_pretrained(tmp_dir)
24222418

@@ -2455,7 +2451,7 @@ def test_cpu_offload(self):
24552451
torch.manual_seed(0)
24562452
base_output = model(**inputs_dict_class)
24572453

2458-
model_size = compute_module_sizes(model)[""]
2454+
model_size = compute_module_sizes(model)[0][""]
24592455
# We test several splits of sizes to make sure it works.
24602456
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
24612457
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -2498,7 +2494,7 @@ def test_model_parallelism(self):
24982494
torch.manual_seed(0)
24992495
base_output = model(**inputs_dict_class)
25002496

2501-
model_size = compute_module_sizes(model)[""]
2497+
model_size = compute_module_sizes(model)[0][""]
25022498
# We test several splits of sizes to make sure it works.
25032499
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
25042500
with tempfile.TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)