Skip to content

Commit ba9b679

Browse files
Xiaofei Wangcopybara-github
authored andcommitted
Use GetMessageClass instead of MessageFactory.GetPrototype.
PiperOrigin-RevId: 521874847
1 parent d824311 commit ba9b679

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

pybind11_protobuf/proto_cast_util.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ class GlobalState {
184184
py::object factory_;
185185
py::object find_message_type_by_name_;
186186
py::object get_prototype_;
187+
py::object get_message_class_;
187188

188189
absl::flat_hash_map<std::string, py::module_> import_cache_;
189190
};
@@ -199,9 +200,15 @@ GlobalState::GlobalState() {
199200
auto message_factory =
200201
ImportCached("google.protobuf.message_factory");
201202
global_pool_ = descriptor_pool.attr("Default")();
202-
factory_ = message_factory.attr("MessageFactory")(global_pool_);
203203
find_message_type_by_name_ = global_pool_.attr("FindMessageTypeByName");
204-
get_prototype_ = factory_.attr("GetPrototype");
204+
if (hasattr(message_factory, "GetMessageClass")) {
205+
get_message_class_ = message_factory.attr("GetMessageClass");
206+
} else {
207+
// TODO(pybind11-infra): Cleanup `MessageFactory.GetProtoType` after it
208+
// is deprecated. See b/258832141.
209+
factory_ = message_factory.attr("MessageFactory")(global_pool_);
210+
get_prototype_ = factory_.attr("GetPrototype");
211+
}
205212
} catch (py::error_already_set& e) {
206213
if (IsImportError(e)) {
207214
std::cerr << "Add a python dependency on "
@@ -216,6 +223,7 @@ GlobalState::GlobalState() {
216223
factory_ = {};
217224
find_message_type_by_name_ = {};
218225
get_prototype_ = {};
226+
get_message_class_ = {};
219227
}
220228

221229
// determine the proto implementation.
@@ -296,7 +304,14 @@ py::object GlobalState::PyMessageInstance(const Descriptor* descriptor) {
296304
if (global_pool_) {
297305
try {
298306
auto d = find_message_type_by_name_(descriptor->full_name());
299-
auto p = get_prototype_(d);
307+
py::object p;
308+
if (get_message_class_.check()) {
309+
p = get_message_class_(d);
310+
} else {
311+
// TODO(pybind11-infra): Cleanup `MessageFactory.GetProtoType` after it
312+
// is deprecated. See b/258832141.
313+
p = get_prototype_(d);
314+
}
300315
return p();
301316
} catch (...) {
302317
// TODO(pybind11-infra): narrow down to expected exception(s).

0 commit comments

Comments
 (0)