Skip to content

Commit 9e8c71c

Browse files
Prompt user to set_env_variables.sh + remove global scope imports from runner files (#243)
1 parent 14c1197 commit 9e8c71c

File tree

69 files changed

+1584
-333
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+1584
-333
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ jobs:
2828
run:
2929
python3 -m flake8
3030

31+
- name: Ensure runner files don't do imports in global scope and check if env checking codeblock prepended
32+
run:
33+
python3 -m unittest tests.test_imports
34+
3135
- name: Git checkout w/ submodules
3236
uses: actions/checkout@v4
3337
with:

computer_vision/classification/alexnet/run.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright (c) 2024, Ampere Computing LLC
3-
import argparse
4-
import torch
5-
import torchvision
6-
from utils.benchmark import run_model
7-
from utils.cv.imagenet import ImageNet
8-
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
3+
try:
4+
from utils import misc # noqa
5+
except ModuleNotFoundError:
6+
import os
7+
import sys
8+
filename = "set_env_variables.sh"
9+
directory = os.path.realpath(__file__).split("/")[:-1]
10+
for idx in range(1, len(directory) - 1):
11+
subdir = "/".join(directory[:-idx])
12+
if filename in os.listdir(subdir):
13+
print(f"\nPlease run \033[91m'source {os.path.join(subdir, filename)}'\033[0m first.")
14+
break
15+
else:
16+
print(f"\n\033[91mFAIL: Couldn't find {filename}, are you running this script as part of Ampere Model Library?"
17+
f"\033[0m")
18+
sys.exit(1)
919

1020

1121
def parse_args():
22+
import argparse
1223
parser = argparse.ArgumentParser(description="Run AlexNet model.")
1324
parser.add_argument("-p", "--precision",
1425
type=str, choices=["fp32"], required=True,
@@ -38,6 +49,10 @@ def parse_args():
3849

3950

4051
def run_pytorch_fp(model_name, batch_size, num_runs, timeout, images_path, labels_path, disable_jit_freeze=False):
52+
import torch
53+
import torchvision
54+
from utils.benchmark import run_model
55+
from utils.cv.imagenet import ImageNet
4156
from utils.pytorch import PyTorchRunner
4257

4358
def run_single_pass(pytorch_runner, imagenet):
@@ -64,6 +79,7 @@ def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, images_path, lab
6479

6580

6681
def main():
82+
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
6783
args = parse_args()
6884
download_ampere_imagenet()
6985

computer_vision/classification/densenet_121/run.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright (c) 2024, Ampere Computing LLC
3-
import argparse
4-
import torch
5-
import torchvision
6-
from utils.benchmark import run_model
7-
from utils.cv.imagenet import ImageNet
8-
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
3+
4+
try:
5+
from utils import misc # noqa
6+
except ModuleNotFoundError:
7+
import os
8+
import sys
9+
filename = "set_env_variables.sh"
10+
directory = os.path.realpath(__file__).split("/")[:-1]
11+
for idx in range(1, len(directory) - 1):
12+
subdir = "/".join(directory[:-idx])
13+
if filename in os.listdir(subdir):
14+
print(f"\nPlease run \033[91m'source {os.path.join(subdir, filename)}'\033[0m first.")
15+
break
16+
else:
17+
print(f"\n\033[91mFAIL: Couldn't find {filename}, are you running this script as part of Ampere Model Library?"
18+
f"\033[0m")
19+
sys.exit(1)
920

1021

1122
def parse_args():
23+
import argparse
1224
parser = argparse.ArgumentParser(description="Run DenseNet 121 model.")
1325
parser.add_argument("-m", "--model_path",
1426
type=str,
@@ -41,6 +53,10 @@ def parse_args():
4153

4254

4355
def run_pytorch_fp(model_name, batch_size, num_runs, timeout, images_path, labels_path, disable_jit_freeze=False):
56+
import torch
57+
import torchvision
58+
from utils.benchmark import run_model
59+
from utils.cv.imagenet import ImageNet
4460
from utils.pytorch import PyTorchRunner
4561

4662
def run_single_pass(pytorch_runner, imagenet):
@@ -63,6 +79,8 @@ def run_single_pass(pytorch_runner, imagenet):
6379

6480

6581
def run_ort_fp(model_path, batch_size, num_runs, timeout, images_path, labels_path):
82+
from utils.benchmark import run_model
83+
from utils.cv.imagenet import ImageNet
6684
from utils.ort import OrtRunner
6785

6886
def run_single_pass(ort_runner, imagenet):
@@ -94,6 +112,7 @@ def run_ort_fp32(model_path, batch_size, num_runs, timeout, images_path, labels_
94112

95113

96114
def main():
115+
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
97116
args = parse_args()
98117
download_ampere_imagenet()
99118

computer_vision/classification/densenet_169/run.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright (c) 2024, Ampere Computing LLC
3-
import argparse
4-
from utils.benchmark import run_model
5-
from utils.cv.imagenet import ImageNet
6-
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
3+
try:
4+
from utils import misc # noqa
5+
except ModuleNotFoundError:
6+
import os
7+
import sys
8+
filename = "set_env_variables.sh"
9+
directory = os.path.realpath(__file__).split("/")[:-1]
10+
for idx in range(1, len(directory) - 1):
11+
subdir = "/".join(directory[:-idx])
12+
if filename in os.listdir(subdir):
13+
print(f"\nPlease run \033[91m'source {os.path.join(subdir, filename)}'\033[0m first.")
14+
break
15+
else:
16+
print(f"\n\033[91mFAIL: Couldn't find {filename}, are you running this script as part of Ampere Model Library?"
17+
f"\033[0m")
18+
sys.exit(1)
719

820

921
def parse_args():
22+
import argparse
1023
parser = argparse.ArgumentParser(description="Run ResNet-50 v1.5 model.")
1124
parser.add_argument("-m", "--model_path",
1225
type=str,
@@ -37,6 +50,8 @@ def parse_args():
3750

3851

3952
def run_tf_fp(model_path, batch_size, num_runs, timeout, images_path, labels_path):
53+
from utils.benchmark import run_model
54+
from utils.cv.imagenet import ImageNet
4055
from utils.tf import TFFrozenModelRunner
4156

4257
def run_single_pass(tf_runner, imagenet):
@@ -58,6 +73,8 @@ def run_single_pass(tf_runner, imagenet):
5873

5974

6075
def run_tflite(model_path, batch_size, num_runs, timeout, images_path, labels_path):
76+
from utils.benchmark import run_model
77+
from utils.cv.imagenet import ImageNet
6178
from utils.tflite import TFLiteRunner
6279

6380
def run_single_pass(tflite_runner, imagenet):
@@ -92,6 +109,7 @@ def run_tflite_int8(model_path, batch_size, num_runs, timeout, images_path, labe
92109

93110

94111
def main():
112+
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
95113
args = parse_args()
96114
download_ampere_imagenet()
97115

computer_vision/classification/googlenet/run.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright (c) 2024, Ampere Computing LLC
3-
import argparse
4-
import os
5-
import warnings
6-
import torch
7-
import torchvision
8-
from utils.benchmark import run_model
9-
from utils.cv.imagenet import ImageNet
10-
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
11-
warnings.filterwarnings("ignore")
3+
try:
4+
from utils import misc # noqa
5+
except ModuleNotFoundError:
6+
import os
7+
import sys
8+
filename = "set_env_variables.sh"
9+
directory = os.path.realpath(__file__).split("/")[:-1]
10+
for idx in range(1, len(directory) - 1):
11+
subdir = "/".join(directory[:-idx])
12+
if filename in os.listdir(subdir):
13+
print(f"\nPlease run \033[91m'source {os.path.join(subdir, filename)}'\033[0m first.")
14+
break
15+
else:
16+
print(f"\n\033[91mFAIL: Couldn't find {filename}, are you running this script as part of Ampere Model Library?"
17+
f"\033[0m")
18+
sys.exit(1)
1219

1320

1421
def parse_args():
22+
import argparse
1523
parser = argparse.ArgumentParser(description="Run GoogLeNet model.")
1624
parser.add_argument("-p", "--precision",
1725
type=str, choices=["fp32"], required=True,
@@ -41,6 +49,11 @@ def parse_args():
4149

4250

4351
def run_pytorch_fp(model_name, batch_size, num_runs, timeout, images_path, labels_path, disable_jit_freeze=False):
52+
import os
53+
import torch
54+
import torchvision
55+
from utils.benchmark import run_model
56+
from utils.cv.imagenet import ImageNet
4457
from utils.pytorch import PyTorchRunner
4558

4659
def run_single_pass(pytorch_runner, imagenet):
@@ -69,6 +82,7 @@ def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, images_path, lab
6982

7083

7184
def main():
85+
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
7286
args = parse_args()
7387
download_ampere_imagenet()
7488

computer_vision/classification/inception_resnet_v2/run.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright (c) 2024, Ampere Computing LLC
3-
import argparse
4-
from utils.benchmark import run_model
5-
from utils.cv.imagenet import ImageNet
6-
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
3+
try:
4+
from utils import misc # noqa
5+
except ModuleNotFoundError:
6+
import os
7+
import sys
8+
filename = "set_env_variables.sh"
9+
directory = os.path.realpath(__file__).split("/")[:-1]
10+
for idx in range(1, len(directory) - 1):
11+
subdir = "/".join(directory[:-idx])
12+
if filename in os.listdir(subdir):
13+
print(f"\nPlease run \033[91m'source {os.path.join(subdir, filename)}'\033[0m first.")
14+
break
15+
else:
16+
print(f"\n\033[91mFAIL: Couldn't find {filename}, are you running this script as part of Ampere Model Library?"
17+
f"\033[0m")
18+
sys.exit(1)
719

820

921
def parse_args():
22+
import argparse
1023
parser = argparse.ArgumentParser(description="Run Inception ResNet v2 model.")
1124
parser.add_argument("-m", "--model_path",
1225
type=str,
@@ -37,6 +50,8 @@ def parse_args():
3750

3851

3952
def run_tf_fp(model_path, batch_size, num_runs, timeout, images_path, labels_path):
53+
from utils.benchmark import run_model
54+
from utils.cv.imagenet import ImageNet
4055
from utils.tf import TFFrozenModelRunner
4156

4257
def run_single_pass(tf_runner, imagenet):
@@ -58,6 +73,8 @@ def run_single_pass(tf_runner, imagenet):
5873

5974

6075
def run_tflite(model_path, batch_size, num_runs, timeout, images_path, labels_path):
76+
from utils.benchmark import run_model
77+
from utils.cv.imagenet import ImageNet
6178
from utils.tflite import TFLiteRunner
6279

6380
def run_single_pass(tflite_runner, imagenet):
@@ -92,6 +109,7 @@ def run_tflite_int8(model_path, batch_size, num_runs, timeout, images_path, labe
92109

93110

94111
def main():
112+
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
95113
args = parse_args()
96114
download_ampere_imagenet()
97115

computer_vision/classification/inception_v2/run.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright (c) 2024, Ampere Computing LLC
3-
import argparse
4-
from utils.benchmark import run_model
5-
from utils.cv.imagenet import ImageNet
6-
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
3+
try:
4+
from utils import misc # noqa
5+
except ModuleNotFoundError:
6+
import os
7+
import sys
8+
filename = "set_env_variables.sh"
9+
directory = os.path.realpath(__file__).split("/")[:-1]
10+
for idx in range(1, len(directory) - 1):
11+
subdir = "/".join(directory[:-idx])
12+
if filename in os.listdir(subdir):
13+
print(f"\nPlease run \033[91m'source {os.path.join(subdir, filename)}'\033[0m first.")
14+
break
15+
else:
16+
print(f"\n\033[91mFAIL: Couldn't find {filename}, are you running this script as part of Ampere Model Library?"
17+
f"\033[0m")
18+
sys.exit(1)
719

820

921
def parse_args():
22+
import argparse
1023
parser = argparse.ArgumentParser(description="Run Inception v2 model.")
1124
parser.add_argument("-m", "--model_path",
1225
type=str,
@@ -37,6 +50,8 @@ def parse_args():
3750

3851

3952
def run_tf_fp(model_path, batch_size, num_runs, timeout, images_path, labels_path):
53+
from utils.benchmark import run_model
54+
from utils.cv.imagenet import ImageNet
4055
from utils.tf import TFFrozenModelRunner
4156

4257
def run_single_pass(tf_runner, imagenet):
@@ -58,6 +73,8 @@ def run_single_pass(tf_runner, imagenet):
5873

5974

6075
def run_tflite(model_path, batch_size, num_runs, timeout, images_path, labels_path):
76+
from utils.benchmark import run_model
77+
from utils.cv.imagenet import ImageNet
6178
from utils.tflite import TFLiteRunner
6279

6380
def run_single_pass(tflite_runner, imagenet):
@@ -80,6 +97,8 @@ def run_single_pass(tflite_runner, imagenet):
8097

8198

8299
def run_ort_fp(model_path, batch_size, num_runs, timeout, images_path, labels_path):
100+
from utils.benchmark import run_model
101+
from utils.cv.imagenet import ImageNet
83102
from utils.ort import OrtRunner
84103

85104
def run_single_pass(ort_runner, imagenet):
@@ -117,6 +136,7 @@ def run_ort_fp16(model_path, batch_size, num_runs, timeout, images_path, labels_
117136

118137

119138
def main():
139+
from utils.misc import print_goodbye_message_and_die, download_ampere_imagenet
120140
args = parse_args()
121141
download_ampere_imagenet()
122142

0 commit comments

Comments
 (0)