Skip to content

Commit 26f9973

Browse files
committed
Be a little more precise about how the temporary reference tracer follows line numbers.
1 parent 159080d commit 26f9973

File tree

8 files changed

+124
-26
lines changed

8 files changed

+124
-26
lines changed

typed_python/PyTemporaryReferenceTracer.cpp

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,33 +50,47 @@ bool PyTemporaryReferenceTracer::isLineNewStatement(PyObject* code, int line) {
5050

5151

5252
int PyTemporaryReferenceTracer::globalTraceFun(PyObject* dummyObj, PyFrameObject* frame, int what, PyObject* arg) {
53-
if (frame != globalTracer.mostRecentEmptyFrame) {
54-
bool shouldProcess = true;
55-
56-
if (what == PyTrace_LINE) {
57-
shouldProcess = globalTracer.isLineNewStatement(
58-
(PyObject*)frame->f_code,
59-
frame->f_lineno
60-
);
61-
}
53+
if (frame != globalTracer.mostRecentEmptyFrame &&
54+
globalTracer.frameToActions.find(frame) != globalTracer.frameToActions.end()) {
55+
// always process exception and return statements
56+
bool forceProcess = (
57+
what == PyTrace_EXCEPTION ||
58+
what == PyTrace_RETURN
59+
);
6260

63-
if (shouldProcess) {
64-
auto it = globalTracer.frameToHandles.find(frame);
61+
// we process any statement on a line that's a new statement
62+
bool shouldProcess = globalTracer.isLineNewStatement(
63+
(PyObject*)frame->f_code,
64+
PyFrame_GetLineNumber(frame)
65+
);
6566

66-
if (it == globalTracer.frameToHandles.end()) {
67+
if (shouldProcess || forceProcess) {
68+
auto it = globalTracer.frameToActions.find(frame);
69+
70+
if (it == globalTracer.frameToActions.end()) {
6771
globalTracer.mostRecentEmptyFrame = frame;
6872
} else {
6973
globalTracer.mostRecentEmptyFrame = nullptr;
7074

71-
for (auto objAndAction: it->second) {
72-
if (objAndAction.second == TraceAction::ConvertTemporaryReference) {
73-
((PyInstance*)objAndAction.first)->resolveTemporaryReference();
74-
}
75+
std::vector<FrameAction> persistingActions;
7576

76-
decref(objAndAction.first);
77+
for (auto& frameAction: it->second) {
78+
if (frameAction.lineNumber != PyFrame_GetLineNumber(frame) || forceProcess) {
79+
if (frameAction.action == TraceAction::ConvertTemporaryReference) {
80+
((PyInstance*)frameAction.obj)->resolveTemporaryReference();
81+
}
82+
83+
decref(frameAction.obj);
84+
} else {
85+
persistingActions.push_back(frameAction);
86+
}
7787
}
7888

79-
globalTracer.frameToHandles.erase(it);
89+
if (persistingActions.size()) {
90+
it->second = persistingActions;
91+
} else {
92+
globalTracer.frameToActions.erase(it);
93+
}
8094
}
8195
}
8296
}
@@ -88,7 +102,7 @@ int PyTemporaryReferenceTracer::globalTraceFun(PyObject* dummyObj, PyFrameObject
88102
);
89103
}
90104

91-
if (globalTracer.frameToHandles.size() == 0) {
105+
if (globalTracer.frameToActions.size() == 0) {
92106
// uninstall ourself
93107
PyEval_SetTrace(globalTracer.priorTraceFunc, globalTracer.priorTraceFuncArg);
94108
decref(globalTracer.priorTraceFuncArg);
@@ -115,7 +129,13 @@ void PyTemporaryReferenceTracer::installGlobalTraceHandlerIfNecessary() {
115129

116130
void PyTemporaryReferenceTracer::traceObject(PyObject* o, PyFrameObject* f) {
117131
// mark that we're going to trace
118-
globalTracer.frameToHandles[f].push_back(std::make_pair(incref(o), TraceAction::ConvertTemporaryReference));
132+
globalTracer.frameToActions[f].push_back(
133+
FrameAction(
134+
incref(o),
135+
TraceAction::ConvertTemporaryReference,
136+
PyFrame_GetLineNumber(f)
137+
)
138+
);
119139

120140
if (globalTracer.mostRecentEmptyFrame == f) {
121141
globalTracer.mostRecentEmptyFrame = nullptr;
@@ -126,7 +146,13 @@ void PyTemporaryReferenceTracer::traceObject(PyObject* o, PyFrameObject* f) {
126146

127147
void PyTemporaryReferenceTracer::keepaliveForCurrentInstruction(PyObject* o, PyFrameObject* f) {
128148
// mark that we're going to trace
129-
globalTracer.frameToHandles[f].push_back(std::make_pair(incref(o), TraceAction::Decref));
149+
globalTracer.frameToActions[f].push_back(
150+
FrameAction(
151+
incref(o),
152+
TraceAction::Decref,
153+
PyFrame_GetLineNumber(f)
154+
)
155+
);
130156

131157
if (globalTracer.mostRecentEmptyFrame == f) {
132158
globalTracer.mostRecentEmptyFrame = nullptr;

typed_python/PyTemporaryReferenceTracer.hpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,24 @@ class PyTemporaryReferenceTracer {
3232
priorTraceFuncArg(nullptr)
3333
{}
3434

35-
std::unordered_map<PyFrameObject*, std::vector<std::pair<PyObject*, TraceAction> > > frameToHandles;
35+
// perform an action on the first instruction where a
36+
// frame goes out of scope or where it is no longer on the
37+
// given line number
38+
class FrameAction {
39+
public:
40+
FrameAction(PyObject* inO, TraceAction inA, int inLine) :
41+
obj(inO),
42+
action(inA),
43+
lineNumber(inLine)
44+
{
45+
}
46+
47+
PyObject* obj;
48+
TraceAction action;
49+
int lineNumber;
50+
};
51+
52+
std::unordered_map<PyFrameObject*, std::vector<FrameAction> > frameToActions;
3653

3754
std::unordered_map<PyObject*, std::set<int> > codeObjectToExpressionLines;
3855

typed_python/_types.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3204,6 +3204,12 @@ PyObject *getTypePointer(PyObject* nullValue, PyObject* args) {
32043204

32053205
return PyLong_FromLong((uint64_t)type);
32063206
}
3207+
PyObject* _temporaryReferenceTracerActive(PyObject* null, PyObject* args, PyObject* kwargs) {
3208+
return incref(
3209+
PyTemporaryReferenceTracer::globalTracer.frameToActions.size() ?
3210+
Py_True : Py_False
3211+
);
3212+
}
32073213

32083214
PyObject* gilReleaseThreadLoop(PyObject* null, PyObject* args, PyObject* kwargs) {
32093215
PyEnsureGilReleased releaseTheGil;
@@ -3283,6 +3289,7 @@ static PyMethodDef module_methods[] = {
32833289
{"couldConvertObjectToTypeAtLevel", (PyCFunction)couldConvertObjectToTypeAtLevel, METH_VARARGS | METH_KEYWORDS, NULL},
32843290
{"isValidArithmeticUpcast", (PyCFunction)isValidArithmeticUpcast, METH_VARARGS | METH_KEYWORDS, NULL},
32853291
{"isValidArithmeticConversion", (PyCFunction)isValidArithmeticConversion, METH_VARARGS | METH_KEYWORDS, NULL},
3292+
{"_temporaryReferenceTracerActive", (PyCFunction)_temporaryReferenceTracerActive, METH_VARARGS | METH_KEYWORDS, NULL},
32863293
{"gilReleaseThreadLoop", (PyCFunction)gilReleaseThreadLoop, METH_VARARGS | METH_KEYWORDS, NULL},
32873294
{"setModuleDict", (PyCFunction)setModuleDict, METH_VARARGS | METH_KEYWORDS, NULL},
32883295
{NULL, NULL}

typed_python/compiler/python_ast_analysis.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,33 @@ def visit(x):
200200
return closureVars
201201

202202

203+
def extractLineNumbersWithStatements(astNode):
204+
res = set()
205+
206+
def visit(x):
207+
if isinstance(x, Statement):
208+
if x.matches.FunctionDef or x.matches.ClassDef or x.matches.AsyncFunctionDef:
209+
return False
210+
211+
res.add(x.line_number)
212+
213+
if isinstance(x, Expr):
214+
if (
215+
x.matches.Lambda
216+
or x.matches.ListComp
217+
or x.matches.SetComp
218+
or x.matches.DictComp
219+
or x.matches.GeneratorExp
220+
):
221+
return False
222+
223+
return True
224+
225+
visitPyAstChildren(astNode, visit)
226+
227+
return res
228+
229+
203230
def extractFunctionDefsInOrder(astNode):
204231
res = []
205232

typed_python/compiler/tests/held_class_compilation_test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import time
3-
3+
import sys
4+
from typed_python import _types
45
from typed_python import Class, Final, ListOf, Held, Member, Entrypoint, Forward
56

67

@@ -88,6 +89,8 @@ def move(c):
8889
self.checkCompiler(lambda c: c.h1.typeOfSelf(), c)
8990

9091
def test_compile_list_of_held_class(self):
92+
assert not _types._temporaryReferenceTracerActive()
93+
9194
class H(Class, Final):
9295
x = Member(int, nonempty=True)
9396
y = Member(float, nonempty=True)
@@ -123,6 +126,14 @@ def incrementAllRange(l):
123126

124127
incrementAllRange(aList)
125128

129+
def testIt(x):
130+
assert _types._temporaryReferenceTracerActive()
131+
return x
132+
133+
testIt(aList[0].increment)
134+
135+
assert not _types._temporaryReferenceTracerActive()
136+
126137
self.assertEqual(getitem(aList, 0).x, 1)
127138
self.assertEqual(getitem(aList, 5).x, 1)
128139

@@ -138,6 +149,8 @@ def incrementViaIterator(l):
138149

139150
incrementViaIterator(aList)
140151

152+
assert not _types._temporaryReferenceTracerActive()
153+
141154
self.assertEqual(getitem(aList, 0).x, 2)
142155
self.assertEqual(getitem(aList, 5).x, 2)
143156

typed_python/compiler/tests/held_class_interpreter_semantics_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,20 @@ def runTest():
5050

5151
def test_construct_and_call_method_multiline(self):
5252
def runTest():
53+
res = H(x=10, y=20).addToX(
54+
# this is deliberately on another line
55+
10
56+
)
57+
return res
58+
59+
def runTest2():
5360
return H(x=10, y=20).addToX(
5461
# this is deliberately on another line
5562
10
5663
)
5764

5865
assert runTest() == 20
66+
assert runTest2() == 20
5967

6068
def test_can_assign_to_held_class_in_list(self):
6169
def runTest():

typed_python/internals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,15 +627,15 @@ def Held(T):
627627
def extractCodeObjectNewStatementLineNumbers(codeObject):
628628
"""Return the subset of the line numbers on which codeObject has new statements."""
629629
from typed_python.python_ast import convertFunctionToAlgebraicPyAst
630+
from typed_python.compiler.python_ast_analysis import extractLineNumbersWithStatements
630631

631632
try:
632633
ast = convertFunctionToAlgebraicPyAst(codeObject)
633634

634635
res = set()
635636

636637
if ast.matches.FunctionDef:
637-
for statement in ast.body:
638-
res.add(statement.line_number)
638+
res = extractLineNumbersWithStatements(ast.body)
639639

640640
return res
641641
except Exception:

typed_python/python_ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def ExpressionStr(self):
786786
"args": TupleOf(Arg),
787787
"vararg": OneOf(Arg, None),
788788
"kwonlyargs": TupleOf(Arg),
789-
"kw_defaults": TupleOf(Expr),
789+
"kw_defaults": TupleOf(OneOf(None, Expr)),
790790
"kwarg": OneOf(Arg, None),
791791
"defaults": TupleOf(Expr),
792792
},

0 commit comments

Comments
 (0)