@@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
14081408 def load (cls , model , example_inputs ):
14091409 import torch ._inductor
14101410 import torch .export ._trace
1411- from torch .export .dynamic_shapes import _tree_map_with_path
1411+ from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
14121412
14131413 key = weakref .ref (model )
14141414 if key not in cls .cache :
@@ -1428,7 +1428,7 @@ def load(cls, model, example_inputs):
14281428 else :
14291429 _register_dataclass_output_as_pytree (example_outputs )
14301430
1431- combined_args = tuple ( example_args ) + tuple ( example_kwargs . values () )
1431+ combined_args = _combine_args ( model , example_args , example_kwargs )
14321432 dynamic_shapes = _tree_map_with_path (
14331433 _produce_dynamic_shapes_for_export , combined_args
14341434 )
@@ -1449,13 +1449,13 @@ def load(cls, model, example_inputs):
14491449
14501450
14511451def export (model , example_inputs ):
1452- from torch .export .dynamic_shapes import _tree_map_with_path
1452+ from torch .export .dynamic_shapes import _combine_args , _tree_map_with_path
14531453
14541454 example_args , example_kwargs = _normalize_bench_inputs (example_inputs )
14551455 example_outputs = model (* example_args , ** example_kwargs )
14561456 _register_dataclass_output_as_pytree (example_outputs )
14571457
1458- combined_args = tuple ( example_args ) + tuple ( example_kwargs . values () )
1458+ combined_args = _combine_args ( model , example_args , example_kwargs )
14591459 dynamic_shapes = _tree_map_with_path (
14601460 _produce_dynamic_shapes_for_export , combined_args
14611461 )
0 commit comments