Skip to content

Commit 159080d

Browse files
committed
Ensure that held class temporaries get resolved over multiline expressions correctly.
Basically, we want to ensure that we can write expressions like listOfHeldClass[0].x = 10 and have the assignment modify the copy of the held class that lives within the list. To do this, we need listOfHeldClass[0] to return a RefTo(H) (where H is whatever the list is holding). Unfortunately, that would mean that we have a reference to a held class living in the Python heap, which is very likely to cause a segfault, and which doesn't mirror the semantics of the compiler. The solution we originally adopted was to keep track of all held class temporaries and to convert them to non-temporaries (i.e. make copies of the objects) when the current statement has finished executing. This mirrors the way the compiler works: temporaries exist for the duration of the expression we're compiling, which allows chained sequences like this to work. HOWEVER, the mechanism we use is to put a trace on the parent stack frame, and to use that trace to discover when we've hit the next frame. Unfortunately, an expression like listOfHeldClass[0].f( 10 ), when the '10' is written on its own line, causes us to evaluate the '10' on a new line, which triggers the temporary to resolve too early. This breaks the entire model. With this change, we cache, for each code object, the subset of line numbers that are the starts of new statements. These line numbers are the only ones at which the references are converted to concrete instances of the held class.
1 parent 070213b commit 159080d

File tree

4 files changed

+104
-12
lines changed

4 files changed

+104
-12
lines changed

typed_python/PyTemporaryReferenceTracer.cpp

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,68 @@
1616

1717
#include "PyTemporaryReferenceTracer.hpp"
1818

19+
// Report whether 'line' is one of the statement-start lines recorded for 'code'.
//
// The set of statement-start lines is computed once per code object by calling
// the Python function 'extractCodeObjectNewStatementLineNumbers' on the object
// returned by internalsModule(), and is then cached in
// 'codeObjectToExpressionLines' keyed on the code object pointer.
//
// code: the code object to look up (also used as the cache key).
// line: the line number to test.
//
// Returns true if 'line' is in the cached set for 'code'. If the Python call
// fails, the error is cleared and the cached set stays empty, so every line of
// that code object is reported as "not a new statement".
// NOTE(review): an empty set after a failure means LINE events never trigger
// processing for this code object - confirm that this is the desired fallback.
bool PyTemporaryReferenceTracer::isLineNewStatement(PyObject* code, int line) {
    auto it = codeObjectToExpressionLines.find(code);

    if (it != codeObjectToExpressionLines.end()) {
        return it->second.find(line) != it->second.end();
    }

    // this permanently memoizes this code object in this global object
    // this should be OK because there are (usually) a small and finite number of
    // code objects in a given program. We take a reference so the pointer we use
    // as the key stays valid.
    incref(code);
    auto& lineNumbers = codeObjectToExpressionLines[code];

    static PyObject* internals = internalsModule();

    // the "O" format consumes exactly one argument, so no trailing sentinel is passed
    PyObjectStealer res(
        PyObject_CallMethod(internals, "extractCodeObjectNewStatementLineNumbers", "O", code)
    );

    if (!res) {
        // best effort: we are running inside a trace callback and must not
        // leave an exception pending
        PyErr_Clear();
    } else {
        iterate((PyObject*)res, [&](PyObject* lineNo) {
            if (!PyLong_Check(lineNo)) {
                return;
            }

            const long value = PyLong_AsLong(lineNo);

            // PyLong_AsLong signals failure (e.g. overflow) by returning -1 with
            // an exception set; drop the value and clear the error
            if (value == -1 && PyErr_Occurred()) {
                PyErr_Clear();
                return;
            }

            // only keep values that survive the narrowing to 'int' unchanged
            const int narrowed = static_cast<int>(value);

            if (static_cast<long>(narrowed) == value) {
                lineNumbers.insert(narrowed);
            }
        });
    }

    return lineNumbers.find(line) != lineNumbers.end();
}
50+
1951

2052
int PyTemporaryReferenceTracer::globalTraceFun(PyObject* dummyObj, PyFrameObject* frame, int what, PyObject* arg) {
2153
if (frame != globalTracer.mostRecentEmptyFrame) {
22-
auto it = globalTracer.frameToHandles.find(frame);
54+
bool shouldProcess = true;
55+
56+
if (what == PyTrace_LINE) {
57+
shouldProcess = globalTracer.isLineNewStatement(
58+
(PyObject*)frame->f_code,
59+
frame->f_lineno
60+
);
61+
}
62+
63+
if (shouldProcess) {
64+
auto it = globalTracer.frameToHandles.find(frame);
2365

24-
if (it == globalTracer.frameToHandles.end()) {
25-
globalTracer.mostRecentEmptyFrame = frame;
26-
} else {
27-
globalTracer.mostRecentEmptyFrame = nullptr;
66+
if (it == globalTracer.frameToHandles.end()) {
67+
globalTracer.mostRecentEmptyFrame = frame;
68+
} else {
69+
globalTracer.mostRecentEmptyFrame = nullptr;
2870

29-
for (auto objAndAction: it->second) {
30-
if (objAndAction.second == TraceAction::ConvertTemporaryReference) {
31-
((PyInstance*)objAndAction.first)->resolveTemporaryReference();
71+
for (auto objAndAction: it->second) {
72+
if (objAndAction.second == TraceAction::ConvertTemporaryReference) {
73+
((PyInstance*)objAndAction.first)->resolveTemporaryReference();
74+
}
75+
76+
decref(objAndAction.first);
3277
}
3378

34-
decref(objAndAction.first);
79+
globalTracer.frameToHandles.erase(it);
3580
}
36-
37-
globalTracer.frameToHandles.erase(it);
3881
}
3982
}
4083

typed_python/PyTemporaryReferenceTracer.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@ class PyTemporaryReferenceTracer {
3434

3535
std::unordered_map<PyFrameObject*, std::vector<std::pair<PyObject*, TraceAction> > > frameToHandles;
3636

37+
std::unordered_map<PyObject*, std::set<int> > codeObjectToExpressionLines;
38+
3739
// the most recent frame we touched that has nothing in it
3840
PyFrameObject* mostRecentEmptyFrame;
3941

4042
Py_tracefunc priorTraceFunc;
4143

4244
PyObject* priorTraceFuncArg;
4345

46+
bool isLineNewStatement(PyObject* code, int line);
47+
4448
static PyTemporaryReferenceTracer globalTracer;
4549

4650
static int globalTraceFun(PyObject* obj, PyFrameObject* frame, int what, PyObject* arg);

typed_python/compiler/tests/held_class_interpreter_semantics_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ class H(Class, Final):
1616
def f(self):
1717
return self.x + self.y
1818

19+
def addToX(self, y):
    """Return the sum of this instance's ``x`` attribute and ``y``."""
    total = self.x + y
    return total
21+
1922
def increment(self):
2023
self.x += 1
2124
self.y += 1
@@ -45,6 +48,26 @@ def runTest():
4548

4649
assert runTest() == 30
4750

51+
def test_construct_and_call_method_multiline(self):
    """Call a method on a freshly constructed H with the argument on its own line."""
    def buildAndCall():
        # the call is split across lines on purpose: the argument must be
        # evaluated on a different source line than the construction of H
        return H(x=10, y=20).addToX(
            # this is deliberately on another line
            10
        )

    assert buildAndCall() == 20
59+
60+
def test_can_assign_to_held_class_in_list(self):
    """Mutating ``heldItems[0]`` through a method call must modify the list's element."""
    def mutateFirstElement():
        heldItems = ListOf(H)()
        heldItems.resize(1)

        # the mutation happens in one statement...
        heldItems[0].increment()

        # ...and is observed through the list in a later statement
        assert heldItems[0].x == 1

    mutateFirstElement()
70+
4871
def test_list_of_held_class_item_type(self):
4972
assert sys.gettrace() is None
5073

typed_python/internals.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2019 typed_python Authors
1+
# Copyright 2017-2022 typed_python Authors
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -622,3 +622,25 @@ def Held(T):
622622
raise Exception(f"{T} is not a Class")
623623

624624
return T.HeldClass
625+
626+
627+
def extractCodeObjectNewStatementLineNumbers(codeObject):
    """Return the subset of the line numbers on which codeObject has new statements.

    Args:
        codeObject: the object handed to ``convertFunctionToAlgebraicPyAst`` to
            obtain the AST whose statements are inspected.

    Returns:
        A set containing the ``line_number`` of each statement in ``ast.body``
        when the converted AST matches ``FunctionDef``. The set is empty when the
        AST is of another kind, or when conversion or inspection raises (in which
        case the traceback is printed).
    """
    from typed_python.python_ast import convertFunctionToAlgebraicPyAst

    try:
        ast = convertFunctionToAlgebraicPyAst(codeObject)

        if not ast.matches.FunctionDef:
            return set()

        # NOTE(review): only the statements directly in ast.body are visited;
        # statements nested inside compound statements are presumably not
        # included - confirm this is intended.
        return {statement.line_number for statement in ast.body}
    except Exception:
        # nasty to swallow the exception like this. At least we report it...
        import traceback
        traceback.print_exc()

        # same type as the success path, so callers always receive a set
        return set()

0 commit comments

Comments
 (0)