@@ -40,7 +40,9 @@ def visit_BinOp(self, node):
4040
4141 try :
4242 original_expression = expression_printer (node )
43- value = eval (original_expression )
43+ globals = {}
44+ locals = {}
45+ value = eval (original_expression , globals , locals )
4446 except Exception as e :
4547 return node
4648
@@ -65,19 +67,28 @@ def visit_BinOp(self, node):
6567 expression_printer = ExpressionPrinter ()
6668 folded_expression = expression_printer (new_node )
6769
68- if len (folded_expression ) > len (original_expression ):
69- # Result is longer than original expression
70+ if len (folded_expression ) >= len (original_expression ):
71+ # Result is not shorter than original expression
7072 return node
7173
72- globals = {'__builtins__' : {}} # Completely empty globals
73- locals = {}
74- assert eval (folded_expression , globals , locals ) == value
74+ try :
75+ globals = {'__builtins__' : {'float' : float }}
76+ locals = {}
77+ folded_value = eval (folded_expression , globals , locals )
78+ except NameError as ne :
79+ if ne .name in ['inf' , 'infj' , 'nan' ]:
80+ # When the value is something like inf+0j the expression printer will print it that way, which is not valid Python.
81+ # In python code it should be '1e999+0j', which parses as a BinOp that the expression printer can handle.
82+ # It's not worth fixing the expression printer to handle this case, since it is unlikely to occur in real code.
83+ return node
84+ raise
85+
86+ if isinstance (value , float ) and math .isnan (value ):
87+ assert math .isnan (folded_value )
88+ else :
89+ assert folded_value == value and type (folded_value ) == type (value )
7590
76- # Some complex number values are parsed as a BinOp
77- # Make sure we represent our AST the same way so it roundtrips correctly
78- parsed_folded_expression = ast .parse (folded_expression , 'folded expression' , 'eval' )
79- assert isinstance (parsed_folded_expression , ast .Expression )
80- if isinstance (parsed_folded_expression .body , ast .BinOp ):
81- new_node = parsed_folded_expression .body
91+ #print(f'{original_expression=}')
92+ #print(f'{folded_expression=}')
8293
8394 return self .add_child (new_node , node .parent , node .namespace )
0 commit comments