Skip to content

Commit 546728e

Browse files
committed
Compress non-simple alternative getattr calls into a function call.
1 parent 42af319 commit 546728e

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

typed_python/compiler/tests/alternative_compilation_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,3 +1754,34 @@ def eq(a1: AB, a2: AB):
17541754
assert eq(AB.A(a=1), AB.A(a=1))
17551755
assert eqMixed(AB.A(a=1), AB.A(a=1))
17561756
assert eqConcrete(AB.A(a=1), AB.A(a=1))
1757+
1758+
def test_compiled_attribute_access(self):
1759+
A = Alternative(
1760+
"A",
1761+
X=dict(x=int, y=str, z=str),
1762+
Y=dict(x=int, y=str, z=float),
1763+
Z=dict(),
1764+
)
1765+
1766+
def getx(a: A):
1767+
try:
1768+
return a.x
1769+
except Exception:
1770+
return 'None'
1771+
1772+
def gety(a: A):
1773+
try:
1774+
return a.y
1775+
except Exception:
1776+
return 'None'
1777+
1778+
def getz(a: A):
1779+
try:
1780+
return a.z
1781+
except Exception:
1782+
return 'None'
1783+
1784+
for a in [A.X(x=1, y='2', z='3'), A.Y(x=1, y='2', z=3), A.Z()]:
1785+
assert getx(a) == Entrypoint(getx)(a)
1786+
assert gety(a) == Entrypoint(gety)(a)
1787+
assert getz(a) == Entrypoint(getz)(a)

typed_python/compiler/type_wrappers/alternative_wrapper.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -364,17 +364,40 @@ def convert_attribute(self, context, instance, attribute, nocheck=False):
364364
else:
365365
outputType = mergeTypeWrappers(possibleTypes)
366366

367-
output = context.allocateUninitializedSlot(outputType)
367+
native = context.converter.defineNativeFunction(
368+
'getattr(' + self.typeRepresentation.__name__ + ", " + attribute + ")",
369+
('getattr', self, attribute),
370+
[self],
371+
outputType,
372+
lambda context, out, instance: self.generateNativeGetattr(
373+
context, outputType, validIndices, out, instance, attribute
374+
)
375+
)
376+
377+
if outputType.is_pass_by_ref:
378+
return context.push(
379+
outputType,
380+
lambda out: native.call(out, instance)
381+
)
382+
else:
383+
return context.pushPod(
384+
outputType,
385+
native.call(instance)
386+
)
387+
388+
def generateNativeGetattr(self, context, outputType, validIndices, out, instance, attr):
389+
which = instance.nonref_expr.ElementPtrIntegers(0, 1).load()
390+
for ix in validIndices:
391+
with context.ifelse(which.eq(ix)) as (ifTrue, ifFalse):
392+
with ifTrue:
393+
res = self.refAs(context, instance, ix).convert_attribute(attr)
394+
if res is not None:
395+
res = res.convert_to_type(outputType, ConversionLevel.Signature)
368396

369-
with context.switch(instance.nonref_expr.ElementPtrIntegers(0, 1).load(), validIndices, False) as indicesAndContexts:
370-
for ix, subcontext in indicesAndContexts:
371-
with subcontext:
372-
attr = self.refAs(context, instance, ix).convert_attribute(attribute)
373-
attr = attr.convert_to_type(outputType, ConversionLevel.Signature)
374-
output.convert_copy_initialize(attr)
375-
context.markUninitializedSlotInitialized(output)
397+
context.pushReturnValue(res)
376398

377-
return output
399+
if len(validIndices) != len(self.alternatives):
400+
context.pushException(AttributeError, attr)
378401

379402
def convert_check_matches(self, context, instance, typename):
380403
index = -1

0 commit comments

Comments
 (0)