Skip to content

Commit 215be68

Browse files
committed
Implement more nonlinear operators
1 parent 397598e commit 215be68

File tree

8 files changed

+274
-65
lines changed

8 files changed

+274
-65
lines changed

lib/ipopt_model_ext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ NB_MODULE(ipopt_model_ext, m)
2424
.value("Restoration_Failed", ApplicationReturnStatus::Restoration_Failed)
2525
.value("Error_In_Step_Computation", ApplicationReturnStatus::Error_In_Step_Computation)
2626
.value("Maximum_CpuTime_Exceeded", ApplicationReturnStatus::Maximum_CpuTime_Exceeded)
27+
.value("Maximum_WallTime_Exceeded", ApplicationReturnStatus::Maximum_WallTime_Exceeded)
2728
.value("Not_Enough_Degrees_Of_Freedom",
2829
ApplicationReturnStatus::Not_Enough_Degrees_Of_Freedom)
2930
.value("Invalid_Problem_Definition", ApplicationReturnStatus::Invalid_Problem_Definition)

lib/main.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -318,19 +318,28 @@ auto test_highs() -> void
318318
model.optimize();
319319
}
320320

321-
auto test_highs_capi() -> void
321+
auto debug_highs_passname(int N) -> void
322322
{
323323
void *highs = highs::Highs_create();
324+
325+
auto start = std::chrono::high_resolution_clock::now();
326+
327+
HighsModelMixin model;
328+
for (auto i = 0; i < N; i++)
329+
{
330+
highs::Highs_addCol(highs, 0.0, 0.0, 0.0, 0, nullptr, nullptr);
324331

325-
double c[] = {0.0, 1.0, 0.0};
326-
double lower[] = {0, 0, 0};
327-
double upper[] = {1.0, 1.0, 1.0};
332+
auto col = highs::Highs_getNumCol(highs);
328333

329-
// highs::Highs_addCols(highs, 3, c, lower, upper, 0, nullptr, nullptr, nullptr);
334+
assert(col == i + 1);
330335

331-
auto hessian_nz = highs::Highs_getHessianNumNz(highs);
332-
highs::Highs_run(highs);
333-
hessian_nz = highs::Highs_getHessianNumNz(highs);
336+
auto name = fmt::format("x{}", i);
337+
highs::Highs_passColName(highs, i, name.c_str());
338+
}
339+
340+
auto end = std::chrono::high_resolution_clock::now();
341+
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
342+
fmt::print("N={}, time={} milliseconds\n", N, duration.count());
334343
}
335344

336345
void debug_highs(int N)
@@ -350,16 +359,18 @@ void debug_highs(int N)
350359

351360
void test_highs_add_variable()
352361
{
353-
std::vector<int> Ns{100, 1000, 5000, 10000};
362+
std::vector<int> Ns{1000, 10000, 100000, 1000000};
354363

355364
for (const auto& N : Ns)
356365
{
357-
debug_highs(N);
366+
debug_highs_passname(N);
358367
}
359368
}
360369

361370
auto main() -> int
362371
{
372+
highs::load_library("E:\\HiGHS\\install\\bin\\highs.dll");
373+
363374
test_highs_add_variable();
364375
return 0;
365376
}

lib/nlcore_ext.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
namespace nb = nanobind;
99

1010
#include "pyoptinterface/nlcore.hpp"
11+
#include "cppad/utility/pow_int.hpp"
1112

1213
using a_double = CppAD::AD<double>;
1314
using advec = std::vector<a_double>;
@@ -45,12 +46,6 @@ NB_MODULE(nlcore_ext, m)
4546
.def(nb::self *= double())
4647
.def(nb::self /= double());
4748

48-
m.def("sqrt", [](const a_double &x) { return CppAD::sqrt(x); });
49-
m.def("sin", [](const a_double &x) { return CppAD::sin(x); });
50-
m.def("cos", [](const a_double &x) { return CppAD::cos(x); });
51-
m.def("exp", [](const a_double &x) { return CppAD::exp(x); });
52-
m.def("log", [](const a_double &x) { return CppAD::log(x); });
53-
5449
nb::bind_vector<advec, nb::rv_policy::reference_internal>(m, "advec");
5550

5651
m.def("Independent", nb::overload_cast<advec &>(&CppAD::Independent<advec>), nb::arg("x"));
@@ -129,6 +124,31 @@ NB_MODULE(nlcore_ext, m)
129124
m.def("sparse_jacobian", &sparse_jacobian<double>);
130125
m.def("sparse_hessian", &sparse_hessian<double>);
131126

127+
m.def("abs", [](const a_double &x) { return CppAD::abs(x); });
128+
m.def("acos", [](const a_double &x) { return CppAD::acos(x); });
129+
m.def("acosh", [](const a_double &x) { return CppAD::acosh(x); });
130+
m.def("asin", [](const a_double &x) { return CppAD::asin(x); });
131+
m.def("asinh", [](const a_double &x) { return CppAD::asinh(x); });
132+
m.def("atan", [](const a_double &x) { return CppAD::atan(x); });
133+
m.def("atanh", [](const a_double &x) { return CppAD::atanh(x); });
134+
m.def("cos", [](const a_double &x) { return CppAD::cos(x); });
135+
m.def("cosh", [](const a_double &x) { return CppAD::cosh(x); });
136+
m.def("erf", [](const a_double &x) { return CppAD::erf(x); });
137+
m.def("erfc", [](const a_double &x) { return CppAD::erfc(x); });
138+
m.def("exp", [](const a_double &x) { return CppAD::exp(x); });
139+
m.def("expm1", [](const a_double &x) { return CppAD::expm1(x); });
140+
m.def("log1p", [](const a_double &x) { return CppAD::log1p(x); });
141+
m.def("log", [](const a_double &x) { return CppAD::log(x); });
142+
m.def("pow", [](const a_double &x, const int &y) { return CppAD::pow(x, y); });
143+
m.def("pow", [](const a_double &x, const double &y) { return CppAD::pow(x, y); });
144+
m.def("pow", [](const double &x, const a_double &y) { return CppAD::pow(x, y); });
145+
m.def("pow", [](const a_double &x, const a_double &y) { return CppAD::pow(x, y); });
146+
m.def("sin", [](const a_double &x) { return CppAD::sin(x); });
147+
m.def("sinh", [](const a_double &x) { return CppAD::sinh(x); });
148+
m.def("sqrt", [](const a_double &x) { return CppAD::sqrt(x); });
149+
m.def("tan", [](const a_double &x) { return CppAD::tan(x); });
150+
m.def("tanh", [](const a_double &x) { return CppAD::tanh(x); });
151+
132152
nb::enum_<graph_op_enum>(m, "graph_op")
133153
.value("abs", graph_op_enum::abs_graph_op)
134154
.value("acos", graph_op_enum::acos_graph_op)

src/pyoptinterface/__init__.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,29 @@
2626

2727
from pyoptinterface._src.aml import make_nd_variable, quicksum, quicksum_
2828

29-
from pyoptinterface._src.nlcore_ext import sqrt, sin, cos, exp, log
29+
from pyoptinterface._src.nlcore_ext import (
30+
abs,
31+
acos,
32+
acosh,
33+
asin,
34+
asinh,
35+
atan,
36+
atanh,
37+
cos,
38+
cosh,
39+
erf,
40+
erfc,
41+
exp,
42+
expm1,
43+
log1p,
44+
log,
45+
pow,
46+
sin,
47+
sinh,
48+
sqrt,
49+
tan,
50+
tanh,
51+
)
3052

3153
# Alias of ConstraintSense
3254
Eq = ConstraintSense.Equal

src/pyoptinterface/_src/codegen_c.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,31 @@
66
from typing import IO
77

88
op2name = {
9+
graph_op.abs: "fabs",
10+
graph_op.acos: "acos",
11+
graph_op.acosh: "acosh",
12+
graph_op.asin: "asin",
13+
graph_op.asinh: "asinh",
14+
graph_op.atan: "atan",
15+
graph_op.atanh: "atanh",
16+
graph_op.cos: "cos",
17+
graph_op.cosh: "cosh",
18+
graph_op.erf: "erf",
19+
graph_op.erfc: "erfc",
20+
graph_op.exp: "exp",
21+
graph_op.expm1: "expm1",
22+
graph_op.log1p: "log1p",
23+
graph_op.log: "log",
24+
graph_op.pow: "pow",
25+
graph_op.sin: "sin",
26+
graph_op.sinh: "sinh",
27+
graph_op.sqrt: "sqrt",
28+
graph_op.tan: "tan",
29+
graph_op.tanh: "tanh",
930
graph_op.add: "+",
1031
graph_op.sub: "-",
1132
graph_op.mul: "*",
1233
graph_op.div: "/",
13-
graph_op.sqrt: "sqrt",
14-
graph_op.sin: "sin",
15-
graph_op.cos: "cos",
16-
graph_op.exp: "exp",
17-
graph_op.log: "log",
1834
graph_op.azmul: "*",
1935
graph_op.neg: "-",
2036
}

src/pyoptinterface/_src/codegen_llvm.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,46 @@ def create_indirect_load_store(module: ir.Module):
205205
builder.ret_void()
206206

207207

208+
op2name = {
209+
graph_op.abs: "fabs",
210+
graph_op.acos: "acos",
211+
graph_op.acosh: "acosh",
212+
graph_op.asin: "asin",
213+
graph_op.asinh: "asinh",
214+
graph_op.atan: "atan",
215+
graph_op.atanh: "atanh",
216+
graph_op.cos: "cos",
217+
graph_op.cosh: "cosh",
218+
graph_op.erf: "erf",
219+
graph_op.erfc: "erfc",
220+
graph_op.exp: "exp",
221+
graph_op.expm1: "expm1",
222+
graph_op.log1p: "log1p",
223+
graph_op.log: "log",
224+
graph_op.pow: "pow",
225+
graph_op.sin: "sin",
226+
graph_op.sinh: "sinh",
227+
graph_op.sqrt: "sqrt",
228+
graph_op.tan: "tan",
229+
graph_op.tanh: "tanh"
230+
}
231+
232+
math_ops = set(op2name.keys())
233+
234+
binary_ops = set([graph_op.pow])
235+
208236
def create_llvmir_basic_functions(module: ir.Module):
209237
create_azmul(module)
210238
create_sign(module)
211239
create_direct_load_store(module)
212240
create_indirect_load_store(module)
213241

214-
sqrt = ir.Function(module, ir.FunctionType(D, [D]), name="sqrt")
215-
sin = ir.Function(module, ir.FunctionType(D, [D]), name="sin")
216-
cos = ir.Function(module, ir.FunctionType(D, [D]), name="cos")
217-
exp = ir.Function(module, ir.FunctionType(D, [D]), name="exp")
218-
log = ir.Function(module, ir.FunctionType(D, [D]), name="log")
242+
for (op, op_name) in op2name.items():
243+
if op in binary_ops:
244+
func_type = ir.FunctionType(D, [D, D])
245+
else:
246+
func_type = ir.FunctionType(D, [D])
247+
ir.Function(module, func_type, name=op_name)
219248

220249
# sin = module.declare_intrinsic('llvm.sin', [D])
221250
# cos = module.declare_intrinsic('llvm.cos', [D])
@@ -319,11 +348,13 @@ def generate_llvmir_from_graph(
319348

320349
# sin = module.get_global("llvm.sin.f64")
321350
# cos = module.get_global("llvm.cos.f64")
322-
sqrt = module.get_global("sqrt")
323-
sin = module.get_global("sin")
324-
cos = module.get_global("cos")
325-
exp = module.get_global("exp")
326-
log = module.get_global("log")
351+
math_functions = dict()
352+
for op_name in op2name.values():
353+
op_function = module.get_global(op_name)
354+
if op_function is None:
355+
raise ValueError(f"Math function {op_name} not found in module")
356+
math_functions[op_name] = op_function
357+
327358
azmul = module.get_global("azmul")
328359
sign = module.get_global("sign")
329360
load_direct = module.get_global("load_direct")
@@ -403,7 +434,7 @@ def get_node_value(node: int):
403434
return val
404435

405436
for iter in graph_obj:
406-
op_enum = iter.op_enum
437+
op = iter.op_enum
407438
n_result = iter.n_result
408439
arg_node = iter.arg_node
409440

@@ -414,31 +445,30 @@ def get_node_value(node: int):
414445
if len(arg_node) == 2:
415446
arg2 = get_node_value(arg_node[1])
416447

417-
if op_enum == graph_op.add:
448+
if op == graph_op.add:
418449
ret_val = builder.fadd(arg1, arg2)
419-
elif op_enum == graph_op.sub:
450+
elif op == graph_op.sub:
420451
ret_val = builder.fsub(arg1, arg2)
421-
elif op_enum == graph_op.mul:
452+
elif op == graph_op.mul:
422453
ret_val = builder.fmul(arg1, arg2)
423-
elif op_enum == graph_op.div:
454+
elif op == graph_op.div:
424455
ret_val = builder.fdiv(arg1, arg2)
425-
elif op_enum == graph_op.sqrt:
426-
ret_val = builder.call(sqrt, [arg1])
427-
elif op_enum == graph_op.sin:
428-
ret_val = builder.call(sin, [arg1])
429-
elif op_enum == graph_op.cos:
430-
ret_val = builder.call(cos, [arg1])
431-
elif op_enum == graph_op.exp:
432-
ret_val = builder.call(exp, [arg1])
433-
elif op_enum == graph_op.log:
434-
ret_val = builder.call(log, [arg1])
435-
elif op_enum == graph_op.azmul:
456+
elif op == graph_op.azmul:
436457
ret_val = builder.fmul(arg1, arg2)
437458
# ret_val = builder.call(azmul, [arg1, arg2])
438-
elif op_enum == graph_op.neg:
459+
elif op == graph_op.neg:
439460
ret_val = builder.fneg(arg1)
461+
elif op == graph_op.sign:
462+
ret_val = builder.call(sign, [arg1])
463+
elif op in math_ops:
464+
op_name = op2name[op]
465+
op_function = math_functions[op_name]
466+
if op in binary_ops:
467+
ret_val = builder.call(op_function, [arg1, arg2])
468+
else:
469+
ret_val = builder.call(op_function, [arg1])
440470
else:
441-
raise ValueError(f"Unknown op_enum: {op_enum}")
471+
raise ValueError(f"Unknown op_enum: {op}")
442472

443473
ret_val.name = f"v[{result_node}]"
444474
v_dict[result_node] = ret_val

src/pyoptinterface/_src/ipopt.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .jit_llvm import LLJITCompiler
1111
from .tracefun import trace_adfun
1212

13+
from .core_ext import ConstraintIndex
14+
from .nlcore_ext import NLConstraintIndex
15+
1316
from .attributes import (
1417
VariableAttribute,
1518
ConstraintAttribute,
@@ -282,13 +285,33 @@ def get_terminationstatus(model):
282285
),
283286
}
284287

288+
289+
def get_constraint_primal(model, constraint):
290+
if isinstance(constraint, ConstraintIndex):
291+
return model.get_constraint_primal(constraint.index)
292+
elif isinstance(constraint, NLConstraintIndex):
293+
index = constraint.index
294+
dim = constraint.dim
295+
values = [model.get_constraint_primal(index + i) for i in range(dim)]
296+
return values
297+
298+
raise ValueError(f"Unknown constraint type: {type(constraint)}")
299+
300+
def get_constraint_dual(model, constraint):
301+
if isinstance(constraint, ConstraintIndex):
302+
return model.get_constraint_dual(constraint.index)
303+
elif isinstance(constraint, NLConstraintIndex):
304+
index = constraint.index
305+
dim = constraint.dim
306+
values = [model.get_constraint_dual(index + i) for i in range(dim)]
307+
return values
308+
309+
raise ValueError(f"Unknown constraint type: {type(constraint)}")
310+
311+
285312
constraint_attribute_get_func_map = {
286-
ConstraintAttribute.Primal: lambda model, constraint: model.get_constraint_primal(
287-
constraint
288-
),
289-
ConstraintAttribute.Dual: lambda model, constraint: model.get_constraint_dual(
290-
constraint
291-
),
313+
ConstraintAttribute.Primal: get_constraint_primal,
314+
ConstraintAttribute.Dual: get_constraint_dual,
292315
}
293316

294317
constraint_attribute_set_func_map = {}
@@ -332,8 +355,8 @@ def optimize(self, jit_engine="LLVM"):
332355

333356
super().optimize()
334357

335-
def register_function(self, f, /, x, name, p=()):
336-
adfun = trace_adfun(f, x, p)
358+
def register_function(self, f, /, var, name, param=()):
359+
adfun = trace_adfun(f, var, param)
337360
return super().register_function(adfun, name)
338361

339362
def get_variable_attribute(self, variable, attribute: VariableAttribute):

0 commit comments

Comments
 (0)