Skip to content

Commit 70c4120

Browse files
authored
Improved log messages for autotuning (#817)
1 parent 5354356 commit 70c4120

File tree

9 files changed

+123
-28
lines changed

9 files changed

+123
-28
lines changed

docs/conf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,11 @@ def connect(self, event: str, callback: Callable[..., None]) -> None:
129129
}
130130

131131
theme_variables = pytorch_sphinx_theme2.get_theme_variables()
132-
templates_path = [
133-
"_templates",
134-
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
135-
]
132+
templates_path = ["_templates"]
133+
if pytorch_sphinx_theme2.__file__ is not None:
134+
templates_path.append(
135+
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates")
136+
)
136137

137138
html_context = {
138139
"theme_variables": theme_variables,

helion/autotuner/base_cache.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,28 @@ def get(self) -> Config | None:
153153
def put(self, config: Config) -> None:
154154
raise NotImplementedError
155155

156+
def _get_cache_info_message(self) -> str:
157+
"""Return a message describing where the cache is and how to clear it."""
158+
return ""
159+
156160
def autotune(self) -> Config:
157161
if os.environ.get("HELION_SKIP_CACHE", "") not in {"", "0", "false", "False"}:
158162
return self.autotuner.autotune()
159163

160164
if (config := self.get()) is not None:
161165
counters["autotune"]["cache_hit"] += 1
162166
log.debug("cache hit: %s", str(config))
167+
cache_info = self._get_cache_info_message()
168+
self.autotuner.log(
169+
f"Found cached config for {self.kernel.kernel.name}, skipping autotuning.\n{cache_info}"
170+
)
163171
return config
164172

165173
counters["autotune"]["cache_miss"] += 1
166174
log.debug("cache miss")
167175

176+
self.autotuner.log("Starting autotuning process, this may take a while...")
177+
168178
config = self.autotuner.autotune()
169179

170180
self.put(config)

helion/autotuner/base_search.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import torch.multiprocessing as mp
3030
from torch.utils._pytree import tree_flatten
3131
from torch.utils._pytree import tree_map
32-
from tqdm.rich import tqdm
3332
from triton.testing import do_bench
3433

3534
from .. import exc
@@ -40,6 +39,7 @@
4039
from .logger import LambdaLogger
4140
from .logger import classify_triton_exception
4241
from .logger import format_triton_compile_failure
42+
from .progress_bar import iter_with_progress
4343

4444
log = logging.getLogger(__name__)
4545

@@ -321,15 +321,14 @@ def parallel_benchmark(
321321
else:
322322
is_workings = [True] * len(configs)
323323
results = []
324-
iterator = zip(configs, fns, is_workings, strict=True)
325-
if self.settings.autotune_progress_bar:
326-
iterator = tqdm(
327-
iterator,
328-
total=len(configs),
329-
desc=desc,
330-
unit="config",
331-
disable=not self.settings.autotune_progress_bar,
332-
)
324+
325+
# Render a progress bar only when the user requested it.
326+
iterator = iter_with_progress(
327+
zip(configs, fns, is_workings, strict=True),
328+
total=len(configs),
329+
description=desc,
330+
enabled=self.settings.autotune_progress_bar,
331+
)
333332
for config, fn, is_working in iterator:
334333
if is_working:
335334
# benchmark one-by-one to avoid noisy results

helion/autotuner/benchmarking.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import statistics
55
from typing import Callable
66

7-
from tqdm.rich import tqdm
87
from triton import runtime
98

9+
from .progress_bar import iter_with_progress
10+
1011

1112
def interleaved_bench(
1213
fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
@@ -38,9 +39,15 @@ def interleaved_bench(
3839
]
3940

4041
di.synchronize()
41-
iterator = range(repeat)
42-
if desc is not None:
43-
iterator = tqdm(iterator, desc=desc, total=repeat, unit="round")
42+
43+
# When a description is supplied we show a progress bar so the user can
44+
# track the repeated benchmarking loop.
45+
iterator = iter_with_progress(
46+
range(repeat),
47+
total=repeat,
48+
description=desc,
49+
enabled=desc is not None,
50+
)
4451
for i in iterator:
4552
for j in range(len(fns)):
4653
clear_cache()

helion/autotuner/differential_evolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def _autotune(self) -> Config:
9595
)
9696
self.initial_two_generations()
9797
for i in range(2, self.max_generations):
98+
self.log(f"Generation {i} starting")
9899
replaced = self.evolve_population()
99-
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
100+
self.log(f"Generation {i} complete: replaced={replaced}", self.statistics)
100101
self.rebenchmark_population()
101102
return self.best.config

helion/autotuner/local_cache.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def put(self, config: Config) -> None:
9494
path = self._get_local_cache_path()
9595
config.save(path)
9696

97+
def _get_cache_info_message(self) -> str:
98+
cache_dir = self._get_local_cache_path().parent
99+
return f"Cache directory: {cache_dir}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."
100+
97101

98102
class StrictLocalAutotuneCache(LocalAutotuneCache):
99103
"""

helion/autotuner/pattern_search.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646

4747
def _autotune(self) -> Config:
4848
self.log(
49-
f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}"
49+
f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}"
5050
)
5151
visited = set()
5252
self.population = []
@@ -59,7 +59,7 @@ def _autotune(self) -> Config:
5959
self.population.append(member)
6060
self.parallel_benchmark_population(self.population, desc="Initial population")
6161
# again with higher accuracy
62-
self.rebenchmark_population(self.population, desc="Initial rebench")
62+
self.rebenchmark_population(self.population, desc="Verifying initial results")
6363
self.population.sort(key=performance)
6464
starting_points = []
6565
for member in self.population[: self.copies]:
@@ -88,21 +88,25 @@ def _autotune(self) -> Config:
8888
new_population[id(member)] = member
8989
if num_active == 0:
9090
break
91+
92+
# Log generation header before compiling/benchmarking
93+
self.log(
94+
f"Generation {generation} starting: {num_neighbors} neighbors, {num_active} active search path(s)"
95+
)
96+
9197
self.population = [*new_population.values()]
9298
# compile any unbenchmarked members in parallel
9399
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
94100
if unbenchmarked:
95101
self.parallel_benchmark_population(
96-
unbenchmarked, desc=f"Gen {generation} neighbors"
102+
unbenchmarked, desc=f"Generation {generation}: Exploring neighbors"
97103
)
98104
# higher-accuracy rebenchmark
99105
self.rebenchmark_population(
100-
self.population, desc=f"Gen {generation} rebench"
101-
)
102-
self.log(
103-
f"Generation {generation}, {num_neighbors} neighbors, {num_active} active:",
104-
self.statistics,
106+
self.population, desc=f"Generation {generation}: Verifying top configs"
105107
)
108+
# Log final statistics for this generation
109+
self.log(f"Generation {generation} complete:", self.statistics)
106110
return self.best.config
107111

108112
def _pattern_search_from(

helion/autotuner/progress_bar.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Progress-bar utilities used by the autotuner.
2+
3+
We rely on `rich` to render colored, full-width progress bars that
4+
show the description, percentage complete, and how many items have been
5+
processed.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from typing import TYPE_CHECKING
11+
from typing import TypeVar
12+
13+
from rich.progress import BarColumn
14+
from rich.progress import MofNCompleteColumn
15+
from rich.progress import Progress
16+
from rich.progress import ProgressColumn
17+
from rich.progress import TextColumn
18+
from rich.text import Text
19+
20+
if TYPE_CHECKING:
21+
from collections.abc import Iterable
22+
from collections.abc import Iterator
23+
24+
from rich.progress import Task
25+
26+
T = TypeVar("T")
27+
28+
29+
class SpeedColumn(ProgressColumn):
    """Progress-bar column showing throughput in configs per second."""

    def render(self, task: Task) -> Text:
        # task.speed is None until rich has enough samples to estimate a rate.
        speed = task.speed
        if speed is None:
            label = "- configs/s"
        else:
            label = f"{speed:.1f} configs/s"
        return Text(label, style="magenta")
37+
38+
39+
def iter_with_progress(
    iterable: Iterable[T], *, total: int, description: str | None = None, enabled: bool
) -> Iterator[T]:
    """Yield every item of *iterable*, optionally under a rich progress bar.

    Parameters
    ----------
    iterable:
        Source of items; they are passed through unmodified.
    total:
        Expected number of items, used to size the bar.
    description:
        Label shown to the left of the bar. ``"Progress"`` when ``None``.
    enabled:
        When ``False`` items are yielded directly with no bar machinery;
        when ``True`` a rich progress bar tracks the iteration.
    """
    if not enabled:
        # Plain pass-through: no Progress object is ever constructed.
        yield from iterable
        return

    label = "Progress" if description is None else description
    columns = (
        TextColumn("[progress.description]{task.description}"),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        BarColumn(bar_width=None, complete_style="yellow", finished_style="green"),
        MofNCompleteColumn(),
        SpeedColumn(),
    )
    with Progress(*columns) as bar:
        yield from bar.track(iterable, total=total, description=label)

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,4 @@ pre-commit
44
filecheck
55
expecttest
66
numpy
7-
tqdm
87
rich

0 commit comments

Comments
 (0)