Skip to content

Commit a6fbed7

Browse files
REGR: fix string contains/match methods with compiled regex with flags (pandas-dev#62251)
1 parent d1c8ce6 commit a6fbed7

File tree

5 files changed

+158
-24
lines changed

5 files changed

+158
-24
lines changed

doc/source/whatsnew/v2.3.3.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ become the default string dtype in pandas 3.0. See
2222

2323
Bug fixes
2424
^^^^^^^^^
25-
-
25+
- Fix regression in :meth:`~Series.str.contains`, :meth:`~Series.str.match` and :meth:`~Series.str.fullmatch`
26+
with a compiled regex and custom flags (:issue:`62240`)
2627

2728
.. ---------------------------------------------------------------------------
2829
.. _whatsnew_233.contributors:

pandas/core/arrays/_arrow_string_mixins.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -301,29 +301,23 @@ def _str_contains(
301301

302302
def _str_match(
303303
self,
304-
pat: str | re.Pattern,
304+
pat: str,
305305
case: bool = True,
306306
flags: int = 0,
307307
na: Scalar | lib.NoDefault = lib.no_default,
308308
):
309-
if isinstance(pat, re.Pattern):
310-
# GH#61952
311-
pat = pat.pattern
312-
if isinstance(pat, str) and not pat.startswith("^"):
309+
if not pat.startswith("^"):
313310
pat = f"^{pat}"
314311
return self._str_contains(pat, case, flags, na, regex=True)
315312

316313
def _str_fullmatch(
317314
self,
318-
pat: str | re.Pattern,
315+
pat: str,
319316
case: bool = True,
320317
flags: int = 0,
321318
na: Scalar | lib.NoDefault = lib.no_default,
322319
):
323-
if isinstance(pat, re.Pattern):
324-
# GH#61952
325-
pat = pat.pattern
326-
if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")):
320+
if not pat.endswith("$") or pat.endswith("\\$"):
327321
pat = f"{pat}$"
328322
return self._str_match(pat, case, flags, na)
329323

pandas/core/arrays/string_arrow.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from pandas._typing import (
5252
ArrayLike,
5353
Dtype,
54+
Scalar,
5455
Self,
5556
npt,
5657
)
@@ -329,8 +330,6 @@ def _data(self):
329330
_str_startswith = ArrowStringArrayMixin._str_startswith
330331
_str_endswith = ArrowStringArrayMixin._str_endswith
331332
_str_pad = ArrowStringArrayMixin._str_pad
332-
_str_match = ArrowStringArrayMixin._str_match
333-
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
334333
_str_lower = ArrowStringArrayMixin._str_lower
335334
_str_upper = ArrowStringArrayMixin._str_upper
336335
_str_strip = ArrowStringArrayMixin._str_strip
@@ -345,6 +344,28 @@ def _data(self):
345344
_str_len = ArrowStringArrayMixin._str_len
346345
_str_slice = ArrowStringArrayMixin._str_slice
347346

347+
@staticmethod
348+
def _is_re_pattern_with_flags(pat: str | re.Pattern) -> bool:
349+
# check if `pat` is a compiled regex pattern with flags that are not
350+
# supported by pyarrow
351+
return (
352+
isinstance(pat, re.Pattern)
353+
and (pat.flags & ~(re.IGNORECASE | re.UNICODE)) != 0
354+
)
355+
356+
@staticmethod
357+
def _preprocess_re_pattern(pat: re.Pattern, case: bool) -> tuple[str, bool, int]:
358+
pattern = pat.pattern
359+
flags = pat.flags
360+
# pyarrow does not support `flags` but does support `case` -> move IGNORECASE into `case` and clear it from the flags
361+
if flags & re.IGNORECASE:
362+
case = False
363+
flags = flags & ~re.IGNORECASE
364+
# when creating a pattern with re.compile and a string, it automatically
365+
# gets a UNICODE flag, while pyarrow assumes unicode for strings anyway
366+
flags = flags & ~re.UNICODE
367+
return pattern, case, flags
368+
348369
def _str_contains(
349370
self,
350371
pat,
@@ -353,13 +374,42 @@ def _str_contains(
353374
na=lib.no_default,
354375
regex: bool = True,
355376
):
356-
if flags:
377+
if flags or self._is_re_pattern_with_flags(pat):
357378
return super()._str_contains(pat, case, flags, na, regex)
358379
if isinstance(pat, re.Pattern):
359-
pat = pat.pattern
380+
# TODO flags passed separately by user are ignored
381+
pat, case, flags = self._preprocess_re_pattern(pat, case)
360382

361383
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
362384

385+
def _str_match(
386+
self,
387+
pat: str | re.Pattern,
388+
case: bool = True,
389+
flags: int = 0,
390+
na: Scalar | lib.NoDefault = lib.no_default,
391+
):
392+
if flags or self._is_re_pattern_with_flags(pat):
393+
return super()._str_match(pat, case, flags, na)
394+
if isinstance(pat, re.Pattern):
395+
pat, case, flags = self._preprocess_re_pattern(pat, case)
396+
397+
return ArrowStringArrayMixin._str_match(self, pat, case, flags, na)
398+
399+
def _str_fullmatch(
400+
self,
401+
pat: str | re.Pattern,
402+
case: bool = True,
403+
flags: int = 0,
404+
na: Scalar | lib.NoDefault = lib.no_default,
405+
):
406+
if flags or self._is_re_pattern_with_flags(pat):
407+
return super()._str_fullmatch(pat, case, flags, na)
408+
if isinstance(pat, re.Pattern):
409+
pat, case, flags = self._preprocess_re_pattern(pat, case)
410+
411+
return ArrowStringArrayMixin._str_fullmatch(self, pat, case, flags, na)
412+
363413
def _str_replace(
364414
self,
365415
pat: str | re.Pattern,

pandas/core/strings/object_array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,7 @@ def _str_match(
252252
):
253253
if not case:
254254
flags |= re.IGNORECASE
255-
if isinstance(pat, re.Pattern):
256-
pat = pat.pattern
255+
257256
regex = re.compile(pat, flags=flags)
258257

259258
f = lambda x: regex.match(x) is not None
@@ -268,8 +267,7 @@ def _str_fullmatch(
268267
):
269268
if not case:
270269
flags |= re.IGNORECASE
271-
if isinstance(pat, re.Pattern):
272-
pat = pat.pattern
270+
273271
regex = re.compile(pat, flags=flags)
274272

275273
f = lambda x: regex.fullmatch(x) is not None

pandas/tests/strings/test_find_replace.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,60 @@ def test_contains_nan(any_string_dtype):
292292

293293
def test_contains_compiled_regex(any_string_dtype):
294294
# GH#61942
295-
ser = Series(["foo", "bar", "baz"], dtype=any_string_dtype)
295+
expected_dtype = (
296+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
297+
)
298+
299+
ser = Series(["foo", "bar", "Baz"], dtype=any_string_dtype)
300+
296301
pat = re.compile("ba.")
297302
result = ser.str.contains(pat)
303+
expected = Series([False, True, False], dtype=expected_dtype)
304+
tm.assert_series_equal(result, expected)
305+
306+
# TODO passing case=False together with a compiled pattern currently works for pyarrow-backed dtypes but raises for python-backed ones
307+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
308+
result = ser.str.contains(pat, case=False)
309+
expected = Series([False, True, True], dtype=expected_dtype)
310+
tm.assert_series_equal(result, expected)
311+
else:
312+
with pytest.raises(
313+
ValueError, match="cannot process flags argument with a compiled pattern"
314+
):
315+
ser.str.contains(pat, case=False)
316+
317+
pat = re.compile("ba.", flags=re.IGNORECASE)
318+
result = ser.str.contains(pat)
319+
expected = Series([False, True, True], dtype=expected_dtype)
320+
tm.assert_series_equal(result, expected)
298321

322+
# TODO should this be supported?
323+
with pytest.raises(
324+
ValueError, match="cannot process flags argument with a compiled pattern"
325+
):
326+
ser.str.contains(pat, flags=re.IGNORECASE)
327+
328+
329+
def test_contains_compiled_regex_flags(any_string_dtype):
330+
# ensure flags other than IGNORECASE (here MULTILINE) are respected
299331
expected_dtype = (
300332
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
301333
)
334+
335+
ser = Series(["foobar", "foo\nbar", "Baz"], dtype=any_string_dtype)
336+
337+
pat = re.compile("^ba")
338+
result = ser.str.contains(pat)
339+
expected = Series([False, False, False], dtype=expected_dtype)
340+
tm.assert_series_equal(result, expected)
341+
342+
pat = re.compile("^ba", flags=re.MULTILINE)
343+
result = ser.str.contains(pat)
344+
expected = Series([False, True, False], dtype=expected_dtype)
345+
tm.assert_series_equal(result, expected)
346+
347+
pat = re.compile("^ba", flags=re.MULTILINE | re.IGNORECASE)
348+
result = ser.str.contains(pat)
302349
expected = Series([False, True, True], dtype=expected_dtype)
303350
tm.assert_series_equal(result, expected)
304351

@@ -837,14 +884,36 @@ def test_match_case_kwarg(any_string_dtype):
837884

838885
def test_match_compiled_regex(any_string_dtype):
839886
# GH#61952
840-
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
841-
result = values.str.match(re.compile(r"ab"), case=False)
842887
expected_dtype = (
843888
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
844889
)
890+
891+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
892+
893+
result = values.str.match(re.compile("ab"))
894+
expected = Series([True, False, True, False], dtype=expected_dtype)
895+
tm.assert_series_equal(result, expected)
896+
897+
# TODO passing case=False together with a compiled pattern currently works for pyarrow-backed dtypes but raises for python-backed ones
898+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
899+
result = values.str.match(re.compile("ab"), case=False)
900+
expected = Series([True, True, True, True], dtype=expected_dtype)
901+
tm.assert_series_equal(result, expected)
902+
else:
903+
with pytest.raises(
904+
ValueError, match="cannot process flags argument with a compiled pattern"
905+
):
906+
values.str.match(re.compile("ab"), case=False)
907+
908+
result = values.str.match(re.compile("ab", flags=re.IGNORECASE))
845909
expected = Series([True, True, True, True], dtype=expected_dtype)
846910
tm.assert_series_equal(result, expected)
847911

912+
with pytest.raises(
913+
ValueError, match="cannot process flags argument with a compiled pattern"
914+
):
915+
values.str.match(re.compile("ab"), flags=re.IGNORECASE)
916+
848917

849918
# --------------------------------------------------------------------------------------
850919
# str.fullmatch
@@ -917,14 +986,36 @@ def test_fullmatch_case_kwarg(any_string_dtype):
917986

918987
def test_fullmatch_compiled_regex(any_string_dtype):
919988
# GH#61952
920-
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
921-
result = values.str.fullmatch(re.compile(r"ab"), case=False)
922989
expected_dtype = (
923990
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
924991
)
992+
993+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
994+
995+
result = values.str.fullmatch(re.compile("ab"))
996+
expected = Series([True, False, False, False], dtype=expected_dtype)
997+
tm.assert_series_equal(result, expected)
998+
999+
# TODO passing case=False together with a compiled pattern currently works for pyarrow-backed dtypes but raises for python-backed ones
1000+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
1001+
result = values.str.fullmatch(re.compile("ab"), case=False)
1002+
expected = Series([True, True, False, False], dtype=expected_dtype)
1003+
tm.assert_series_equal(result, expected)
1004+
else:
1005+
with pytest.raises(
1006+
ValueError, match="cannot process flags argument with a compiled pattern"
1007+
):
1008+
values.str.fullmatch(re.compile("ab"), case=False)
1009+
1010+
result = values.str.fullmatch(re.compile("ab", flags=re.IGNORECASE))
9251011
expected = Series([True, True, False, False], dtype=expected_dtype)
9261012
tm.assert_series_equal(result, expected)
9271013

1014+
with pytest.raises(
1015+
ValueError, match="cannot process flags argument with a compiled pattern"
1016+
):
1017+
values.str.fullmatch(re.compile("ab"), flags=re.IGNORECASE)
1018+
9281019

9291020
# --------------------------------------------------------------------------------------
9301021
# str.findall

0 commit comments

Comments
 (0)