diff --git a/.github/workflows/test_corpus.yaml b/.github/workflows/test_corpus.yaml index c1c5b744..c0b4d78d 100644 --- a/.github/workflows/test_corpus.yaml +++ b/.github/workflows/test_corpus.yaml @@ -17,7 +17,7 @@ on: type: boolean description: 'Regenerate results' required: true - default: true + default: false workflow_call: inputs: ref: diff --git a/.github/workflows/xtest.yaml b/.github/workflows/xtest.yaml index 519e7d53..245114de 100644 --- a/.github/workflows/xtest.yaml +++ b/.github/workflows/xtest.yaml @@ -26,7 +26,7 @@ jobs: - name: Run tests run: | - + if [[ "${{ matrix.python }}" == "python3.4" ]]; then (cd /usr/lib64/python3.4/test && python3.4 make_ssl_certs.py) elif [[ "${{ matrix.python }}" == "python3.5" ]]; then diff --git a/.gitignore b/.gitignore index 4a7e4622..a0abb5c0 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ docs/source/transforms/*.min.py .circleci-config.yml .coverage .mypy_cache/ +NOTES.md diff --git a/corpus_test/generate_report.py b/corpus_test/generate_report.py index 8ce94dbf..fccf9beb 100644 --- a/corpus_test/generate_report.py +++ b/corpus_test/generate_report.py @@ -6,7 +6,7 @@ from result import Result, ResultReader -ENHANCED_REPORT = os.environ.get('ENHANCED_REPORT', False) +ENHANCED_REPORT = os.environ.get('ENHANCED_REPORT', True) @dataclass @@ -64,6 +64,9 @@ def mean_percent_of_original(self) -> float: def larger_than_original(self) -> Iterable[Result]: """Return those entries that have a larger minified size than the original size""" for result in self.entries.values(): + if result.outcome != 'Minified': + continue + if result.original_size < result.minified_size: yield result @@ -91,10 +94,18 @@ def compare_size_increase(self, base: 'ResultSet') -> Iterable[Result]: """ for result in self.entries.values(): + if result.outcome != 'Minified': + # This result was not minified, so we can't compare + continue + if result.corpus_entry not in base.entries: continue base_result = base.entries[result.corpus_entry] + if base_result.outcome != 'Minified': + # The base result was not minified, so we can't compare + continue + if result.minified_size > base_result.minified_size: yield result @@ -104,10 +115,17 @@ def compare_size_decrease(self, base: 'ResultSet') -> Iterable[Result]: """ for result in self.entries.values(): + if result.outcome != 'Minified': + continue + if result.corpus_entry not in base.entries: continue base_result = base.entries[result.corpus_entry] + if base_result.outcome != 'Minified': + # The base result was not minified, so we can't compare + continue + if result.minified_size < base_result.minified_size: yield result @@ -164,6 +182,103 @@ def format_difference(compare: Iterable[Result], base: Iterable[Result]) -> str: else: return s +def report_larger_than_original(results_dir: str, python_versions: str, minifier_sha: str) -> str: + yield ''' +## Larger than original + +| Corpus Entry | Original Size | Minified Size | +|--------------|--------------:|--------------:|''' + + for python_version in python_versions: + try: + summary = result_summary(results_dir, python_version, minifier_sha) + except FileNotFoundError: + continue + + larger_than_original = sorted(summary.larger_than_original(), key=lambda result: result.original_size) + + for entry in larger_than_original: + yield f'| {entry.corpus_entry} | {entry.original_size} | {entry.minified_size} ({entry.minified_size - entry.original_size:+}) |' + +def report_unstable(results_dir: str, python_versions: str, minifier_sha: str) -> str: + yield ''' +## Unstable + +| Corpus Entry | Python Version | Original Size | +|--------------|----------------|--------------:|''' + + for python_version in python_versions: + try: + summary = result_summary(results_dir, python_version, minifier_sha) + except FileNotFoundError: + continue + + unstable = sorted(summary.unstable_minification(), key=lambda result: result.original_size) + + for entry in unstable: + yield f'| {entry.corpus_entry} | {python_version} | {entry.original_size} |' + +def report_exceptions(results_dir: str, python_versions: str, minifier_sha: str) -> str: + yield ''' +## Exceptions + +| Corpus Entry | Python Version | Exception | +|--------------|----------------|-----------|''' + + exceptions_found = False + + for python_version in python_versions: + try: + summary = result_summary(results_dir, python_version, minifier_sha) + except FileNotFoundError: + continue + + exceptions = sorted(summary.exception(), key=lambda result: result.original_size) + + for entry in exceptions: + exceptions_found = True + yield f'| {entry.corpus_entry} | {python_version} | {entry.outcome} |' + + if not exceptions_found: + yield ' None | | |' + +def report_larger_than_base(results_dir: str, python_versions: str, minifier_sha: str, base_sha: str) -> str: + yield ''' +## Top 10 Larger than base + +| Corpus Entry | Original Size | Minified Size | +|--------------|--------------:|--------------:|''' + + there_are_some_larger_than_base = False + + for python_version in python_versions: + try: + summary = result_summary(results_dir, python_version, minifier_sha) + except FileNotFoundError: + continue + + base_summary = result_summary(results_dir, python_version, base_sha) + larger_than_original = sorted(summary.compare_size_increase(base_summary), key=lambda result: result.original_size)[:10] + + for entry in larger_than_original: + there_are_some_larger_than_base = True + yield f'| {entry.corpus_entry} | {entry.original_size} | {entry.minified_size} ({entry.minified_size - base_summary.entries[entry.corpus_entry].minified_size:+}) |' + + if not there_are_some_larger_than_base: + yield '| N/A | N/A | N/A |' + +def report_slowest(results_dir: str, python_versions: str, minifier_sha: str) -> str: + yield ''' +## Top 10 Slowest + +| Corpus Entry | Original Size | Minified Size | Time | +|--------------|--------------:|--------------:|-----:|''' + + for python_version in python_versions: + summary = result_summary(results_dir, python_version, minifier_sha) + + for entry in sorted(summary.entries.values(), key=lambda entry: entry.time, reverse=True)[:10]: + yield f'| {entry.corpus_entry} | {entry.original_size} | {entry.minified_size} | {entry.time:.3f} |' def report(results_dir: str, minifier_ref: str, minifier_sha: str, base_ref: str, base_sha: str) -> Iterable[str]: """ @@ -236,50 +351,11 @@ def format_size_change_detail() -> str: ) if ENHANCED_REPORT: - yield ''' -## Larger than original - -| Corpus Entry | Original Size | Minified Size | -|--------------|--------------:|--------------:|''' - - for python_version in ['3.11']: - summary = result_summary(results_dir, python_version, minifier_sha) - larger_than_original = sorted(summary.larger_than_original(), key=lambda result: result.original_size) - - for entry in larger_than_original: - yield f'| {entry.corpus_entry} | {entry.original_size} | {entry.minified_size} ({entry.minified_size - entry.original_size:+}) |' - - yield ''' -## Top 10 Larger than base - -| Corpus Entry | Original Size | Minified Size | -|--------------|--------------:|--------------:|''' - - there_are_some_larger_than_base = False - - for python_version in ['3.11']: - summary = result_summary(results_dir, python_version, minifier_sha) - base_summary = result_summary(results_dir, python_version, base_sha) - larger_than_original = sorted(summary.compare_size_increase(base_summary), key=lambda result: result.original_size)[:10] - - for entry in larger_than_original: - there_are_some_larger_than_base = True - yield f'| {entry.corpus_entry} | {entry.original_size} | {entry.minified_size} ({entry.minified_size - base_summary.entries[entry.corpus_entry].minified_size:+}) |' - - if not there_are_some_larger_than_base: - yield '| N/A | N/A | N/A |' - - yield ''' -## Top 10 Slowest - -| Corpus Entry | Original Size | Minified Size | Time | -|--------------|--------------:|--------------:|-----:|''' - - for python_version in ['3.11']: - summary = result_summary(results_dir, python_version, minifier_sha) - - for entry in sorted(summary.entries.values(), key=lambda entry: entry.time, reverse=True)[:10]: - yield f'| {entry.corpus_entry} | {entry.original_size} | {entry.minified_size} | {entry.time:.3f} |' + yield from report_larger_than_original(results_dir, ['3.11'], minifier_sha) + yield from report_larger_than_base(results_dir, ['3.11'], minifier_sha, base_sha) + yield from report_slowest(results_dir, ['3.11'], minifier_sha) + yield from report_unstable(results_dir, ['2.7', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '3.9', '3.10', '3.11'], minifier_sha) + yield from report_exceptions(results_dir, ['3.6', '3.7', '3.8', '3.9', '3.10', '3.11'], minifier_sha) def main(): diff --git a/corpus_test/generate_results.py b/corpus_test/generate_results.py index f43d4ad4..a8c3fdf9 100644 --- a/corpus_test/generate_results.py +++ b/corpus_test/generate_results.py @@ -1,8 +1,14 @@ import argparse +import datetime +import gzip import os import sys import time + +import logging + + import python_minifier from result import Result, ResultWriter @@ -23,8 +29,13 @@ def minify_corpus_entry(corpus_path, corpus_entry): :rtype: Result """ - with open(os.path.join(corpus_path, corpus_entry), 'rb') as f: - source = f.read() + if os.path.isfile(os.path.join(corpus_path, corpus_entry + '.py.gz')): + with gzip.open(os.path.join(corpus_path, corpus_entry + '.py.gz'), 'rb') as f: + source = f.read() + else: + with open(os.path.join(corpus_path, corpus_entry), 'rb') as f: + source = f.read() + result = Result(corpus_entry, len(source), 0, 0, '') @@ -72,21 +83,54 @@ def corpus_test(corpus_path, results_path, sha, regenerate_results): :param str sha: The python-minifier sha we are testing :param bool regenerate_results: Regenerate results even if they are present """ - corpus_entries = os.listdir(corpus_path) - python_version = '.'.join([str(s) for s in sys.version_info[:2]]) + + log_path = 'results_' + python_version + '_' + sha + '.log' + print('Logging in GitHub Actions is absolute garbage. Logs are going to ' + log_path) + + logging.basicConfig(filename=os.path.join(results_path, log_path), level=logging.DEBUG) + + corpus_entries = [entry[:-len('.py.gz')] for entry in os.listdir(corpus_path)] + results_file_path = os.path.join(results_path, 'results_' + python_version + '_' + sha + '.csv') - if os.path.isfile(results_file_path) and not regenerate_results: - print('Results file already exists: %s', results_file_path) - return + if os.path.isfile(results_file_path): + logging.info('Results file already exists: %s', results_file_path) + if regenerate_results: + os.remove(results_file_path) + + total_entries = len(corpus_entries) + logging.info('Testing python-minifier on %d entries' % total_entries) + tested_entries = 0 + + start_time = time.time() + next_checkpoint = time.time() + 60 with ResultWriter(results_file_path) as result_writer: + logging.info('%d results already present' % len(result_writer)) + for entry in corpus_entries: - print(entry) + if entry in result_writer: + continue + + logging.debug(entry) + result = minify_corpus_entry(corpus_path, entry) result_writer.write(result) + tested_entries += 1 + + sys.stdout.flush() + + if time.time() > next_checkpoint: + percent = len(result_writer) / total_entries * 100 + time_per_entry = (time.time() - start_time) / tested_entries + entries_remaining = len(corpus_entries) - len(result_writer) + time_remaining = int(entries_remaining * time_per_entry) + logging.info('Tested %d/%d entries (%d%%) %s seconds remaining' % (len(result_writer), total_entries, percent, time_remaining)) + sys.stdout.flush() + next_checkpoint = time.time() + 60 + logging.info('Finished') def bool_parse(value): return value == 'true' diff --git a/corpus_test/result.py b/corpus_test/result.py index 00123dc8..b02e4ac1 100644 --- a/corpus_test/result.py +++ b/corpus_test/result.py @@ -1,3 +1,6 @@ +import os + + class Result(object): def __init__(self, corpus_entry, original_size, minified_size, time, outcome): @@ -21,15 +24,37 @@ def __init__(self, results_path): :param str results_path: The path to the results file """ self._results_path = results_path + self._size = 0 + self._existing_result_set = set() + + if not os.path.isfile(self._results_path): + return + + with open(self._results_path, 'r') as f: + for line in f: + if line != 'corpus_entry,original_size,minified_size,time,result\n': + self._existing_result_set.add(line.split(',')[0]) + + self._size += len(self._existing_result_set) def __enter__(self): - self.results = open(self._results_path, 'w') + self.results = open(self._results_path, 'a') self.results.write('corpus_entry,original_size,minified_size,time,result\n') return self def __exit__(self, exc_type, exc_val, exc_tb): self.results.close() + def __contains__(self, item): + """ + :param str item: The name of the entry in the corpus + :return bool: True if the entry already exists in the results file + """ + return item in self._existing_result_set + + def __len__(self): + return self._size + def write(self, result): """ :param Result result: The result to write to the file @@ -41,6 +66,7 @@ def write(self, result): str(result.time) + ',' + result.outcome + '\n' ) self.results.flush() + self._size += 1 class ResultReader: @@ -66,7 +92,11 @@ def __next__(self): """ :return Result: The next result in the file """ + line = self.results.readline() + while line == 'corpus_entry,original_size,minified_size,time,result\n': + line = self.results.readline() + if line == '': raise StopIteration else: diff --git a/src/python_minifier/expression_printer.py b/src/python_minifier/expression_printer.py index 4544db96..e0adc847 100644 --- a/src/python_minifier/expression_printer.py +++ b/src/python_minifier/expression_printer.py @@ -14,6 +14,7 @@ class ExpressionPrinter(object): def __init__(self): self.precedences = { + 'NamedExpr': 1, # NamedExpr 'Lambda': 2, # Lambda 'IfExp': 3, # IfExp 'comprehension': 3.5, @@ -128,7 +129,7 @@ def visit_Bytes(self, node): def visit_List(self, node): self.printer.delimiter('[') - self._exprlist(node.elts) + self._starred_list(node.elts) self.printer.delimiter(']') def visit_Tuple(self, node): @@ -138,14 +139,25 @@ def visit_Tuple(self, node): self.printer.delimiter(')') return - self._exprlist(node.elts) + if [n for n in node.elts if is_ast_node(n, 'NamedExpr')]: + self.printer.delimiter('(') + delimiter = Delimiter(self.printer) + for expr in node.elts: + delimiter.new_item() + self._expression(expr) + self.printer.delimiter(')') + else: + delimiter = Delimiter(self.printer) + for expr in node.elts: + delimiter.new_item() + self._expression(expr) if len(node.elts) == 1: self.printer.delimiter(',') def visit_Set(self, node): self.printer.delimiter('{') - self._exprlist(node.elts) + self._starred_list(node.elts) self.printer.delimiter('}') def visit_Dict(self, node): @@ -156,17 +168,32 @@ def key_datum(key, datum): if key is None: self.printer.operator('**') - if 0 < self.precedence(datum) <=7: + if 0 < self.precedence(datum) <= 7: self.printer.delimiter('(') self._expression(datum) self.printer.delimiter(')') else: - self._expression(datum) + if is_ast_node(datum, 'NamedExpr'): + self.printer.delimiter('(') + self._expression(datum) + self.printer.delimiter(')') + else: + self._expression(datum) else: - self._expression(key) + if is_ast_node(key, 'NamedExpr'): + self.printer.delimiter('(') + self._expression(key) + self.printer.delimiter(')') + else: + self._expression(key) self.printer.delimiter(':') - self._expression(datum) + if is_ast_node(datum, 'NamedExpr'): + self.printer.delimiter('(') + self._expression(datum) + self.printer.delimiter(')') + else: + self._expression(datum) self.printer.delimiter('{') @@ -407,6 +434,12 @@ def visit_keyword(self, node): if node.arg is None: self.printer.operator('**') self._expression(node.value) + elif is_ast_node(node.value, 'NamedExpr'): + self.printer.identifier(node.arg) + self.printer.delimiter('=') + self.printer.delimiter('(') + self.visit_NamedExpr(node.value) + self.printer.delimiter(')') else: self.printer.identifier(node.arg) self.printer.delimiter('=') @@ -414,7 +447,12 @@ def visit_keyword(self, node): def visit_IfExp(self, node): - self._rhs(node.body, node) + if isinstance(node.body, ast.IfExp): + self.printer.delimiter('(') + self._lhs(node.body, node) + self.printer.delimiter(')') + else: + self._lhs(node.body, node) self.printer.keyword('if') @@ -422,7 +460,12 @@ def visit_IfExp(self, node): self.printer.keyword('else') - self._expression(node.orelse) + if is_ast_node(node.orelse, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.orelse) + self.printer.delimiter(')') + else: + self._expression(node.orelse) def visit_Attribute(self, node): value_precedence = self.precedence(node.value) @@ -465,7 +508,25 @@ def visit_Subscript(self, node): elif isinstance(node.slice, ast.Ellipsis): self.visit_Ellipsis(node) elif sys.version_info >= (3, 9) and isinstance(node.slice, ast.Tuple): - self.visit_Tuple(node.slice) + contains_starred = False + if [n for n in node.slice.elts if is_ast_node(n, 'Starred')]: + contains_starred = True + self.printer.delimiter('(') + + with Delimiter(self.printer) as delimiter: + for expr in node.slice.elts: + delimiter.new_item() + self._expression(expr) + + if len(node.slice.elts) == 0: + self.printer.delimiter('(') + self.printer.delimiter(')') + elif len(node.slice.elts) == 1: + self.printer.delimiter(',') + + if contains_starred: + self.printer.delimiter(')') + elif sys.version_info >= (3, 9): self._expression(node.slice) else: @@ -474,25 +535,53 @@ def visit_Subscript(self, node): self.printer.delimiter(']') def visit_Index(self, node): - self._expression(node.value) + if isinstance(node.value, ast.Tuple): + self.printer.delimiter('(') + self.visit_Tuple(node.value) + self.printer.delimiter(')') + else: + self._expression_list(node.value) def visit_Slice(self, node): if node.lower: - self._expression(node.lower) + if is_ast_node(node.lower, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.lower) + self.printer.delimiter(')') + else: + self._expression(node.lower) self.printer.delimiter(':') if node.upper: - self._expression(node.upper) + if is_ast_node(node.upper, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.upper) + self.printer.delimiter(')') + else: + self._expression(node.upper) + if node.step: self.printer.delimiter(':') - self._expression(node.step) + if is_ast_node(node.step, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.step) + self.printer.delimiter(')') + else: + self._expression(node.step) def visit_ExtSlice(self, node): delimiter = Delimiter(self.printer) for s in node.dims: + assert isinstance(s, (ast.Index, ast.Slice, ast.Ellipsis)) + delimiter.new_item() - self._expression(s) + if isinstance(s, ast.Index): + self.visit_Index(s) + elif isinstance(s, ast.Slice): + self.visit_Slice(s) + elif isinstance(s, ast.Ellipsis): + self.visit_Ellipsis(s) if len(node.dims) == 1: self.printer.delimiter(',') @@ -526,7 +615,13 @@ def visit_GeneratorExp(self, node, omit_parens=False): def visit_DictComp(self, node): self.printer.delimiter('{') - self._expression(node.key) + + if 0 < self.precedence(node.key) < 3: + self.printer.delimiter('(') + self._expression(node.key) + self.printer.delimiter(')') + else: + self._expression(node.key) self.printer.delimiter(':') self._expression(node.value) [self.visit_comprehension(x) for x in node.generators] @@ -539,7 +634,7 @@ def visit_comprehension(self, node): self.printer.keyword('async') self.printer.keyword('for') - self._exprlist([node.target]) + self._target_list(node.target) self.printer.keyword('in') self._rhs(node.iter, node) @@ -561,7 +656,12 @@ def visit_Lambda(self, node): self.printer.delimiter(':') - self._expression(node.body) + if is_ast_node(node.body, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.body) + self.printer.delimiter(')') + else: + self._expression(node.body) def visit_arguments(self, node): args = getattr(node, 'posonlyargs', []) + node.args @@ -576,7 +676,13 @@ def visit_arguments(self, node): if i >= count_no_defaults: self.printer.delimiter('=') - self._expression(node.defaults[i - count_no_defaults]) + default = node.defaults[i - count_no_defaults] + if is_ast_node(default, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(default) + self.printer.delimiter(')') + else: + self._expression(node.defaults[i - count_no_defaults]) if hasattr(node, 'posonlyargs') and node.posonlyargs and i + 1 == len(node.posonlyargs): self.printer.delimiter(',') @@ -635,7 +741,12 @@ def visit_arg(self, node): if node.annotation: self.printer.delimiter(':') - self._expression(node.annotation) + if is_ast_node(node.annotation, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.annotation) + self.printer.delimiter(')') + else: + self._expression(node.annotation) def visit_Repr(self, node): self.printer.delimiter('`') @@ -648,49 +759,155 @@ def visit_Expression(self, node): self._expression(node.body) def _expression(self, expression): + """ + An `expression` in the python grammer. + + Tuples must be parenthesized. + Yield/YieldFrom must be parenthesized. + """ + if is_ast_node(expression, (ast.Yield, 'YieldFrom')): self.printer.delimiter('(') - self._yield_expr(expression) + self._yield_expression(expression) self.printer.delimiter(')') elif isinstance(expression, ast.Tuple) and len(expression.elts) > 0: self.printer.delimiter('(') self.visit_Tuple(expression) self.printer.delimiter(')') - elif is_ast_node(expression, 'NamedExpr'): - self.printer.delimiter('(') - self.visit_NamedExpr(expression) - self.printer.delimiter(')') + #elif is_ast_node(expression, 'NamedExpr'): + #self.printer.delimiter('(') + # self.visit_NamedExpr(expression) + #self.printer.delimiter(')') else: self.visit(expression) def _testlist(self, test): if is_ast_node(test, (ast.Yield, 'YieldFrom')): self.printer.delimiter('(') - self._yield_expr(test) - self.printer.delimiter(')') - elif is_ast_node(test, 'NamedExpr'): - self.printer.delimiter('(') - self.visit_NamedExpr(test) + self._yield_expression(test) self.printer.delimiter(')') else: self.visit(test) - def _exprlist(self, exprlist): - delimiter = Delimiter(self.printer) - for expr in exprlist: - delimiter.new_item() - self._expression(expr) + # region Grammar elements + def _expression_list(self, exprlist): + """ + An 'expression_list' in the grammar + + This may be a single expression or a list of expressions. + If it is a list of expressions the exprlist is a Tuple node, which does not need to be enclosed by parentheses. + An empty tuple needs to be printed as '()' + A tuple with a single element needs to have a trailing comma + + If the list contains a starred expression, it needs to be parenthesized. + """ + + if isinstance(exprlist, ast.Tuple): + + contains_starred = False + if [n for n in exprlist.elts if is_ast_node(n, 'Starred')]: + contains_starred = True + + with Delimiter(self.printer, add_parens=contains_starred) as delimiter: + for expr in exprlist.elts: + delimiter.new_item() + self._expression(expr) - def _yield_expr(self, yield_node): + if len(exprlist.elts) == 0: + self.printer.delimiter('(') + self.printer.delimiter(')') + return + + if len(exprlist.elts) == 1: + self.printer.delimiter(',') + + elif isinstance(exprlist, list): + delimiter = Delimiter(self.printer) + for e in exprlist: + delimiter.new_item() + self._expression(e) + else: + if is_ast_node(exprlist, 'Starred'): + self.printer.delimiter('(') + self._expression(exprlist) + self.printer.delimiter(')') + else: + self._expression(exprlist) + + def _starred_list(self, exprlist): + """ + A 'starred_list' in the grammar + + This is very similar to an expression_list, but it may contain a starred expression without being parenthesized. + """ + + if isinstance(exprlist, ast.Tuple): + delimiter = Delimiter(self.printer) + for expr in exprlist.elts: + delimiter.new_item() + self._expression(expr) + + if len(exprlist.elts) == 0: + self.printer.delimiter('(') + self.printer.delimiter(')') + return + + if len(exprlist.elts) == 1: + self.printer.delimiter(',') + + elif isinstance(exprlist, list): + delimiter = Delimiter(self.printer) + for e in exprlist: + delimiter.new_item() + self._expression(e) + else: + self._expression(exprlist) + + def _target_list(self, target_list): + """ + A 'target_list' in the grammar + """ + return self._starred_list(target_list) + + def _starred_expression(self, starred_expression): + """ + A 'starred_expression' in the grammar + """ + return self._expression(starred_expression) + + def _assignment_expression(self, assignment_expression): + """ + An 'assignment_expression' in the grammar + """ + if is_ast_node(assignment_expression, 'NamedExpr'): + self.visit_NamedExpr(assignment_expression) + else: + self._expression(assignment_expression) + + def _yield_expression(self, yield_node): if isinstance(yield_node, ast.Yield): self.printer.keyword('yield') elif isinstance(yield_node, ast.YieldFrom): self.printer.keyword('yield') self.printer.keyword('from') - if yield_node.value is not None: + if yield_node.value is None: + return + + if is_ast_node(yield_node.value, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(yield_node.value) + self.printer.delimiter(')') + elif is_ast_node(yield_node, ast.Yield): + if sys.version_info < (3, 8): + self._expression_list(yield_node.value) + else: + self._starred_list(yield_node.value) + elif is_ast_node(yield_node, 'YieldFrom'): self._expression(yield_node.value) + # endregion + @staticmethod def _is_right_associative(operator): return isinstance(operator, ast.Pow) @@ -740,7 +957,12 @@ def visit_JoinedStr(self, node): def visit_NamedExpr(self, node): self._expression(node.target) self.printer.operator(':=') - self._expression(node.value) + if isinstance(node.value, ast.NamedExpr): + self.printer.delimiter('(') + self.visit_NamedExpr(node.value) + self.printer.delimiter(')') + else: + self._expression(node.value) def visit_Await(self, node): assert isinstance(node, ast.Await) diff --git a/src/python_minifier/f_string.py b/src/python_minifier/f_string.py index 2dead030..4f0675de 100644 --- a/src/python_minifier/f_string.py +++ b/src/python_minifier/f_string.py @@ -166,7 +166,13 @@ def get_candidates(self): if self.is_curly(self.node.value): self.printer.delimiter(' ') - self._expression(self.node.value) + if is_ast_node(self.node.value, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(self.node.value) + self.printer.delimiter(')') + + else: + self._expression(self.node.value) if self.node.conversion == 115: self.printer.append('!s', TokenTypes.Delimiter) diff --git a/src/python_minifier/module_printer.py b/src/python_minifier/module_printer.py index cc2ec7d5..e75830cf 100644 --- a/src/python_minifier/module_printer.py +++ b/src/python_minifier/module_printer.py @@ -56,21 +56,46 @@ def visit_Expr(self, node): assert isinstance(node, ast.Expr) if is_ast_node(node.value, (ast.Yield, 'YieldFrom')): - self._yield_expr(node.value) + self._yield_expression(node.value) + elif is_ast_node(node.value, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.value) + self.printer.delimiter(')') + elif isinstance(node.value, ast.Tuple): + self.visit_Tuple(node.value) else: - self._testlist(node.value) + self._starred_expression(node.value) self.printer.end_statement() def visit_Assert(self, node): + """ + Assert statement + + assert_stmt ::= "assert" expression ["," expression] + + https://docs.python.org/3.11/reference/simple_stmts.html#the-assert-statement + """ assert isinstance(node, ast.Assert) self.printer.keyword('assert') - self._expression(node.test) + + if is_ast_node(node.test, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.test) + self.printer.delimiter(')') + else: + self._expression(node.test) if node.msg: self.printer.delimiter(',') - self._expression(node.msg) + + if is_ast_node(node.msg, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.msg) + self.printer.delimiter(')') + else: + self._expression(node.msg) self.printer.end_statement() @@ -78,29 +103,45 @@ def visit_Assign(self, node): assert isinstance(node, ast.Assign) for target_node in node.targets: - self._testlist(target_node) + self._target_list(target_node) self.printer.delimiter('=') # Yield nodes that are the sole node on the right hand side of an assignment do not need parens if is_ast_node(node.value, (ast.Yield, 'YieldFrom')): - self._yield_expr(node.value) + self._yield_expression(node.value) + elif is_ast_node(node.value, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.value) + self.printer.delimiter(')') + elif isinstance(node.value, ast.Tuple): + self.visit_Tuple(node.value) else: - self._testlist(node.value) + self._starred_expression(node.value) self.printer.end_statement() def visit_AugAssign(self, node): assert isinstance(node, ast.AugAssign) - self._testlist(node.target) + self._target_list(node.target) self.visit(node.op) self.printer.delimiter('=') # Yield nodes that are the sole node on the right hand side of an assignment do not need parens if is_ast_node(node.value, (ast.Yield, 'YieldFrom')): - self._yield_expr(node.value) + self._yield_expression(node.value) + + # NamedExpr nodes that are the sole node on the right hand side of an assignment MUST have parens + elif is_ast_node(node.value, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.value) + self.printer.delimiter(')') + else: - self._testlist(node.value) + if sys.version_info >= (3,9): + self._starred_list(node.value) # still documented as expression_list + else: + self._expression_list(node.value) self.printer.end_statement() @@ -116,12 +157,39 @@ def visit_AnnAssign(self, node): if node.annotation: self.printer.delimiter(':') - self._expression(node.annotation) + if is_ast_node(node.annotation, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.annotation) + self.printer.delimiter(')') + else: + self._expression(node.annotation) if node.value: self.printer.delimiter('=') - self._expression(node.value) + if isinstance(node.value, ast.Tuple): + if sys.version_info < (3, 8) and len(node.value.elts) != 0: + self.printer.delimiter('(') + self.visit_Tuple(node.value) + self.printer.delimiter(')') + else: + self.visit_Tuple(node.value) + elif is_ast_node(node.value, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.value) + self.printer.delimiter(')') + elif is_ast_node(node.value, (ast.Yield, 'YieldFrom')): + if sys.version_info >= (3, 8): + self._yield_expression(node.value) + else: + self.printer.delimiter('(') + self._yield_expression(node.value) + self.printer.delimiter(')') + else: + if sys.version_info >= (3, 8): + self._starred_expression(node.value) + else: + self._expression_list(node.value) self.printer.end_statement() @@ -135,22 +203,25 @@ def visit_Delete(self, node): assert isinstance(node, ast.Delete) self.printer.keyword('del') - self._exprlist(node.targets) + self._target_list(node.targets) self.printer.end_statement() def visit_Return(self, node): assert isinstance(node, ast.Return) self.printer.keyword('return') - if isinstance(node.value, ast.Tuple): - if sys.version_info < (3, 8) and [n for n in node.value.elts if is_ast_node(n, 'Starred')]: + + if node.value is not None: + if is_ast_node(node.value, 'NamedExpr'): self.printer.delimiter('(') - self._testlist(node.value) + self.visit_NamedExpr(node.value) self.printer.delimiter(')') else: - self._testlist(node.value) - elif node.value is not None: - self._testlist(node.value) + if sys.version_info < (3, 8): + self._expression_list(node.value) + else: + self._starred_list(node.value) + self.printer.end_statement() def visit_Print(self, node): @@ -177,13 +248,13 @@ def visit_Print(self, node): def visit_Yield(self, node): assert isinstance(node, ast.Yield) - self._yield_expr(node) + self._yield_expression(node) self.printer.end_statement() def visit_YieldFrom(self, node): assert isinstance(node, ast.YieldFrom) - self._yield_expr(node) + self._yield_expression(node) self.printer.end_statement() def visit_Raise(self, node): @@ -207,11 +278,21 @@ def visit_Raise(self, node): # Python3 if node.exc: - self._expression(node.exc) + if is_ast_node(node.exc, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.exc) + self.printer.delimiter(')') + else: + self._expression(node.exc) if node.cause: self.printer.keyword('from') - self._expression(node.cause) + if is_ast_node(node.cause, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.cause) + self.printer.delimiter(')') + else: + self._expression(node.cause) self.printer.end_statement() @@ -309,7 +390,7 @@ def visit_If(self, node, el=False): else: self.printer.keyword('if') - self._expression(node.test) + self._assignment_expression(node.test) self.printer.delimiter(':') self._suite(node.body) @@ -334,9 +415,17 @@ def visit_For(self, node, is_async=False): self.printer.keyword('async') self.printer.keyword('for') - self._exprlist([node.target]) + self._target_list(node.target) self.printer.keyword('in') - self._expression(node.iter) + + if is_ast_node(node.iter, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.iter) + self.printer.delimiter(')') + elif sys.version_info >= (3, 9): + self._starred_list(node.iter) + else: + self._expression_list(node.iter) self.printer.delimiter(':') self._suite(node.body) @@ -352,7 +441,7 @@ def visit_While(self, node): self.printer.newline() self.printer.keyword('while') - self._expression(node.test) + self._assignment_expression(node.test) self.printer.delimiter(':') self._suite(node.body) @@ -425,7 +514,12 @@ def visit_ExceptHandler(self, node, star=False): self.printer.operator('*') if node.type is not None: - self._expression(node.type) + if is_ast_node(node.type, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.type) + self.printer.delimiter(')') + else: + self._expression(node.type) if node.name is not None: self.printer.keyword('as') @@ -462,16 +556,33 @@ def visit_With(self, node, is_async=False): self.printer.delimiter(')') else: self.visit_withitem(item) + + self.printer.delimiter(':') + self._suite(node.body) + else: - self.visit_withitem(node) - self.printer.delimiter(':') - self._suite(node.body) + def python2_nested_with(node): + self.visit_withitem(node) + if len(node.body) == 1 and isinstance(node.body[0], ast.With): + self.printer.delimiter(',') + python2_nested_with(node.body[0]) + else: + self.printer.delimiter(':') + self._suite(node.body) + + python2_nested_with(node) + def visit_withitem(self, node): assert (hasattr(ast, 'withitem') and isinstance(node, ast.withitem)) or isinstance(node, ast.With) - self._expression(node.context_expr) + if is_ast_node(node.context_expr, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.context_expr) + self.printer.delimiter(')') + else: + self._expression(node.context_expr) if node.optional_vars is not None: self.printer.keyword('as') @@ -486,7 +597,7 @@ def visit_FunctionDef(self, node, is_async=False): for d in node.decorator_list: self.printer.operator('@') - self._expression(d) + self._assignment_expression(d) self.printer.newline() if is_async: @@ -500,7 +611,12 @@ def visit_FunctionDef(self, node, is_async=False): if hasattr(node, 'returns') and node.returns is not None: self.printer.delimiter('->') - self._expression(node.returns) + if is_ast_node(node.returns, 'NamedExpr'): + self.printer.delimiter('(') + self.visit_NamedExpr(node.returns) + self.printer.delimiter(')') + else: + self._expression(node.returns) self.printer.delimiter(':') else: self.printer.delimiter(':') @@ -517,7 +633,7 @@ def visit_ClassDef(self, node): for d in node.decorator_list: self.printer.operator('@') - self._expression(d) + self._assignment_expression(d) self.printer.newline() self.printer.keyword('class') diff --git a/src/python_minifier/token_printer.py b/src/python_minifier/token_printer.py index 2a916434..e7f5449d 100644 --- a/src/python_minifier/token_printer.py +++ b/src/python_minifier/token_printer.py @@ -67,6 +67,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def new_item(self): """Add a new item to the delimited group.""" + if self._add_parens and not self._context_manager: + raise ValueError('Cannot use add_parens without using as a context manager') + if self._first: self._first = False if self._context_manager and self._add_parens: diff --git a/test/skip_invalid.py b/test/skip_invalid.py new file mode 100644 index 00000000..dc8270a4 --- /dev/null +++ b/test/skip_invalid.py @@ -0,0 +1,9 @@ +import pytest +def skip_invalid(test): + def wrapper(statement): + if isinstance(statement, tuple): + statement, valid_condition = statement + if valid_condition is False: + pytest.skip('not supported in this version of Python') + test(statement) + return wrapper diff --git a/test/test_assignment_expressions.py b/test/test_assignment_expressions.py index 7f371ec0..cfc48bda 100644 --- a/test/test_assignment_expressions.py +++ b/test/test_assignment_expressions.py @@ -20,3 +20,21 @@ def test_pep(): expected_ast = ast.parse(source) actual_ast = unparse(expected_ast) compare_ast(expected_ast, ast.parse(actual_ast)) + +''' +#(a:=B) +#a=(b:=c) +# foo(h:=6, x=(y := f(x))) +# def foo(answer=(p := 42)):pass +# def foo(answer: (p := 42) = 5, **asd:(c:=6)) -> (z:=1):pass +# a: (p := 42) = 5 +# a += (b := 1) +# (x := lambda: 1) +# lambda: 1 +(x := 1) and 2 +# lambda line: (m := re.match(pattern, line)) and m.group(1) +# f'{(x:=10)}' +# f'{x:=10}' +# with (x := await a, y := await b): pass +def test_named_expression_assignment_05(self): + (x := 1, 2) +''' \ No newline at end of file diff --git a/test/test_expressions.py b/test/test_expressions.py new file mode 100644 index 00000000..1b4c8948 --- /dev/null +++ b/test/test_expressions.py @@ -0,0 +1,292 @@ +import ast +import sys + +import pytest +from python_minifier import unparse +from python_minifier.ast_compare import compare_ast +from skip_invalid import skip_invalid + + +@pytest.mark.parametrize('statement', [ + '1 if 1 else 1', + '1,2 if(1,2)else 1,2', + '(1,)if(1,)else 1,', + '()if()else()', + 'lambda:1 if(lambda:1)else lambda:1', + '(lambda a:1,)if(lambda a:1,)else lambda a:1,', + '1,lambda a:1 if(1,lambda a:1)else 1,lambda a:1', + ('(a:=1)if(b:=1)else(b:=1)', sys.version_info >= (3, 8)), + '(yield)if(yield)else(yield)', + '(yield 1)if(yield 1)else(yield 1)', + ('(yield from 1)if(yield from 1)else(yield from 1)', sys.version_info >= (3, 3)), + 'b.do if b.do else b.do', + "''.join()if''.join()else''.join()", + (('(a if b else a) if (a if b else a) else (a if b else a)', '(a if b else a)if(a if b else a)else a if b else a'), True) +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_if_exp(statement): + if isinstance(statement, tuple): + statement, expected = statement + else: + expected = statement + + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == expected + +@pytest.mark.parametrize('statement', [ + '1+1', + '1,2+1,2', + '1,2+1,2', + '1,+(1,)', + '()+()', + 'lambda:1+(lambda:1)', + 'lambda:1,+(lambda:1,)', + '1,lambda:1+1,lambda:1', + ('(a:=1)+(b:=1)', sys.version_info >= (3, 8)), + 'yield+(yield)', + 'yield 1+(yield 1)', + ('yield from 1+(yield from 1)', sys.version_info >= (3, 3)), + 'b.do+b.do', + "''.join()+''.join()", + 'a if b else c+a if b else c' +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_binop(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + 'a()', + '(1,2)()', + '(1,)()', + '()()', + 'lambda:1()', + '(lambda a:1,)()', + '(1,lambda a:1)()', + ('(a:=1)()', sys.version_info >= (3, 8)), + '(yield)()', + '(yield 1)()', + ('(yield from 1)()', sys.version_info >= (3, 3)), + 'b.do()', + "''.join()()", + '(a if b else a)()' +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_call(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + '1<1<1', + '1,2<1,2<1,2', + '(1,)<(1,)<1,', + '()<()<()', + '(lambda:1)<(lambda:1)<(lambda:1)', + '(lambda a:1,)<(lambda a:1,)<(lambda a:1,)', + '1,lambda a:1<1,lambda a:1<(1,lambda a:1)', + ('(a:=1)<(b:=1)<(c:=1)', sys.version_info >= (3, 8)), + '(yield)<(yield)<(yield)', + '(yield 1)>(yield 1)>(yield 1)', + ('(yield from 1)<(yield from 1)<(yield from 1)', sys.version_info >= (3, 3)), + 'b.do= (3, 0)), + ('(1 for*a in 1)', sys.version_info >= (3, 0)), + ('(1 for*a,b in 1)', sys.version_info >= (3, 0)), + ('(1 for*a,*c in 1)', sys.version_info >= (3, 0)), + '(b.do for b.do in b.do)', + '(lambda:1 for a in(lambda:1))', + '((lambda a:1,)for a in(lambda a:1,))', + '((1,lambda a:1)for a in(1,lambda a:1))', + ('(a:=1 for a in(a:=1))', sys.version_info >= (3, 8)), + '((yield)for a in(yield))', + '((yield 1)for a in(yield 1))', + ('((yield from 1)for a in(yield from 1))', sys.version_info >= (3, 3)), + "(''.join()for a in''.join())" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_comprehension(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + ('await 1', sys.version_info >= (3, 7)), + ('await(1)', sys.version_info < (3, 7)), + + ('await 1,', sys.version_info >= (3, 7)), + ('await(1)', sys.version_info < (3, 7)), + + ('await 1,2', sys.version_info >= (3, 7)), + ('await(1,2)', sys.version_info < (3, 7)), + + 'await()', + 'await(lambda:1)', + + ('await(lambda a:1,)', sys.version_info >= (3, 7)), + ('await(lambda a:1)', sys.version_info < (3, 7)), + + 'await(1,lambda a:1)', + ('await(b:=1)', sys.version_info >= (3, 8)), + + ('await 1 if True else 1', sys.version_info >= (3, 7)), + ('await(1 if True else 1)', sys.version_info < (3, 7)), + + ('await b,1 if True else 1', sys.version_info >= (3, 7)), + ('await(b,1 if True else 1)', sys.version_info < (3, 7)), + + ('await 1 if True else 1,', sys.version_info >= (3, 7)), + ('await(1 if True else 1)', sys.version_info < (3, 7)), + + ('await 1 if True else 1,b', sys.version_info >= (3, 7)), + ('await(1 if True else 1,b)', sys.version_info < (3, 7)), + + ('await b.do', sys.version_info >= (3, 7)), + ('await(b.do)', sys.version_info < (3, 7)), + + ("await''.join()", sys.version_info >= (3, 7)), + ("await(''.join())", sys.version_info < (3, 7)), +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_await(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + '1,2', + ('(a:=1,b:=32)', sys.version_info >= (3, 8)), + ('(1,b:=32)', sys.version_info >= (3, 8)), + 'lambda:1,lambda:2', + '1 if True else 1,2 if True else 2', + '(a for a in a),(b for b in b)', + 'a or b,a and b', + 'a+b,a-b', +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_tuple(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + 'a[1]', + ('a[a:=1]', sys.version_info >= (3, 8)), + 'a[lambda a:1]', + 'a[1 if True else 1]', + 'a[b.do]', + "a[''.join()]", + 'a[1,2]', + 'a[1:1]', + ('a[(a:=1):(b:=1)]', sys.version_info >= (3, 8)), + 'a[lambda:1:lambda:2]', + 'a[1 if True else 1:2 if True else 2]', + 'a[b.do:b.do]', + "a[''.join():''.join()]", + 'a[1,2:1,2]', + 'a[1:1:1]', + ('a[(a:=1):(b:=1):(c:=1)]', sys.version_info >= (3, 8)), + 'a[lambda:1:lambda:2:lambda:3]', + 'a[1 if True else 1:2 if True else 2:3 if True else 3]', + 'a[b.do:b.do:b.do]', + "a[''.join():''.join():''.join()]", + 'a[1,2:1,2:1,2]', + "a[('a','a'),:]", + ('a[(*c,)]', sys.version_info >= (3, 0)), + ('a[(*c,1)]', sys.version_info >= (3, 0)), + ('a[(*a,*b)]', sys.version_info >= (3, 0)), + ('a[(*a,*b):(*a,*b)]', sys.version_info >= (3, 0)), + ('a[(*a,*b):(*a,*b):(*a,*b)]', sys.version_info >= (3, 0)), + 'x[name]', + 'x[1:2]', + 'x[1:2,3]', + 'x[()]', + 'x[1:2,2:2]', + 'x[a,...,b:c]', + 'x[a,...,b]', + 'x[a,b]', + 'x[a:b,]', + 'testme[:42,...,:24:None,24,100]' +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_slice(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + '{1:1}', + '{(1,1):(1,1)}', + '{(1,):(1,)}', + '{():()}', + ('{(a:=1):(a:=1)}', sys.version_info >= (3, 8)), + '{lambda:1:lambda:1}', + '{1 if True else 1:1 if True else 1}', + '{b.do:b.do}', + "{''.join():''.join()}", +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_dict(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + '{1}', + '{1,1}', + '{(1,)}', + '{()}', + ('{a:=1}', sys.version_info >= (3, 8)), + '{lambda:1}', + '{1 if True else 1}', + '{b.do}', + "{''.join()}", +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_set(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + '[1,1]', + '[(1,1),(1,1)]', + '[(1,),(1,)]', + '[(),()]', + ('[a:=1,b:=1]', sys.version_info >= (3, 8)), + '[lambda:1,lambda:1]', + '[1 if True else 1,1 if True else 1]', + '[b.do,b.do]', + "[''.join(),''.join()]", +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_list(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement diff --git a/test/test_slice.py b/test/test_slice.py deleted file mode 100644 index f2f68f3c..00000000 --- a/test/test_slice.py +++ /dev/null @@ -1,23 +0,0 @@ -import ast - -from python_minifier import unparse -from python_minifier.ast_compare import compare_ast - -def test_slice(): - """AST for slices was changed in 3.9""" - - source = ''' -x[name] -x[1:2] -x[1:2, 3] -x[()] -x[1:2, 2:2] -x[a, ..., b:c] -x[a, ..., b] -x[(a, b)] -x[a:b,] -''' - - expected_ast = ast.parse(source) - actual_ast = unparse(expected_ast) - compare_ast(expected_ast, ast.parse(actual_ast)) diff --git a/test/test_statements.py b/test/test_statements.py new file mode 100644 index 00000000..7b4b2590 --- /dev/null +++ b/test/test_statements.py @@ -0,0 +1,533 @@ +""" +Test statements correctly use parentheses when needed + +The important things to test are expressions that might need parentheses: +- lambda +- named expressions +- tuples + - empty + - single element + - multiple elements +- yield +- yield from +- attribute access + +""" + +import ast +import sys + +import pytest + +from python_minifier import unparse +from python_minifier.ast_compare import compare_ast +from skip_invalid import skip_invalid + +@pytest.mark.parametrize('statement', [ + 'a=1', + 'a=b=1', + 'a=1,', + 'a=b=1,', + 'a=1,2', + 'a=b=1,2', + 'a=()', + 'a=b=()', + ('a=*a', sys.version_info >= (3, 0)), + ('a=*a,b', sys.version_info >= (3, 0)), + ('a=b=*a', sys.version_info >= (3, 0)), + ('a=*a,*c', sys.version_info >= (3, 0)), + ('a=b=*a,*c', sys.version_info >= (3, 0)), + 'a=lambda:1', + 'a=lambda a:1,', + 'a=1,lambda a:1', + ('a=*a,1,lambda a:1', sys.version_info >= (3, 0)), + ('a=(b:=1)', sys.version_info >= (3, 8)), + ('a=b=(c:=1)', sys.version_info >= (3, 8)), + 'a=1 if True else 1', + 'a=b,1 if True else 1', + 'a=1 if True else 1,', + 'a=1 if True else 1,b', + 'a=yield', + 'a=yield 1', + ('a=yield from 1', sys.version_info >= (3, 3)), + 'a=b.do', + "a=''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_assign(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'a:int=1', + + ('a:int=1,', sys.version_info >= (3, 8)), + ('a:int=(1,)', sys.version_info < (3, 8)), + + ('a:int=1,2', sys.version_info >= (3, 8)), + ('a:int=(1,2)', sys.version_info < (3, 8)), + + 'a:int=()', + + ('a:int=*a', sys.version_info >= (3, 8)), + ('a:int=(*a)', sys.version_info < (3, 8)), + + ('a:int=*a,b', sys.version_info >= (3, 8)), + ('a:int=(*a,b)', sys.version_info < (3, 8)), + + ('a:int=*a,*c', sys.version_info >= (3, 8)), + ('a:int=(*a,*c)', sys.version_info < (3, 8)), + + 'a:int=lambda:1', + + ('a:int=lambda a:1,', sys.version_info >= (3, 8)), + ('a:int=1,lambda a:1', sys.version_info >= (3, 8)), + + ('a:int=*a,1,lambda a:1', sys.version_info >= (3, 8)), + ('a:int=(*a,1,lambda a:1)', sys.version_info < (3, 8)), + + ('a:int=(b:=1)', sys.version_info >= (3, 8)), + + 'a:int=1 if True else 1', + + ('a:int=b,1 if True else 1', sys.version_info >= (3, 8)), + ('a:int=(b,1 if True else 1)', sys.version_info < (3, 8)), + + ('a:int=1 if True else 1,', sys.version_info >= (3, 8)), + ('a:int=(1 if True else 1,)', sys.version_info < (3, 8)), + + ('a:int=1 if True else 1,b', sys.version_info >= (3, 8)), + ('a:int=(1 if True else 1,b)', sys.version_info < (3, 8)), + + ('a:int=yield', sys.version_info >= (3, 8)), + ('a:int=(yield)', sys.version_info < (3, 8)), + + ('a:int=yield 1', sys.version_info >= (3, 8)), + ('a:int=(yield 1)', sys.version_info < (3, 8)), + + ('a:int=yield from 1', sys.version_info >= (3, 8)), + ('a:int=(yield from 1)', sys.version_info < (3, 8)), + + 'a:int=b.do', + "a:int=''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_annassign(statement): + if sys.version_info < (3, 6): + pytest.skip('annotations not supported') + + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'a+=1', + 'a+=1,', + 'a+=1,2', + 'a+=()', + ('a+=*a', sys.version_info >= (3, 9)), + ('a+=*a,b', sys.version_info >= (3, 9)), + ('a+=*a,*c', sys.version_info >= (3, 9)), + ('a+=(*a)', (3, 0) < sys.version_info < (3, 9)), + ('a+=(*a,b)', (3, 0) < sys.version_info < (3, 9)), + ('a+=(*a,*c)', (3, 0) < sys.version_info < (3, 9)), + 'a+=lambda:1', + 'a+=lambda a:1,', + 'a+=1,lambda a:1', + ('a+=*a,1,lambda a:1', sys.version_info >= (3, 9)), + ('a+=(*a,1,lambda a:1)', (3, 0) <= sys.version_info < (3, 9)), + ('a+=(b:=1)', sys.version_info >= (3, 8)), + 'a+=1 if True else 1', + 'a+=b,1 if True else 1', + 'a+=1 if True else 1,', + 'a+=1 if True else 1,b', + 'a+=yield', + 'a+=yield 1', + ('a+=yield from 1', sys.version_info >= (3, 3)), + 'a+=b.do', + "a+=''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_augassign(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + '1', + '1,', + '1,2', + '()', + ('*a', sys.version_info >= (3, 0)), + ('*a,b', sys.version_info >= (3, 0)), + ('*a,*c', sys.version_info >= (3, 0)), + 'lambda:1', + 'lambda a:1,', + '1,lambda a:1', + ('*a,1,lambda a:1', sys.version_info >= (3, 0)), + ('lambda:(a:=1)', sys.version_info >= (3, 8)), + 'lambda:(yield)', + 'lambda:(yield a)', + 'lambda:(yield a,)', + 'lambda:(yield a,b)', + ('lambda:(yield(b:=1))', sys.version_info >= (3, 8)), + ('lambda:(yield from a)', sys.version_info >= (3, 3)), + ('lambda:(yield from(a,))', sys.version_info >= (3, 3)), + ('lambda:(yield from(a,b))', sys.version_info >= (3, 3)), + ('(b:=1)', sys.version_info >= (3, 8)), + '1 if True else 1', + 'b,1 if True else 1', + '1 if True else 1,', + '1 if True else 1,b', + 'yield', + 'yield 1', + ('yield from 1', sys.version_info >= (3, 3)), + 'b.do', + "''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_expression(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'assert 1', + 'assert 1,msg', + ('assert 1,(a:=1)', sys.version_info >= (3, 8)), + 'assert(1,2)', + 'assert(1,2),msg', + 'assert()', + 'assert(),msg', + 'assert lambda:1', + 'assert lambda a:1,msg', + 'assert(lambda:1,a),msg', + 'assert 1,lambda a:1', + ('assert(b:=1)', sys.version_info >= (3, 8)), + ('assert(b:=1),(c:=1)', sys.version_info >= (3, 8)), + 'assert 1 if True else 1', + 'assert(b,1 if True else 1),msg', + 'assert 1 if True else 1,msg', + 'assert(1 if True else 1,b)', + 'assert(yield)', + 'assert(yield 1)', + ('assert(yield from 1)', sys.version_info >= (3, 3)), + 'assert b.do', + "assert''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_assert(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'del a', + 'del a,b', + ('del()', sys.version_info >= (3, 0)), + 'del b.do', +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_del(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'return', + 'return 1', + 'return 1,', + 'return 1,2', + 'return()', + ('return*a', sys.version_info >= (3, 8)), + ('return*a,b', sys.version_info >= (3, 8)), + ('return*a', sys.version_info >= (3, 8)), + ('return*a,*c', sys.version_info >= (3, 8)), + ('return*a,*c', sys.version_info >= (3, 8)), + ('return(*a)', (3, 0) < sys.version_info < (3, 8)), + ('return(*a,b)', (3, 0) < sys.version_info < (3, 8)), + ('return(*a)', (3, 0) < sys.version_info < (3, 8)), + ('return(*a,*c)', (3, 0) < sys.version_info < (3, 8)), + ('return(*a,*c)', (3, 0) < sys.version_info < (3, 8)), + 'return lambda:1', + 'return lambda a:1,', + 'return 1,lambda a:1', + ('return*a,1,lambda a:1', sys.version_info >= (3, 8)), + ('return(*a,1,lambda a:1)', (3, 0) < sys.version_info < (3, 8)), + ('return(b:=1)', sys.version_info >= (3, 8)), + 'return 1 if True else 1', + 'return b,1 if True else 1', + 'return 1 if True else 1,', + 'return 1 if True else 1,b', + 'return b.do', + "return''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_return(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'yield 1', + 'yield 1,', + 'yield 1,2', + 'yield()', + ('yield*a', sys.version_info >= (3, 8)), + ('yield*a,b', sys.version_info >= (3, 8)), + ('yield*a,*c', sys.version_info >= (3, 8)), + ('yield(*a)', (3, 0) < sys.version_info < (3, 8)), + ('yield(*a,b)', (3, 0) < sys.version_info < (3, 8)), + ('yield(*a,*c)', (3, 0) < sys.version_info < (3, 8)), + 'yield lambda:1', + 'yield lambda a:1,', + 'yield 1,lambda a:1', + ('yield*a,1,lambda a:1', sys.version_info >= (3, 8)), + ('yield(*a,1,lambda a:1)', (3, 0) < sys.version_info < (3, 8)), + ('yield(b:=1)', sys.version_info >= (3, 8)), + 'yield 1 if True else 1', + 'yield b,1 if True else 1', + 'yield 1 if True else 1,', + 'yield 1 if True else 1,b', + ('yield from 1', sys.version_info >= (3, 3)), + ('yield from(1,)', sys.version_info >= (3, 3)), + ('yield from(1,2)', sys.version_info >= (3, 3)), + 'yield b.do', + "yield''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_yield(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'raise 1', + 'raise(1,)', + 'raise(1,2)', + 'raise()', + 'raise lambda:1', + 'raise(lambda a:1,)', + 'raise(1,lambda a:1)', + ('raise(*a,1,lambda a:1)', sys.version_info >= (3,0)), + ('raise(b:=1)', sys.version_info >= (3,8)), + 'raise 1 if True else 1', + 'raise(b,1 if True else 1)', + 'raise(1 if True else 1,)', + 'raise(1 if True else 1,b)', + 'raise b.do', + "raise''.join()", + ('raise 1 from 1', sys.version_info >= (3, 0)), + ('raise(1,)from(1,)', sys.version_info >= (3, 0)), + ('raise(1,2)from(1,2)', sys.version_info >= (3, 0)), + ('raise()from()', sys.version_info >= (3, 0)), + ('raise lambda:1 from lambda:1', sys.version_info >= (3, 0)), + ('raise(lambda a:1,)from(lambda a:1,)', sys.version_info >= (3, 0)), + ('raise(1,lambda a:1)from(1,lambda a:1)', sys.version_info >= (3, 0)), + ('raise(*a,1,lambda a:1)from(*a,1,lambda a:1)', sys.version_info >= (3, 0)), + ('raise(b:=1)from(b:=1)', sys.version_info >= (3,8)), + ('raise 1 if True else 1 from 1 if True else 1', sys.version_info >= (3, 0)), + ('raise(b,1 if True else 1)from(b,1 if True else 1)', sys.version_info >= (3, 0)), + ('raise(1 if True else 1,)from(1 if True else 1,)', sys.version_info >= (3, 0)), + ('raise(1 if True else 1,b)from(1 if True else 1,b)', sys.version_info >= (3, 0)), + ('raise b.do from b.do', sys.version_info >= (3, 0)), + ("raise''.join()from''.join()" , sys.version_info >= (3, 0)), +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_raise(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'if 1:pass', + 'if(1,):pass', + 'if(1,2):pass', + 'if():pass', + ('if(*a,):pass', sys.version_info > (3, 0)), + ('if(*a,b):pass', sys.version_info > (3, 0)), + ('if(*a,*c):pass', sys.version_info > (3, 0)), + 'if lambda:1:pass', + 'if(lambda a:1,):pass', + 'if(1,lambda a:1):pass', + ('if(*a,1,lambda a:1):pass', sys.version_info > (3, 0)), + ('if b:=1:pass', sys.version_info >= (3,8)), + 'if 1 if True else 1:pass', + 'if(b,1 if True else 1):pass', + 'if(1 if True else 1,):pass', + 'if(1 if True else 1,b):pass', + 'if(yield):pass', + 'if(yield 1):pass', + ('if(yield from 1):pass', sys.version_info >= (3, 3)), + 'if b.do:pass', + "if''.join():pass" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_if(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'while 1:pass', + 'while(1,):pass', + 'while(1,2):pass', + 'while():pass', + ('while(*a,):pass', sys.version_info >= (3, 0)), + ('while(*a,b):pass', sys.version_info >= (3, 0)), + ('while(*a,*c):pass', sys.version_info >= (3, 0)), + 'while lambda:1:pass', + 'while(lambda a:1,):pass', + 'while(1,lambda a:1):pass', + ('while(*a,1,lambda a:1):pass', sys.version_info >= (3, 0)), + ('while b:=1:pass', sys.version_info >= (3, 8)), + 'while 1 if True else 1:pass', + 'while(b,1 if True else 1):pass', + 'while(1 if True else 1,):pass', + 'while(1 if True else 1,b):pass', + 'while(yield):pass', + 'while(yield 1):pass', + ('while(yield from 1):pass', sys.version_info >= (3, 3)), + 'while b.do:pass', + "while''.join():pass" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_while(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + 'for a in a:pass', + 'for a,in a:pass', + 'for a,b in a:pass', + ('for()in a:pass', sys.version_info >= (3, 0)), + ('for*a in a:pass', sys.version_info >= (3, 0)), + ('for*a,b in a:pass', sys.version_info >= (3, 0)), + ('for*a,*c in a:pass', sys.version_info >= (3, 0)), + 'for b.do in a:pass', + 'for a in b:pass', + 'for a in b,:pass', + 'for a in b,c:pass', + 'for a in():pass', + ('for a in*a:pass', sys.version_info >= (3, 9)), + ('for a in*a,b:pass', sys.version_info >= (3, 9)), + ('for a in*a,*c:pass', sys.version_info >= (3, 9)), + ('for a in(*a):pass', (3, 0) < sys.version_info < (3, 9)), + ('for a in(*a,b):pass', (3, 0) < sys.version_info < (3, 9)), + ('for a in(*a,*c):pass', (3, 0) < sys.version_info < (3, 9)), + 'for a in lambda:1:pass', + 'for a in lambda a:1,:pass', + 'for a in 1,lambda a:1:pass', + ('for a in*a,1,lambda a:1:pass', sys.version_info >= (3, 9)), + ('for a in(*a,1,lambda a:1):pass', (3, 0) < sys.version_info < (3, 9)), + ('for a in(b:=1):pass', sys.version_info >= (3, 8)), + 'for a in 1 if True else 1:pass', + 'for a in b,1 if True else 1:pass', + 'for a in 1 if True else 1,:pass', + 'for a in 1 if True else 1,b:pass', + 'for a in(yield):pass', + 'for a in(yield 1):pass', + ('for a in(yield from 1):pass', sys.version_info >= (3, 3)), + 'for a in b.do:pass', + "for a in''.join():pass" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_for(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + + +@pytest.mark.parametrize('statement', [ + ' A', + '(A,)', + '(A,A)', + '()', + ('*a', sys.version_info >= (3, 11)), + ('(*a,b)', sys.version_info > (3, 0)), + ('(*a,*c)', sys.version_info > (3, 0)), + ' lambda:A', + '(lambda a:A,)', + '(A,lambda a:A)', + ('(*a,A,lambda a:A)', sys.version_info > (3, 0)), + ('(b:=A)', sys.version_info >= (3, 8)), + ' A if True else A', + '(b,A if True else A)', + '(A if True else A,)', + '(A if True else A,b)', + '(yield)', + '(yield A)', + ('(yield from A)', sys.version_info >= (3, 3)), + ' b.do', + "''.join()" +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_except(statement): + + statement = 'try:pass\nexcept' + statement + ':pass' + + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement + +@pytest.mark.parametrize('statement', [ + 'with 1:pass', + 'with 1,2:pass', + 'with():pass', + 'with lambda:1:pass', + 'with 1,lambda a:1:pass', + ('with(b:=1):pass', sys.version_info >= (3, 8)), + 'with 1 if True else 1:pass', + 'with b,1 if True else 1:pass', + 'with(yield):pass', + 'with(yield 1):pass', + ('with(yield from 1):pass', sys.version_info >= (3, 3)), + 'with b.do:pass', + "with''.join():pass", + 'with 1 as a:pass', + 'with 1,2 as a:pass', + 'with()as a:pass', + 'with lambda:1 as a:pass', + 'with 1,lambda a:1 as a,b:pass', + ('with(b:=1)as a:pass', sys.version_info >= (3, 8)), + 'with 1 if True else 1 as a:pass', + 'with b,1 if True else 1 as a:pass', + 'with(yield)as a:pass', + 'with(yield 1)as a:pass', + ('with(yield from 1)as a:pass', sys.version_info >= (3, 3)), + 'with b.do as a:pass', + "with''.join()as a:pass", +], ids=lambda s: s[0] if isinstance(s, tuple) else s) +@skip_invalid +def test_with(statement): + expected_ast = ast.parse(statement) + minified = unparse(expected_ast) + compare_ast(expected_ast, ast.parse(minified)) + assert minified == statement +