
Commit 81b856f

Working AVX through explicit vectorization.

It's not ideal that LLVM can't figure this out for us, but it's better than not having it.
1 parent 7c1e277

4 files changed (+125, -5 lines)

typed_python/compiler/native_compiler/binary_shared_object.py

Lines changed: 3 additions & 1 deletion
@@ -81,7 +81,9 @@ def fromDisk(path, globalVariableDefinitions, functionNameToType, usedExternalFu
     def fromModule(module, globalVariableDefinitions, functionNameToType, usedExternalFunctions, functionDefinitions):
         target_triple = llvm.get_process_triple()
         target = llvm.Target.from_triple(target_triple)
-        target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
+        features = llvm.get_host_cpu_features()
+
+        target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default', features=features.flatten())
 
         # returns the contents of a '.o' file coming out of a c++ compiler like clang
         o_file_contents = target_machine_shared_object.emit_object(module)
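
For reference, a minimal sketch (assuming llvmlite is the llvm binding imported in this module, which the API used here suggests) of what the new features argument carries: get_host_cpu_features() returns a dict-like FeatureMap, and flatten() renders it as the "+feat,-feat" string that create_target_machine accepts.

import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# dict-like FeatureMap, e.g. {'avx': True, 'avx512f': False, ...}
features = llvm.get_host_cpu_features()

# the flattened form is what create_target_machine takes,
# e.g. "+64bit,+avx,+avx2,...,-avx512f" (exact contents vary by host)
print(features.flatten())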

typed_python/compiler/native_compiler/llvm_execution_engine.py

Lines changed: 5 additions & 2 deletions
@@ -25,10 +25,13 @@
 llvm.initialize_native_target()
 llvm.initialize_native_asmprinter()  # yes, even this one
 
+features = llvm.get_host_cpu_features()
+
+
 target_triple = llvm.get_process_triple()
 target = llvm.Target.from_triple(target_triple)
-target_machine = target.create_target_machine()
-target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
+target_machine = target.create_target_machine(features=features.flatten())
+target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default', features=features.flatten())
 
 ctypes.CDLL(_types.__file__, mode=ctypes.RTLD_GLOBAL)
 
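
A sketch of the JIT side of this change (illustrative, not the commit's code: it assumes llvmlite and builds a throwaway MCJIT engine the way llvmlite's docs do), showing the feature-enabled target_machine being handed to the execution engine so JIT-compiled code can use the host's vector instructions:

import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

target = llvm.Target.from_triple(llvm.get_process_triple())
target_machine = target.create_target_machine(
    features=llvm.get_host_cpu_features().flatten()
)

# an empty backing module is enough to construct the engine
engine = llvm.create_mcjit_compiler(llvm.parse_assembly(""), target_machine)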

typed_python/compiler/tests/compilable_builtin_test.py

Lines changed: 2 additions & 2 deletions
@@ -14,10 +14,10 @@
 
 class InlineLlvmFunc(CompilableBuiltin):
     def __eq__(self, other):
-        return isinstance(other, inlineLlvmFunc)
+        return isinstance(other, InlineLlvmFunc)
 
     def __hash__(self):
-        return hash("inlineLlvmFunc")
+        return hash("InlineLlvmFunc")
 
     def convert_call(self, context, instance, args, kwargs):
         return context.pushPod(
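
The two-line fix above is worth spelling out: the old code referenced inlineLlvmFunc (lowercase i), a name that doesn't exist, so comparing an InlineLlvmFunc instance to anything raised NameError instead of returning a bool. A standalone sketch of the corrected contract (plain Python; no typed_python imports needed to see the bug):

class InlineLlvmFunc:
    def __eq__(self, other):
        # was: isinstance(other, inlineLlvmFunc) -- a NameError at call time
        return isinstance(other, InlineLlvmFunc)

    def __hash__(self):
        # must agree with __eq__: all instances compare equal, so they share a hash
        return hash("InlineLlvmFunc")

assert InlineLlvmFunc() == InlineLlvmFunc()
assert hash(InlineLlvmFunc()) == hash(InlineLlvmFunc())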
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+import time
+
+from typed_python import Entrypoint, ListOf
+from typed_python.compiler.type_wrappers.compilable_builtin import CompilableBuiltin
+from typed_python.compiler.type_wrappers.runtime_functions import externalCallTarget, Float64, Void
+
+
+tp_llvm_vecMultAdd = externalCallTarget("tp_llvm_vecMultAdd", Void, Float64.pointer(), Float64.pointer(), Float64.pointer(), inlineLlvmDefinition="""
+define external void @"tp_llvm_vecMultAdd"(double* %p1, double* %p2, double* %p3) {
+entry:
+    %p1_vec_ptr = bitcast double* %p1 to <8 x double>*
+    %p2_vec_ptr = bitcast double* %p2 to <8 x double>*
+    %p3_vec_ptr = bitcast double* %p3 to <8 x double>*
+
+    ; note that we have to use 'align 1' here because we don't make any guarantees
+    ; about alignment in TP internals at all, mostly due to laziness. As a result, you
+    ; can get a segfault if your memory is not aligned to the native alignment of the
+    ; vector type - this is never a problem when loading a primitive like int64, but
+    ; the AVX instructions generated by these loads will assume 64-byte alignment if
+    ; you leave off the annotation, which is not always the case, and the resulting
+    ; aligned processor read will crash.
+    %p1_vec = load <8 x double>, <8 x double>* %p1_vec_ptr, align 1
+    %p2_vec = load <8 x double>, <8 x double>* %p2_vec_ptr, align 1
+    %p3_vec = fmul <8 x double> %p1_vec, %p2_vec
+
+    store <8 x double> %p3_vec, <8 x double>* %p3_vec_ptr, align 1
+
+    ret void
+}
+""")
+
+
+class TpLlvmVecMultAdd(CompilableBuiltin):
+    def __eq__(self, other):
+        return isinstance(other, TpLlvmVecMultAdd)
+
+    def __hash__(self):
+        return hash("TpLlvmVecMultAdd")
+
+    def convert_call(self, context, instance, args, kwargs):
+        context.pushEffect(
+            tp_llvm_vecMultAdd.call(
+                args[0],
+                args[1],
+                args[2]
+            )
+        )
+        return context.constant(None)
+
+
+@Entrypoint
+def fmultAdd1(l, p1, p2, p3):
+    i = 0
+
+    while i < l:
+        p3[i] = p1[i] * p2[i]
+        i += 1
+
+
+@Entrypoint
+def fmultAdd2(l, p1, p2, p3):
+    i = 0
+
+    while i + 8 < l:
+        TpLlvmVecMultAdd()(p1 + i, p2 + i, p3 + i)
+        i += 8
+
+    while i < l:
+        p3[i] = p1[i] * p2[i]
+        i += 1
+
+
+@Entrypoint
+def fmultAdd1Times(ct, l, p1, p2, p3):
+    for i in range(ct):
+        fmultAdd1(l, p1, p2, p3)
+
+
+@Entrypoint
+def fmultAdd2Times(ct, l, p1, p2, p3):
+    for i in range(ct):
+        fmultAdd2(l, p1, p2, p3)
+
+
+def test_inline_vectorization_working():
+    l1 = ListOf(float)()
+    l2 = ListOf(float)()
+    l3 = ListOf(float)()
+
+    N = 1024
+
+    l1.resize(N)
+    l2.resize(N)
+    l3.resize(N)
+
+    fmultAdd1Times(1, N, l1.pointerUnsafe(0), l2.pointerUnsafe(0), l3.pointerUnsafe(0))
+    fmultAdd2Times(1, N, l1.pointerUnsafe(0), l2.pointerUnsafe(0), l3.pointerUnsafe(0))
+
+    t0 = time.time()
+    fmultAdd1Times(1000000, N, l1.pointerUnsafe(0), l2.pointerUnsafe(0), l3.pointerUnsafe(0))
+    t1 = time.time()
+
+    t2 = time.time()
+    fmultAdd2Times(1000000, N, l1.pointerUnsafe(0), l2.pointerUnsafe(0), l3.pointerUnsafe(0))
+    t3 = time.time()
+
+    print(t1 - t0)
+    print(t3 - t2)
+
+    speedup = (t1 - t0) / (t3 - t2)
+
+    # I get about 4x because of AVX instructions. LLVM can't figure out how to do this
+    # directly for whatever reason, but the inlined primitive does it. I don't get the
+    # same speedup on the macOS workers on GitLab, so the threshold is set quite low.
+    assert speedup > 1.25
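
To see why the align 1 annotation matters, here is a sketch (assuming llvmlite, on a build whose LLVM still accepts the typed-pointer IR used above) that emits host assembly for a trimmed copy of the vector body. On an x86-64 host with AVX, the align 1 loads should lower to unaligned moves (vmovupd) rather than aligned ones (vmovapd), which would fault on unaligned pointers. Note also that <8 x double> is 512 bits wide, so on a 256-bit AVX machine LLVM legalizes each operation into two instructions - consistent with the roughly 4x (rather than 8x) speedup reported above.

import llvmlite.binding as llvm

# a trimmed, standalone copy of the IR body from the test above
IR = '''
define void @vmul8(double* %p1, double* %p2, double* %p3) {
entry:
    %a = bitcast double* %p1 to <8 x double>*
    %b = bitcast double* %p2 to <8 x double>*
    %c = bitcast double* %p3 to <8 x double>*
    %va = load <8 x double>, <8 x double>* %a, align 1
    %vb = load <8 x double>, <8 x double>* %b, align 1
    %vc = fmul <8 x double> %va, %vb
    store <8 x double> %vc, <8 x double>* %c, align 1
    ret void
}
'''

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

target = llvm.Target.from_triple(llvm.get_process_triple())
tm = target.create_target_machine(features=llvm.get_host_cpu_features().flatten())

mod = llvm.parse_assembly(IR)
mod.verify()

# expect vmovupd / vmulpd on an AVX host; vmovapd here would mean
# the alignment annotation was dropped
print(tm.emit_assembly(mod))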
