Commit d9cc213

xmfan authored and facebook-github-bot committed
Stop benchmarking compile time of dead code (#145590)
Summary: Fixes pytorch/pytorch#144775. See details on the problem in pytorch/pytorch#144775 (comment).

We fixed some silent incorrectness, but the fix results in fewer nodes being DCE'd: the benchmark iteration loop contained dead code that could include side-effecting ops, which are not safe to DCE, so the regression is expected. This PR removes the compile-time benchmarking of that dead code, which should reduce benchmark noise and aligns the setup with the benchmarking used by the performance tests.

New benchmark results:

```
dev,name,batch_size,accuracy,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks,autograd_captures,autograd_compiles,cudagraph_skips,compilation_latency
cuda,BartForConditionalGeneration,1,pass,897,1,0,0,0,0,0,39.322364 # after pytorch/pytorch#144319
cuda,BartForConditionalGeneration,1,pass,897,1,0,0,0,0,0,38.972257 # before pytorch/pytorch#144319
```

X-link: pytorch/pytorch#145590
Approved by: https://github.com/jansel
ghstack dependencies: #145447

Reviewed By: ZainRizvi

Differential Revision: D68860252

fbshipit-source-id: 60371bdf3ba6e6f38766d6589690a221f8cebda4
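To make the DCE point concrete, here is a minimal, hypothetical sketch (the `step` function and names below are illustrative, not from the harness): a warm-up iteration discards its result, so the call looks dead from a data-flow view, but an in-place mutation inside it is a side effect that cannot be safely eliminated.

```python
import torch

def step(buf, x):
    # Illustrative step: mutates `buf` in place (a side effect) and
    # returns a value that the warm-up loop will simply discard.
    buf.add_(x)
    return buf.sum()

buf = torch.zeros(4)
x = torch.ones(4)
for _ in range(2):
    step(buf, x)  # result unused ("dead"), but the add_ must still run
print(buf)  # tensor([2., 2., 2., 2.]) -- eliminating step() would change this
```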
1 parent 0e370a0 commit d9cc213

File tree

2 files changed: +43 −13 lines changed

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 29 additions & 11 deletions
```diff
@@ -2781,11 +2781,11 @@ def batch_size_finder(self, device, model_name, initial_batch_size=1024):
             batch_size = self.decay_batch_exp(batch_size)
         return 1
 
-    def run_n_iterations(self, mod, inputs):
+    def run_n_iterations(self, mod, inputs, model_iter_fn):
         n = self.args.iterations
         for _ in range(n - 1):
-            self.model_iter_fn(mod, inputs, collect_outputs=False)
-        return self.model_iter_fn(mod, inputs, collect_outputs=True)
+            model_iter_fn(mod, inputs, collect_outputs=False)
+        return model_iter_fn(mod, inputs, collect_outputs=True)
 
     @torch._disable_dynamo(recursive=True)
     def optimizer_zero_grad(self, mod):
@@ -2953,7 +2953,9 @@ def record_status(accuracy_status, dynamo_start_stats):
                 clone_inputs(example_inputs),
             )
             self.init_optimizer(name, current_device, model_fp64.parameters())
-            fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
+            fp64_outputs = self.run_n_iterations(
+                model_fp64, inputs_fp64, self.model_iter_fn
+            )
             fp64_outputs = tree_map(
                 lambda x: x.to(torch.float64)
                 if isinstance(x, torch.Tensor) and x.is_floating_point()
@@ -2986,7 +2988,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 model_copy = self.deepcopy_and_maybe_parallelize(model)
                 self.init_optimizer(name, current_device, model_copy.parameters())
                 correct_result = self.run_n_iterations(
-                    model_copy, clone_inputs(example_inputs)
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
                 )
             except Exception as e:
                 accuracy_status = (
@@ -3007,7 +3009,7 @@ def record_status(accuracy_status, dynamo_start_stats):
                 model_copy = self.deepcopy_and_maybe_parallelize(model)
                 self.init_optimizer(name, current_device, model_copy.parameters())
                 correct_rerun_result = self.run_n_iterations(
-                    model_copy, clone_inputs(example_inputs)
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
                 )
             except Exception as e:
                 accuracy_status = (
@@ -3066,13 +3068,15 @@ def record_status(accuracy_status, dynamo_start_stats):
                 )
                 new_result = optimized_model_iter_fn(model_copy, example_inputs)
             else:
-                optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
+                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
                 with maybe_enable_compiled_autograd(
                     self.args.compiled_autograd,
                     fullgraph=self.args.nopython,
                     dynamic=self.args.dynamic_shapes,
                 ):
-                    new_result = optimized_model_iter_fn(model_copy, example_inputs)
+                    new_result = self.run_n_iterations(
+                        model_copy, example_inputs, optimized_model_iter_fn
+                    )
         except Exception as e:
             log.exception("")
             print(
@@ -3167,7 +3171,9 @@ def check_tolerance(
                 lambda x: x.to(base_device), example_inputs_copy
             )
             self.init_optimizer(name, base_device, model_copy.parameters())
-            correct_result = self.run_n_iterations(model_copy, example_inputs_copy)
+            correct_result = self.run_n_iterations(
+                model_copy, example_inputs_copy, self.model_iter_fn
+            )
 
         # Run with Dynamo
         # Sometime CI fails with random triton compilation failure which will be skipped for now
@@ -3176,8 +3182,10 @@ def check_tolerance(
         torch._dynamo.reset()
         try:
             self.init_optimizer(name, current_device, model.parameters())
-            optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
-            new_result = optimized_model_iter_fn(model, example_inputs)
+            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+            new_result = self.run_n_iterations(
+                model_copy, example_inputs, optimized_model_iter_fn
+            )
         except Exception:
             log.exception("")
             print(
@@ -4460,6 +4468,16 @@ def run(runner, args, original_dir=None):
         # Stricter check to disable fallbacks
         args.suppress_errors = False
 
+    if not args.disable_cudagraphs:
+        runner.skip_models.update(
+            {
+                # xfail: https://github.com/pytorch/pytorch/issues/145773
+                "convit_base",
+                "llama",
+                "cm3leon_generate",
+            }
+        )
+
     if args.device_index is not None:
         if args.multiprocess:
             print("Cannot specify both --device_index and --multiprocess")
```

userbenchmark/dynamo/dynamobench/timm_models.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -10,9 +10,9 @@
 
 
 try:
-    from .common import BenchmarkRunner, download_retry_decorator, main
+    from .common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main
 except ImportError:
-    from common import BenchmarkRunner, download_retry_decorator, main
+    from common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main
 
 import torch
 from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
@@ -218,6 +218,18 @@ def __init__(self):
         super().__init__()
         self.suite_name = "timm_models"
 
+    @property
+    def _config(self):
+        return load_yaml_file("timm_models.yaml")
+
+    @property
+    def _skip(self):
+        return self._config["skip"]
+
+    @property
+    def skip_models(self):
+        return self._skip["all"]
+
     @property
     def force_amp_for_fp16_bf16_models(self):
         return FORCE_AMP_FOR_FP16_BF16_MODELS
```
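For context, a sketch of the configuration shape these new properties assume; only the `skip` → `all` path is implied by the code above, and the layout and model names here are placeholders, not the actual contents of `timm_models.yaml`:

```python
import yaml  # PyYAML; stands in for the harness's load_yaml_file

# Hypothetical excerpt mirroring config["skip"]["all"] as read above.
config = yaml.safe_load(
    """
skip:
  all:
    - model_a
    - model_b
"""
)
assert config["skip"]["all"] == ["model_a", "model_b"]
```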
