Skip to content

Commit 10fc4b0

Browse files
pybind11_protobuf authorscopybara-github
authored andcommitted
Raise "Proto Message has an Unknown Field" only when the extension field is known by the Python side.
PiperOrigin-RevId: 541546258
1 parent 0a22829 commit 10fc4b0

File tree

6 files changed

+26
-6
lines changed

6 files changed

+26
-6
lines changed

pybind11_protobuf/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ cc_library(
9393
"@com_google_absl//absl/meta:type_traits",
9494
"@com_google_absl//absl/strings",
9595
"@com_google_absl//absl/synchronization",
96+
"@com_google_protobuf//:proto_api",
9697
"@com_google_protobuf//:protobuf",
9798
],
9899
)

pybind11_protobuf/check_unknown_fields.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ bool MessageMayContainExtensionsMemoized(const ::google::protobuf::Descriptor* d
6666
}
6767

6868
struct HasUnknownFields {
69-
HasUnknownFields(const ::google::protobuf::Descriptor* root_descriptor)
70-
: root_descriptor(root_descriptor) {}
69+
HasUnknownFields(const ::google::protobuf::python::PyProto_API* py_proto_api,
70+
const ::google::protobuf::Descriptor* root_descriptor)
71+
: py_proto_api(py_proto_api), root_descriptor(root_descriptor) {}
7172

7273
std::string FieldFQN() const { return absl::StrJoin(field_fqn_parts, "."); }
7374
std::string FieldFQNWithFieldNumber() const {
@@ -81,6 +82,7 @@ struct HasUnknownFields {
8182

8283
std::string BuildErrorMessage() const;
8384

85+
const ::google::protobuf::python::PyProto_API* py_proto_api;
8486
const ::google::protobuf::Descriptor* root_descriptor = nullptr;
8587
const ::google::protobuf::Descriptor* unknown_field_parent_descriptor = nullptr;
8688
std::vector<std::string> field_fqn_parts;
@@ -97,9 +99,15 @@ bool HasUnknownFields::FindUnknownFieldsRecursive(
9799
reflection.GetUnknownFields(*sub_message);
98100
if (!unknown_field_set.empty()) {
99101
unknown_field_parent_descriptor = sub_message->GetDescriptor();
100-
field_fqn_parts.resize(depth);
101102
unknown_field_number = unknown_field_set.field(0).number();
102-
return true;
103+
104+
// Stop only if the extension is known by Python.
105+
if (py_proto_api->GetDefaultDescriptorPool()->FindExtensionByNumber(
106+
unknown_field_parent_descriptor,
107+
unknown_field_number)) {
108+
field_fqn_parts.resize(depth);
109+
return true;
110+
}
103111
}
104112

105113
// If this message does not include submessages which allow extensions,
@@ -174,9 +182,10 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
174182
}
175183

176184
std::optional<std::string> CheckAndBuildErrorMessageIfAny(
185+
const ::google::protobuf::python::PyProto_API* py_proto_api,
177186
const ::google::protobuf::Message* message) {
178187
const auto* root_descriptor = message->GetDescriptor();
179-
HasUnknownFields search{root_descriptor};
188+
HasUnknownFields search{py_proto_api, root_descriptor};
180189
if (!search.FindUnknownFieldsRecursive(message, 0u)) {
181190
return std::nullopt;
182191
}

pybind11_protobuf/check_unknown_fields.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <optional>
55

66
#include "google/protobuf/message.h"
7+
#include "python/google/protobuf/proto_api.h"
78
#include "absl/strings/string_view.h"
89

910
namespace pybind11_protobuf::check_unknown_fields {
@@ -12,6 +13,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
1213
absl::string_view unknown_field_parent_message_fqn);
1314

1415
std::optional<std::string> CheckAndBuildErrorMessageIfAny(
16+
const ::google::protobuf::python::PyProto_API* py_proto_api,
1517
const ::google::protobuf::Message* top_message);
1618

1719
} // namespace pybind11_protobuf::check_unknown_fields

pybind11_protobuf/native_proto_caster.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ inline void AllowUnknownFieldsFor(
6767
}
6868

6969
} // namespace pybind11_protobuf
70+
7071
namespace pybind11 {
7172
namespace detail {
7273

pybind11_protobuf/proto_cast_util.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,8 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
836836
}
837837

838838
std::optional<std::string> emsg =
839-
check_unknown_fields::CheckAndBuildErrorMessageIfAny(src);
839+
check_unknown_fields::CheckAndBuildErrorMessageIfAny(
840+
GlobalState::instance()->py_proto_api(), src);
840841
if (emsg) {
841842
throw py::value_error(*emsg);
842843
}

pybind11_protobuf/tests/extension_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ def test_reserialize_allow_unknown_outer(self):
197197
extension_in_other_file_pb2.AllowUnknownInnerExtension.hook].value
198198
self.assertEqual(97, b_inner_value)
199199

200+
def test_reserialize_allow_python_unknown_fields(self):
201+
inner = get_allow_unknown_inner(63)
202+
# Creates a message with only unknown fields.
203+
a = extension_pb2.BaseMessage.FromString(inner.SerializeToString())
204+
b = m.reserialize_base_message(a)
205+
self.assertEqual(a.SerializeToString(), b.SerializeToString())
200206

201207
if __name__ == '__main__':
202208
absltest.main()

0 commit comments

Comments
 (0)