@@ -205,17 +205,46 @@ def create_indirect_load_store(module: ir.Module):
205205 builder .ret_void ()
206206
207207
208+ op2name = {
209+ graph_op .abs : "fabs" ,
210+ graph_op .acos : "acos" ,
211+ graph_op .acosh : "acosh" ,
212+ graph_op .asin : "asin" ,
213+ graph_op .asinh : "asinh" ,
214+ graph_op .atan : "atan" ,
215+ graph_op .atanh : "atanh" ,
216+ graph_op .cos : "cos" ,
217+ graph_op .cosh : "cosh" ,
218+ graph_op .erf : "erf" ,
219+ graph_op .erfc : "erfc" ,
220+ graph_op .exp : "exp" ,
221+ graph_op .expm1 : "expm1" ,
222+ graph_op .log1p : "log1p" ,
223+ graph_op .log : "log" ,
224+ graph_op .pow : "pow" ,
225+ graph_op .sin : "sin" ,
226+ graph_op .sinh : "sinh" ,
227+ graph_op .sqrt : "sqrt" ,
228+ graph_op .tan : "tan" ,
229+ graph_op .tanh : "tanh"
230+ }
231+
232+ math_ops = set (op2name .keys ())
233+
234+ binary_ops = set ([graph_op .pow ])
235+
208236def create_llvmir_basic_functions (module : ir .Module ):
209237 create_azmul (module )
210238 create_sign (module )
211239 create_direct_load_store (module )
212240 create_indirect_load_store (module )
213241
214- sqrt = ir .Function (module , ir .FunctionType (D , [D ]), name = "sqrt" )
215- sin = ir .Function (module , ir .FunctionType (D , [D ]), name = "sin" )
216- cos = ir .Function (module , ir .FunctionType (D , [D ]), name = "cos" )
217- exp = ir .Function (module , ir .FunctionType (D , [D ]), name = "exp" )
218- log = ir .Function (module , ir .FunctionType (D , [D ]), name = "log" )
242+ for (op , op_name ) in op2name .items ():
243+ if op in binary_ops :
244+ func_type = ir .FunctionType (D , [D , D ])
245+ else :
246+ func_type = ir .FunctionType (D , [D ])
247+ ir .Function (module , func_type , name = op_name )
219248
220249 # sin = module.declare_intrinsic('llvm.sin', [D])
221250 # cos = module.declare_intrinsic('llvm.cos', [D])
@@ -319,11 +348,13 @@ def generate_llvmir_from_graph(
319348
320349 # sin = module.get_global("llvm.sin.f64")
321350 # cos = module.get_global("llvm.cos.f64")
322- sqrt = module .get_global ("sqrt" )
323- sin = module .get_global ("sin" )
324- cos = module .get_global ("cos" )
325- exp = module .get_global ("exp" )
326- log = module .get_global ("log" )
351+ math_functions = dict ()
352+ for op_name in op2name .values ():
353+ op_function = module .get_global (op_name )
354+ if op_function is None :
355+ raise ValueError (f"Math function { op_name } not found in module" )
356+ math_functions [op_name ] = op_function
357+
327358 azmul = module .get_global ("azmul" )
328359 sign = module .get_global ("sign" )
329360 load_direct = module .get_global ("load_direct" )
@@ -403,7 +434,7 @@ def get_node_value(node: int):
403434 return val
404435
405436 for iter in graph_obj :
406- op_enum = iter .op_enum
437+ op = iter .op_enum
407438 n_result = iter .n_result
408439 arg_node = iter .arg_node
409440
@@ -414,31 +445,30 @@ def get_node_value(node: int):
414445 if len (arg_node ) == 2 :
415446 arg2 = get_node_value (arg_node [1 ])
416447
417- if op_enum == graph_op .add :
448+ if op == graph_op .add :
418449 ret_val = builder .fadd (arg1 , arg2 )
419- elif op_enum == graph_op .sub :
450+ elif op == graph_op .sub :
420451 ret_val = builder .fsub (arg1 , arg2 )
421- elif op_enum == graph_op .mul :
452+ elif op == graph_op .mul :
422453 ret_val = builder .fmul (arg1 , arg2 )
423- elif op_enum == graph_op .div :
454+ elif op == graph_op .div :
424455 ret_val = builder .fdiv (arg1 , arg2 )
425- elif op_enum == graph_op .sqrt :
426- ret_val = builder .call (sqrt , [arg1 ])
427- elif op_enum == graph_op .sin :
428- ret_val = builder .call (sin , [arg1 ])
429- elif op_enum == graph_op .cos :
430- ret_val = builder .call (cos , [arg1 ])
431- elif op_enum == graph_op .exp :
432- ret_val = builder .call (exp , [arg1 ])
433- elif op_enum == graph_op .log :
434- ret_val = builder .call (log , [arg1 ])
435- elif op_enum == graph_op .azmul :
456+ elif op == graph_op .azmul :
436457 ret_val = builder .fmul (arg1 , arg2 )
437458 # ret_val = builder.call(azmul, [arg1, arg2])
438- elif op_enum == graph_op .neg :
459+ elif op == graph_op .neg :
439460 ret_val = builder .fneg (arg1 )
461+ elif op == graph_op .sign :
462+ ret_val = builder .call (sign , [arg1 ])
463+ elif op in math_ops :
464+ op_name = op2name [op ]
465+ op_function = math_functions [op_name ]
466+ if op in binary_ops :
467+ ret_val = builder .call (op_function , [arg1 , arg2 ])
468+ else :
469+ ret_val = builder .call (op_function , [arg1 ])
440470 else :
441- raise ValueError (f"Unknown op_enum: { op_enum } " )
471+ raise ValueError (f"Unknown op_enum: { op } " )
442472
443473 ret_val .name = f"v[{ result_node } ]"
444474 v_dict [result_node ] = ret_val
0 commit comments