Commit 0116cce

rwgk authored and copybara-github committed
Raise "Proto Message has an Unknown Field" in certain situations.
The motivation for this change is to flag accidental loss of protobuf extensions when `use_fast_cpp_protos` is in use, to turn silent but potentially critical failures into noisy failures, for example:

```
Proto Message of type pybind11.test.NestRepeated has an Unknown Field with parent of type pybind11.test.BaseMessage: base_msgs.1003 (pybind11_protobuf/tests/extension_nest_repeated.proto, pybind11_protobuf/tests/extension.proto). Please add the required `cc_proto_library` `deps`. Only if there is no alternative to suppressing this error, use `pybind11_protobuf::AllowUnknownFieldsFor("pybind11.test.NestRepeated", "base_msgs");` (Warning: suppressions may mask critical bugs.)
```

See the updated "Protobuf Extensions" section in README.md for background and details.

check_unknown_fields.cc is mostly the work of @laramiel, with an initial implementation only by @rwgk.

`MessageMayContainExtensionsMemoized()` is for speed-up, but note that it also limits the scope of the Unknown Field detection (as explained in the updated README.md section).

PiperOrigin-RevId: 468553207
1 parent 179a3ec commit 0116cce

13 files changed, +473 -34 lines changed

README.md

Lines changed: 51 additions & 11 deletions
@@ -70,18 +70,58 @@ C++ native protobuf object with C++ when passed by `const &` or `const *`.
 
 ### Protobuf Extensions
 
-With `use_fast_cpp_protos`, any `py_proto_library` becomes implicitly dependent
-on the corresponding `cc_proto_library`, but this is currently not handled
-automatically, nor clearly diagnosed or formally enforced. In the absence of
+When `use_fast_cpp_protos` is in use, and
 protobuf extensions
-usually there is no problem: Python code tends to depend organically on
-a given `py_proto_library`, and pybind11-wrapped C++ code tends to depend
-organically on the corresponding `cc_proto_library`. However, when extensions
-are involved, a well-known pitfall is to accidentally omit the corresponding
-`cc_proto_library`. Currently this needs to be kept in mind as a pitfall
-(sorry), but the usual best-practice unit testing is likely to catch such
-situations. Once discovered, the fix is easy: add the `cc_proto_library`
-to the `deps` of the relevant `pybind_library` or `pybind_extension`.
+are involved, a well-known pitfall is that extensions are silently moved
+to the `proto2::UnknownFieldSet` when a message is deserialized in C++,
+but the `cc_proto_library` for the extensions is not linked in. The root
+cause is an asymmetry in the handling of Python protos vs C++ protos: when
+a Python proto is deserialized, both the Python descriptor pool and the C++
+descriptor pool are inspected, but when a C++ proto is deserialized, only
+the C++ descriptor pool is inspected. Until this asymmetry is resolved, the
+`cc_proto_library` for all extensions involved must be added to the `deps` of
+the relevant `pybind_library` or `pybind_extension`, but this is sufficiently
+unobvious to be a setup for regular accidents, potentially with critical
+consequences.
+
+To guard against the most common type of accident, native_proto_caster.h
+includes a safety mechanism that raises "Proto Message has an Unknown Field"
+in certain situations:
+
+* When `use_fast_cpp_protos` is in use,
+* a protobuf message is returned from C++ to Python,
+* the message involves protobuf extensions (recursively),
+* and the `proto2::UnknownFieldSet` for the message or any of its submessages
+  is not empty.
+
+`pybind11_protobuf::AllowUnknownFieldsFor` is an escape hatch for situations in
+which
+
+* unknown fields existed before the safety mechanism was
+  introduced.
+* unknown fields are needed in the future.
+
+An example of a full error message (with line breaks here for readability):
+
+```
+Proto Message of type pybind11.test.NestRepeated has an Unknown Field with
+parent of type pybind11.test.BaseMessage: base_msgs.1003
+(pybind11_protobuf/tests/extension_nest_repeated.proto,
+pybind11_protobuf/tests/extension.proto).
+Please add the required `cc_proto_library` `deps`.
+Only if there is no alternative to suppressing this error, use
+`pybind11_protobuf::AllowUnknownFieldsFor("pybind11.test.NestRepeated", "base_msgs");`
+(Warning: suppressions may mask critical bugs.)
+```
+
+The current implementation is a compromise solution, trading off simplicity
+of implementation, runtime performance, and precision. Generally, the runtime
+overhead is expected to be very small, but fields flagged as unknown may not
+necessarily be in extensions.
+Alerting developers of new code to unknown fields is assumed to be generally
+helpful, but the unknown fields detection is limited to messages with
+extensions, to avoid the runtime overhead for the presumably much more common
+case that no extensions are involved.
 
 ### Enumerations
 
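The README text above restricts the check to messages that involve extensions "recursively". A minimal standard-library sketch of how a recursive, memoized test over message types could work is given here. `ToyDescriptor` and `MayContainExtensions` are hypothetical names introduced only for illustration; they are not part of protobuf or pybind11_protobuf, and the real code works on `::google::protobuf::Descriptor`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a message descriptor (assumption, not the
// protobuf API).
struct ToyDescriptor {
  std::string full_name;
  bool has_extension_range = false;  // models extension_range_count() > 0
  std::vector<const ToyDescriptor*> message_fields;  // message-typed fields
};

using MayContainExtensionsMap = std::unordered_map<const ToyDescriptor*, bool>;

// Returns true if `descriptor`, or any message type reachable through its
// message-typed fields, declares an extension range. Results for types
// without their own extension range are cached in `memoized`.
bool MayContainExtensions(const ToyDescriptor* descriptor,
                          MayContainExtensionsMap* memoized) {
  if (descriptor->has_extension_range) return true;

  // Insert a provisional `false` before recursing, so that a type that
  // (directly or indirectly) contains itself does not recurse forever.
  auto [it, inserted] = memoized->try_emplace(descriptor, false);
  if (!inserted) return it->second;

  for (const ToyDescriptor* child : descriptor->message_fields) {
    if (MayContainExtensions(child, memoized)) {
      // Look the entry up again: `it` may be stale after the recursion.
      (*memoized)[descriptor] = true;
      return true;
    }
  }
  return false;
}
```

In this sketch, a type whose only message field has an extension range yields `true` and is cached as such, while a plain type or a purely self-referential type yields `false`.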
pybind11_protobuf/BUILD

Lines changed: 19 additions & 0 deletions
@@ -25,6 +25,7 @@ pybind_library(
         "//visibility:public",
     ],
     deps = [
+        ":check_unknown_fields",
         ":enum_type_caster",
         ":proto_cast_util",
         "@com_google_protobuf//:protobuf",
@@ -58,6 +59,7 @@ pybind_library(
         "//conditions:default": [],
     }),
     deps = [
+        ":check_unknown_fields",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
@@ -79,3 +81,20 @@ pybind_library(
         "@com_google_protobuf//:protobuf",
     ],
 )
+
+cc_library(
+    name = "check_unknown_fields",
+    srcs = ["check_unknown_fields.cc"],
+    hdrs = ["check_unknown_fields.h"],
+    visibility = [
+        "//visibility:private",
+    ],
+    deps = [
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/meta:type_traits",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_protobuf//:protobuf",
+    ],
+)

pybind11_protobuf/check_unknown_fields.cc

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
+#include "pybind11_protobuf/check_unknown_fields.h"
+
+#include <cassert>
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "google/protobuf/unknown_field_set.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+
+namespace pybind11_protobuf::check_unknown_fields {
+namespace {
+
+using AllowListSet = absl::flat_hash_set<std::string>;
+using MayContainExtensionsMap =
+    absl::flat_hash_map<const ::google::protobuf::Descriptor*, bool>;
+
+AllowListSet* GetAllowList() {
+  static auto* allow_list = new AllowListSet();
+  return allow_list;
+}
+
+std::string MakeAllowListKey(
+    absl::string_view top_message_descriptor_full_name,
+    absl::string_view unknown_field_parent_message_fqn) {
+  return absl::StrCat(top_message_descriptor_full_name, ":",
+                      unknown_field_parent_message_fqn);
+}
+
+/// Recurses through the message Descriptor class looking for valid extensions.
+/// Stores the result to `memoized`.
+bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor* descriptor,
+                                          MayContainExtensionsMap* memoized) {
+  if (descriptor->extension_range_count() > 0) return true;
+
+  auto [it, inserted] = memoized->try_emplace(descriptor, false);
+  if (!inserted) {
+    return it->second;
+  }
+
+  for (int i = 0; i < descriptor->field_count(); i++) {
+    auto* fd = descriptor->field(i);
+    if (fd->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) continue;
+    if (MessageMayContainExtensionsRecursive(fd->message_type(), memoized)) {
+      (*memoized)[descriptor] = true;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+bool MessageMayContainExtensionsMemoized(const ::google::protobuf::Descriptor* descriptor) {
+  static auto* memoized = new MayContainExtensionsMap();
+  static absl::Mutex lock;
+  absl::MutexLock l(&lock);
+  return MessageMayContainExtensionsRecursive(descriptor, memoized);
+}
+
+struct HasUnknownFields {
+  HasUnknownFields(const ::google::protobuf::Descriptor* root_descriptor)
+      : root_descriptor(root_descriptor) {}
+
+  std::string FieldFQN() const { return absl::StrJoin(field_fqn_parts, "."); }
+  std::string FieldFQNWithFieldNumber() const {
+    return field_fqn_parts.empty()
+               ? absl::StrCat(unknown_field_number)
+               : absl::StrCat(FieldFQN(), ".", unknown_field_number);
+  }
+
+  bool FindUnknownFieldsRecursive(const ::google::protobuf::Message* sub_message,
+                                  uint32_t depth);
+
+  std::string BuildErrorMessage() const;
+
+  const ::google::protobuf::Descriptor* root_descriptor = nullptr;
+  const ::google::protobuf::Descriptor* unknown_field_parent_descriptor = nullptr;
+  std::vector<std::string> field_fqn_parts;
+  int unknown_field_number;
+};
+
+/// Recurses through the message fields class looking for UnknownFields.
+bool HasUnknownFields::FindUnknownFieldsRecursive(
+    const ::google::protobuf::Message* sub_message, uint32_t depth) {
+  const ::google::protobuf::Reflection& reflection = *sub_message->GetReflection();
+
+  // If there are unknown fields, stop searching.
+  const ::google::protobuf::UnknownFieldSet& unknown_field_set =
+      reflection.GetUnknownFields(*sub_message);
+  if (!unknown_field_set.empty()) {
+    unknown_field_parent_descriptor = sub_message->GetDescriptor();
+    field_fqn_parts.resize(depth);
+    unknown_field_number = unknown_field_set.field(0).number();
+    return true;
+  }
+
+  // If this message does not include submessages which allow extensions,
+  // then it cannot include unknown fields.
+  if (!MessageMayContainExtensionsMemoized(sub_message->GetDescriptor())) {
+    return false;
+  }
+
+  // Otherwise the method has to check all present fields, including
+  // extensions to determine if they include unknown fields.
+  std::vector<const ::google::protobuf::FieldDescriptor*> present_fields;
+  reflection.ListFields(*sub_message, &present_fields);
+
+  for (const auto* field : present_fields) {
+    if (field->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+      continue;
+    }
+    if (field->is_repeated()) {
+      int field_size = reflection.FieldSize(*sub_message, field);
+      for (int i = 0; i != field_size; ++i) {
+        if (FindUnknownFieldsRecursive(
+                &reflection.GetRepeatedMessage(*sub_message, field, i),
+                depth + 1U)) {
+          field_fqn_parts[depth] = field->name();
+          return true;
+        }
+      }
+    } else if (FindUnknownFieldsRecursive(
+                   &reflection.GetMessage(*sub_message, field), depth + 1U)) {
+      field_fqn_parts[depth] = field->name();
+      return true;
+    }
+  }
+
+  return false;
+}
+
+std::string HasUnknownFields::BuildErrorMessage() const {
+  assert(unknown_field_parent_descriptor != nullptr);
+  assert(root_descriptor != nullptr);
+
+  std::string emsg = absl::StrCat(  //
+      "Proto Message of type ", root_descriptor->full_name(),
+      " has an Unknown Field");
+  if (root_descriptor != unknown_field_parent_descriptor) {
+    absl::StrAppend(&emsg, " with parent of type ",
+                    unknown_field_parent_descriptor->full_name());
+  }
+  absl::StrAppend(&emsg, ": ", FieldFQNWithFieldNumber(), " (",
+                  root_descriptor->file()->name());
+  if (root_descriptor->file() != unknown_field_parent_descriptor->file()) {
+    absl::StrAppend(&emsg, ", ",
+                    unknown_field_parent_descriptor->file()->name());
+  }
+  absl::StrAppend(
+      &emsg,
+      "). Please add the required `cc_proto_library` `deps`. "
+      "Only if there is no alternative to suppressing this error, use "
+      "`pybind11_protobuf::AllowUnknownFieldsFor(\"",
+      root_descriptor->full_name(), "\", \"", FieldFQN(),
+      "\");` (Warning: suppressions may mask critical bugs.)");
+
+  return emsg;
+}
+
+}  // namespace
+
+void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
+                           absl::string_view unknown_field_parent_message_fqn) {
+  GetAllowList()->insert(MakeAllowListKey(top_message_descriptor_full_name,
+                                          unknown_field_parent_message_fqn));
+}
+
+std::optional<std::string> CheckAndBuildErrorMessageIfAny(
+    const ::google::protobuf::Message* message) {
+  const auto* root_descriptor = message->GetDescriptor();
+  HasUnknownFields search{root_descriptor};
+  if (!search.FindUnknownFieldsRecursive(message, 0u)) {
+    return std::nullopt;
+  }
+  if (GetAllowList()->count(MakeAllowListKey(root_descriptor->full_name(),
+                                             search.FieldFQN())) != 0) {
+    return std::nullopt;
+  }
+  return search.BuildErrorMessage();
+}
+
+}  // namespace pybind11_protobuf::check_unknown_fields
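
`FindUnknownFieldsRecursive` builds the field path while the recursion unwinds: the call that finds a non-empty unknown field set sizes `field_fqn_parts` to its own depth, and each caller then writes its field name into its own slot. A simplified standard-library sketch of that pattern follows. `ToyMessage` and `UnknownFieldSearch` are hypothetical stand-ins rather than the protobuf reflection API, and the sketch leaves out the extension gate and the repeated-field loop.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a parsed message (assumption, not protobuf).
struct ToyMessage {
  std::vector<int> unknown_field_numbers;  // models the UnknownFieldSet
  // (field name, submessage) pairs for the message-typed fields that are set.
  std::vector<std::pair<std::string, const ToyMessage*>> sub_messages;
};

struct UnknownFieldSearch {
  std::vector<std::string> field_fqn_parts;
  int unknown_field_number = 0;

  // Depth-first search that stops at the first message with unknown fields.
  bool Find(const ToyMessage& message, std::size_t depth) {
    if (!message.unknown_field_numbers.empty()) {
      // The finding frame sizes the path; callers fill it in on the way up.
      field_fqn_parts.resize(depth);
      unknown_field_number = message.unknown_field_numbers.front();
      return true;
    }
    for (const auto& [name, sub] : message.sub_messages) {
      if (Find(*sub, depth + 1)) {
        field_fqn_parts[depth] = name;
        return true;
      }
    }
    return false;
  }

  // Joins the path parts with '.', e.g. {"a", "b"} becomes "a.b".
  std::string FieldFQN() const {
    std::string out;
    for (std::size_t i = 0; i < field_fqn_parts.size(); ++i) {
      if (i != 0) out += ".";
      out += field_fqn_parts[i];
    }
    return out;
  }
};
```

Sizing the vector once at the deepest frame avoids pushing and popping path components on every descent, at the cost of the path being valid only after a successful search.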

pybind11_protobuf/check_unknown_fields.h

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#ifndef PYBIND11_PROTOBUF_CHECK_UNKNOWN_FIELDS_H_
+#define PYBIND11_PROTOBUF_CHECK_UNKNOWN_FIELDS_H_
+
+#include <optional>
+
+#include "google/protobuf/message.h"
+#include "absl/strings/string_view.h"
+
+namespace pybind11_protobuf::check_unknown_fields {
+
+void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
+                           absl::string_view unknown_field_parent_message_fqn);
+
+std::optional<std::string> CheckAndBuildErrorMessageIfAny(
+    const ::google::protobuf::Message* top_message);
+
+}  // namespace pybind11_protobuf::check_unknown_fields
+
+#endif  // PYBIND11_PROTOBUF_CHECK_UNKNOWN_FIELDS_H_

pybind11_protobuf/native_proto_caster.h

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 
 #include "google/protobuf/message.h"
 #include "absl/strings/string_view.h"
+#include "pybind11_protobuf/check_unknown_fields.h"
 #include "pybind11_protobuf/enum_type_caster.h"
 #include "pybind11_protobuf/proto_caster_impl.h"
 
@@ -61,8 +62,8 @@ inline void ImportNativeProtoCasters() { InitializePybindProtoCastUtil(); }
 inline void AllowUnknownFieldsFor(
     absl::string_view top_message_descriptor_full_name,
     absl::string_view unknown_field_parent_message_fqn) {
-  // Preparation for cl/467099971.
-  // TODO(240452999): Submit cl/464678844.
+  check_unknown_fields::AllowUnknownFieldsFor(top_message_descriptor_full_name,
+                                              unknown_field_parent_message_fqn);
 }
 
 }  // namespace pybind11_protobuf
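
The suppression that `AllowUnknownFieldsFor` registers is a string key of the form `<top message full name>:<field path>` stored in a process-wide set. The sketch below models that with the standard library in place of Abseil. `IsAllowed` is a hypothetical helper added for illustration; the commit performs the lookup inline in `CheckAndBuildErrorMessageIfAny`.

```cpp
#include <cassert>
#include <string>
#include <string_view>
#include <unordered_set>

using AllowListSet = std::unordered_set<std::string>;

// Process-wide allow list, created on first use and intentionally never
// destroyed, mirroring the function-local static pointer in the diff.
AllowListSet* GetAllowList() {
  static auto* allow_list = new AllowListSet();
  return allow_list;
}

// Builds "<top message full name>:<field path>".
std::string MakeAllowListKey(std::string_view top_message_full_name,
                             std::string_view unknown_field_parent_fqn) {
  std::string key(top_message_full_name);
  key += ":";
  key += unknown_field_parent_fqn;
  return key;
}

void AllowUnknownFieldsFor(std::string_view top_message_full_name,
                           std::string_view unknown_field_parent_fqn) {
  GetAllowList()->insert(
      MakeAllowListKey(top_message_full_name, unknown_field_parent_fqn));
}

// Hypothetical helper: true if a suppression was registered for this pair.
bool IsAllowed(std::string_view top_message_full_name,
               std::string_view unknown_field_parent_fqn) {
  return GetAllowList()->count(MakeAllowListKey(
             top_message_full_name, unknown_field_parent_fqn)) != 0;
}
```

A suppression matches only the exact top-level message type and field path, so allowing `base_msgs` under `pybind11.test.NestRepeated` does not silence other paths. Like the code in the diff, this sketch does not lock around the set, so registering suppressions during module initialization, before messages are cast concurrently, seems the prudent usage.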

pybind11_protobuf/proto_cast_util.cc

Lines changed: 7 additions & 0 deletions
@@ -22,6 +22,7 @@
 #include "absl/strings/str_replace.h"
 #include "absl/strings/str_split.h"
 #include "absl/types/optional.h"
+#include "pybind11_protobuf/check_unknown_fields.h"
 
 namespace py = pybind11;
 
@@ -801,6 +802,12 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
     return GenericPyProtoCast(src, policy, parent, is_const);
   }
 
+  std::optional<std::string> emsg =
+      check_unknown_fields::CheckAndBuildErrorMessageIfAny(src);
+  if (emsg) {
+    throw py::value_error(*emsg);
+  }
+
   // If this is a dynamically generated proto, then we're going to need to
   // construct a mapping between C++ pool() and python pool(), and then
   // use the PyProto_API to make it work.
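
The guard added to `GenericProtoCast` follows a common shape: a check returns `std::optional<std::string>`, where an empty optional means "no problem", and the caller turns a present message into an exception before doing any further work. A standard-library sketch of that shape is shown here. `CheckForUnknownFields` and `CastOrThrow` are hypothetical names, and `std::invalid_argument` stands in for `py::value_error`, which pybind11 surfaces in Python as `ValueError`.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical check: returns an error message, or std::nullopt if the
// (simulated) message is clean.
std::optional<std::string> CheckForUnknownFields(bool has_unknown_fields) {
  if (!has_unknown_fields) return std::nullopt;
  return std::string("Proto Message has an Unknown Field");
}

// Hypothetical cast step: refuses to convert when the check reports an error.
std::string CastOrThrow(bool has_unknown_fields) {
  std::optional<std::string> emsg = CheckForUnknownFields(has_unknown_fields);
  if (emsg) {
    // Stand-in for `throw py::value_error(*emsg);` in the diff.
    throw std::invalid_argument(*emsg);
  }
  return "converted";
}
```

Returning an optional keeps the checking code free of any pybind11 dependency; only the call site decides which exception type to raise.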

pybind11_protobuf/tests/BUILD

Lines changed: 18 additions & 0 deletions
@@ -45,6 +45,22 @@ py_proto_library(
     deps = [":extension_proto"],
 )
 
+proto_library(
+    name = "extension_nest_repeated_proto",
+    srcs = ["extension_nest_repeated.proto"],
+    deps = [":extension_proto"],
+)
+
+cc_proto_library(
+    name = "extension_nest_repeated_cc_proto",
+    deps = [":extension_nest_repeated_proto"],
+)
+
+py_proto_library(
+    name = "extension_nest_repeated_py_pb2",
+    deps = [":extension_nest_repeated_proto"],
+)
+
 proto_library(
     name = "extension_in_other_file_in_deps_proto",
     srcs = ["extension_in_other_file_in_deps.proto"],
@@ -135,6 +151,7 @@ pybind_extension(
         ":extension_cc_proto",
         # Intentionally omitted: ":extension_in_other_file_cc_proto",
         ":extension_in_other_file_in_deps_cc_proto",
+        ":extension_nest_repeated_cc_proto",
         ":test_cc_proto",
         "@com_google_protobuf//:protobuf",
         "//pybind11_protobuf:native_proto_caster",
@@ -153,6 +170,7 @@ py_test(
     deps = [
         ":extension_in_other_file_in_deps_py_pb2",
         ":extension_in_other_file_py_pb2",
+        ":extension_nest_repeated_py_pb2",
         ":extension_py_pb2",
         ":test_py_pb2",  # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811
         "@com_google_absl_py//absl/testing:absltest",
