Skip to content

Commit 5710696

Browse files
committed
Implement methods using AST
SignatureVariablesCollector: remove attr constructor_source from constructor since it is not used. Add skip_self attr for class static methods. Add ClassVisitor and MethodVisitor classes to astvisitors.py inspectmethod function uses AST to find methods. Add attributes arguments, is_static, is_class, return_type to UmlMethod class. Known limitation: the return type of method is not implemented.
1 parent 0974bf1 commit 5710696

File tree

8 files changed

+124
-68
lines changed

8 files changed

+124
-68
lines changed

py2puml/domain/umlclass.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import List
2-
from dataclasses import dataclass
1+
from typing import List, Dict
2+
from dataclasses import dataclass, field
33

44
from py2puml.domain.umlitem import UmlItem
55

@@ -10,16 +10,23 @@ class UmlAttribute:
1010
type: str
1111
static: bool
1212

13-
1413
@dataclass
1514
class UmlMethod:
1615
name: str
17-
signature: str
16+
arguments: Dict = field(default_factory=dict)
17+
is_static: bool = False
18+
is_class: bool = False
19+
return_type: str = None
20+
21+
@property
22+
def signature(self):
23+
if self.arguments:
24+
return ', '.join([f'{arg_type} {arg_name}' for arg_name, arg_type in self.arguments.items()])
25+
return ''
1826

1927

2028
@dataclass
2129
class UmlClass(UmlItem):
2230
attributes: List[UmlAttribute]
23-
# TODO move to UmlItem?
2431
methods: List[UmlMethod]
2532
is_abstract: bool = False

py2puml/export/puml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
PUML_FILE_END = '@enduml\n'
1313
PUML_ITEM_START_TPL = '{item_type} {item_fqn} {{\n'
1414
PUML_ATTR_TPL = ' {attr_name}: {attr_type}{staticity}\n'
15-
PUML_METHOD_TPL = ' {name}{signature}\n'
15+
PUML_METHOD_TPL = ' {staticity} {name}({signature})\n'
1616
PUML_ITEM_END = '}\n'
1717
PUML_RELATION_TPL = '{source_fqn} {rel_type}-- {target_fqn}\n'
1818

@@ -40,7 +40,7 @@ def to_puml_content(diagram_name: str, uml_items: List[UmlItem], uml_relations:
4040
for uml_attr in uml_class.attributes:
4141
yield PUML_ATTR_TPL.format(attr_name=uml_attr.name, attr_type=uml_attr.type, staticity=FEATURE_STATIC if uml_attr.static else FEATURE_INSTANCE)
4242
for uml_method in uml_class.methods:
43-
yield PUML_METHOD_TPL.format(name=uml_method.name, signature=uml_method.signature)
43+
yield PUML_METHOD_TPL.format(name=uml_method.name, signature=uml_method.signature, staticity=FEATURE_STATIC if uml_method.is_static else FEATURE_INSTANCE)
4444
yield PUML_ITEM_END
4545
else:
4646
raise TypeError(f'cannot process uml_item of type {uml_item.__class__}')

py2puml/inspection/inspectclass.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
from importlib import import_module
2-
from inspect import isabstract, getmembers, signature
2+
from inspect import isabstract, getsource
33
from re import compile as re_compile
44
from typing import Type, List, Dict
5-
5+
from ast import parse, AST
66

77
from py2puml.domain.umlitem import UmlItem
88
from py2puml.domain.umlclass import UmlClass, UmlAttribute, UmlMethod
99
from py2puml.domain.umlrelation import UmlRelation, RelType
10-
from py2puml.parsing.astvisitors import shorten_compound_type_annotation
10+
from py2puml.parsing.astvisitors import shorten_compound_type_annotation, ClassVisitor
1111
from py2puml.parsing.parseclassconstructor import parse_class_constructor
1212
from py2puml.parsing.moduleresolver import ModuleResolver
13-
# from py2puml.utils import investigate_domain_definition
1413

1514

1615
CONCRETE_TYPE_PATTERN = re_compile("^<(?:class|enum) '([\\.|\\w]+)'>$")
1716

17+
1818
def handle_inheritance_relation(
1919
class_type: Type,
2020
class_fqn: str,
@@ -28,6 +28,7 @@ def handle_inheritance_relation(
2828
UmlRelation(base_type_fqn, class_fqn, RelType.INHERITANCE)
2929
)
3030

31+
3132
def inspect_static_attributes(
3233
class_type: Type,
3334
class_type_fqn: str,
@@ -59,7 +60,7 @@ def inspect_static_attributes(
5960
# utility which outputs the fully-qualified name of the attribute types
6061
module_resolver = ModuleResolver(import_module(class_type.__module__))
6162

62-
# builds the definitions of the class attrbutes and their relationships by iterating over the type annotations
63+
# builds the definitions of the class attrbutes and their relationships by iterating over the type annotations
6364
for attr_name, attr_class in type_annotations.items():
6465
attr_raw_type = str(attr_class)
6566
concrete_type_match = CONCRETE_TYPE_PATTERN.search(attr_raw_type)
@@ -89,18 +90,16 @@ def inspect_static_attributes(
8990

9091
return definition_attrs
9192

92-
def inspect_methods(
93-
definition_methods, class_type,
94-
):
95-
no_dunder = lambda method_name: not (method_name[0].startswith('__') or method_name[0].endswith('__'))
96-
methods = filter(no_dunder, getmembers(class_type, callable))
97-
for name, method in methods:
98-
method_signature = signature(method)
99-
uml_method = UmlMethod(
100-
name=name,
101-
signature=str(method_signature),
102-
)
103-
definition_methods.append(uml_method)
93+
94+
def inspect_methods(definition_methods: List, class_type: Type, root_module_name: str):
95+
""" This function parses a class using AST to identify methods. """
96+
print(f'inspecting {class_type.__name__} from {class_type.__module__}')
97+
class_source: str = getsource(class_type)
98+
class_ast: AST = parse(class_source)
99+
visitor = ClassVisitor(class_type, root_module_name)
100+
visitor.visit(class_ast)
101+
for method in visitor.uml_methods:
102+
definition_methods.append(method)
104103

105104

106105
def inspect_class_type(
@@ -117,11 +116,12 @@ def inspect_class_type(
117116
instance_attributes, compositions = parse_class_constructor(class_type, class_type_fqn, root_module_name)
118117
attributes.extend(instance_attributes)
119118
domain_relations.extend(compositions.values())
120-
121-
inspect_methods(domain_items_by_fqn[class_type_fqn].methods, class_type)
119+
120+
inspect_methods(domain_items_by_fqn[class_type_fqn].methods, class_type, root_module_name)
122121

123122
handle_inheritance_relation(class_type, class_type_fqn, root_module_name, domain_relations)
124123

124+
125125
def inspect_dataclass_type(
126126
class_type: Type,
127127
class_type_fqn: str,

py2puml/parsing/astvisitors.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,38 @@
11

2-
from typing import Dict, List, Tuple
3-
2+
from typing import Dict, List, Tuple, Type
3+
from inspect import getsource, unwrap
44
from ast import (
55
NodeVisitor, arg, expr,
66
FunctionDef, Assign, AnnAssign,
77
Attribute, Name, Subscript, get_source_segment
88
)
99
from collections import namedtuple
10+
from textwrap import dedent
11+
from importlib import import_module
1012

11-
from py2puml.domain.umlclass import UmlAttribute
13+
from py2puml.domain.umlclass import UmlAttribute, UmlMethod
1214
from py2puml.domain.umlrelation import UmlRelation, RelType
1315
from py2puml.parsing.compoundtypesplitter import CompoundTypeSplitter, SPLITTING_CHARACTERS
1416
from py2puml.parsing.moduleresolver import ModuleResolver, NamespacedType
1517

16-
1718
Variable = namedtuple('Variable', ['id', 'type_expr'])
1819

20+
1921
class SignatureVariablesCollector(NodeVisitor):
20-
'''
21-
Collects the variables and their type annotations from the signature of a constructor method
22-
'''
23-
def __init__(self, constructor_source: str, *args, **kwargs):
22+
"""
23+
Collects the variables and their type annotations from the signature of a method
24+
"""
25+
def __init__(self, skip_self=False, *args, **kwargs):
2426
super().__init__(*args, **kwargs)
25-
self.constructor_source = constructor_source
27+
self.skip_self = skip_self
2628
self.class_self_id: str = None
2729
self.variables: List[Variable] = []
2830

2931
def visit_arg(self, node: arg):
3032
variable = Variable(node.arg, node.annotation)
3133

3234
# first constructor variable is the name for the 'self' reference
33-
if self.class_self_id is None:
35+
if self.class_self_id is None and not self.skip_self:
3436
self.class_self_id = variable.id
3537
# other arguments are constructor parameters
3638
else:
@@ -66,6 +68,57 @@ def visit_Subscript(self, node: Subscript):
6668
pass
6769

6870

71+
class ClassVisitor(NodeVisitor):
72+
73+
def __init__(self, class_type: Type, root_fqn: str, *args, **kwargs):
74+
super().__init__(*args, **kwargs)
75+
self.uml_methods: List[UmlMethod] = []
76+
77+
def visit_FunctionDef(self, node: FunctionDef):
78+
method_visitor = MethodVisitor()
79+
method_visitor.visit(node)
80+
self.uml_methods.append(method_visitor.uml_method)
81+
82+
83+
class MethodVisitor(NodeVisitor):
84+
"""
85+
Node visitor subclass used to walk the abstract syntax tree of a method class and identify method arguments.
86+
87+
If the method is the class constructor, instance attributes (and their type) are also identified by looking both at the constructor signature and constructor's body. When searching in the constructor's body, the visitor looks for relevant assignments (with and without type annotation).
88+
"""
89+
90+
def __init__(self, *args, **kwargs):
91+
super().__init__(*args, **kwargs)
92+
self.variables_namespace: List[Variable] = []
93+
self.class_self_id: str
94+
self.uml_method: UmlMethod
95+
96+
def generic_visit(self, node):
97+
NodeVisitor.generic_visit(self, node)
98+
99+
def visit_FunctionDef(self, node: FunctionDef):
100+
decorators = [decorator.id for decorator in node.decorator_list]
101+
is_static = 'staticmethod' in decorators
102+
is_class = 'classmethod' in decorators
103+
variables_collector = SignatureVariablesCollector(skip_self=is_static)
104+
variables_collector.visit(node)
105+
self.variables_namespace = variables_collector.variables
106+
107+
if node.name == '__init__':
108+
self.class_self_id: str = variables_collector.class_self_id
109+
self.generic_visit(node) #Only visit child nodes for constructor
110+
111+
self.uml_method = UmlMethod(name=node.name, is_static=is_static, is_class=is_class)
112+
for argument in variables_collector.variables:
113+
if argument.type_expr:
114+
if hasattr(argument.type_expr, 'id'):
115+
self.uml_method.arguments[argument.id] = argument.type_expr.id
116+
else:
117+
self.uml_method.arguments[argument.id] = f'SUBscript {argument.type_expr.value.id}'
118+
else:
119+
self.uml_method.arguments[argument.id] = None
120+
121+
69122
class ConstructorVisitor(NodeVisitor):
70123
'''
71124
Identifies the attributes (and infer their type) assigned to self in the body of a constructor method
@@ -105,7 +158,7 @@ def generic_visit(self, node):
105158
def visit_FunctionDef(self, node: FunctionDef):
106159
# retrieves constructor arguments ('self' reference and typed arguments)
107160
if node.name == '__init__':
108-
variables_collector = SignatureVariablesCollector(self.constructor_source)
161+
variables_collector = SignatureVariablesCollector()
109162
variables_collector.visit(node)
110163
self.class_self_id: str = variables_collector.class_self_id
111164
self.variables_namespace = variables_collector.variables
@@ -158,7 +211,6 @@ def visit_Assign(self, node: Assign):
158211
# other assignments were done in new variables that can shadow existing ones
159212
self.variables_namespace.extend(variables_collector.variables)
160213

161-
162214
def derive_type_annotation_details(self, annotation: expr) -> Tuple[str, List[str]]:
163215
'''
164216
From a type annotation, derives:
@@ -182,13 +234,12 @@ def derive_type_annotation_details(self, annotation: expr) -> Tuple[str, List[st
182234
return short_type, [full_namespaced_type]
183235
# compound type (List[...], Tuple[Dict[str, float], module.DomainType], etc.)
184236
elif isinstance(annotation, Subscript):
185-
return shorten_compound_type_annotation(
186-
get_source_segment(self.constructor_source, annotation),
187-
self.module_resolver
188-
)
189-
237+
source_segment = get_source_segment(self.constructor_source, annotation)
238+
short_type, associated_types = shorten_compound_type_annotation(source_segment, self.module_resolver)
239+
return short_type, associated_types
190240
return None, []
191241

242+
192243
def shorten_compound_type_annotation(type_annotation: str, module_resolver: ModuleResolver) -> Tuple[str, List[str]]:
193244
'''
194245
In the string representation of a compound type annotation, the elementary types can be prefixed by their packages or sub-packages.

tests/modules/withmethods/withmethods.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from __future__ import annotations
12
from typing import List, Tuple
23
from math import pi
34

45
from tests.modules import withenum
56
from tests import modules
67
from tests.modules.withenum import TimeUnit
78

8-
99
class Coordinates:
1010
def __init__(self, x: float, y: float) -> None:
1111
self.x = x
@@ -30,3 +30,6 @@ def __init__(self, x: int, y: str) -> None:
3030
self.time_resolution: Tuple[str, withenum.TimeUnit] = 'minute', TimeUnit.MINUTE
3131
self.x = x
3232
self.y = y
33+
34+
def do_something(self, posarg_nohint, posarg_hint: str, posarg_default=3) -> int:
35+
return 44

tests/puml_files/with_methods.puml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@ class tests.modules.withmethods.withmethods.Point {
1111
time_resolution: Tuple[str, TimeUnit]
1212
x: int
1313
y: str
14-
{static} Point from_values(x: int, y: str)
14+
{static} Point from_values(int x, str y)
1515
Tuple[float, str] get_coordinates(self)
16-
__init__(self, x: int, y: str)
16+
__init__(self, int x, str y)
1717
}
1818
class tests.modules.withmethods.withinheritedmethods.ThreeDimensionalPoint {
1919
z: float
20-
__init__(self, x: int, y: str, z: float)
21-
move(self, offset: int)
20+
__init__(self, int x, str y, float z)
21+
move(self, int offset)
2222
bool check_positive(self)
2323
}
2424
class tests.modules.withmethods.withmethods.Coordinates {
2525
x: float
2626
y: float
27-
__init__(self, x: float, y: float)
27+
__init__(self, float x, float y)
2828
}
2929
tests.modules.withmethods.withmethods.Point *-- tests.modules.withmethods.withmethods.Coordinates
3030
tests.modules.withmethods.withmethods.Point <|-- tests.modules.withmethods.withinheritedmethods.ThreeDimensionalPoint

tests/py2puml/inspection/test_inspectclass.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,9 @@ def test_inspect_module_should_handle_compound_types_with_numbers_in_their_name(
172172
def test_inspect_module_should_find_methods(
173173
domain_items_by_fqn: Dict[str, UmlItem], domain_relations: List[UmlRelation]
174174
):
175-
'''
175+
"""
176176
Test that methods are detected including static methods
177-
178-
This test case assumes that the methods will be sorted by type as follow:
179-
1a - instance methods (special methods aka "dunder")
180-
1b - all other instance methods
181-
2 - static methods
182-
'''
177+
"""
183178

184179
inspect_module(
185180
import_module('tests.modules.withmethods.withmethods'),
@@ -189,16 +184,16 @@ def test_inspect_module_should_find_methods(
189184

190185
# Coordinates UmlClass
191186
coordinates_umlitem: UmlClass = domain_items_by_fqn['tests.modules.withmethods.withmethods.Coordinates']
192-
assert len(coordinates_umlitem.methods) == 0
187+
assert len(coordinates_umlitem.methods) == 1
193188

194189
# Point UmlClass
195190
point_umlitem: UmlClass = domain_items_by_fqn['tests.modules.withmethods.withmethods.Point']
196-
assert len(point_umlitem.methods) == 2
191+
assert len(point_umlitem.methods) == 4
197192

198-
# FIXME dunder methods are filtered out for now
199-
# assert point_umlitem.methods[0].name == '__init__' # 1a - Instance method (special)
200-
assert point_umlitem.methods[0].name == 'from_values' # 2 - Static method
201-
assert point_umlitem.methods[1].name == 'get_coordinates' # 1b - Instance method (regular)
193+
assert point_umlitem.methods[0].name == 'from_values'
194+
assert point_umlitem.methods[1].name == 'get_coordinates'
195+
assert point_umlitem.methods[2].name == '__init__'
196+
assert point_umlitem.methods[3].name == 'do_something'
202197
# FIXME: use 'assert_method' once UmlMethod restructured
203198

204199

@@ -219,10 +214,9 @@ def test_inspect_module_inherited_methods(
219214
coordinates_3d_umlitem: UmlClass = domain_items_by_fqn['tests.modules.withmethods.withinheritedmethods.ThreeDimensionalPoint']
220215

221216
# FIXME inherited methods should not be mentionned
222-
assert len(coordinates_3d_umlitem.methods) == 4
217+
assert len(coordinates_3d_umlitem.methods) == 3
223218

224-
assert coordinates_3d_umlitem.methods[0].name == 'check_positive'
225-
assert coordinates_3d_umlitem.methods[1].name == 'from_values' # inherited method
226-
assert coordinates_3d_umlitem.methods[2].name == 'get_coordinates' # inherited method
227-
assert coordinates_3d_umlitem.methods[3].name == 'move'
219+
assert coordinates_3d_umlitem.methods[2].name == 'check_positive'
220+
assert coordinates_3d_umlitem.methods[0].name == '__init__'
221+
assert coordinates_3d_umlitem.methods[1].name == 'move'
228222
# FIXME: use 'assert_method' once UmlMethod restructured

tests/py2puml/parsing/test_astvisitors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from tests.asserts.variable import assert_Variable
1414
from tests.py2puml.parsing.mockedinstance import MockedInstance
15+
from tests.modules.withmethods import withmethods
1516

1617

1718
class ParseMyConstructorArguments:
@@ -31,7 +32,7 @@ def test_SignatureVariablesCollector_collect_arguments():
3132
constructor_source: str = dedent(getsource(ParseMyConstructorArguments.__init__.__code__))
3233
constructor_ast: AST = parse(constructor_source)
3334

34-
collector = SignatureVariablesCollector(constructor_source)
35+
collector = SignatureVariablesCollector()
3536
collector.visit(constructor_ast)
3637

3738
assert collector.class_self_id == 'me'

0 commit comments

Comments
 (0)