|
| 1 | +#include "pybind11_protobuf/check_unknown_fields.h" |
| 2 | + |
| 3 | +#include <cassert> |
| 4 | +#include <cstdint> |
| 5 | +#include <optional> |
| 6 | +#include <string> |
| 7 | +#include <vector> |
| 8 | + |
| 9 | +#include "google/protobuf/descriptor.h" |
| 10 | +#include "google/protobuf/message.h" |
| 11 | +#include "google/protobuf/unknown_field_set.h" |
| 12 | +#include "absl/container/flat_hash_map.h" |
| 13 | +#include "absl/container/flat_hash_set.h" |
| 14 | +#include "absl/strings/str_cat.h" |
| 15 | +#include "absl/strings/str_join.h" |
| 16 | +#include "absl/strings/string_view.h" |
| 17 | +#include "absl/synchronization/mutex.h" |
| 18 | + |
| 19 | +namespace pybind11_protobuf::check_unknown_fields { |
| 20 | +namespace { |
| 21 | + |
| 22 | +using AllowListSet = absl::flat_hash_set<std::string>; |
| 23 | +using MayContainExtensionsMap = |
| 24 | + absl::flat_hash_map<const ::google::protobuf::Descriptor*, bool>; |
| 25 | + |
| 26 | +AllowListSet* GetAllowList() { |
| 27 | + static auto* allow_list = new AllowListSet(); |
| 28 | + return allow_list; |
| 29 | +} |
| 30 | + |
| 31 | +std::string MakeAllowListKey( |
| 32 | + absl::string_view top_message_descriptor_full_name, |
| 33 | + absl::string_view unknown_field_parent_message_fqn) { |
| 34 | + return absl::StrCat(top_message_descriptor_full_name, ":", |
| 35 | + unknown_field_parent_message_fqn); |
| 36 | +} |
| 37 | + |
| 38 | +/// Recurses through the message Descriptor class looking for valid extensions. |
| 39 | +/// Stores the result to `memoized`. |
| 40 | +bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor* descriptor, |
| 41 | + MayContainExtensionsMap* memoized) { |
| 42 | + if (descriptor->extension_range_count() > 0) return true; |
| 43 | + |
| 44 | + auto [it, inserted] = memoized->try_emplace(descriptor, false); |
| 45 | + if (!inserted) { |
| 46 | + return it->second; |
| 47 | + } |
| 48 | + |
| 49 | + for (int i = 0; i < descriptor->field_count(); i++) { |
| 50 | + auto* fd = descriptor->field(i); |
| 51 | + if (fd->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) continue; |
| 52 | + if (MessageMayContainExtensionsRecursive(fd->message_type(), memoized)) { |
| 53 | + (*memoized)[descriptor] = true; |
| 54 | + return true; |
| 55 | + } |
| 56 | + } |
| 57 | + |
| 58 | + return false; |
| 59 | +} |
| 60 | + |
| 61 | +bool MessageMayContainExtensionsMemoized(const ::google::protobuf::Descriptor* descriptor) { |
| 62 | + static auto* memoized = new MayContainExtensionsMap(); |
| 63 | + static absl::Mutex lock; |
| 64 | + absl::MutexLock l(&lock); |
| 65 | + return MessageMayContainExtensionsRecursive(descriptor, memoized); |
| 66 | +} |
| 67 | + |
| 68 | +struct HasUnknownFields { |
| 69 | + HasUnknownFields(const ::google::protobuf::Descriptor* root_descriptor) |
| 70 | + : root_descriptor(root_descriptor) {} |
| 71 | + |
| 72 | + std::string FieldFQN() const { return absl::StrJoin(field_fqn_parts, "."); } |
| 73 | + std::string FieldFQNWithFieldNumber() const { |
| 74 | + return field_fqn_parts.empty() |
| 75 | + ? absl::StrCat(unknown_field_number) |
| 76 | + : absl::StrCat(FieldFQN(), ".", unknown_field_number); |
| 77 | + } |
| 78 | + |
| 79 | + bool FindUnknownFieldsRecursive(const ::google::protobuf::Message* sub_message, |
| 80 | + uint32_t depth); |
| 81 | + |
| 82 | + std::string BuildErrorMessage() const; |
| 83 | + |
| 84 | + const ::google::protobuf::Descriptor* root_descriptor = nullptr; |
| 85 | + const ::google::protobuf::Descriptor* unknown_field_parent_descriptor = nullptr; |
| 86 | + std::vector<std::string> field_fqn_parts; |
| 87 | + int unknown_field_number; |
| 88 | +}; |
| 89 | + |
| 90 | +/// Recurses through the message fields class looking for UnknownFields. |
| 91 | +bool HasUnknownFields::FindUnknownFieldsRecursive( |
| 92 | + const ::google::protobuf::Message* sub_message, uint32_t depth) { |
| 93 | + const ::google::protobuf::Reflection& reflection = *sub_message->GetReflection(); |
| 94 | + |
| 95 | + // If there are unknown fields, stop searching. |
| 96 | + const ::google::protobuf::UnknownFieldSet& unknown_field_set = |
| 97 | + reflection.GetUnknownFields(*sub_message); |
| 98 | + if (!unknown_field_set.empty()) { |
| 99 | + unknown_field_parent_descriptor = sub_message->GetDescriptor(); |
| 100 | + field_fqn_parts.resize(depth); |
| 101 | + unknown_field_number = unknown_field_set.field(0).number(); |
| 102 | + return true; |
| 103 | + } |
| 104 | + |
| 105 | + // If this message does not include submessages which allow extensions, |
| 106 | + // then it cannot include unknown fields. |
| 107 | + if (!MessageMayContainExtensionsMemoized(sub_message->GetDescriptor())) { |
| 108 | + return false; |
| 109 | + } |
| 110 | + |
| 111 | + // Otherwise the method has to check all present fields, including |
| 112 | + // extensions to determine if they include unknown fields. |
| 113 | + std::vector<const ::google::protobuf::FieldDescriptor*> present_fields; |
| 114 | + reflection.ListFields(*sub_message, &present_fields); |
| 115 | + |
| 116 | + for (const auto* field : present_fields) { |
| 117 | + if (field->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { |
| 118 | + continue; |
| 119 | + } |
| 120 | + if (field->is_repeated()) { |
| 121 | + int field_size = reflection.FieldSize(*sub_message, field); |
| 122 | + for (int i = 0; i != field_size; ++i) { |
| 123 | + if (FindUnknownFieldsRecursive( |
| 124 | + &reflection.GetRepeatedMessage(*sub_message, field, i), |
| 125 | + depth + 1U)) { |
| 126 | + field_fqn_parts[depth] = field->name(); |
| 127 | + return true; |
| 128 | + } |
| 129 | + } |
| 130 | + } else if (FindUnknownFieldsRecursive( |
| 131 | + &reflection.GetMessage(*sub_message, field), depth + 1U)) { |
| 132 | + field_fqn_parts[depth] = field->name(); |
| 133 | + return true; |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + return false; |
| 138 | +} |
| 139 | + |
| 140 | +std::string HasUnknownFields::BuildErrorMessage() const { |
| 141 | + assert(unknown_field_parent_descriptor != nullptr); |
| 142 | + assert(root_descriptor != nullptr); |
| 143 | + |
| 144 | + std::string emsg = absl::StrCat( // |
| 145 | + "Proto Message of type ", root_descriptor->full_name(), |
| 146 | + " has an Unknown Field"); |
| 147 | + if (root_descriptor != unknown_field_parent_descriptor) { |
| 148 | + absl::StrAppend(&emsg, " with parent of type ", |
| 149 | + unknown_field_parent_descriptor->full_name()); |
| 150 | + } |
| 151 | + absl::StrAppend(&emsg, ": ", FieldFQNWithFieldNumber(), " (", |
| 152 | + root_descriptor->file()->name()); |
| 153 | + if (root_descriptor->file() != unknown_field_parent_descriptor->file()) { |
| 154 | + absl::StrAppend(&emsg, ", ", |
| 155 | + unknown_field_parent_descriptor->file()->name()); |
| 156 | + } |
| 157 | + absl::StrAppend( |
| 158 | + &emsg, |
| 159 | + "). Please add the required `cc_proto_library` `deps`. " |
| 160 | + "Only if there is no alternative to suppressing this error, use " |
| 161 | + "`pybind11_protobuf::AllowUnknownFieldsFor(\"", |
| 162 | + root_descriptor->full_name(), "\", \"", FieldFQN(), |
| 163 | + "\");` (Warning: suppressions may mask critical bugs.)"); |
| 164 | + |
| 165 | + return emsg; |
| 166 | +} |
| 167 | + |
| 168 | +} // namespace |
| 169 | + |
| 170 | +void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, |
| 171 | + absl::string_view unknown_field_parent_message_fqn) { |
| 172 | + GetAllowList()->insert(MakeAllowListKey(top_message_descriptor_full_name, |
| 173 | + unknown_field_parent_message_fqn)); |
| 174 | +} |
| 175 | + |
| 176 | +std::optional<std::string> CheckAndBuildErrorMessageIfAny( |
| 177 | + const ::google::protobuf::Message* message) { |
| 178 | + const auto* root_descriptor = message->GetDescriptor(); |
| 179 | + HasUnknownFields search{root_descriptor}; |
| 180 | + if (!search.FindUnknownFieldsRecursive(message, 0u)) { |
| 181 | + return std::nullopt; |
| 182 | + } |
| 183 | + if (GetAllowList()->count(MakeAllowListKey(root_descriptor->full_name(), |
| 184 | + search.FieldFQN())) != 0) { |
| 185 | + return std::nullopt; |
| 186 | + } |
| 187 | + return search.BuildErrorMessage(); |
| 188 | +} |
| 189 | + |
| 190 | +} // namespace pybind11_protobuf::check_unknown_fields |
0 commit comments