• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: kenton@google.com (Kenton Varda)
9 //  Based on original Protocol Buffers design by
10 //  Sanjay Ghemawat, Jeff Dean, and others.
11 
12 #include "google/protobuf/descriptor_database.h"
13 
14 #include <algorithm>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include "absl/container/btree_set.h"
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_replace.h"
23 #include "absl/strings/string_view.h"
24 #include "google/protobuf/descriptor.pb.h"
25 
26 
27 namespace google {
28 namespace protobuf {
29 
30 namespace {
RecordMessageNames(const DescriptorProto & desc_proto,absl::string_view prefix,absl::btree_set<std::string> * output)31 void RecordMessageNames(const DescriptorProto& desc_proto,
32                         absl::string_view prefix,
33                         absl::btree_set<std::string>* output) {
34   ABSL_CHECK(desc_proto.has_name());
35   std::string full_name = prefix.empty()
36                               ? desc_proto.name()
37                               : absl::StrCat(prefix, ".", desc_proto.name());
38   output->insert(full_name);
39 
40   for (const auto& d : desc_proto.nested_type()) {
41     RecordMessageNames(d, full_name, output);
42   }
43 }
44 
RecordMessageNames(const FileDescriptorProto & file_proto,absl::btree_set<std::string> * output)45 void RecordMessageNames(const FileDescriptorProto& file_proto,
46                         absl::btree_set<std::string>* output) {
47   for (const auto& d : file_proto.message_type()) {
48     RecordMessageNames(d, file_proto.package(), output);
49   }
50 }
51 
52 template <typename Fn>
ForAllFileProtos(DescriptorDatabase * db,Fn callback,std::vector<std::string> * output)53 bool ForAllFileProtos(DescriptorDatabase* db, Fn callback,
54                       std::vector<std::string>* output) {
55   std::vector<std::string> file_names;
56   if (!db->FindAllFileNames(&file_names)) {
57     return false;
58   }
59   absl::btree_set<std::string> set;
60   FileDescriptorProto file_proto;
61   for (const auto& f : file_names) {
62     file_proto.Clear();
63     if (!db->FindFileByName(f, &file_proto)) {
64       ABSL_LOG(ERROR) << "File not found in database (unexpected): " << f;
65       return false;
66     }
67     callback(file_proto, &set);
68   }
69   output->insert(output->end(), set.begin(), set.end());
70   return true;
71 }
72 }  // namespace
73 
74 DescriptorDatabase::~DescriptorDatabase() = default;
75 
FindAllPackageNames(std::vector<std::string> * output)76 bool DescriptorDatabase::FindAllPackageNames(std::vector<std::string>* output) {
77   return ForAllFileProtos(
78       this,
79       [](const FileDescriptorProto& file_proto,
80          absl::btree_set<std::string>* set) {
81         set->insert(file_proto.package());
82       },
83       output);
84 }
85 
FindAllMessageNames(std::vector<std::string> * output)86 bool DescriptorDatabase::FindAllMessageNames(std::vector<std::string>* output) {
87   return ForAllFileProtos(
88       this,
89       [](const FileDescriptorProto& file_proto,
90          absl::btree_set<std::string>* set) {
91         RecordMessageNames(file_proto, set);
92       },
93       output);
94 }
95 
96 // ===================================================================
97 
SimpleDescriptorDatabase()98 SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
~SimpleDescriptorDatabase()99 SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {}
100 
101 template <typename Value>
AddFile(const FileDescriptorProto & file,Value value)102 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddFile(
103     const FileDescriptorProto& file, Value value) {
104   if (!by_name_.emplace(file.name(), value).second) {
105     ABSL_LOG(ERROR) << "File already exists in database: " << file.name();
106     return false;
107   }
108 
109   // We must be careful here -- calling file.package() if file.has_package() is
110   // false could access an uninitialized static-storage variable if we are being
111   // run at startup time.
112   std::string path = file.has_package() ? file.package() : std::string();
113   if (!path.empty()) path += '.';
114 
115   for (int i = 0; i < file.message_type_size(); i++) {
116     if (!AddSymbol(path + file.message_type(i).name(), value)) return false;
117     if (!AddNestedExtensions(file.name(), file.message_type(i), value))
118       return false;
119   }
120   for (int i = 0; i < file.enum_type_size(); i++) {
121     if (!AddSymbol(path + file.enum_type(i).name(), value)) return false;
122   }
123   for (int i = 0; i < file.extension_size(); i++) {
124     if (!AddSymbol(path + file.extension(i).name(), value)) return false;
125     if (!AddExtension(file.name(), file.extension(i), value)) return false;
126   }
127   for (int i = 0; i < file.service_size(); i++) {
128     if (!AddSymbol(path + file.service(i).name(), value)) return false;
129   }
130 
131   return true;
132 }
133 
134 namespace {
135 
136 // Returns true if and only if all characters in the name are alphanumerics,
137 // underscores, or periods.
ValidateSymbolName(absl::string_view name)138 bool ValidateSymbolName(absl::string_view name) {
139   for (char c : name) {
140     // I don't trust ctype.h due to locales.  :(
141     if (c != '.' && c != '_' && (c < '0' || c > '9') && (c < 'A' || c > 'Z') &&
142         (c < 'a' || c > 'z')) {
143       return false;
144     }
145   }
146   return true;
147 }
148 
// Returns an iterator to the last element of `container` whose key sorts less
// than or equal to `key`. The container's upper_bound() yields the first
// element sorting *greater* than `key`, so the answer is the one just before
// it.
//
// When no element sorts <= `key`, begin() is returned unchanged. For an empty
// container that is end(); otherwise it points at an element greater than
// `key`. Callers must account for that case themselves.
template <typename Container, typename Key>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key) {
  auto position = container->upper_bound(key);
  if (position == container->begin()) return position;
  return --position;
}
159 
// Like the two-argument FindLastLessOrEqual(), but for a sorted range such as
// a std::vector: locates the element with std::upper_bound using `cmp`.
// Returns begin() when no element sorts <= `key` (end() if `container` is
// empty).
template <typename Container, typename Key, typename Cmp>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key, const Cmp& cmp) {
  const auto first = container->begin();
  auto position = std::upper_bound(first, container->end(), key, cmp);
  if (position == first) return position;
  return --position;
}
168 
169 // True if either the arguments are equal or super_symbol identifies a
170 // parent symbol of sub_symbol (e.g. "foo.bar" is a parent of
171 // "foo.bar.baz", but not a parent of "foo.barbaz").
IsSubSymbol(absl::string_view sub_symbol,absl::string_view super_symbol)172 bool IsSubSymbol(absl::string_view sub_symbol, absl::string_view super_symbol) {
173   return sub_symbol == super_symbol ||
174          (absl::StartsWith(super_symbol, sub_symbol) &&
175           super_symbol[sub_symbol.size()] == '.');
176 }
177 
178 }  // namespace
179 
180 template <typename Value>
AddSymbol(absl::string_view name,Value value)181 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddSymbol(
182     absl::string_view name, Value value) {
183   // We need to make sure not to violate our map invariant.
184 
185   // If the symbol name is invalid it could break our lookup algorithm (which
186   // relies on the fact that '.' sorts before all other characters that are
187   // valid in symbol names).
188   if (!ValidateSymbolName(name)) {
189     ABSL_LOG(ERROR) << "Invalid symbol name: " << name;
190     return false;
191   }
192 
193   // Try to look up the symbol to make sure a super-symbol doesn't already
194   // exist.
195   auto iter = FindLastLessOrEqual(&by_symbol_, name);
196 
197   if (iter == by_symbol_.end()) {
198     // Apparently the map is currently empty.  Just insert and be done with it.
199     by_symbol_.try_emplace(name, value);
200     return true;
201   }
202 
203   if (IsSubSymbol(iter->first, name)) {
204     ABSL_LOG(ERROR) << "Symbol name \"" << name
205                     << "\" conflicts with the existing "
206                        "symbol \""
207                     << iter->first << "\".";
208     return false;
209   }
210 
211   // OK, that worked.  Now we have to make sure that no symbol in the map is
212   // a sub-symbol of the one we are inserting.  The only symbol which could
213   // be so is the first symbol that is greater than the new symbol.  Since
214   // |iter| points at the last symbol that is less than or equal, we just have
215   // to increment it.
216   ++iter;
217 
218   if (iter != by_symbol_.end() && IsSubSymbol(name, iter->first)) {
219     ABSL_LOG(ERROR) << "Symbol name \"" << name
220                     << "\" conflicts with the existing "
221                        "symbol \""
222                     << iter->first << "\".";
223     return false;
224   }
225 
226   // OK, no conflicts.
227 
228   // Insert the new symbol using the iterator as a hint, the new entry will
229   // appear immediately before the one the iterator is pointing at.
230   by_symbol_.insert(iter, {std::string(name), value});
231 
232   return true;
233 }
234 
235 template <typename Value>
AddNestedExtensions(const std::string & filename,const DescriptorProto & message_type,Value value)236 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddNestedExtensions(
237     const std::string& filename, const DescriptorProto& message_type,
238     Value value) {
239   for (int i = 0; i < message_type.nested_type_size(); i++) {
240     if (!AddNestedExtensions(filename, message_type.nested_type(i), value))
241       return false;
242   }
243   for (int i = 0; i < message_type.extension_size(); i++) {
244     if (!AddExtension(filename, message_type.extension(i), value)) return false;
245   }
246   return true;
247 }
248 
249 template <typename Value>
AddExtension(const std::string & filename,const FieldDescriptorProto & field,Value value)250 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddExtension(
251     const std::string& filename, const FieldDescriptorProto& field,
252     Value value) {
253   if (!field.extendee().empty() && field.extendee()[0] == '.') {
254     // The extension is fully-qualified.  We can use it as a lookup key in
255     // the by_symbol_ table.
256     if (!by_extension_
257              .emplace(
258                  std::make_pair(field.extendee().substr(1), field.number()),
259                  value)
260              .second) {
261       ABSL_LOG(ERROR)
262           << "Extension conflicts with extension already in database: "
263              "extend "
264           << field.extendee() << " { " << field.name() << " = "
265           << field.number() << " } from:" << filename;
266       return false;
267     }
268   } else {
269     // Not fully-qualified.  We can't really do anything here, unfortunately.
270     // We don't consider this an error, though, because the descriptor is
271     // valid.
272   }
273   return true;
274 }
275 
276 template <typename Value>
FindFile(const std::string & filename)277 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindFile(
278     const std::string& filename) {
279   auto it = by_name_.find(filename);
280   if (it == by_name_.end()) return {};
281   return it->second;
282 }
283 
284 template <typename Value>
FindSymbol(const std::string & name)285 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindSymbol(
286     const std::string& name) {
287   auto iter = FindLastLessOrEqual(&by_symbol_, name);
288 
289   return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name))
290              ? iter->second
291              : Value();
292 }
293 
294 template <typename Value>
FindExtension(const std::string & containing_type,int field_number)295 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindExtension(
296     const std::string& containing_type, int field_number) {
297   auto it = by_extension_.find({containing_type, field_number});
298   if (it == by_extension_.end()) return {};
299   return it->second;
300 }
301 
302 template <typename Value>
FindAllExtensionNumbers(const std::string & containing_type,std::vector<int> * output)303 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllExtensionNumbers(
304     const std::string& containing_type, std::vector<int>* output) {
305   auto it = by_extension_.lower_bound(std::make_pair(containing_type, 0));
306   bool success = false;
307 
308   for (; it != by_extension_.end() && it->first.first == containing_type;
309        ++it) {
310     output->push_back(it->first.second);
311     success = true;
312   }
313 
314   return success;
315 }
316 
317 template <typename Value>
FindAllFileNames(std::vector<std::string> * output)318 void SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllFileNames(
319     std::vector<std::string>* output) {
320   output->resize(by_name_.size());
321   int i = 0;
322   for (const auto& kv : by_name_) {
323     (*output)[i] = kv.first;
324     i++;
325   }
326 }
327 
328 // -------------------------------------------------------------------
329 
Add(const FileDescriptorProto & file)330 bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
331   FileDescriptorProto* new_file = new FileDescriptorProto;
332   new_file->CopyFrom(file);
333   return AddAndOwn(new_file);
334 }
335 
AddAndOwn(const FileDescriptorProto * file)336 bool SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
337   files_to_delete_.emplace_back(file);
338   return index_.AddFile(*file, file);
339 }
340 
AddUnowned(const FileDescriptorProto * file)341 bool SimpleDescriptorDatabase::AddUnowned(const FileDescriptorProto* file) {
342   return index_.AddFile(*file, file);
343 }
344 
FindFileByName(const std::string & filename,FileDescriptorProto * output)345 bool SimpleDescriptorDatabase::FindFileByName(const std::string& filename,
346                                               FileDescriptorProto* output) {
347   return MaybeCopy(index_.FindFile(filename), output);
348 }
349 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)350 bool SimpleDescriptorDatabase::FindFileContainingSymbol(
351     const std::string& symbol_name, FileDescriptorProto* output) {
352   return MaybeCopy(index_.FindSymbol(symbol_name), output);
353 }
354 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)355 bool SimpleDescriptorDatabase::FindFileContainingExtension(
356     const std::string& containing_type, int field_number,
357     FileDescriptorProto* output) {
358   return MaybeCopy(index_.FindExtension(containing_type, field_number), output);
359 }
360 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)361 bool SimpleDescriptorDatabase::FindAllExtensionNumbers(
362     const std::string& extendee_type, std::vector<int>* output) {
363   return index_.FindAllExtensionNumbers(extendee_type, output);
364 }
365 
366 
FindAllFileNames(std::vector<std::string> * output)367 bool SimpleDescriptorDatabase::FindAllFileNames(
368     std::vector<std::string>* output) {
369   index_.FindAllFileNames(output);
370   return true;
371 }
372 
MaybeCopy(const FileDescriptorProto * file,FileDescriptorProto * output)373 bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file,
374                                          FileDescriptorProto* output) {
375   if (file == nullptr) return false;
376   output->CopyFrom(*file);
377   return true;
378 }
379 
380 // -------------------------------------------------------------------
381 
382 class EncodedDescriptorDatabase::DescriptorIndex {
383  public:
384   using Value = std::pair<const void*, int>;
385   // Helpers to recursively add particular descriptors and all their contents
386   // to the index.
387   template <typename FileProto>
388   bool AddFile(const FileProto& file, Value value);
389 
390   Value FindFile(absl::string_view filename);
391   Value FindSymbol(absl::string_view name);
392   Value FindSymbolOnlyFlat(absl::string_view name) const;
393   Value FindExtension(absl::string_view containing_type, int field_number);
394   bool FindAllExtensionNumbers(absl::string_view containing_type,
395                                std::vector<int>* output);
396   void FindAllFileNames(std::vector<std::string>* output) const;
397 
398  private:
399   friend class EncodedDescriptorDatabase;
400 
401   bool AddSymbol(absl::string_view symbol);
402 
403   template <typename DescProto>
404   bool AddNestedExtensions(absl::string_view filename,
405                            const DescProto& message_type);
406   template <typename FieldProto>
407   bool AddExtension(absl::string_view filename, const FieldProto& field);
408 
409   // All the maps below have two representations:
410   //  - a absl::btree_set<> where we insert initially.
411   //  - a std::vector<> where we flatten the structure on demand.
412   // The initial tree helps avoid O(N) behavior of inserting into a sorted
413   // vector, while the vector reduces the heap requirements of the data
414   // structure.
415 
416   void EnsureFlat();
417 
418   using String = std::string;
419 
EncodeString(absl::string_view str) const420   String EncodeString(absl::string_view str) const { return String(str); }
DecodeString(const String & str,int) const421   absl::string_view DecodeString(const String& str, int) const { return str; }
422 
423   struct EncodedEntry {
424     // Do not use `Value` here to avoid the padding of that object.
425     const void* data;
426     int size;
427     // Keep the package here instead of each SymbolEntry to save space.
428     String encoded_package;
429 
valuegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::EncodedEntry430     Value value() const { return {data, size}; }
431   };
432   std::vector<EncodedEntry> all_values_;
433 
434   struct FileEntry {
435     int data_offset;
436     String encoded_name;
437 
namegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileEntry438     absl::string_view name(const DescriptorIndex& index) const {
439       return index.DecodeString(encoded_name, data_offset);
440     }
441   };
442   struct FileCompare {
443     const DescriptorIndex& index;
444 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare445     bool operator()(const FileEntry& a, const FileEntry& b) const {
446       return a.name(index) < b.name(index);
447     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare448     bool operator()(const FileEntry& a, absl::string_view b) const {
449       return a.name(index) < b;
450     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare451     bool operator()(absl::string_view a, const FileEntry& b) const {
452       return a < b.name(index);
453     }
454   };
455   absl::btree_set<FileEntry, FileCompare> by_name_{FileCompare{*this}};
456   std::vector<FileEntry> by_name_flat_;
457 
458   struct SymbolEntry {
459     int data_offset;
460     String encoded_symbol;
461 
packagegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry462     absl::string_view package(const DescriptorIndex& index) const {
463       return index.DecodeString(index.all_values_[data_offset].encoded_package,
464                                 data_offset);
465     }
symbolgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry466     absl::string_view symbol(const DescriptorIndex& index) const {
467       return index.DecodeString(encoded_symbol, data_offset);
468     }
469 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry470     std::string AsString(const DescriptorIndex& index) const {
471       auto p = package(index);
472       return absl::StrCat(p, p.empty() ? "" : ".", symbol(index));
473     }
474   };
475 
476   struct SymbolCompare {
477     const DescriptorIndex& index;
478 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare479     std::string AsString(const SymbolEntry& entry) const {
480       return entry.AsString(index);
481     }
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare482     static absl::string_view AsString(absl::string_view str) { return str; }
483 
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare484     std::pair<absl::string_view, absl::string_view> GetParts(
485         const SymbolEntry& entry) const {
486       auto package = entry.package(index);
487       if (package.empty()) return {entry.symbol(index), absl::string_view{}};
488       return {package, entry.symbol(index)};
489     }
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare490     std::pair<absl::string_view, absl::string_view> GetParts(
491         absl::string_view str) const {
492       return {str, {}};
493     }
494 
495     template <typename T, typename U>
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare496     bool operator()(const T& lhs, const U& rhs) const {
497       auto lhs_parts = GetParts(lhs);
498       auto rhs_parts = GetParts(rhs);
499 
500       // Fast path to avoid making the whole string for common cases.
501       if (int res =
502               lhs_parts.first.substr(0, rhs_parts.first.size())
503                   .compare(rhs_parts.first.substr(0, lhs_parts.first.size()))) {
504         // If the packages already differ, exit early.
505         return res < 0;
506       } else if (lhs_parts.first.size() == rhs_parts.first.size()) {
507         return lhs_parts.second < rhs_parts.second;
508       }
509       return AsString(lhs) < AsString(rhs);
510     }
511   };
512   absl::btree_set<SymbolEntry, SymbolCompare> by_symbol_{SymbolCompare{*this}};
513   std::vector<SymbolEntry> by_symbol_flat_;
514 
515   struct ExtensionEntry {
516     int data_offset;
517     String encoded_extendee;
extendeegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionEntry518     absl::string_view extendee(const DescriptorIndex& index) const {
519       return index.DecodeString(encoded_extendee, data_offset).substr(1);
520     }
521     int extension_number;
522   };
523   struct ExtensionCompare {
524     const DescriptorIndex& index;
525 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare526     bool operator()(const ExtensionEntry& a, const ExtensionEntry& b) const {
527       return std::make_tuple(a.extendee(index), a.extension_number) <
528              std::make_tuple(b.extendee(index), b.extension_number);
529     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare530     bool operator()(const ExtensionEntry& a,
531                     std::tuple<absl::string_view, int> b) const {
532       return std::make_tuple(a.extendee(index), a.extension_number) < b;
533     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare534     bool operator()(std::tuple<absl::string_view, int> a,
535                     const ExtensionEntry& b) const {
536       return a < std::make_tuple(b.extendee(index), b.extension_number);
537     }
538   };
539   absl::btree_set<ExtensionEntry, ExtensionCompare> by_extension_{
540       ExtensionCompare{*this}};
541   std::vector<ExtensionEntry> by_extension_flat_;
542 };
543 
Add(const void * encoded_file_descriptor,int size)544 bool EncodedDescriptorDatabase::Add(const void* encoded_file_descriptor,
545                                     int size) {
546   FileDescriptorProto file;
547   if (file.ParseFromArray(encoded_file_descriptor, size)) {
548     return index_->AddFile(file, std::make_pair(encoded_file_descriptor, size));
549   } else {
550     ABSL_LOG(ERROR) << "Invalid file descriptor data passed to "
551                        "EncodedDescriptorDatabase::Add().";
552     return false;
553   }
554 }
555 
AddCopy(const void * encoded_file_descriptor,int size)556 bool EncodedDescriptorDatabase::AddCopy(const void* encoded_file_descriptor,
557                                         int size) {
558   void* copy = operator new(size);
559   memcpy(copy, encoded_file_descriptor, size);
560   files_to_delete_.push_back(copy);
561   return Add(copy, size);
562 }
563 
FindFileByName(const std::string & filename,FileDescriptorProto * output)564 bool EncodedDescriptorDatabase::FindFileByName(const std::string& filename,
565                                                FileDescriptorProto* output) {
566   return MaybeParse(index_->FindFile(filename), output);
567 }
568 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)569 bool EncodedDescriptorDatabase::FindFileContainingSymbol(
570     const std::string& symbol_name, FileDescriptorProto* output) {
571   return MaybeParse(index_->FindSymbol(symbol_name), output);
572 }
573 
FindNameOfFileContainingSymbol(const std::string & symbol_name,std::string * output)574 bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
575     const std::string& symbol_name, std::string* output) {
576   auto encoded_file = index_->FindSymbol(symbol_name);
577   if (encoded_file.first == nullptr) return false;
578 
579   // Optimization:  The name should be the first field in the encoded message.
580   //   Try to just read it directly.
581   io::CodedInputStream input(static_cast<const uint8_t*>(encoded_file.first),
582                              encoded_file.second);
583 
584   const uint32_t kNameTag = internal::WireFormatLite::MakeTag(
585       FileDescriptorProto::kNameFieldNumber,
586       internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
587 
588   if (input.ReadTagNoLastTag() == kNameTag) {
589     // Success!
590     return internal::WireFormatLite::ReadString(&input, output);
591   } else {
592     // Slow path.  Parse whole message.
593     FileDescriptorProto file_proto;
594     if (!file_proto.ParseFromArray(encoded_file.first, encoded_file.second)) {
595       return false;
596     }
597     *output = file_proto.name();
598     return true;
599   }
600 }
601 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)602 bool EncodedDescriptorDatabase::FindFileContainingExtension(
603     const std::string& containing_type, int field_number,
604     FileDescriptorProto* output) {
605   return MaybeParse(index_->FindExtension(containing_type, field_number),
606                     output);
607 }
608 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)609 bool EncodedDescriptorDatabase::FindAllExtensionNumbers(
610     const std::string& extendee_type, std::vector<int>* output) {
611   return index_->FindAllExtensionNumbers(extendee_type, output);
612 }
613 
614 template <typename FileProto>
AddFile(const FileProto & file,Value value)615 bool EncodedDescriptorDatabase::DescriptorIndex::AddFile(const FileProto& file,
616                                                          Value value) {
617   // We push `value` into the array first. This is important because the AddXXX
618   // functions below will expect it to be there.
619   all_values_.push_back({value.first, value.second, {}});
620 
621   if (!ValidateSymbolName(file.package())) {
622     ABSL_LOG(ERROR) << "Invalid package name: " << file.package();
623     return false;
624   }
625   all_values_.back().encoded_package = EncodeString(file.package());
626 
627   if (!by_name_
628            .insert({static_cast<int>(all_values_.size() - 1),
629                     EncodeString(file.name())})
630            .second ||
631       std::binary_search(by_name_flat_.begin(), by_name_flat_.end(),
632                          file.name(), by_name_.key_comp())) {
633     ABSL_LOG(ERROR) << "File already exists in database: " << file.name();
634     return false;
635   }
636 
637   for (const auto& message_type : file.message_type()) {
638     if (!AddSymbol(message_type.name())) return false;
639     if (!AddNestedExtensions(file.name(), message_type)) return false;
640   }
641   for (const auto& enum_type : file.enum_type()) {
642     if (!AddSymbol(enum_type.name())) return false;
643   }
644   for (const auto& extension : file.extension()) {
645     if (!AddSymbol(extension.name())) return false;
646     if (!AddExtension(file.name(), extension)) return false;
647   }
648   for (const auto& service : file.service()) {
649     if (!AddSymbol(service.name())) return false;
650   }
651 
652   return true;
653 }
654 
655 template <typename Iter, typename Iter2, typename Index>
CheckForMutualSubsymbols(absl::string_view symbol_name,Iter * iter,Iter2 end,const Index & index)656 static bool CheckForMutualSubsymbols(absl::string_view symbol_name, Iter* iter,
657                                      Iter2 end, const Index& index) {
658   if (*iter != end) {
659     if (IsSubSymbol((*iter)->AsString(index), symbol_name)) {
660       ABSL_LOG(ERROR) << "Symbol name \"" << symbol_name
661                       << "\" conflicts with the existing symbol \""
662                       << (*iter)->AsString(index) << "\".";
663       return false;
664     }
665 
666     // OK, that worked.  Now we have to make sure that no symbol in the map is
667     // a sub-symbol of the one we are inserting.  The only symbol which could
668     // be so is the first symbol that is greater than the new symbol.  Since
669     // |iter| points at the last symbol that is less than or equal, we just have
670     // to increment it.
671     ++*iter;
672 
673     if (*iter != end && IsSubSymbol(symbol_name, (*iter)->AsString(index))) {
674       ABSL_LOG(ERROR) << "Symbol name \"" << symbol_name
675                       << "\" conflicts with the existing symbol \""
676                       << (*iter)->AsString(index) << "\".";
677       return false;
678     }
679   }
680   return true;
681 }
682 
AddSymbol(absl::string_view symbol)683 bool EncodedDescriptorDatabase::DescriptorIndex::AddSymbol(
684     absl::string_view symbol) {
685   SymbolEntry entry = {static_cast<int>(all_values_.size() - 1),
686                        EncodeString(symbol)};
687   std::string entry_as_string = entry.AsString(*this);
688 
689   // We need to make sure not to violate our map invariant.
690 
691   // If the symbol name is invalid it could break our lookup algorithm (which
692   // relies on the fact that '.' sorts before all other characters that are
693   // valid in symbol names).
694   if (!ValidateSymbolName(symbol)) {
695     ABSL_LOG(ERROR) << "Invalid symbol name: " << entry_as_string;
696     return false;
697   }
698 
699   auto iter = FindLastLessOrEqual(&by_symbol_, entry);
700   if (!CheckForMutualSubsymbols(entry_as_string, &iter, by_symbol_.end(),
701                                 *this)) {
702     return false;
703   }
704 
705   // Same, but on by_symbol_flat_
706   auto flat_iter =
707       FindLastLessOrEqual(&by_symbol_flat_, entry, by_symbol_.key_comp());
708   if (!CheckForMutualSubsymbols(entry_as_string, &flat_iter,
709                                 by_symbol_flat_.end(), *this)) {
710     return false;
711   }
712 
713   // OK, no conflicts.
714 
715   // Insert the new symbol using the iterator as a hint, the new entry will
716   // appear immediately before the one the iterator is pointing at.
717   by_symbol_.insert(iter, entry);
718 
719   return true;
720 }
721 
722 template <typename DescProto>
AddNestedExtensions(absl::string_view filename,const DescProto & message_type)723 bool EncodedDescriptorDatabase::DescriptorIndex::AddNestedExtensions(
724     absl::string_view filename, const DescProto& message_type) {
725   for (const auto& nested_type : message_type.nested_type()) {
726     if (!AddNestedExtensions(filename, nested_type)) return false;
727   }
728   for (const auto& extension : message_type.extension()) {
729     if (!AddExtension(filename, extension)) return false;
730   }
731   return true;
732 }
733 
734 template <typename FieldProto>
AddExtension(absl::string_view filename,const FieldProto & field)735 bool EncodedDescriptorDatabase::DescriptorIndex::AddExtension(
736     absl::string_view filename, const FieldProto& field) {
737   if (!field.extendee().empty() && field.extendee()[0] == '.') {
738     // The extension is fully-qualified.  We can use it as a lookup key in
739     // the by_symbol_ table.
740     if (!by_extension_
741              .insert({static_cast<int>(all_values_.size() - 1),
742                       EncodeString(field.extendee()), field.number()})
743              .second ||
744         std::binary_search(
745             by_extension_flat_.begin(), by_extension_flat_.end(),
746             std::make_pair(field.extendee().substr(1), field.number()),
747             by_extension_.key_comp())) {
748       ABSL_LOG(ERROR)
749           << "Extension conflicts with extension already in database: "
750              "extend "
751           << field.extendee() << " { " << field.name() << " = "
752           << field.number() << " } from:" << filename;
753       return false;
754     }
755   } else {
756     // Not fully-qualified.  We can't really do anything here, unfortunately.
757     // We don't consider this an error, though, because the descriptor is
758     // valid.
759   }
760   return true;
761 }
762 
763 std::pair<const void*, int>
FindSymbol(absl::string_view name)764 EncodedDescriptorDatabase::DescriptorIndex::FindSymbol(absl::string_view name) {
765   EnsureFlat();
766   return FindSymbolOnlyFlat(name);
767 }
768 
769 std::pair<const void*, int>
FindSymbolOnlyFlat(absl::string_view name) const770 EncodedDescriptorDatabase::DescriptorIndex::FindSymbolOnlyFlat(
771     absl::string_view name) const {
772   auto iter =
773       FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp());
774 
775   return iter != by_symbol_flat_.end() &&
776                  IsSubSymbol(iter->AsString(*this), name)
777              ? all_values_[iter->data_offset].value()
778              : Value();
779 }
780 
781 std::pair<const void*, int>
FindExtension(absl::string_view containing_type,int field_number)782 EncodedDescriptorDatabase::DescriptorIndex::FindExtension(
783     absl::string_view containing_type, int field_number) {
784   EnsureFlat();
785 
786   auto it = std::lower_bound(
787       by_extension_flat_.begin(), by_extension_flat_.end(),
788       std::make_tuple(containing_type, field_number), by_extension_.key_comp());
789   return it == by_extension_flat_.end() ||
790                  it->extendee(*this) != containing_type ||
791                  it->extension_number != field_number
792              ? std::make_pair(nullptr, 0)
793              : all_values_[it->data_offset].value();
794 }
795 
796 template <typename T, typename Less>
MergeIntoFlat(absl::btree_set<T,Less> * s,std::vector<T> * flat)797 static void MergeIntoFlat(absl::btree_set<T, Less>* s, std::vector<T>* flat) {
798   if (s->empty()) return;
799   std::vector<T> new_flat(s->size() + flat->size());
800   std::merge(s->begin(), s->end(), flat->begin(), flat->end(), &new_flat[0],
801              s->key_comp());
802   *flat = std::move(new_flat);
803   s->clear();
804 }
805 
EnsureFlat()806 void EncodedDescriptorDatabase::DescriptorIndex::EnsureFlat() {
807   all_values_.shrink_to_fit();
808   // Merge each of the sets into their flat counterpart.
809   MergeIntoFlat(&by_name_, &by_name_flat_);
810   MergeIntoFlat(&by_symbol_, &by_symbol_flat_);
811   MergeIntoFlat(&by_extension_, &by_extension_flat_);
812 }
813 
FindAllExtensionNumbers(absl::string_view containing_type,std::vector<int> * output)814 bool EncodedDescriptorDatabase::DescriptorIndex::FindAllExtensionNumbers(
815     absl::string_view containing_type, std::vector<int>* output) {
816   EnsureFlat();
817 
818   bool success = false;
819   auto it = std::lower_bound(
820       by_extension_flat_.begin(), by_extension_flat_.end(),
821       std::make_tuple(containing_type, 0), by_extension_.key_comp());
822   for (;
823        it != by_extension_flat_.end() && it->extendee(*this) == containing_type;
824        ++it) {
825     output->push_back(it->extension_number);
826     success = true;
827   }
828 
829   return success;
830 }
831 
FindAllFileNames(std::vector<std::string> * output) const832 void EncodedDescriptorDatabase::DescriptorIndex::FindAllFileNames(
833     std::vector<std::string>* output) const {
834   output->resize(by_name_.size() + by_name_flat_.size());
835   int i = 0;
836   for (const auto& entry : by_name_) {
837     (*output)[i] = std::string(entry.name(*this));
838     i++;
839   }
840   for (const auto& entry : by_name_flat_) {
841     (*output)[i] = std::string(entry.name(*this));
842     i++;
843   }
844 }
845 
846 std::pair<const void*, int>
FindFile(absl::string_view filename)847 EncodedDescriptorDatabase::DescriptorIndex::FindFile(
848     absl::string_view filename) {
849   EnsureFlat();
850 
851   auto it = std::lower_bound(by_name_flat_.begin(), by_name_flat_.end(),
852                              filename, by_name_.key_comp());
853   return it == by_name_flat_.end() || it->name(*this) != filename
854              ? std::make_pair(nullptr, 0)
855              : all_values_[it->data_offset].value();
856 }
857 
858 
FindAllFileNames(std::vector<std::string> * output)859 bool EncodedDescriptorDatabase::FindAllFileNames(
860     std::vector<std::string>* output) {
861   index_->FindAllFileNames(output);
862   return true;
863 }
864 
MaybeParse(std::pair<const void *,int> encoded_file,FileDescriptorProto * output)865 bool EncodedDescriptorDatabase::MaybeParse(
866     std::pair<const void*, int> encoded_file, FileDescriptorProto* output) {
867   if (encoded_file.first == nullptr) return false;
868   absl::string_view source(static_cast<const char*>(encoded_file.first),
869                            encoded_file.second);
870   return internal::ParseNoReflection(source, *output);
871 }
872 
EncodedDescriptorDatabase()873 EncodedDescriptorDatabase::EncodedDescriptorDatabase()
874     : index_(new DescriptorIndex()) {}
875 
~EncodedDescriptorDatabase()876 EncodedDescriptorDatabase::~EncodedDescriptorDatabase() {
877   for (void* p : files_to_delete_) {
878     operator delete(p);
879   }
880 }
881 
882 // ===================================================================
883 
DescriptorPoolDatabase(const DescriptorPool & pool,DescriptorPoolDatabaseOptions options)884 DescriptorPoolDatabase::DescriptorPoolDatabase(
885     const DescriptorPool& pool, DescriptorPoolDatabaseOptions options)
886     : pool_(pool), options_(std::move(options)) {}
~DescriptorPoolDatabase()887 DescriptorPoolDatabase::~DescriptorPoolDatabase() {}
888 
FindFileByName(const std::string & filename,FileDescriptorProto * output)889 bool DescriptorPoolDatabase::FindFileByName(const std::string& filename,
890                                             FileDescriptorProto* output) {
891   const FileDescriptor* file = pool_.FindFileByName(filename);
892   if (file == nullptr) return false;
893   output->Clear();
894   file->CopyTo(output);
895   if (options_.preserve_source_code_info) {
896     file->CopySourceCodeInfoTo(output);
897   }
898   return true;
899 }
900 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)901 bool DescriptorPoolDatabase::FindFileContainingSymbol(
902     const std::string& symbol_name, FileDescriptorProto* output) {
903   const FileDescriptor* file = pool_.FindFileContainingSymbol(symbol_name);
904   if (file == nullptr) return false;
905   output->Clear();
906   file->CopyTo(output);
907   if (options_.preserve_source_code_info) {
908     file->CopySourceCodeInfoTo(output);
909   }
910   return true;
911 }
912 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)913 bool DescriptorPoolDatabase::FindFileContainingExtension(
914     const std::string& containing_type, int field_number,
915     FileDescriptorProto* output) {
916   const Descriptor* extendee = pool_.FindMessageTypeByName(containing_type);
917   if (extendee == nullptr) return false;
918 
919   const FieldDescriptor* extension =
920       pool_.FindExtensionByNumber(extendee, field_number);
921   if (extension == nullptr) return false;
922 
923   output->Clear();
924   extension->file()->CopyTo(output);
925   if (options_.preserve_source_code_info) {
926     extension->file()->CopySourceCodeInfoTo(output);
927   }
928   return true;
929 }
930 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)931 bool DescriptorPoolDatabase::FindAllExtensionNumbers(
932     const std::string& extendee_type, std::vector<int>* output) {
933   const Descriptor* extendee = pool_.FindMessageTypeByName(extendee_type);
934   if (extendee == nullptr) return false;
935 
936   std::vector<const FieldDescriptor*> extensions;
937   pool_.FindAllExtensions(extendee, &extensions);
938 
939   for (const FieldDescriptor* extension : extensions) {
940     output->push_back(extension->number());
941   }
942 
943   return true;
944 }
945 
946 // ===================================================================
947 
MergedDescriptorDatabase(DescriptorDatabase * source1,DescriptorDatabase * source2)948 MergedDescriptorDatabase::MergedDescriptorDatabase(
949     DescriptorDatabase* source1, DescriptorDatabase* source2) {
950   sources_.push_back(source1);
951   sources_.push_back(source2);
952 }
MergedDescriptorDatabase(const std::vector<DescriptorDatabase * > & sources)953 MergedDescriptorDatabase::MergedDescriptorDatabase(
954     const std::vector<DescriptorDatabase*>& sources)
955     : sources_(sources) {}
~MergedDescriptorDatabase()956 MergedDescriptorDatabase::~MergedDescriptorDatabase() {}
957 
FindFileByName(const std::string & filename,FileDescriptorProto * output)958 bool MergedDescriptorDatabase::FindFileByName(const std::string& filename,
959                                               FileDescriptorProto* output) {
960   for (DescriptorDatabase* source : sources_) {
961     if (source->FindFileByName(filename, output)) {
962       return true;
963     }
964   }
965   return false;
966 }
967 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)968 bool MergedDescriptorDatabase::FindFileContainingSymbol(
969     const std::string& symbol_name, FileDescriptorProto* output) {
970   for (size_t i = 0; i < sources_.size(); i++) {
971     if (sources_[i]->FindFileContainingSymbol(symbol_name, output)) {
972       // The symbol was found in source i.  However, if one of the previous
973       // sources defines a file with the same name (which presumably doesn't
974       // contain the symbol, since it wasn't found in that source), then we
975       // must hide it from the caller.
976       FileDescriptorProto temp;
977       for (size_t j = 0; j < i; j++) {
978         if (sources_[j]->FindFileByName(output->name(), &temp)) {
979           // Found conflicting file in a previous source.
980           return false;
981         }
982       }
983       return true;
984     }
985   }
986   return false;
987 }
988 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)989 bool MergedDescriptorDatabase::FindFileContainingExtension(
990     const std::string& containing_type, int field_number,
991     FileDescriptorProto* output) {
992   for (size_t i = 0; i < sources_.size(); i++) {
993     if (sources_[i]->FindFileContainingExtension(containing_type, field_number,
994                                                  output)) {
995       // The symbol was found in source i.  However, if one of the previous
996       // sources defines a file with the same name (which presumably doesn't
997       // contain the symbol, since it wasn't found in that source), then we
998       // must hide it from the caller.
999       FileDescriptorProto temp;
1000       for (size_t j = 0; j < i; j++) {
1001         if (sources_[j]->FindFileByName(output->name(), &temp)) {
1002           // Found conflicting file in a previous source.
1003           return false;
1004         }
1005       }
1006       return true;
1007     }
1008   }
1009   return false;
1010 }
1011 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)1012 bool MergedDescriptorDatabase::FindAllExtensionNumbers(
1013     const std::string& extendee_type, std::vector<int>* output) {
1014   // NOLINTNEXTLINE(google3-runtime-rename-unnecessary-ordering)
1015   absl::btree_set<int> merged_results;
1016   std::vector<int> results;
1017   bool success = false;
1018   for (DescriptorDatabase* source : sources_) {
1019     if (source->FindAllExtensionNumbers(extendee_type, &results)) {
1020       for (int r : results) merged_results.insert(r);
1021       success = true;
1022     }
1023     results.clear();
1024   }
1025   for (int r : merged_results) output->push_back(r);
1026   return success;
1027 }
1028 
1029 
FindAllFileNames(std::vector<std::string> * output)1030 bool MergedDescriptorDatabase::FindAllFileNames(
1031     std::vector<std::string>* output) {
1032   bool implemented = false;
1033   for (DescriptorDatabase* source : sources_) {
1034     std::vector<std::string> source_output;
1035     if (source->FindAllFileNames(&source_output)) {
1036       output->reserve(output->size() + source_output.size());
1037       for (auto& source_out : source_output) {
1038         output->push_back(std::move(source_out));
1039       }
1040       implemented = true;
1041     }
1042   }
1043   return implemented;
1044 }
1045 
1046 }  // namespace protobuf
1047 }  // namespace google
1048