Skip to content

Commit 3260d58

Browse files
update python wrapper to current state (#42)
1 parent dc8fb53 commit 3260d58

File tree

5 files changed

+11
-18
lines changed

5 files changed

+11
-18
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,4 @@ jobs:
7272
[ -z "$files" ] || clang-format --dry-run --Werror $files
7373
7474
- name: Bazel tests
75-
run: bazel test src:all
75+
run: bazel test src/...

src/py/simplex_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_create_simplex_decoder():
4545
decoder.decode_to_errors([1])
4646
assert decoder.mask_from_errors([1]) == 0
4747
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
48-
assert decoder.decode([1, 2]) == 0
48+
assert decoder.decode([1]) == 0
4949

5050

5151
if __name__ == "__main__":

src/py/tesseract_test.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,17 @@ def test_create_config():
3535

3636

3737
def test_create_node():
38-
node = tesseract_decoder.tesseract.Node(dets=["a"])
39-
assert node.dets == ["a"]
40-
41-
42-
def test_create_qnode():
43-
qnode = tesseract_decoder.tesseract.QNode(num_dets=5, errs=[42])
44-
assert qnode.num_dets == 5
45-
assert str(qnode) == "QNode(cost=0, num_dets=5, errs=[42])"
38+
node = tesseract_decoder.tesseract.Node(errors=[1, 0])
39+
assert node.errors == [1, 0]
4640

4741

4842
def test_create_decoder():
4943
config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)
5044
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
5145
decoder.decode_to_errors([0])
52-
decoder.decode_to_errors([0], 0)
46+
decoder.decode_to_errors(detections=[0], det_order=0, det_beam=0)
5347
assert decoder.mask_from_errors([1]) == 0
54-
assert decoder.cost_from_errors([1]) == pytest.approx(1.609438)
48+
assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907)
5549
assert decoder.decode([0]) == 0
5650

5751

src/simplex.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,7 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
299299
}
300300

301301
// Get the model status
302-
[[maybe_unused]] const HighsModelStatus& model_status =
303-
highs->getModelStatus();
302+
[[maybe_unused]] const HighsModelStatus& model_status = highs->getModelStatus();
304303
assert(model_status == HighsModelStatus::kOptimal);
305304
}
306305

src/tesseract.pybind.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ void add_tesseract_module(py::module &root) {
4646
.def("__str__", &TesseractConfig::str);
4747

4848
py::class_<Node>(m, "Node")
49-
.def(py::init<double, size_t, std::vector<size_t>>(), py::arg("errs") = std::vector<size_t>(),
50-
py::arg("cost") = 0.0, py::arg("num_dets") = 0)
51-
.def_readwrite("errs", &Node::errors)
49+
.def(py::init<double, size_t, std::vector<size_t>>(), py::arg("cost") = 0.0,
50+
py::arg("num_detectors") = 0, py::arg("errors") = std::vector<size_t>())
51+
.def_readwrite("errors", &Node::errors)
5252
.def_readwrite("cost", &Node::cost)
53-
.def_readwrite("num_dets", &Node::num_detectors)
53+
.def_readwrite("num_detectors", &Node::num_detectors)
5454
.def(py::self > py::self)
5555
.def("__str__", &Node::str);
5656

0 commit comments

Comments
 (0)