Skip to content

Commit 9c9454a

Browse files
Switch from using stim_py to use py::object in python wrapper (#48)
following the discussion in quantumlib/Stim#963, this will unblock #41
1 parent 0471d4e commit 9c9454a

File tree

10 files changed

+166
-98
lines changed

10 files changed

+166
-98
lines changed

WORKSPACE

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,3 @@ http_archive(
6666
urls = ["https://github.com/bazelbuild/platforms/archive/refs/tags/0.0.6.zip"],
6767
strip_prefix = "platforms-0.0.6",
6868
)
69-
70-
http_archive(
71-
name = "stim_py",
72-
build_file = "//external:stim_py.BUILD",
73-
sha256 = "95236006859d6754be99629d4fb44788e742e962ac8c59caad421ca088f7350e",
74-
strip_prefix = "stim-1.15.0",
75-
urls = ["https://github.com/quantumlib/Stim/releases/download/v1.15.0/stim-1.15.0.tar.gz"],
76-
)

external/stim_py.BUILD

Lines changed: 0 additions & 64 deletions
This file was deleted.

src/BUILD

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ cc_library(
6767
pybind_library(
6868
name = "tesseract_decoder_pybind",
6969
srcs = [
70+
"stim_utils.pybind.h",
7071
"common.pybind.h",
7172
"utils.pybind.h",
7273
"simplex.pybind.h",
@@ -77,7 +78,6 @@ pybind_library(
7778
":libutils",
7879
":libsimplex",
7980
":libtesseract",
80-
"@stim_py//:stim_pybind_lib",
8181
],
8282
)
8383

@@ -88,7 +88,6 @@ pybind_extension(
8888
],
8989
deps = [
9090
":tesseract_decoder_pybind",
91-
"@stim_py//:stim",
9291
],
9392
)
9493

@@ -213,4 +212,3 @@ cc_binary(
213212
"@nlohmann_json//:json",
214213
],
215214
)
216-

src/common.pybind.h

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
#include <vector>
2323

2424
#include "common.h"
25-
#include "src/stim/dem/dem_instruction.pybind.h"
26-
#include "stim/dem/detector_error_model_target.pybind.h"
25+
#include "stim_utils.pybind.h"
2726

2827
namespace py = pybind11;
2928

@@ -39,8 +38,9 @@ void add_common_module(py::module &root) {
3938
.def(py::self == py::self)
4039
.def(py::self != py::self)
4140
.def("as_dem_instruction_targets", [](common::Symptom s) {
42-
std::vector<stim_pybind::ExposedDemTarget> ret;
43-
for (auto &t : s.as_dem_instruction_targets()) ret.emplace_back(t);
41+
std::vector<py::object> ret;
42+
for (auto &t : s.as_dem_instruction_targets())
43+
ret.push_back(make_py_object(t, "DemTarget"));
4444
return ret;
4545
});
4646

@@ -57,15 +57,38 @@ void add_common_module(py::module &root) {
5757
std::vector<bool> &>(),
5858
py::arg("likelihood_cost"), py::arg("probability"), py::arg("detectors"),
5959
py::arg("observables"), py::arg("dets_array"))
60-
.def(py::init([](stim_pybind::ExposedDemInstruction edi) {
61-
return new common::Error(edi.as_dem_instruction());
60+
.def(py::init([](py::object edi) {
61+
std::vector<double> args;
62+
std::vector<stim::DemTarget> targets;
63+
auto di = parse_py_dem_instruction(edi, args, targets);
64+
return new common::Error(di);
6265
}),
6366
py::arg("error"));
6467

65-
m.def("merge_identical_errors", &common::merge_identical_errors, py::arg("dem"));
66-
m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors, py::arg("dem"));
67-
m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"), py::arg("error_counts"),
68-
py::arg("num_shots"));
68+
m.def(
69+
"merge_identical_errors",
70+
[](py::object dem) {
71+
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
72+
auto res = common::merge_identical_errors(input_dem);
73+
return make_py_object(res, "DetectorErrorModel");
74+
},
75+
py::arg("dem"));
76+
m.def(
77+
"remove_zero_probability_errors",
78+
[](py::object dem) {
79+
return make_py_object(
80+
common::remove_zero_probability_errors(parse_py_object<stim::DetectorErrorModel>(dem)),
81+
"DetectorErrorModel");
82+
},
83+
py::arg("dem"));
84+
m.def(
85+
"dem_from_counts",
86+
[](py::object orig_dem, const std::vector<size_t> error_counts, size_t num_shots) {
87+
auto dem = parse_py_object<stim::DetectorErrorModel>(orig_dem);
88+
return make_py_object(common::dem_from_counts(dem, error_counts, num_shots),
89+
"DetectorErrorModel");
90+
},
91+
py::arg("orig_dem"), py::arg("error_counts"), py::arg("num_shots"));
6992
}
7093

7194
#endif

src/py/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ py_test(
77
visibility = ["//:__subpackages__"],
88
deps = [
99
"@pypi//pytest",
10+
"@pypi//stim",
1011
"//src:lib_tesseract_decoder",
1112
],
1213
)
@@ -17,6 +18,7 @@ py_test(
1718
visibility = ["//:__subpackages__"],
1819
deps = [
1920
"@pypi//pytest",
21+
"@pypi//stim",
2022
"//src:lib_tesseract_decoder",
2123
],
2224
)
@@ -27,6 +29,7 @@ py_test(
2729
visibility = ["//:__subpackages__"],
2830
deps = [
2931
"@pypi//pytest",
32+
"@pypi//stim",
3033
"//src:lib_tesseract_decoder",
3134
],
3235
)
@@ -37,6 +40,7 @@ py_test(
3740
visibility = ["//:__subpackages__"],
3841
deps = [
3942
"@pypi//pytest",
43+
"@pypi//stim",
4044
"//src:lib_tesseract_decoder",
4145
],
4246
)

src/py/tesseract_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def test_create_config():
3232
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
3333
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
3434
)
35+
assert (
36+
tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL).dem
37+
== _DETECTOR_ERROR_MODEL
38+
)
3539

3640

3741
def test_create_node():

src/simplex.pybind.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,29 @@
2121

2222
#include "common.h"
2323
#include "simplex.h"
24+
#include "stim_utils.pybind.h"
2425

2526
namespace py = pybind11;
2627

28+
namespace {
29+
SimplexConfig simplex_config_maker(py::object dem, bool parallelize = false,
30+
size_t window_length = 0, size_t window_slide_length = 0,
31+
bool verbose = false) {
32+
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
33+
return SimplexConfig({input_dem, parallelize, window_length, window_slide_length, verbose});
34+
}
35+
36+
}; // namespace
37+
2738
void add_simplex_module(py::module &root) {
2839
auto m =
2940
root.def_submodule("simplex", "Module containing the SimplexDecoder and related methods");
3041

3142
py::class_<SimplexConfig>(m, "SimplexConfig")
32-
.def(py::init<stim::DetectorErrorModel, bool, size_t, size_t, bool>(), py::arg("dem"),
33-
py::arg("parallelize") = false, py::arg("window_length") = 0,
34-
py::arg("window_slide_length") = 0, py::arg("verbose") = false)
35-
.def_readwrite("dem", &SimplexConfig::dem)
43+
.def(py::init(&simplex_config_maker), py::arg("dem"), py::arg("parallelize") = false,
44+
py::arg("window_length") = 0, py::arg("window_slide_length") = 0,
45+
py::arg("verbose") = false)
46+
.def_property("dem", &dem_getter<SimplexConfig>, &dem_setter<SimplexConfig>)
3647
.def_readwrite("parallelize", &SimplexConfig::parallelize)
3748
.def_readwrite("window_length", &SimplexConfig::window_length)
3849
.def_readwrite("window_slide_length", &SimplexConfig::window_slide_length)

src/stim_utils.pybind.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#ifndef _STIM_UTILS_PYBIND_H
2+
#define _STIM_UTILS_PYBIND_H
3+
4+
#include <pybind11/operators.h>
5+
#include <pybind11/pybind11.h>
6+
#include <pybind11/stl.h>
7+
8+
#include "stim.h"
9+
10+
namespace {
11+
namespace py = pybind11;
12+
}
13+
14+
template <typename T>
15+
py::object make_py_object(const T cpp_obj, const char* py_name) {
16+
auto stim_lib = py::module::import("stim");
17+
return stim_lib.attr(py_name)(cpp_obj.str());
18+
}
19+
20+
template <typename T>
21+
T parse_py_object(py::object py_obj) {
22+
std::string obj_str = py::cast<std::string>(py_obj.attr("__str__")());
23+
return T(obj_str);
24+
}
25+
26+
stim::DemInstructionType parse_dit(std::string dit_str) {
27+
if (dit_str == "error") return stim::DemInstructionType::DEM_ERROR;
28+
if (dit_str == "detector") return stim::DemInstructionType::DEM_DETECTOR;
29+
if (dit_str == "logical_observable") return stim::DemInstructionType::DEM_LOGICAL_OBSERVABLE;
30+
if (dit_str == "shift_detectors") return stim::DemInstructionType::DEM_SHIFT_DETECTORS;
31+
if (dit_str == "repeat") return stim::DemInstructionType::DEM_REPEAT_BLOCK;
32+
throw std::invalid_argument("unknown dem instruction type: " + dit_str);
33+
return stim::DemInstructionType::DEM_DETECTOR;
34+
}
35+
36+
stim::DemTarget parse_py_dem_target(py::object py_obj) {
37+
return stim::DemTarget::from_text(py::cast<std::string>(py_obj.attr("__str__")()));
38+
}
39+
40+
stim::DemInstruction parse_py_dem_instruction(py::object py_obj, std::vector<double>& args,
41+
std::vector<stim::DemTarget>& targets) {
42+
for (auto t : py_obj.attr("args_copy")()) args.push_back(t.cast<double>());
43+
stim::SpanRef args_ref(args);
44+
45+
for (auto t : py_obj.attr("targets_copy")())
46+
targets.push_back(parse_py_dem_target(t.cast<py::object>()));
47+
48+
stim::SpanRef targets_ref(targets);
49+
auto ty = parse_dit(py::cast<std::string>(py_obj.attr("type")));
50+
std::string tag = py::cast<std::string>(py_obj.attr("tag"));
51+
52+
auto di = stim::DemInstruction();
53+
di.arg_data = args_ref;
54+
di.target_data = targets_ref;
55+
di.tag = tag;
56+
di.type = ty;
57+
return di;
58+
}
59+
60+
template <typename T>
61+
py::object dem_getter(const T& config) {
62+
return make_py_object(config.dem, "DetectorErrorModel");
63+
}
64+
template <typename T>
65+
void dem_setter(T& config, py::object dem) {
66+
config.dem = parse_py_object<stim::DetectorErrorModel>(dem);
67+
}
68+
69+
#endif

src/tesseract.pybind.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,35 @@
1919
#include <pybind11/pybind11.h>
2020
#include <pybind11/stl.h>
2121

22+
#include "stim_utils.pybind.h"
2223
#include "tesseract.h"
2324

2425
namespace py = pybind11;
2526

27+
namespace {
28+
TesseractConfig tesseract_config_maker(
29+
py::object dem, int det_beam = INF_DET_BEAM, bool beam_climbing = false,
30+
bool no_revisit_dets = false, bool at_most_two_errors_per_detector = false,
31+
bool verbose = false, size_t pqlimit = std::numeric_limits<size_t>::max(),
32+
std::vector<std::vector<size_t>> det_orders = std::vector<std::vector<size_t>>(),
33+
double det_penalty = 0.0) {
34+
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
35+
return TesseractConfig({input_dem, det_beam, beam_climbing, no_revisit_dets,
36+
at_most_two_errors_per_detector, verbose, pqlimit, det_orders,
37+
det_penalty});
38+
}
39+
}; // namespace
2640
void add_tesseract_module(py::module &root) {
2741
auto m = root.def_submodule("tesseract", "Module containing the tesseract algorithm");
2842

2943
m.attr("INF_DET_BEAM") = INF_DET_BEAM;
3044
py::class_<TesseractConfig>(m, "TesseractConfig")
31-
.def(py::init<stim::DetectorErrorModel, int, bool, bool, bool, bool, size_t,
32-
std::vector<std::vector<size_t>>, double>(),
33-
py::arg("dem"), py::arg("det_beam") = INF_DET_BEAM, py::arg("beam_climbing") = false,
34-
py::arg("no_revisit_dets") = false, py::arg("at_most_two_errors_per_detector") = false,
35-
py::arg("verbose") = false, py::arg("pqlimit") = std::numeric_limits<size_t>::max(),
45+
.def(py::init(&tesseract_config_maker), py::arg("dem"), py::arg("det_beam") = INF_DET_BEAM,
46+
py::arg("beam_climbing") = false, py::arg("no_revisit_dets") = false,
47+
py::arg("at_most_two_errors_per_detector") = false, py::arg("verbose") = false,
48+
py::arg("pqlimit") = std::numeric_limits<size_t>::max(),
3649
py::arg("det_orders") = std::vector<std::vector<size_t>>(), py::arg("det_penalty") = 0.0)
37-
.def_readwrite("dem", &TesseractConfig::dem)
50+
.def_property("dem", &dem_getter<TesseractConfig>, &dem_setter<TesseractConfig>)
3851
.def_readwrite("det_beam", &TesseractConfig::det_beam)
3952
.def_readwrite("no_revisit_dets", &TesseractConfig::no_revisit_dets)
4053
.def_readwrite("at_most_two_errors_per_detector",

src/utils.pybind.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,27 @@ void add_utils_module(py::module &root) {
2828

2929
m.attr("EPSILON") = EPSILON;
3030
m.attr("INF") = INF;
31-
m.def("get_detector_coords", &get_detector_coords, py::arg("dem"));
32-
m.def("build_detector_graph", &build_detector_graph, py::arg("dem"));
33-
m.def("get_errors_from_dem", &get_errors_from_dem, py::arg("dem"));
31+
m.def(
32+
"get_detector_coords",
33+
[](py::object dem) {
34+
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
35+
return get_detector_coords(input_dem);
36+
},
37+
py::arg("dem"));
38+
m.def(
39+
"build_detector_graph",
40+
[](py::object dem) {
41+
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
42+
return build_detector_graph(input_dem);
43+
},
44+
py::arg("dem"));
45+
m.def(
46+
"get_errors_from_dem",
47+
[](py::object dem) {
48+
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
49+
return get_errors_from_dem(input_dem);
50+
},
51+
py::arg("dem"));
3452

3553
// Not exposing sampling_from_dem and sample_shots because they depend on
3654
// stim::SparseShot which stim doesn't expose to python.

0 commit comments

Comments
 (0)