Skip to content

Commit a947b55

Browse files
BUG: use arrow backend for digits references in str.replace (#62872)
Co-authored-by: zishan044 <winchesterfelix007@gmail.com>
1 parent 1c06e34 commit a947b55

File tree

4 files changed

+73
-9
lines changed

4 files changed

+73
-9
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,7 @@ Conversion
10351035

10361036
Strings
10371037
^^^^^^^
1038+
- Bug in :meth:`Series.str.replace` raising an error on valid group references (``\1``, ``\2``, etc.) on series converted to PyArrow backend dtype (:issue:`62653`)
10381039
- Bug in :meth:`Series.str.zfill` raising ``AttributeError`` for :class:`ArrowDtype` (:issue:`61485`)
10391040
- Bug in :meth:`Series.value_counts` would not respect ``sort=False`` for series having ``string`` dtype (:issue:`55224`)
10401041
- Bug in multiplication with a :class:`StringDtype` incorrectly allowing multiplying by bools; explicitly cast to integers instead (:issue:`62595`)

pandas/core/arrays/_arrow_string_mixins.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,12 @@ def _str_replace(
173173
or callable(repl)
174174
or not case
175175
or flags
176-
or (
177-
isinstance(repl, str)
178-
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
179-
)
176+
or (isinstance(repl, str) and r"\g<" in repl)
180177
):
181178
raise NotImplementedError(
182179
"replace is not supported with a re.Pattern, callable repl, "
183180
"case=False, flags!=0, or when the replacement string contains "
184-
"named group references (\\g<...>, \\d+)"
181+
"named group references (\\g<...>)"
185182
)
186183

187184
func = pc.replace_substring_regex if regex else pc.replace_substring

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,7 @@ def _str_replace(
425425
or flags
426426
or ( # substitution contains a named group pattern
427427
# https://docs.python.org/3/library/re.html
428-
isinstance(repl, str)
429-
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
428+
isinstance(repl, str) and r"\g<" in repl
430429
)
431430
):
432431
return super()._str_replace(pat, repl, n, case, flags, regex)

pandas/tests/strings/test_find_replace.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010
from pandas import (
1111
Series,
12+
StringDtype,
1213
_testing as tm,
1314
)
1415
from pandas.tests.strings import (
@@ -584,6 +585,10 @@ def test_replace_callable_raises(any_string_dtype, repl):
584585
r"\g<three> \g<two> \g<one>",
585586
["Three Two One", "Baz Bar Foo"],
586587
),
588+
(
589+
r"\3 \2 \1",
590+
["Three Two One", "Baz Bar Foo"],
591+
),
587592
(
588593
r"\g<3> \g<2> \g<1>",
589594
["Three Two One", "Baz Bar Foo"],
@@ -599,6 +604,7 @@ def test_replace_callable_raises(any_string_dtype, repl):
599604
],
600605
ids=[
601606
"named_groups_full_swap",
607+
"numbered_groups_no_g_full_swap",
602608
"numbered_groups_full_swap",
603609
"single_group_with_literal",
604610
"mixed_group_reference_with_literal",
@@ -623,22 +629,83 @@ def test_replace_named_groups_regex_swap(
623629
[
624630
r"\g<20>",
625631
r"\20",
632+
r"\40",
633+
r"\4",
626634
],
627635
)
628636
@pytest.mark.parametrize("use_compile", [True, False])
629637
def test_replace_named_groups_regex_swap_expected_fail(
630-
any_string_dtype, repl, use_compile
638+
any_string_dtype, repl, use_compile, request
631639
):
632640
# GH#57636
641+
if (
642+
not use_compile
643+
and r"\g" not in repl
644+
and isinstance(any_string_dtype, StringDtype)
645+
and any_string_dtype.storage == "pyarrow"
646+
):
647+
# calls pyarrow method directly
648+
if repl == r"\20":
649+
mark = pytest.mark.xfail(reason="PyArrow interprets as group + literal")
650+
request.applymarker(mark)
651+
652+
pa = pytest.importorskip("pyarrow")
653+
error_type = pa.ArrowInvalid
654+
error_msg = r"only has \d parenthesized subexpressions"
655+
else:
656+
error_type = re.error
657+
error_msg = "invalid group reference"
658+
633659
pattern = r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)"
634660
if use_compile:
635661
pattern = re.compile(pattern)
636662
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype)
637663

638-
with pytest.raises(re.error, match="invalid group reference"):
664+
with pytest.raises(error_type, match=error_msg):
639665
ser.str.replace(pattern, repl, regex=True)
640666

641667

668+
@pytest.mark.parametrize(
669+
"pattern, repl",
670+
[
671+
(r"(\w+) (\w+) (\w+)", r"\20"),
672+
(r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)", r"\20"),
673+
],
674+
)
675+
def test_pyarrow_ambiguous_group_references(pyarrow_string_dtype, pattern, repl):
676+
# GH#62653
677+
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=pyarrow_string_dtype)
678+
679+
result = ser.str.replace(pattern, repl, regex=True)
680+
expected = Series(["Two0", "Bar0"], dtype=pyarrow_string_dtype)
681+
tm.assert_series_equal(result, expected)
682+
683+
684+
@pytest.mark.parametrize(
685+
"pattern, repl, expected_list",
686+
[
687+
(
688+
r"\[(?P<one>\d+)\]",
689+
r"(\1)",
690+
["var.one(0)", "var.two(1)", "var.three(2)"],
691+
),
692+
(
693+
r"\[(\d+)\]",
694+
r"(\1)",
695+
["var.one(0)", "var.two(1)", "var.three(2)"],
696+
),
697+
],
698+
)
699+
@td.skip_if_no("pyarrow")
700+
def test_pyarrow_backend_group_replacement(pattern, repl, expected_list):
701+
ser = Series(["var.one[0]", "var.two[1]", "var.three[2]"]).convert_dtypes(
702+
dtype_backend="pyarrow"
703+
)
704+
result = ser.str.replace(pattern, repl, regex=True)
705+
expected = Series(expected_list).convert_dtypes(dtype_backend="pyarrow")
706+
tm.assert_series_equal(result, expected)
707+
708+
642709
def test_replace_callable_named_groups(any_string_dtype):
643710
# test regex named groups
644711
ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype)

0 commit comments

Comments
 (0)