Skip to content

Commit 59e373d

Browse files
committed
Cleanup
1 parent b96457c commit 59e373d

File tree

1 file changed

+21
-26
lines changed
  • pybind11_stubgen/parser/mixins

1 file changed

+21
-26
lines changed

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,13 @@ class _NumpyArrayAnnotation:
543543
),
544544
)
545545
)
546-
__DIM_VARS = ["n", "m"]
547-
__DIM_STRING_PATTERN = re.compile(r'"\[(.*?)\]"')
546+
numpy_flags: set[QualifiedName] = {
547+
QualifiedName.from_str("flags.writeable"),
548+
QualifiedName.from_str("flags.c_contiguous"),
549+
QualifiedName.from_str("flags.f_contiguous"),
550+
}
551+
dim_vars: set[str] = {"n", "m"}
552+
__dim_string_pattern = re.compile(r'"\[(.*?)\]"')
548553

549554
def __init__(
550555
self,
@@ -696,7 +701,7 @@ def _from_new_style(
696701
if dims_and_flags:
697702
dims_str_param = dims_and_flags[0]
698703
if isinstance(dims_str_param, Value):
699-
match = cls.__DIM_STRING_PATTERN.search(dims_str_param.repr)
704+
match = cls.__dim_string_pattern.search(dims_str_param.repr)
700705
if match:
701706
dims_str_content = match.group(1)
702707
dims_list = [
@@ -721,22 +726,17 @@ def _to_dims(
721726
return None
722727
elif isinstance(dim_param, ResolvedType):
723728
dim = str(dim_param)
724-
if len(dim) > 1 and dim not in cls.__DIM_VARS:
729+
if len(dim) > 1 and dim not in cls.dim_vars:
725730
return None
726731
else:
727732
return None
728733
result.append(dim)
729734
return result
730735

731-
@staticmethod
732-
def _fix_flags(flags: list[ResolvedType | Value | InvalidExpression]):
733-
__flags: set[QualifiedName] = {
734-
QualifiedName.from_str("flags.writeable"),
735-
QualifiedName.from_str("flags.c_contiguous"),
736-
QualifiedName.from_str("flags.f_contiguous"),
737-
}
736+
@classmethod
737+
def _fix_flags(cls, flags: list[ResolvedType | Value | InvalidExpression]):
738738
for flag in flags:
739-
if isinstance(flag, ResolvedType) and flag.name in __flags:
739+
if isinstance(flag, ResolvedType) and flag.name in cls.numpy_flags:
740740
flag.name = QualifiedName.from_str(f"numpy.ndarray.{flag.name}")
741741

742742
@staticmethod
@@ -756,7 +756,6 @@ class FixNumpyArrayDimAnnotation(IParser):
756756
# `typing_extension.Annotated` in different python versions
757757
# Rely on later fix by `FixTypingTypeNames`
758758
__annotated_name = QualifiedName.from_str("Annotated")
759-
__DIM_VARS = _NumpyArrayAnnotation._NumpyArrayAnnotation__DIM_VARS
760759

761760
def parse_annotation_str(
762761
self, annotation_str: str
@@ -825,20 +824,18 @@ def __wrap_with_size_helper(self, dims: list[int | str]) -> FixedSize | DynamicS
825824
return return_t(*dims) # type: ignore[arg-type]
826825

827826
def report_error(self, error: ParserError) -> None:
828-
__flags: set[QualifiedName] = {
829-
QualifiedName.from_str("flags.writeable"),
830-
QualifiedName.from_str("flags.c_contiguous"),
831-
QualifiedName.from_str("flags.f_contiguous"),
832-
}
833827
if (
834828
isinstance(error, NameResolutionError)
835829
and len(error.name) == 1
836830
and len(error.name[0]) == 1
837-
and error.name[0] in self.__DIM_VARS
831+
and error.name[0] in _NumpyArrayAnnotation.dim_vars
838832
):
839833
# Ignores all unknown 'm' and 'n' regardless of the context
840834
return
841-
if isinstance(error, NameResolutionError) and error.name in __flags:
835+
if (
836+
isinstance(error, NameResolutionError)
837+
and error.name in _NumpyArrayAnnotation.numpy_flags
838+
):
842839
return
843840
super().report_error(error)
844841

@@ -903,19 +900,17 @@ def on_dynamic_dim(dim: str) -> None:
903900
)
904901

905902
def report_error(self, error: ParserError) -> None:
906-
__flags: set[QualifiedName] = {
907-
QualifiedName.from_str("flags.writeable"),
908-
QualifiedName.from_str("flags.c_contiguous"),
909-
QualifiedName.from_str("flags.f_contiguous"),
910-
}
911903
if (
912904
isinstance(error, NameResolutionError)
913905
and len(error.name) == 1
914906
and error.name[0] in self.__DIM_VARS
915907
):
916908
# allow type variables, which are manually resolved in `handle_module`
917909
return
918-
if isinstance(error, NameResolutionError) and error.name in __flags:
910+
if (
911+
isinstance(error, NameResolutionError)
912+
and error.name in _NumpyArrayAnnotation.numpy_flags
913+
):
919914
return
920915
super().report_error(error)
921916

0 commit comments

Comments
 (0)