Skip to content

Commit 0074dc0

Browse files
committed
Fix the performance issue of HiGHS by avoiding some slow API functions
1 parent 32abc16 commit 0074dc0

File tree

5 files changed

+101
-62
lines changed

5 files changed

+101
-62
lines changed

include/pyoptinterface/highs_model.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
B(Highs_addCol); \
1717
B(Highs_getNumCol); \
1818
B(Highs_changeColIntegrality); \
19-
B(Highs_passColName); \
2019
B(Highs_deleteColsBySet); \
2120
B(Highs_addRow); \
2221
B(Highs_getNumRow); \
23-
B(Highs_passRowName); \
2422
B(Highs_deleteRowsBySet); \
2523
B(Highs_passHessian); \
2624
B(Highs_changeColsCostByRange); \
@@ -50,11 +48,9 @@
5048
B(Highs_getInfoType); \
5149
B(Highs_getInt64InfoValue); \
5250
B(Highs_getDoubleInfoValue); \
53-
B(Highs_getColName); \
5451
B(Highs_getColIntegrality); \
5552
B(Highs_changeColsBoundsBySet); \
5653
B(Highs_getColsBySet); \
57-
B(Highs_getRowName); \
5854
B(Highs_getObjectiveSense); \
5955
B(Highs_getObjectiveValue); \
6056
B(Highs_getColsByRange); \
@@ -217,12 +213,19 @@ class POIHighsModel
217213
// So we need to keep track of binary variables
218214
Hashset<IndexT> binary_variables;
219215

216+
// Store the names internally because use HiGHS API to set them is very expensive
217+
Hashmap<IndexT, std::string> m_var_names, m_con_names;
218+
220219
/* Highs part */
221220
std::unique_ptr<void, HighsfreemodelT> m_model;
222221

223222
public:
224223
// cache the solution
225224
POIHighsSolution m_solution;
225+
// cache the number of variable and constraints because querying them via HiGHS API is very
226+
// expensive
227+
HighsInt m_n_variables = 0;
228+
HighsInt m_n_constraints = 0;
226229
};
227230

228-
using HighsModelMixin = CommercialSolverMixin<POIHighsModel>;
231+
using HighsModelMixin = CommercialSolverMixin<POIHighsModel>;

lib/highs_model.cpp

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,7 @@ VariableIndex POIHighsModel::add_variable(VariableDomain domain, double lb, doub
140140
auto error = highs::Highs_addCol(m_model.get(), 0.0, lb, ub, 0, nullptr, nullptr);
141141
check_error(error);
142142

143-
auto column = highs::Highs_getNumCol(m_model.get());
144-
// 0-based indexing
145-
column -= 1;
143+
auto column = m_n_variables;
146144

147145
if (domain != VariableDomain::Continuous)
148146
{
@@ -155,12 +153,16 @@ VariableIndex POIHighsModel::add_variable(VariableDomain domain, double lb, doub
155153
check_error(error);
156154
}
157155

156+
if (name != nullptr && name[0] == '\0')
157+
{
158+
name = nullptr;
159+
}
158160
if (name)
159161
{
160-
error = highs::Highs_passColName(m_model.get(), column, name);
161-
check_error(error);
162+
m_var_names.insert({variable.index, name});
162163
}
163164

165+
m_n_variables++;
164166
return variable;
165167
}
166168

@@ -177,6 +179,9 @@ void POIHighsModel::delete_variable(const VariableIndex &variable)
177179

178180
m_variable_index.delete_index(variable.index);
179181
binary_variables.erase(variable.index);
182+
183+
m_n_variables--;
184+
m_var_names.erase(variable.index);
180185
}
181186

182187
void POIHighsModel::delete_variables(const Vector<VariableIndex> &variables)
@@ -203,7 +208,9 @@ void POIHighsModel::delete_variables(const Vector<VariableIndex> &variables)
203208
for (int i = 0; i < n_variables; i++)
204209
{
205210
m_variable_index.delete_index(variables[i].index);
211+
m_var_names.erase(variables[i].index);
206212
}
213+
m_n_variables -= columns.size();
207214
}
208215

209216
bool POIHighsModel::is_variable_active(const VariableIndex &variable)
@@ -231,7 +238,7 @@ ConstraintIndex POIHighsModel::add_linear_constraint(const ScalarAffineFunction
231238
const char *name)
232239
{
233240
IndexT index = m_linear_constraint_index.add_index();
234-
ConstraintIndex constraint_index(ConstraintType::Linear, index);
241+
ConstraintIndex constraint(ConstraintType::Linear, index);
235242

236243
AffineFunctionPtrForm<HighsInt, HighsInt, double> ptr_form;
237244
ptr_form.make(this, function);
@@ -259,20 +266,20 @@ ConstraintIndex POIHighsModel::add_linear_constraint(const ScalarAffineFunction
259266
auto error = highs::Highs_addRow(m_model.get(), lb, ub, numnz, cind, cval);
260267
check_error(error);
261268

262-
HighsInt row = highs::Highs_getNumRow(m_model.get());
263-
// 0-based indexing
264-
row -= 1;
269+
HighsInt row = m_n_constraints;
270+
265271
if (name != nullptr && name[0] == '\0')
266272
{
267273
name = nullptr;
268274
}
269275
if (name)
270276
{
271-
error = highs::Highs_passRowName(m_model.get(), row, name);
272-
check_error(error);
277+
m_con_names.insert({constraint.index, name});
273278
}
274279

275-
return constraint_index;
280+
m_n_constraints++;
281+
282+
return constraint;
276283
}
277284

278285
ConstraintIndex POIHighsModel::add_quadratic_constraint(const ScalarQuadraticFunction &function,
@@ -294,21 +301,20 @@ void POIHighsModel::delete_constraint(const ConstraintIndex &constraint)
294301
check_error(error);
295302

296303
m_linear_constraint_index.delete_index(constraint.index);
304+
m_con_names.erase(constraint.index);
297305
}
298306

299307
bool POIHighsModel::is_constraint_active(const ConstraintIndex &constraint)
300308
{
301309
return m_linear_constraint_index.has_index(constraint.index);
302310
}
303311

304-
// #define private public
305-
// #include "Highs.h"
306312
void POIHighsModel::_set_affine_objective(const ScalarAffineFunction &function,
307313
ObjectiveSense sense, bool clear_quadratic)
308314
{
309315
HighsInt error;
310316

311-
HighsInt n_variables = highs::Highs_getNumCol(m_model.get());
317+
HighsInt n_variables = m_n_variables;
312318
if (clear_quadratic)
313319
{
314320
// First delete all quadratic terms
@@ -317,9 +323,6 @@ void POIHighsModel::_set_affine_objective(const ScalarAffineFunction &function,
317323
highs::Highs_passHessian(m_model.get(), n_variables, 0, kHighsHessianFormatTriangular,
318324
colstarts.data(), nullptr, nullptr);
319325

320-
// HighsModel &model = ((Highs *)m_model.get())->model_;
321-
// auto &hessian = model.hessian_;
322-
323326
check_error(error);
324327
}
325328

@@ -358,7 +361,7 @@ void POIHighsModel::set_objective(const ScalarQuadraticFunction &function, Objec
358361

359362
// Add quadratic term
360363
int numqnz = function.size();
361-
HighsInt n_variables = highs::Highs_getNumCol(m_model.get());
364+
HighsInt n_variables = m_n_variables;
362365
if (numqnz > 0)
363366
{
364367
CSCMatrix<HighsInt, HighsInt, double> csc;
@@ -432,8 +435,8 @@ void POIHighsModel::optimize()
432435
x.dual_solution_status = kHighsSolutionStatusNone;
433436
x.has_dual_ray = false;
434437
x.has_primal_ray = false;
435-
auto numCols = highs::Highs_getNumCols(model);
436-
auto numRows = highs::Highs_getNumRows(model);
438+
auto numCols = m_n_variables;
439+
auto numRows = m_n_constraints;
437440
x.model_status = highs::Highs_getModelStatus(model);
438441

439442
HighsInt status;
@@ -604,18 +607,20 @@ double POIHighsModel::get_raw_info_double(const char *info_name)
604607

605608
std::string POIHighsModel::get_variable_name(const VariableIndex &variable)
606609
{
607-
auto column = _checked_variable_index(variable);
608-
char name[kHighsMaximumStringLength];
609-
auto error = highs::Highs_getColName(m_model.get(), column, name);
610-
check_error(error);
611-
return std::string(name);
610+
auto iter = m_var_names.find(variable.index);
611+
if (iter != m_var_names.end())
612+
{
613+
return iter->second;
614+
}
615+
else
616+
{
617+
return fmt::format("x{}", variable.index);
618+
}
612619
}
613620

614621
void POIHighsModel::set_variable_name(const VariableIndex &variable, const char *name)
615622
{
616-
auto column = _checked_variable_index(variable);
617-
auto error = highs::Highs_passColName(m_model.get(), column, name);
618-
check_error(error);
623+
m_var_names[variable.index] = name;
619624
}
620625

621626
VariableDomain POIHighsModel::get_variable_type(const VariableIndex &variable)
@@ -702,18 +707,20 @@ void POIHighsModel::set_variable_upper_bound(const VariableIndex &variable, doub
702707

703708
std::string POIHighsModel::get_constraint_name(const ConstraintIndex &constraint)
704709
{
705-
auto row = _checked_constraint_index(constraint);
706-
char name[kHighsMaximumStringLength];
707-
auto error = highs::Highs_getRowName(m_model.get(), row, name);
708-
check_error(error);
709-
return std::string(name);
710+
auto iter = m_con_names.find(constraint.index);
711+
if (iter != m_con_names.end())
712+
{
713+
return iter->second;
714+
}
715+
else
716+
{
717+
return fmt::format("con{}", constraint.index);
718+
}
710719
}
711720

712721
void POIHighsModel::set_constraint_name(const ConstraintIndex &constraint, const char *name)
713722
{
714-
auto row = _checked_constraint_index(constraint);
715-
auto error = highs::Highs_passRowName(m_model.get(), row, name);
716-
check_error(error);
723+
m_con_names[constraint.index] = name;
717724
}
718725

719726
double POIHighsModel::get_constraint_primal(const ConstraintIndex &constraint)
@@ -805,7 +812,7 @@ void POIHighsModel::set_primal_start(const Vector<VariableIndex> &variables,
805812
if (numnz == 0)
806813
return;
807814

808-
auto numcol = highs::Highs_getNumCol(m_model.get());
815+
auto numcol = m_n_variables;
809816
if (numcol == 0)
810817
return;
811818

lib/highs_model_ext.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ NB_MODULE(highs_model_ext, m)
3333
#define BIND_F(f) .def(#f, &HighsModelMixin::f)
3434
nb::class_<HighsModelMixin, POIHighsModel>(m, "RawModel")
3535
.def(nb::init<>())
36+
.def_ro("m_n_variables", &HighsModelMixin::m_n_variables)
37+
.def_ro("m_n_constraints", &HighsModelMixin::m_n_constraints)
3638
// clang-format off
3739
BIND_F(init)
3840
BIND_F(write)

0 commit comments

Comments
 (0)