Skip to content

Commit 04bbdc0

Browse files
add llvm implementation for insertion sort
1 parent fc75883 commit 04bbdc0

File tree

5 files changed

+329
-9
lines changed

5 files changed

+329
-9
lines changed

pydatastructs/linear_data_structures/_backend/cpp/algorithms/algorithms.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
1414
METH_VARARGS | METH_KEYWORDS, ""},
1515
{"insertion_sort", (PyCFunction) insertion_sort,
1616
METH_VARARGS | METH_KEYWORDS, ""},
17+
{"insertion_sort_llvm", (PyCFunction)insertion_sort_llvm,
18+
METH_VARARGS | METH_KEYWORDS, ""},
1719
{"is_ordered", (PyCFunction) is_ordered,
1820
METH_VARARGS | METH_KEYWORDS, ""},
1921
{"linear_search", (PyCFunction) linear_search,

pydatastructs/linear_data_structures/_backend/cpp/algorithms/llvm_algorithms.py

Lines changed: 122 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_bubble_sort_ptr(dtype: str) -> int:
3939
if dtype not in _SUPPORTED:
4040
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
4141

42-
return _materialize(dtype)
42+
return _materialize(dtype, "bubble")
4343

4444
def _build_bubble_sort_ir(dtype: str) -> str:
4545
if dtype not in _SUPPORTED:
@@ -131,14 +131,127 @@ def _build_bubble_sort_ir(dtype: str) -> str:
131131

132132
return str(mod)
133133

134-
def _materialize(dtype: str) -> int:
134+
def get_insertion_sort_ptr(dtype: str) -> int:
    """Return the address of the JIT-compiled insertion sort for ``dtype``.

    The dtype name is lower-cased and stripped of surrounding whitespace
    before it is looked up in ``_SUPPORTED``.

    Raises
    ------
    ValueError
        If the normalised dtype name is not a key of ``_SUPPORTED``.
    """
    dtype = dtype.lower().strip()
    if dtype in _SUPPORTED:
        return _materialize(dtype, "insertion")
    raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
140+
141+
def _build_insertion_sort_ir(dtype: str) -> str:
    """Generate LLVM IR for an in-place insertion sort over ``dtype`` elements.

    The emitted function is named ``insertion_sort_<dtype>`` and has the
    signature ``void(T* arr, i32 n)``.  It is equivalent to::

        for i in range(1, n):
            key = arr[i]
            j = i - 1
            while j >= 0 and arr[j] > key:
                arr[j + 1] = arr[j]
                j -= 1
            arr[j + 1] = key

    Parameters
    ----------
    dtype
        A key of ``_SUPPORTED``; selects the element type ``T``.

    Returns
    -------
    str
        The textual LLVM IR of the module containing the function.

    Raises
    ------
    ValueError
        If ``dtype`` is not a key of ``_SUPPORTED``.
    """
    if dtype not in _SUPPORTED:
        raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

    T, _ = _SUPPORTED[dtype]
    i32 = ir.IntType(32)
    i64 = ir.IntType(64)
    zero = ir.Constant(i32, 0)
    one = ir.Constant(i32, 1)

    mod = ir.Module(name=f"insertion_sort_{dtype}_module")
    fn_name = f"insertion_sort_{dtype}"

    fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
    fn = ir.Function(mod, fn_ty, name=fn_name)

    arr, n = fn.args
    arr.name, n.name = "arr", "n"

    b_entry = fn.append_basic_block("entry")
    b_outer = fn.append_basic_block("outer")
    b_inner = fn.append_basic_block("inner")
    b_loop_cond = fn.append_basic_block("loop_cond")
    b_loop_cmp = fn.append_basic_block("loop_cmp")
    b_loop_body = fn.append_basic_block("loop_body")
    b_loop_exit = fn.append_basic_block("loop_exit")
    b_latch = fn.append_basic_block("latch")
    b_exit = fn.append_basic_block("exit")

    b = ir.IRBuilder(b_entry)

    # entry: arrays with fewer than two elements are already sorted.
    cond_trivial = b.icmp_signed("<=", n, one)
    b.cbranch(cond_trivial, b_exit, b_outer)

    # outer: for (i = 1; i < n; ++i)
    b.position_at_end(b_outer)
    i_phi = b.phi(i32, name="i")
    i_phi.add_incoming(one, b_entry)
    cond_outer = b.icmp_signed("<", i_phi, n)
    b.cbranch(cond_outer, b_inner, b_exit)

    # inner: key = arr[i]; j = i - 1
    b.position_at_end(b_inner)
    i64_cast = b.sext(i_phi, i64)
    ptr_i = b.gep(arr, [i64_cast])
    key = b.load(ptr_i, name="key")
    j_init = b.sub(i_phi, one, name="j_init")
    b.branch(b_loop_cond)

    # loop_cond: j must be a phi node.  Rebinding a Python variable does not
    # change the SSA value used by this block, so without the phi the loop
    # would test arr[i - 1] forever.
    b.position_at_end(b_loop_cond)
    j_phi = b.phi(i32, name="j")
    j_phi.add_incoming(j_init, b_inner)
    in_range = b.icmp_signed(">=", j_phi, zero)
    # Test j >= 0 *before* touching memory so arr[-1] is never read.
    b.cbranch(in_range, b_loop_cmp, b_loop_exit)

    # loop_cmp: continue shifting only while arr[j] > key
    b.position_at_end(b_loop_cmp)
    j64 = b.sext(j_phi, i64)
    ptr_j = b.gep(arr, [j64])
    val_j = b.load(ptr_j, name="val_j")
    if isinstance(T, ir.IntType):
        greater = b.icmp_signed(">", val_j, key)
    else:
        # Ordered comparison: false when either operand is NaN.
        greater = b.fcmp_ordered(">", val_j, key)
    b.cbranch(greater, b_loop_body, b_loop_exit)

    # loop_body: arr[j + 1] = arr[j]; j -= 1
    b.position_at_end(b_loop_body)
    jp1 = b.add(j_phi, one)
    jp1_64 = b.sext(jp1, i64)
    ptr_jp1 = b.gep(arr, [jp1_64])
    b.store(val_j, ptr_jp1)
    j_next = b.sub(j_phi, one, name="j_next")
    j_phi.add_incoming(j_next, b_loop_body)
    b.branch(b_loop_cond)

    # loop_exit: arr[j + 1] = key.  j_phi (defined in loop_cond) dominates
    # this block on both incoming edges, unlike a value created in loop_body.
    b.position_at_end(b_loop_exit)
    jp1_final = b.add(j_phi, one)
    jp1_final_64 = b.sext(jp1_final, i64)
    ptr_jp1_final = b.gep(arr, [jp1_final_64])
    b.store(key, ptr_jp1_final)
    b.branch(b_latch)

    # latch: ++i and loop back to the outer header
    b.position_at_end(b_latch)
    i_next = b.add(i_phi, one)
    i_phi.add_incoming(i_next, b_latch)
    b.branch(b_outer)

    b.position_at_end(b_exit)
    b.ret_void()

    return str(mod)
239+
240+
def _materialize(dtype: str, algo: str) -> int:
135241
_ensure_target_machine()
136242

137-
if dtype in _fn_ptr_cache:
138-
return _fn_ptr_cache[dtype]
243+
key = f"{algo}_{dtype}"
244+
if key in _fn_ptr_cache:
245+
return _fn_ptr_cache[key]
139246

140247
try:
141-
llvm_ir = _build_bubble_sort_ir(dtype)
248+
if algo == "bubble":
249+
llvm_ir = _build_bubble_sort_ir(dtype)
250+
elif algo == "insertion":
251+
llvm_ir = _build_insertion_sort_ir(dtype)
252+
else:
253+
raise ValueError(f"Unsupported algorithm '{algo}'")
254+
142255
mod = binding.parse_assembly(llvm_ir)
143256
mod.verify()
144257

@@ -156,12 +269,12 @@ def _materialize(dtype: str) -> int:
156269
engine.finalize_object()
157270
engine.run_static_constructors()
158271

159-
addr = engine.get_function_address(f"bubble_sort_{dtype}")
272+
addr = engine.get_function_address(f"{algo}_sort_{dtype}")
160273
if not addr:
161-
raise RuntimeError(f"Failed to get address for bubble_sort_{dtype}")
274+
raise RuntimeError(f"Failed to get address for {algo}_sort_{dtype}")
162275

163-
_fn_ptr_cache[dtype] = addr
164-
_engines[dtype] = engine
276+
_fn_ptr_cache[key] = addr
277+
_engines[key] = engine
165278

166279
return addr
167280

pydatastructs/linear_data_structures/_backend/cpp/algorithms/quadratic_time_sort.hpp

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,5 +729,207 @@ static PyObject* insertion_sort(PyObject* self, PyObject* args, PyObject* kwds)
729729
return args0;
730730
}
731731

732+
static PyObject* insertion_sort_llvm(PyObject* self, PyObject* args, PyObject* kwds) {
733+
static const char* kwlist[] = {"arr", "start", "end", "comp", "dtype", NULL};
734+
PyObject* arr_obj = NULL;
735+
PyObject* start_obj = NULL;
736+
PyObject* end_obj = NULL;
737+
PyObject* comp_obj = NULL;
738+
const char* dtype_cstr = NULL;
739+
740+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OOOs", (char**)kwlist,
741+
&arr_obj, &start_obj, &end_obj, &comp_obj, &dtype_cstr)) {
742+
return NULL;
743+
}
744+
745+
Py_ssize_t arr_len = PyObject_Length(arr_obj);
746+
if (arr_len < 0) {
747+
return NULL;
748+
}
749+
750+
int start = 0;
751+
int end = (int)arr_len;
752+
if (start_obj && start_obj != Py_None)
753+
start = (int)PyLong_AsLong(start_obj);
754+
if (end_obj && end_obj != Py_None)
755+
end = (int)PyLong_AsLong(end_obj);
756+
757+
if (start < 0 || end > arr_len || start >= end) {
758+
Py_RETURN_NONE;
759+
}
760+
761+
bool is_dynamic = false;
762+
PyObject* buffer_obj = arr_obj;
763+
if (PyObject_HasAttrString(arr_obj, "_data")) {
764+
buffer_obj = PyObject_GetAttrString(arr_obj, "_data");
765+
is_dynamic = true;
766+
} else {
767+
Py_INCREF(buffer_obj);
768+
}
769+
770+
if (!PySequence_Check(buffer_obj)) {
771+
Py_DECREF(buffer_obj);
772+
PyErr_SetString(PyExc_TypeError, "Expected a sequence or dynamic array");
773+
return NULL;
774+
}
775+
776+
PyObject* first = PySequence_GetItem(buffer_obj, 0);
777+
if (!first) {
778+
Py_DECREF(buffer_obj);
779+
return NULL;
780+
}
781+
782+
const char* inferred_dtype = dtype_cstr ? dtype_cstr : (
783+
PyLong_Check(first) ? "int32" :
784+
PyFloat_Check(first) ? "float64" : NULL
785+
);
786+
787+
Py_DECREF(first);
788+
789+
if (!inferred_dtype) {
790+
Py_DECREF(buffer_obj);
791+
PyErr_SetString(PyExc_TypeError, "Unsupported element type");
792+
return NULL;
793+
}
794+
795+
PyObject* sys = PyImport_ImportModule("sys");
796+
PyObject* sys_path = PyObject_GetAttrString(sys, "path");
797+
Py_DECREF(sys);
798+
799+
Py_ssize_t original_len = PyList_Size(sys_path);
800+
if (original_len == -1) {
801+
Py_DECREF(sys_path);
802+
return NULL;
803+
}
804+
805+
PyObject* path = PyUnicode_FromString("pydatastructs/linear_data_structures/_backend/cpp/algorithms");
806+
if (!path) {
807+
Py_DECREF(sys_path);
808+
return NULL;
809+
}
810+
811+
int append_result = PyList_Append(sys_path, path);
812+
Py_DECREF(path);
813+
if (append_result != 0) {
814+
Py_DECREF(sys_path);
815+
return NULL;
816+
}
817+
818+
PyObject* mod = PyImport_ImportModule("llvm_algorithms");
819+
820+
if (PyList_SetSlice(sys_path, original_len, original_len + 1, NULL) != 0) {
821+
PyErr_Clear();
822+
}
823+
Py_DECREF(sys_path);
824+
825+
PyObject* fn = PyObject_GetAttrString(mod, "get_insertion_sort_ptr");
826+
Py_DECREF(mod);
827+
828+
PyObject* arg = PyUnicode_FromString(inferred_dtype);
829+
if (!arg) {
830+
Py_DECREF(fn);
831+
return NULL;
832+
}
833+
834+
PyObject* addr_obj = PyObject_CallFunctionObjArgs(fn, arg, NULL);
835+
Py_DECREF(fn);
836+
Py_DECREF(arg);
837+
if (!addr_obj) {
838+
Py_DECREF(buffer_obj);
839+
return NULL;
840+
}
841+
842+
long long addr = PyLong_AsLongLong(addr_obj);
843+
Py_DECREF(addr_obj);
844+
845+
if (PyErr_Occurred()) {
846+
Py_DECREF(buffer_obj);
847+
return NULL;
848+
}
849+
850+
if (strcmp(inferred_dtype, "int32") == 0) {
851+
typedef void (*fn_t)(int*, int);
852+
fn_t insertion_sort_fn = (fn_t)(intptr_t)addr;
853+
int* arr = (int*)malloc(sizeof(int) * arr_len);
854+
for (Py_ssize_t i = 0; i < arr_len; i++) {
855+
PyObject* item = PySequence_GetItem(buffer_obj, i);
856+
arr[i] = (int)PyLong_AsLong(item);
857+
Py_DECREF(item);
858+
}
859+
insertion_sort_fn(arr, (int)arr_len);
860+
for (Py_ssize_t i = 0; i < arr_len; i++) {
861+
PyObject* v = PyLong_FromLong(arr[i]);
862+
PySequence_SetItem(buffer_obj, i, v);
863+
Py_DECREF(v);
864+
}
865+
free(arr);
866+
} else if (strcmp(inferred_dtype, "int64") == 0) {
867+
typedef void (*fn_t)(long long*, int);
868+
fn_t insertion_sort_fn = (fn_t)(intptr_t)addr;
869+
long long* arr = (long long*)malloc(sizeof(long long) * arr_len);
870+
for (Py_ssize_t i = 0; i < arr_len; i++) {
871+
PyObject* item = PySequence_GetItem(buffer_obj, i);
872+
arr[i] = PyLong_AsLongLong(item);
873+
Py_DECREF(item);
874+
}
875+
insertion_sort_fn(arr, (int)arr_len);
876+
for (Py_ssize_t i = 0; i < arr_len; i++) {
877+
PyObject* v = PyLong_FromLongLong(arr[i]);
878+
PySequence_SetItem(buffer_obj, i, v);
879+
Py_DECREF(v);
880+
}
881+
free(arr);
882+
} else if (strcmp(inferred_dtype, "float32") == 0) {
883+
typedef void (*fn_t)(float*, int);
884+
fn_t insertion_sort_fn = (fn_t)(intptr_t)addr;
885+
float* arr = (float*)malloc(sizeof(float) * arr_len);
886+
for (Py_ssize_t i = 0; i < arr_len; i++) {
887+
PyObject* item = PySequence_GetItem(buffer_obj, i);
888+
arr[i] = (float)PyFloat_AsDouble(item);
889+
Py_DECREF(item);
890+
}
891+
insertion_sort_fn(arr, (int)arr_len);
892+
for (Py_ssize_t i = 0; i < arr_len; i++) {
893+
PyObject* v = PyFloat_FromDouble(arr[i]);
894+
PySequence_SetItem(buffer_obj, i, v);
895+
Py_DECREF(v);
896+
}
897+
free(arr);
898+
} else if (strcmp(inferred_dtype, "float64") == 0) {
899+
typedef void (*fn_t)(double*, int);
900+
fn_t insertion_sort_fn = (fn_t)(intptr_t)addr;
901+
double* arr = (double*)malloc(sizeof(double) * arr_len);
902+
for (Py_ssize_t i = 0; i < arr_len; i++) {
903+
PyObject* item = PySequence_GetItem(buffer_obj, i);
904+
arr[i] = PyFloat_AsDouble(item);
905+
Py_DECREF(item);
906+
}
907+
insertion_sort_fn(arr, (int)arr_len);
908+
for (Py_ssize_t i = 0; i < arr_len; i++) {
909+
PyObject* v = PyFloat_FromDouble(arr[i]);
910+
PySequence_SetItem(buffer_obj, i, v);
911+
Py_DECREF(v);
912+
}
913+
free(arr);
914+
} else {
915+
Py_DECREF(buffer_obj);
916+
PyErr_SetString(PyExc_TypeError, "Unsupported dtype for insertion_sort_llvm");
917+
return NULL;
918+
}
919+
920+
if (is_dynamic && PyObject_HasAttrString(arr_obj, "_modify")) {
921+
PyObject* modify_fn = PyObject_GetAttrString(arr_obj, "_modify");
922+
if (modify_fn) {
923+
PyObject_CallFunctionObjArgs(modify_fn, NULL);
924+
Py_DECREF(modify_fn);
925+
} else {
926+
PyErr_Clear();
927+
}
928+
}
929+
930+
Py_DECREF(buffer_obj);
931+
Py_INCREF(arr_obj);
932+
return arr_obj;
933+
}
732934

733935
#endif

pydatastructs/linear_data_structures/algorithms.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,8 @@ def insertion_sort(array, **kwargs):
15131513
backend = kwargs.pop("backend", Backend.PYTHON)
15141514
if backend == Backend.CPP:
15151515
return _algorithms.insertion_sort(array, **kwargs)
1516+
if backend == Backend.LLVM:
1517+
return _algorithms.insertion_sort_llvm(array, **kwargs)
15161518
start = kwargs.get('start', 0)
15171519
end = kwargs.get('end', len(array) - 1)
15181520
comp = kwargs.get('comp', lambda u, v: u <= v)

pydatastructs/linear_data_structures/tests/test_algorithms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def test_selection_sort():
130130
def test_insertion_sort():
131131
_test_common_sort(insertion_sort)
132132
_test_common_sort(insertion_sort, backend=Backend.CPP)
133+
_test_common_sort(insertion_sort, backend=Backend.LLVM)
133134

134135
def test_matrix_multiply_parallel():
135136
ODA = OneDimensionalArray

0 commit comments

Comments
 (0)