Skip to content

Commit a83bdbc

Browse files
committed
Update interface
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
1 parent cb6ed1b commit a83bdbc

File tree

7 files changed

+12
-6
lines changed

7 files changed

+12
-6
lines changed

examples/layer_wise_benchmarks/run_single.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
99
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
1010
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
11-
from tensorrt_llm.tools.layer_wise_benchmarks.runner_base import BalanceMethod
12-
from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import get_runner_cls
11+
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls
1312

1413

1514
def comma_separated_ints(s):
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .runner_factory import get_runner_cls
2+
from .runner_interface import BalanceMethod
3+
4+
__all__ = [
5+
"BalanceMethod",
6+
"get_runner_cls",
7+
]

tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tensorrt_llm.functional import AllReduceStrategy
1313
from tensorrt_llm.mapping import Mapping
1414

15-
from .runner_base import BalanceMethod, RunnerBase
15+
from .runner_interface import BalanceMethod, RunnerBase
1616
from .runner_utils import RunnerMixin, ceil_div
1717

1818

tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tensorrt_llm.functional import AllReduceStrategy
99
from tensorrt_llm.mapping import Mapping
1010

11-
from .runner_base import RunnerBase
11+
from .runner_interface import RunnerBase
1212
from .runner_utils import RunnerMixin
1313

1414

tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .qwen3_next_runner import Qwen3NextRunner
55

66

7-
def get_runner_cls(pretrained_model_name_or_path: str):
7+
def get_runner_cls(pretrained_model_name_or_path: str) -> type:
88
pretrained_config = load_pretrained_config(pretrained_model_name_or_path)
99
return {
1010
"deepseek_v3": DeepSeekV3Runner,

tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tensorrt_llm.mapping import Mapping
2323
from tensorrt_llm.models.modeling_utils import QuantConfig
2424

25-
from .runner_base import BalanceMethod
25+
from .runner_interface import BalanceMethod
2626

2727

2828
def ceil_div(a, b):

0 commit comments

Comments
 (0)