
Commit 97fc8ea

ezyang authored and pytorchmergebot committed
Run the benchmark suite with dynamic batch only (pytorch#97912)
Symbolic shapes compile time on full CI with inductor is horribly long (even though our aot_eager local runs seemed to suggest that the added latency was only 10s per model). To patch over the problem for now, run the benchmark suite with dynamic batch only. This should absolve a lot of sins.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#97912
Approved by: https://github.com/janeyx99, https://github.com/desertfire
1 parent 4cce607 commit 97fc8ea
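As a rough orientation for the diff below, here is a minimal sketch (not taken from the PR; the model, shapes, and sizes are hypothetical placeholders) of what "dynamic batch only" amounts to: turn symbolic shapes on, assume every dimension is static by default, and opt only the batch dimension back into dynamism with mark_dynamic, tolerating a wrong guess via the new allow_ignore_mark_dynamic knob added in this change.

import torch
import torch._dynamo as dynamo

# Sketch only: these are the knobs the benchmark harness flips for
# --dynamic-batch-only (see benchmarks/dynamo/common.py below); the
# allow_ignore_mark_dynamic flag requires this commit.
dynamo.config.dynamic_shapes = True             # symbolic shapes on
dynamo.config.assume_static_by_default = True   # static unless marked
dynamo.config.allow_ignore_mark_dynamic = True  # tolerate a wrong batch guess

model = torch.nn.Linear(16, 8)           # hypothetical model
example_input = torch.randn(32, 16)      # dim 0 plays the role of the batch
dynamo.mark_dynamic(example_input, 0)    # only the batch dimension is dynamic

compiled = torch.compile(model)
compiled(example_input)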

File tree

6 files changed (+42, -6 lines)


.ci/pytorch/test.sh

Lines changed: 2 additions & 2 deletions
@@ -302,7 +302,7 @@ test_perf_for_dashboard() {
         --accuracy --"$dtype" --backend "$backend" "$@" \
         --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_${suite}_${dtype}_training_cuda_accuracy.csv"
       python "benchmarks/dynamo/$suite.py" \
-        --accuracy --"$dtype" --backend "$backend" --dynamic-shapes --disable-cudagraphs "$@" \
+        --accuracy --"$dtype" --backend "$backend" --dynamic-shapes --dynamic-batch-only --disable-cudagraphs "$@" \
         --output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_training_cuda_accuracy.csv"

       # Run performance test
@@ -316,7 +316,7 @@ test_perf_for_dashboard() {
         --performance --cold-start-latency --"$dtype" --backend "$backend" "$@" \
         --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_${suite}_${dtype}_training_cuda_performance.csv"
       python "benchmarks/dynamo/$suite.py" \
-        --performance --cold-start-latency --"$dtype" --backend "$backend" --dynamic-shapes --disable-cudagraphs "$@" \
+        --performance --cold-start-latency --"$dtype" --backend "$backend" --dynamic-shapes --dynamic-batch-only --disable-cudagraphs "$@" \
         --output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_training_cuda_performance.csv"
   done
}

benchmarks/dynamo/common.py

Lines changed: 24 additions & 0 deletions
@@ -1693,6 +1693,11 @@ def get_example_inputs(self):
         action="store_true",
         help="Runs a dynamic shapes version of the benchmark, if available.",
     )
+    parser.add_argument(
+        "--dynamic-batch-only",
+        action="store_true",
+        help="Only assume batch dimension is dynamic. Implies --dynamic-shapes",
+    )
     parser.add_argument(
         "--specialize-int", action="store_true", help="Run with specialize_int=True."
     )
@@ -1956,6 +1961,10 @@ def run(runner, args, original_dir=None):
     if args.dynamic_ci_skips_only:
         args.dynamic_shapes = True
         args.ci = True
+    if args.dynamic_batch_only:
+        args.dynamic_shapes = True
+        torch._dynamo.config.assume_static_by_default = True
+        torch._dynamo.config.allow_ignore_mark_dynamic = True
     if args.dynamic_shapes:
         torch._dynamo.config.dynamic_shapes = True
     if args.specialize_int:
@@ -2329,6 +2338,21 @@ def run(runner, args, original_dir=None):
     elif args.bfloat16:
         model, example_inputs = cast_to_bf16(model, example_inputs)

+    # Look for stuff that looks like batch size, and mark it dynamic.
+    # Better integration would integrate directly with benchmark suite
+    # but cannot conveniently do this
+    # NB: This must be done late enough so that we don't do more
+    # conversions on the inputs
+    # NB: Assumes only the first batch-y like dimension is the batch
+    def detect_and_mark_batch(t):
+        for i, s in enumerate(t.size()):
+            if s == batch_size:
+                torch._dynamo.mark_dynamic(t, i)
+                break
+
+    if args.dynamic_batch_only:
+        tree_map(detect_and_mark_batch, example_inputs)
+
     if args.log_operator_inputs:
         log_operator_inputs(
             model, example_inputs, runner.model_iter_fn, name, args
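The detect_and_mark_batch helper above scans each example input for the first dimension whose size equals the suite's batch size and marks it dynamic. As a standalone usage illustration only: the inputs and batch_size below are hypothetical, tree_map is assumed to be torch.utils._pytree.tree_map, and the isinstance guard is an addition for this sketch, not part of the PR.

import torch
import torch._dynamo
from torch.utils._pytree import tree_map

batch_size = 4  # hypothetical; the suite gets this from the benchmark runner

def detect_and_mark_batch(t):
    # Mark the first dimension whose size matches batch_size as dynamic.
    if not isinstance(t, torch.Tensor):  # guard added for this sketch only
        return t
    for i, s in enumerate(t.size()):
        if s == batch_size:
            torch._dynamo.mark_dynamic(t, i)
            break
    return t

example_inputs = (torch.randn(4, 3, 224, 224), {"mask": torch.ones(4, 10)})
tree_map(detect_and_mark_batch, example_inputs)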

torch/_dynamo/config.py

Lines changed: 7 additions & 0 deletions
@@ -68,6 +68,13 @@
 # see [Note - on the state of mark_dynamic]
 assume_static_by_default = False

+# Typically, if you mark_dynamic a dimension, we will error if the dimension
+# actually ended up getting specialized. This knob changes the behavior so
+# that we don't error at all. This is helpful for our CI where I'm using a
+# heuristic to mark batch dimensions as dynamic and the heuristic may get it
+# wrong.
+allow_ignore_mark_dynamic = False
+
 # Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
 guard_nn_modules = False

torch/_dynamo/guards.py

Lines changed: 4 additions & 3 deletions
@@ -536,9 +536,10 @@ def TENSOR_MATCH(self, guard: Guard):
                 f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
             )
         else:
-            assert not hasattr(
-                value, "_dynamo_dynamic_indices"
-            ), f"Illegal Unreachable state, guard accumulation for dynamic tensor that should have been static. Initial static message: {tensor_static_reason_to_message(reason)}"  # noqa: B950
+            if not config.allow_ignore_mark_dynamic:
+                assert not hasattr(
+                    value, "_dynamo_dynamic_indices"
+                ), f"Illegal Unreachable state, guard accumulation for dynamic tensor that should have been static. Initial static message: {tensor_static_reason_to_message(reason)}"  # noqa: B950

         if len(code) > 0:
             self._produce_guard_code(guard, code)

torch/_dynamo/variables/builder.py

Lines changed: 4 additions & 1 deletion
@@ -1165,7 +1165,10 @@ def wrap_to_fake_tensor_and_record(
             # Precedence: export constraints > eager constraints
             constraint = dim2constraint.get(i)
             if constraint is None:
-                if i in getattr(e, "_dynamo_dynamic_indices", set()):
+                if (
+                    i in getattr(e, "_dynamo_dynamic_indices", set())
+                    and not config.allow_ignore_mark_dynamic
+                ):
                     constraint = RelaxedUnspecConstraint()
             constraint_dims.append(constraint)
torch/fx/experimental/symbolic_shapes.py

Lines changed: 1 addition & 0 deletions
@@ -1579,6 +1579,7 @@ def create_symbol(
         # Even if we're duck shaping, if we haven't seen this particular
         # value before, we also create a new symbol
         sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
+        log.info("create_symbol %s = %s", sympy_expr, val)
         # We always associate vars to vals
         self.var_to_val[sympy_expr] = sympy.Integer(val)
         # Do the appending later, because we always want to populate this
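The new log.info line only shows up if the symbolic shapes logger is enabled at INFO level. A minimal sketch of how one might surface the "create_symbol %s = %s" messages, assuming the module uses a standard logging.getLogger(__name__) logger:

import logging

# Assumption: the logger name matches the module path of symbolic_shapes.py.
logging.basicConfig()
logging.getLogger("torch.fx.experimental.symbolic_shapes").setLevel(logging.INFO)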
