
Commit 09de70c

FindHao authored and facebook-github-bot committed
add API coverage test (#1394)
Summary: This PR adds the userbenchmark for the PyTorch API coverage test.

Command to run for a specific model:
```
python3 run_benchmark.py api-coverage --model resnet18 --device cuda --test train,eval
```
If you want to test all models, the `--model` flag should not be specified.

The testing results are stored in `.userbenchmark/api-coverage/logs/`. Each run generates two log files, logs-${timestamp}.json and logs-${timestamp}.json-api_coverage.csv. The latter follows this format:
```
API coverage rate: 25/1332 = 1.88%
missed APIs:
module_name,func_name
,_TensorBase.acos
...
```
All missed APIs are listed in the csv file. If a logs-${timestamp}.json-api_need_support.csv is also generated, please open an issue to report the APIs that are not yet supported by this userbenchmark.

Pull Request resolved: #1394
Reviewed By: xuzhao9
Differential Revision: D43737578
Pulled By: FindHao
fbshipit-source-id: c7570f395572e26dfce52529195e8575e03185ce
1 parent 745644f commit 09de70c
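
The coverage counting in userbenchmark/api-coverage/run.py (added below) relies on `torch.overrides.TorchFunctionMode`, which lets a context manager intercept every overridable torch API call. As a minimal sketch of that mechanism only (not part of this commit; `RecordingMode` and the sample tensor ops are illustrative):

```python
# Illustrative only: a stripped-down TorchFunctionMode that records which
# torch APIs a snippet calls, the same interception hook CoverageMode uses.
import torch


class RecordingMode(torch.overrides.TorchFunctionMode):
    def __init__(self):
        self.calls = set()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Record the callable, then dispatch to the real implementation.
        self.calls.add(getattr(func, "__name__", str(func)))
        return func(*args, **kwargs)


with RecordingMode() as mode:
    x = torch.randn(4, 4)
    torch.matmul(x, x).relu()

print(sorted(mode.calls))  # expect names such as 'randn', 'matmul', 'relu'
```

The `CoverageMode` class in run.py performs the same interception, but maps each callable to a (module_name, func_name) pair and checks it against the API lists built from `torch.overrides`.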

File tree

2 files changed: +215 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+

userbenchmark/api-coverage/run.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
```python
import itertools
import time
from datetime import datetime
from typing import List
import json
import numpy as np
import argparse
import re
import torch

from ..utils import REPO_PATH, add_path, get_output_dir, get_output_json, dump_output

with add_path(REPO_PATH):
    from torchbenchmark.util.experiment.instantiator import list_models, load_model, TorchBenchModelConfig
    from torchbenchmark.util.experiment.metrics import TorchBenchModelMetrics, get_model_test_metrics
    import torchbenchmark.util.experiment.metrics

BM_NAME = "api-coverage"


def parse_func(func):
    if hasattr(func, '__module__'):
        module_name = func.__module__
        func_name = func.__name__
    else:
        if hasattr(func, '__qualname__'):
            func_name = func.__qualname__
            module_name = ''
        else:
            if type(func) == torch._C.Generator:
                func_name = 'torch._C.Generator'
                module_name = ''
            else:
                raise RuntimeError("no matched module and func name: ", func, type(func))
    return module_name, func_name


def generate_API_list():
    tmp_api_list = set()
    raw_all_apis = set(torch.overrides.get_testing_overrides().keys())
    # map every overridable API to a (module_name, func_name) pair
    for item in raw_all_apis:
        module_name, func_name = parse_func(item)
        # if (module_name, func_name) in api_list:
        #     print("duplicated: ", (module_name, func_name))
        tmp_api_list.add((module_name, func_name))
    ignored_funcs = set([_ for _ in torch.overrides.get_ignored_functions() if _ not in [True, False]])
    tmp_ignored_api_list = set()
    for item in ignored_funcs:
        module_name, func_name = parse_func(item)
        tmp_ignored_api_list.add((module_name, func_name))
    return tmp_api_list, tmp_ignored_api_list


API_LIST, IGNORED_API_LIST = generate_API_list()


class CoverageMode(torch.overrides.TorchFunctionMode):
    """TorchFunctionMode that records which overridable torch APIs get called."""

    def __init__(self, model='', output_file=None):
        self.model = model
        self.seen = set()
        self.api_used = set()
        self.output_file = output_file
        self.api_need_support = set()

    def check_func_in_APIs(self, func):
        module_name, func_name = parse_func(func)
        if (module_name, func_name) not in API_LIST:
            if (module_name, func_name) not in IGNORED_API_LIST and module_name != 'torch._ops.profiler':
                new_pair = (module_name, func_name)
                if new_pair not in self.api_need_support:
                    # debugging purpose
                    # print("not in API_LIST or IGNORED_API_LIST: (%s, %s)" % (module_name, func_name))
                    self.api_need_support.add((module_name, func_name))
        else:
            self.api_used.add((module_name, func_name))
            # debug
            # print("in APIs: ", (module_name, func_name))

    def get_api_coverage_rate(self):
        return len(self.api_used) / len(API_LIST)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.seen.add(func)
        if kwargs is None:
            kwargs = {}
        self.check_func_in_APIs(func)
        return func(*args, **kwargs)

    def commit(self):
        if self.output_file:
            with open(self.output_file, 'a') as f:
                for api in self.api_used:
                    f.write("%s,%s\n" % (api[0], api[1]))

    def update_api_used(self, output: set):
        for api in self.api_used:
            output.add(api)

    def update_need_support(self, output: set):
        for api in self.api_need_support:
            output.add(api)


def generate_model_config(model_name: str) -> List[TorchBenchModelConfig]:
    devices = ["cpu", "cuda"]
    tests = ["train", "eval"]
    cfgs = itertools.product(*[devices, tests])
    result = [TorchBenchModelConfig(
        name=model_name,
        device=device,
        test=test,
        batch_size=None,
        jit=False,
        extra_args=[],
        extra_env=None,
    ) for device, test in cfgs]
    return result


def parse_args(args: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--models", default="",
                        help="Specify the models to run, default (empty) runs all models.")
    parser.add_argument("-d", "--device", default="cuda", help="Specify the device.")
    parser.add_argument("-t", "--test", default="eval,train", help="Specify the test.")
    parser.add_argument("-o", "--output", type=str, help="The default output json file.")
    args = parser.parse_args(args)
    return args


def generate_filter(args: argparse.Namespace):
    allowed_models = args.models
    if allowed_models:
        allowed_models = allowed_models.split(",") if "," in allowed_models else [allowed_models]
    allowed_devices = args.device
    allowed_devices = allowed_devices.split(",") if "," in allowed_devices else [allowed_devices]
    allowed_tests = args.test
    allowed_tests = allowed_tests.split(",") if "," in allowed_tests else [allowed_tests]

    def cfg_filter(cfg: TorchBenchModelConfig) -> bool:
        if cfg.device in allowed_devices and cfg.test in allowed_tests:
            if not allowed_models:
                return True
            else:
                return cfg.name in allowed_models
        return False
    return cfg_filter


def run(args: List[str]):
    args = parse_args(args)
    output_dir = get_output_dir(BM_NAME)
    models = list_models()
    cfgs = list(itertools.chain(*map(generate_model_config, models)))
    cfg_filter = generate_filter(args)
    torchbenchmark.util.experiment.metrics.BENCHMARK_ITERS = 1
    torchbenchmark.util.experiment.metrics.WARMUP_ROUNDS = 0
    single_round_result = []
    api_used = set()
    api_need_support = set()
    for cfg in filter(cfg_filter, cfgs):
        try:
            # load the model instance within the same process
            model = load_model(cfg)
            # get the model test metrics
            with CoverageMode('', '') as coverage:
                try:
                    get_model_test_metrics(model, metrics=["latencies"])
                finally:
                    coverage.update_api_used(api_used)
                    coverage.update_need_support(api_need_support)
        except NotImplementedError:
            # some models don't implement the test specified
            single_round_result.append({
                'cfg': cfg.__dict__,
                'raw_metrics': "NotImplemented",
            })
        except RuntimeError as e:
            single_round_result.append({
                'cfg': cfg.__dict__,
                'raw_metrics': f"RuntimeError: {e}",
            })

    # log detailed results in the .userbenchmark/api-coverage/logs/ directory
    log_dir = output_dir.joinpath("logs")
    log_dir.mkdir(exist_ok=True, parents=True)
    fname = "logs-{}.json".format(datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S"))
    full_fname = log_dir.joinpath(fname)
    with open(full_fname, 'w') as f:
        json.dump(single_round_result, f, indent=4)
    # log the api coverage
    api_coverage_fname = log_dir.joinpath("%s-api_coverage.csv" % fname)
    missed_apis = API_LIST - api_used
    with open(api_coverage_fname, 'w') as f:
        f.write("API coverage rate: %d/%d = %.2f%%\n" %
                (len(api_used), len(API_LIST), len(api_used) / len(API_LIST) * 100))
        f.write("=====Used APIs=====\n")
        f.write("module_name,func_name\n")
        for api in api_used:
            f.write("%s,%s\n" % (api[0], api[1]))
        f.write("=====Missed APIs=====\n")
        f.write("module_name,func_name\n")
        for api in missed_apis:
            f.write("%s,%s\n" % (api[0], api[1]))
    if api_need_support:
        api_need_support_fname = log_dir.joinpath("%s-api_need_support.csv" % fname)
        with open(api_need_support_fname, 'w') as f:
            f.write("APIs called but not in API_LIST and IGNORED_API_LIST\n")
            f.write("module_name,func_name\n")
            for api in api_need_support:
                f.write("%s,%s\n" % (api[0], api[1]))
    print("The detailed results are saved in %s" % api_coverage_fname)
```
