@@ -167,3 +167,131 @@ def _materialize(dtype: str) -> int:
167167
168168 except Exception as e :
169169 raise RuntimeError (f"Failed to materialize function for dtype { dtype } : { e } " )
170+
171+ def get_selection_sort_ptr (dtype : str ) -> int :
172+ """Get function pointer for selection sort with specified dtype."""
173+ dtype = dtype .lower ().strip ()
174+ if dtype not in _SUPPORTED :
175+ raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
176+
177+ return _materialize_selection (dtype )
178+
179+
180+ def _build_selection_sort_ir (dtype : str ) -> str :
181+ if dtype not in _SUPPORTED :
182+ raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
183+
184+ T , _ = _SUPPORTED [dtype ]
185+ i32 = ir .IntType (32 )
186+ i64 = ir .IntType (64 )
187+
188+ mod = ir .Module (name = f"selection_sort_{ dtype } _module" )
189+ fn_name = f"selection_sort_{ dtype } "
190+
191+ fn_ty = ir .FunctionType (ir .VoidType (), [T .as_pointer (), i32 ])
192+ fn = ir .Function (mod , fn_ty , name = fn_name )
193+
194+ arr , n = fn .args
195+ arr .name , n .name = "arr" , "n"
196+
197+ # Basic blocks
198+ b_entry = fn .append_basic_block ("entry" )
199+ b_outer = fn .append_basic_block ("outer" )
200+ b_inner = fn .append_basic_block ("inner" )
201+ b_inner_latch = fn .append_basic_block ("inner.latch" )
202+ b_swap = fn .append_basic_block ("swap" )
203+ b_exit = fn .append_basic_block ("exit" )
204+
205+ b = ir .IRBuilder (b_entry )
206+ cond_trivial = b .icmp_signed ("<=" , n , ir .Constant (i32 , 1 ))
207+ b .cbranch (cond_trivial , b_exit , b_outer )
208+
209+ # Outer loop
210+ b .position_at_end (b_outer )
211+ i_phi = b .phi (i32 , name = "i" )
212+ i_phi .add_incoming (ir .Constant (i32 , 0 ), b_entry ) # start at 0
213+
214+ cond_outer = b .icmp_signed ("<" , i_phi , n )
215+ b .cbranch (cond_outer , b_inner , b_exit )
216+
217+ # Inner loop: find min index
218+ b .position_at_end (b_inner )
219+ min_idx = b_phi = b_phi_i = b .phi (i32 , name = "min_idx" )
220+ min_idx .add_incoming (i_phi , b_outer ) # initial min_idx = i
221+
222+ j_phi = b .phi (i32 , name = "j" )
223+ j_phi .add_incoming (b .add (i_phi , ir .Constant (i32 , 1 )), b_outer )
224+
225+ cond_inner = b .icmp_signed ("<" , j_phi , n )
226+ b .cbranch (cond_inner , b_inner_latch , b_swap )
227+
228+ # Compare and update min_idx
229+ b .position_at_end (b_inner_latch )
230+ j64 = b .sext (j_phi , i64 )
231+ min64 = b .sext (min_idx , i64 )
232+ arr_j_ptr = b .gep (arr , [j64 ], inbounds = True )
233+ arr_min_ptr = b .gep (arr , [min64 ], inbounds = True )
234+ arr_j_val = b .load (arr_j_ptr )
235+ arr_min_val = b .load (arr_min_ptr )
236+
237+ if isinstance (T , ir .IntType ):
238+ cmp = b .icmp_signed ("<" , arr_j_val , arr_min_val )
239+ else :
240+ cmp = b .fcmp_ordered ("<" , arr_j_val , arr_min_val )
241+
242+ with b .if_then (cmp ):
243+ min_idx = j_phi # update min_idx
244+
245+ j_next = b .add (j_phi , ir .Constant (i32 , 1 ))
246+ j_phi .add_incoming (j_next , b_inner_latch )
247+ min_idx .add_incoming (min_idx , b_inner_latch ) # propagate current min_idx
248+ b .branch (b_inner )
249+
250+ # Swap arr[i] and arr[min_idx]
251+ b .position_at_end (b_swap )
252+ i64 = b .sext (i_phi , i64 )
253+ min64 = b .sext (min_idx , i64 )
254+ ptr_i = b .gep (arr , [i64 ], inbounds = True )
255+ ptr_min = b .gep (arr , [min64 ], inbounds = True )
256+ val_i = b .load (ptr_i )
257+ val_min = b .load (ptr_min )
258+ b .store (val_min , ptr_i )
259+ b .store (val_i , ptr_min )
260+
261+ # Increment i
262+ i_next = b .add (i_phi , ir .Constant (i32 , 1 ))
263+ i_phi .add_incoming (i_next , b_swap )
264+ b .branch (b_outer )
265+
266+ # Exit
267+ b .position_at_end (b_exit )
268+ b .ret_void ()
269+
270+ return str (mod )
271+
272+
273+ def _materialize_selection (dtype : str ) -> int :
274+ _ensure_target_machine ()
275+
276+ name = f"selection_sort_{ dtype } "
277+ if dtype in _fn_ptr_cache :
278+ return _fn_ptr_cache [dtype ]
279+
280+ try :
281+ llvm_ir = _build_selection_sort_ir (dtype )
282+ mod = binding .parse_assembly (llvm_ir )
283+ mod .verify ()
284+
285+ engine = binding .create_mcjit_compiler (mod , _target_machine )
286+ engine .finalize_object ()
287+ engine .run_static_constructors ()
288+
289+ addr = engine .get_function_address (name )
290+ if not addr :
291+ raise RuntimeError (f"Failed to get address for { name } " )
292+
293+ _fn_ptr_cache [dtype ] = addr
294+ _engines [dtype ] = engine
295+ return addr
296+ except Exception as e :
297+ raise RuntimeError (f"Failed to materialize function for dtype { dtype } : { e } " )
0 commit comments