Skip to content

Commit 6cf2950

Browse files
authored
Merge pull request #680 from onekey-sec/refactor-progress-reporting
processing: extract progress reporting from business logic
2 parents 861d76b + 9d940c8 commit 6cf2950

File tree

4 files changed

+115
-34
lines changed

4 files changed

+115
-34
lines changed

tests/test_cli.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import List, Optional
2+
from typing import List, Optional, Type
33
from unittest import mock
44

55
import pytest
@@ -11,6 +11,11 @@
1111
from unblob.handlers import BUILTIN_HANDLERS
1212
from unblob.models import DirectoryHandler, Glob, Handler, HexString, MultiFile
1313
from unblob.processing import DEFAULT_DEPTH, DEFAULT_PROCESS_NUM, ExtractionConfig
14+
from unblob.ui import (
15+
NullProgressReporter,
16+
ProgressReporter,
17+
RichConsoleProgressReporter,
18+
)
1419

1520

1621
class TestHandler(Handler):
@@ -174,18 +179,50 @@ def test_dir_for_file(tmp_path: Path):
174179

175180

176181
@pytest.mark.parametrize(
177-
"params, expected_depth, expected_entropy_depth, expected_process_num, expected_verbosity",
182+
"params, expected_depth, expected_entropy_depth, expected_process_num, expected_verbosity, expected_progress_reporter",
178183
[
179-
pytest.param([], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 0, id="empty"),
180184
pytest.param(
181-
["--verbose"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 1, id="verbose-1"
185+
[],
186+
DEFAULT_DEPTH,
187+
1,
188+
DEFAULT_PROCESS_NUM,
189+
0,
190+
RichConsoleProgressReporter,
191+
id="empty",
192+
),
193+
pytest.param(
194+
["--verbose"],
195+
DEFAULT_DEPTH,
196+
1,
197+
DEFAULT_PROCESS_NUM,
198+
1,
199+
NullProgressReporter,
200+
id="verbose-1",
201+
),
202+
pytest.param(
203+
["-vv"],
204+
DEFAULT_DEPTH,
205+
1,
206+
DEFAULT_PROCESS_NUM,
207+
2,
208+
NullProgressReporter,
209+
id="verbose-2",
210+
),
211+
pytest.param(
212+
["-vvv"],
213+
DEFAULT_DEPTH,
214+
1,
215+
DEFAULT_PROCESS_NUM,
216+
3,
217+
NullProgressReporter,
218+
id="verbose-3",
219+
),
220+
pytest.param(
221+
["--depth", "2"], 2, 1, DEFAULT_PROCESS_NUM, 0, mock.ANY, id="depth"
182222
),
183-
pytest.param(["-vv"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 2, id="verbose-2"),
184223
pytest.param(
185-
["-vvv"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 3, id="verbose-3"
224+
["--process-num", "2"], DEFAULT_DEPTH, 1, 2, 0, mock.ANY, id="process-num"
186225
),
187-
pytest.param(["--depth", "2"], 2, 1, DEFAULT_PROCESS_NUM, 0, id="depth"),
188-
pytest.param(["--process-num", "2"], DEFAULT_DEPTH, 1, 2, 0, id="process-num"),
189226
],
190227
)
191228
def test_archive_success(
@@ -194,6 +231,7 @@ def test_archive_success(
194231
expected_entropy_depth: int,
195232
expected_process_num: int,
196233
expected_verbosity: int,
234+
expected_progress_reporter: Type[ProgressReporter],
197235
tmp_path: Path,
198236
):
199237
runner = CliRunner()
@@ -225,6 +263,7 @@ def test_archive_success(
225263
process_num=expected_process_num,
226264
handlers=BUILTIN_HANDLERS,
227265
verbose=expected_verbosity,
266+
progress_reporter=expected_progress_reporter,
228267
)
229268
process_file_mock.assert_called_once_with(config, in_path, None)
230269
logger_config_mock.assert_called_once_with(expected_verbosity, tmp_path, log_path)

unblob/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ExtractionConfig,
2727
process_file,
2828
)
29+
from .ui import NullProgressReporter, RichConsoleProgressReporter
2930

3031
logger = get_logger()
3132

@@ -258,6 +259,9 @@ def cli(
258259
dir_handlers=dir_handlers,
259260
keep_extracted_chunks=keep_extracted_chunks,
260261
verbose=verbose,
262+
progress_reporter=NullProgressReporter
263+
if verbose
264+
else RichConsoleProgressReporter,
261265
)
262266

263267
logger.info("Start processing file", file=file)

unblob/processing.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
import shutil
33
from operator import attrgetter
44
from pathlib import Path
5-
from typing import Iterable, List, Optional, Sequence, Set, Tuple
5+
from typing import Iterable, List, Optional, Sequence, Set, Tuple, Type
66

77
import attr
88
import magic
99
import plotext as plt
10-
from rich import progress
11-
from rich.style import Style
1210
from structlog import get_logger
1311
from unblob_native import math_tools as mt
1412

@@ -45,6 +43,7 @@
4543
UnknownError,
4644
)
4745
from .signals import terminate_gracefully
46+
from .ui import NullProgressReporter, ProgressReporter
4847

4948
logger = get_logger()
5049

@@ -94,6 +93,7 @@ class ExtractionConfig:
9493
handlers: Handlers = BUILTIN_HANDLERS
9594
dir_handlers: DirectoryHandlers = BUILTIN_DIR_HANDLERS
9695
verbose: int = 1
96+
progress_reporter: Type[ProgressReporter] = NullProgressReporter
9797

9898
def get_extract_dir_for(self, path: Path) -> Path:
9999
"""Return extraction dir under root with the name of path."""
@@ -146,26 +146,11 @@ def _process_task(config: ExtractionConfig, task: Task) -> ProcessResult:
146146
processor = Processor(config)
147147
aggregated_result = ProcessResult()
148148

149-
if not config.verbose:
150-
progress_display = progress.Progress(
151-
progress.TextColumn(
152-
"Extraction progress: {task.percentage:>3.0f}%",
153-
style=Style(color="#00FFC8"),
154-
),
155-
progress.BarColumn(
156-
complete_style=Style(color="#00FFC8"), style=Style(color="#002060")
157-
),
158-
)
159-
progress_display.start()
160-
overall_progress_task = progress_display.add_task("Extraction progress:")
149+
progress_reporter = config.progress_reporter()
161150

162151
def process_result(pool, result):
163-
if config.verbose == 0 and progress_display.tasks[0].total is not None:
164-
progress_display.update(
165-
overall_progress_task,
166-
advance=1,
167-
total=progress_display.tasks[0].total + len(result.subtasks),
168-
)
152+
progress_reporter.update(result)
153+
169154
for new_task in result.subtasks:
170155
pool.submit(new_task)
171156
aggregated_result.register(result)
@@ -176,14 +161,10 @@ def process_result(pool, result):
176161
result_callback=process_result,
177162
)
178163

179-
with pool:
164+
with pool, progress_reporter:
180165
pool.submit(task)
181166
pool.process_until_done()
182167

183-
if not config.verbose:
184-
progress_display.remove_task(overall_progress_task)
185-
progress_display.stop()
186-
187168
return aggregated_result
188169

189170

unblob/ui.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Protocol
2+
3+
from rich import progress
4+
from rich.style import Style
5+
6+
from .models import TaskResult
7+
8+
9+
class ProgressReporter(Protocol):
10+
def __enter__(self):
11+
...
12+
13+
def __exit__(self, _exc_type, _exc_value, _tb):
14+
...
15+
16+
def update(self, result: TaskResult):
17+
...
18+
19+
20+
class NullProgressReporter:
21+
def __enter__(self):
22+
pass
23+
24+
def __exit__(self, _exc_type, _exc_value, _tb):
25+
pass
26+
27+
def update(self, result: TaskResult):
28+
pass
29+
30+
31+
class RichConsoleProgressReporter:
32+
def __init__(self):
33+
self._progress = progress.Progress(
34+
progress.TextColumn(
35+
"Extraction progress: {task.percentage:>3.0f}%",
36+
style=Style(color="#00FFC8"),
37+
),
38+
progress.BarColumn(
39+
complete_style=Style(color="#00FFC8"), style=Style(color="#002060")
40+
),
41+
)
42+
self._overall_progress_task = self._progress.add_task("Extraction progress:")
43+
44+
def __enter__(self):
45+
self._progress.start()
46+
47+
def __exit__(self, _exc_type, _exc_value, _tb):
48+
self._progress.remove_task(self._overall_progress_task)
49+
self._progress.stop()
50+
51+
def update(self, result: TaskResult):
52+
if (total := self._progress.tasks[0].total) is not None:
53+
self._progress.update(
54+
self._overall_progress_task,
55+
advance=1,
56+
total=total + len(result.subtasks),
57+
)

0 commit comments

Comments
 (0)