Skip to content

Commit 60d6c95

Browse files
committed
test: add tests for pyarrow \d replacement
Also tests for ambiguous reference \d + another literal as digit.
1 parent a4f483f commit 60d6c95

File tree

1 file changed

+63
-11
lines changed

1 file changed

+63
-11
lines changed

pandas/tests/strings/test_find_replace.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010
from pandas import (
1111
Series,
12+
StringDtype,
1213
_testing as tm,
1314
)
1415
from pandas.tests.strings import (
@@ -584,6 +585,10 @@ def test_replace_callable_raises(any_string_dtype, repl):
584585
r"\g<three> \g<two> \g<one>",
585586
["Three Two One", "Baz Bar Foo"],
586587
),
588+
(
589+
r"\3 \2 \1",
590+
["Three Two One", "Baz Bar Foo"],
591+
),
587592
(
588593
r"\g<3> \g<2> \g<1>",
589594
["Three Two One", "Baz Bar Foo"],
@@ -599,6 +604,7 @@ def test_replace_callable_raises(any_string_dtype, repl):
599604
],
600605
ids=[
601606
"named_groups_full_swap",
607+
"numbered_groups_no_g_full_swap",
602608
"numbered_groups_full_swap",
603609
"single_group_with_literal",
604610
"mixed_group_reference_with_literal",
@@ -623,33 +629,79 @@ def test_replace_named_groups_regex_swap(
623629
[
624630
r"\g<20>",
625631
r"\20",
632+
r"\40",
633+
r"\4",
626634
],
627635
)
628636
@pytest.mark.parametrize("use_compile", [True, False])
629637
def test_replace_named_groups_regex_swap_expected_fail(
630-
any_string_dtype, repl, use_compile
638+
any_string_dtype, repl, use_compile, request
631639
):
632640
# GH#57636
641+
if (
642+
not use_compile
643+
and r"\g" not in repl
644+
and isinstance(any_string_dtype, StringDtype)
645+
and any_string_dtype.storage == "pyarrow"
646+
):
647+
# calls pyarrow method directly
648+
if repl == r"\20":
649+
mark = pytest.mark.xfail(reason="PyArrow interprets as group + literal")
650+
request.applymarker(mark)
651+
652+
pa = pytest.importorskip("pyarrow")
653+
error_type = pa.ArrowInvalid
654+
error_msg = r"only has \d parenthesized subexpressions"
655+
else:
656+
error_type = re.error
657+
error_msg = "invalid group reference"
658+
633659
pattern = r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)"
634660
if use_compile:
635661
pattern = re.compile(pattern)
636662
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype)
637663

638-
with pytest.raises(re.error, match="invalid group reference"):
664+
with pytest.raises(error_type, match=error_msg):
639665
ser.str.replace(pattern, repl, regex=True)
640666

641667

642-
@pytest.mark.parametrize("use_compile", [True, False])
643-
def test_replace_non_named_group(any_string_dtype, use_compile):
644-
ser = Series(["var.one[0]", "var.two[1]", "var.three[2]"], dtype=any_string_dtype)
645-
pattern = r"\[(\d+)\]"
646-
if use_compile:
647-
pattern = re.compile(pattern)
648-
repl = r"(\1)"
668+
@pytest.mark.parametrize(
669+
"pattern, repl",
670+
[
671+
(r"(\w+) (\w+) (\w+)", r"\20"),
672+
(r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)", r"\20"),
673+
],
674+
)
675+
def test_pyarrow_ambiguous_group_references(pyarrow_string_dtype, pattern, repl):
676+
# GH#62653
677+
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=pyarrow_string_dtype)
678+
649679
result = ser.str.replace(pattern, repl, regex=True)
650-
expected = Series(
651-
["var.one(0)", "var.two(1)", "var.three(2)"], dtype=any_string_dtype
680+
expected = Series(["Two0", "Bar0"], dtype=pyarrow_string_dtype)
681+
tm.assert_series_equal(result, expected)
682+
683+
684+
@pytest.mark.parametrize(
685+
"pattern, repl, expected_list",
686+
[
687+
(
688+
r"\[(?P<one>\d+)\]",
689+
r"(\1)",
690+
["var.one(0)", "var.two(1)", "var.three(2)"],
691+
),
692+
(
693+
r"\[(\d+)\]",
694+
r"(\1)",
695+
["var.one(0)", "var.two(1)", "var.three(2)"],
696+
),
697+
],
698+
)
699+
def test_pyarrow_backend_group_replacement(pattern, repl, expected_list):
700+
ser = Series(["var.one[0]", "var.two[1]", "var.three[2]"]).convert_dtypes(
701+
dtype_backend="pyarrow"
652702
)
703+
result = ser.str.replace(pattern, repl, regex=True)
704+
expected = Series(expected_list).convert_dtypes(dtype_backend="pyarrow")
653705
tm.assert_series_equal(result, expected)
654706

655707

0 commit comments

Comments
 (0)