44from typing import Any , Callable , Self
55
66
7- type TransformFunc = Callable [[Node ], None ]
8- type VisitFunc = Callable [[Node ], Any ]
7+ type TransformFunc [ T ] = Callable [[Node [ T ] ], None ]
8+ type VisitFunc [ T ] = Callable [[Node [ T ] ], Any ]
99type AggrFunc = Callable [[Any , Any , Any ], Any ]
1010
1111
12- class Node :
12+ class Node [ T ] :
1313 _left : Self | None
1414 _right : Self | None
15- _data : int
15+ _data : T
1616
17- def __init__ (self , data : int ):
17+ def __init__ (self , data : T ):
1818 self ._left = None
1919 self ._right = None
2020 self ._data = data
@@ -36,11 +36,11 @@ def right(self, right: Self) -> None:
3636 self ._right = right
3737
3838 @property
39- def data (self ) -> int :
39+ def data (self ) -> T :
4040 return self ._data
4141
4242 @data .setter
43- def data (self , data : int ) -> None :
43+ def data (self , data : T ) -> None :
4444 self ._data = data
4545
4646 @property
@@ -64,7 +64,7 @@ def nr_descendants(self) -> int:
6464 count += 1 + self ._right .nr_descendants
6565 return count
6666
67- def transformn (self , func : TransformFunc ) -> None :
67+ def transformn (self , func : TransformFunc [ T ] ) -> None :
6868 func (self )
6969 if self ._left is not None :
7070 self ._left .transformn (func )
@@ -77,10 +77,12 @@ def __str__(self) -> str:
7777 def __repr__ (self ) -> str :
7878 return f"{ self .data } "
7979
80- def visit (self , visit_func : VisitFunc , aggr_func : AggrFunc ) -> Any :
80+ def visit (self , visit_func : VisitFunc [ T ] , aggr_func : AggrFunc ) -> Any :
8181 self_value = visit_func (self )
8282 left_value = (
83- self ._left .visit (visit_func , aggr_func ) if self ._left is not None else None
83+ self ._left .visit (visit_func , aggr_func )
84+ if self ._left is not None
85+ else None
8486 )
8587 right_value = (
8688 self ._right .visit (visit_func , aggr_func )
@@ -90,7 +92,7 @@ def visit(self, visit_func: VisitFunc, aggr_func: AggrFunc) -> Any:
9092 return aggr_func (self_value , left_value , right_value )
9193
9294
93- def str_visit (node : Node ) -> str :
95+ def str_visit [ T ] (node : Node [ T ] ) -> str :
9496 return str (node .data )
9597
9698
@@ -104,34 +106,38 @@ def str_aggr(self_value: str, left_value: str, right_value: str) -> str:
104106 return aggr
105107
106108
107- def double_value (node : Node ) -> None :
109+ def double_value (node : Node [ int ] ) -> None :
108110 node .data = 2 * node .data
109111
110112
111- class Tree :
112- _root : Node | None
113+ class Tree [ T ] :
114+ _root : Node [ T ] | None
113115
114- def __init__ (self , root : Node | None = None ):
116+ def __init__ (self , root : Node [ T ] | None = None ):
115117 self ._root = root
116118
117119 @property
118- def root (self ) -> Node | None :
120+ def root (self ) -> Node [ T ] | None :
119121 return self ._root
120122
121123 @root .setter
122- def root (self , root : Node ) -> None :
124+ def root (self , root : Node [ T ] ) -> None :
123125 self ._root = root
124126
125127 @property
126128 def nr_of_nodes (self ) -> int :
127129 return 0 if self ._root is None else 1 + self ._root .nr_descendants
128130
129- def transformn (self , func : TransformFunc ) -> None :
131+ def transformn (self , func : TransformFunc [ T ] ) -> None :
130132 if self ._root is not None :
131133 self ._root .transformn (func )
132134
133135 def __str__ (self ) -> str :
134- return "" if self ._root is None else self ._root .visit (str_visit , str_aggr )
136+ return (
137+ ""
138+ if self ._root is None
139+ else self ._root .visit (str_visit , str_aggr )
140+ )
135141
136142 def __repr__ (self ) -> str :
137143 return f"{ self ._root } "
0 commit comments