• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 
35 #include <google/protobuf/descriptor_database.h>
36 
37 #include <set>
38 
39 #include <google/protobuf/descriptor.pb.h>
40 #include <google/protobuf/stubs/strutil.h>
41 #include <google/protobuf/stubs/map_util.h>
42 #include <google/protobuf/stubs/stl_util.h>
43 
44 
45 namespace google {
46 namespace protobuf {
47 
48 namespace {
RecordMessageNames(const DescriptorProto & desc_proto,const std::string & prefix,std::set<std::string> * output)49 void RecordMessageNames(const DescriptorProto& desc_proto,
50                         const std::string& prefix,
51                         std::set<std::string>* output) {
52   GOOGLE_CHECK(desc_proto.has_name());
53   std::string full_name = prefix.empty()
54                               ? desc_proto.name()
55                               : StrCat(prefix, ".", desc_proto.name());
56   output->insert(full_name);
57 
58   for (const auto& d : desc_proto.nested_type()) {
59     RecordMessageNames(d, full_name, output);
60   }
61 }
62 
RecordMessageNames(const FileDescriptorProto & file_proto,std::set<std::string> * output)63 void RecordMessageNames(const FileDescriptorProto& file_proto,
64                         std::set<std::string>* output) {
65   for (const auto& d : file_proto.message_type()) {
66     RecordMessageNames(d, file_proto.package(), output);
67   }
68 }
69 
70 template <typename Fn>
ForAllFileProtos(DescriptorDatabase * db,Fn callback,std::vector<std::string> * output)71 bool ForAllFileProtos(DescriptorDatabase* db, Fn callback,
72                       std::vector<std::string>* output) {
73   std::vector<std::string> file_names;
74   if (!db->FindAllFileNames(&file_names)) {
75     return false;
76   }
77   std::set<std::string> set;
78   FileDescriptorProto file_proto;
79   for (const auto& f : file_names) {
80     file_proto.Clear();
81     if (!db->FindFileByName(f, &file_proto)) {
82       GOOGLE_LOG(ERROR) << "File not found in database (unexpected): " << f;
83       return false;
84     }
85     callback(file_proto, &set);
86   }
87   output->insert(output->end(), set.begin(), set.end());
88   return true;
89 }
90 }  // namespace
91 
~DescriptorDatabase()92 DescriptorDatabase::~DescriptorDatabase() {}
93 
FindAllPackageNames(std::vector<std::string> * output)94 bool DescriptorDatabase::FindAllPackageNames(std::vector<std::string>* output) {
95   return ForAllFileProtos(
96       this,
97       [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
98         set->insert(file_proto.package());
99       },
100       output);
101 }
102 
FindAllMessageNames(std::vector<std::string> * output)103 bool DescriptorDatabase::FindAllMessageNames(std::vector<std::string>* output) {
104   return ForAllFileProtos(
105       this,
106       [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
107         RecordMessageNames(file_proto, set);
108       },
109       output);
110 }
111 
112 // ===================================================================
113 
SimpleDescriptorDatabase()114 SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
~SimpleDescriptorDatabase()115 SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {}
116 
117 template <typename Value>
AddFile(const FileDescriptorProto & file,Value value)118 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddFile(
119     const FileDescriptorProto& file, Value value) {
120   if (!InsertIfNotPresent(&by_name_, file.name(), value)) {
121     GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
122     return false;
123   }
124 
125   // We must be careful here -- calling file.package() if file.has_package() is
126   // false could access an uninitialized static-storage variable if we are being
127   // run at startup time.
128   std::string path = file.has_package() ? file.package() : std::string();
129   if (!path.empty()) path += '.';
130 
131   for (int i = 0; i < file.message_type_size(); i++) {
132     if (!AddSymbol(path + file.message_type(i).name(), value)) return false;
133     if (!AddNestedExtensions(file.name(), file.message_type(i), value))
134       return false;
135   }
136   for (int i = 0; i < file.enum_type_size(); i++) {
137     if (!AddSymbol(path + file.enum_type(i).name(), value)) return false;
138   }
139   for (int i = 0; i < file.extension_size(); i++) {
140     if (!AddSymbol(path + file.extension(i).name(), value)) return false;
141     if (!AddExtension(file.name(), file.extension(i), value)) return false;
142   }
143   for (int i = 0; i < file.service_size(); i++) {
144     if (!AddSymbol(path + file.service(i).name(), value)) return false;
145   }
146 
147   return true;
148 }
149 
150 namespace {
151 
152 // Returns true if and only if all characters in the name are alphanumerics,
153 // underscores, or periods.
ValidateSymbolName(StringPiece name)154 bool ValidateSymbolName(StringPiece name) {
155   for (char c : name) {
156     // I don't trust ctype.h due to locales.  :(
157     if (c != '.' && c != '_' && (c < '0' || c > '9') && (c < 'A' || c > 'Z') &&
158         (c < 'a' || c > 'z')) {
159       return false;
160     }
161   }
162   return true;
163 }
164 
// Locates the last key in `container` that sorts less than or equal to `key`.
// upper_bound() yields the *first* key that sorts *greater* than `key`, so
// the wanted element is the one just before it.
//
// When no key is less than or equal to `key`, upper_bound() is begin() and
// that iterator is returned unchanged: for an empty container this is end(),
// otherwise it refers to an element that sorts *after* `key`.
template <typename Container, typename Key>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key) {
  typename Container::const_iterator pos = container->upper_bound(key);
  if (pos == container->begin()) return pos;
  --pos;
  return pos;
}
175 
// Same contract as the overload that uses the container's own upper_bound(),
// but for a range sorted according to `cmp`: the search is done with
// std::upper_bound over [begin, end).  If no element is less than or equal to
// `key`, begin() is returned unchanged.
template <typename Container, typename Key, typename Cmp>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key, const Cmp& cmp) {
  const auto first = container->begin();
  auto pos = std::upper_bound(first, container->end(), key, cmp);
  if (pos == first) return pos;
  --pos;
  return pos;
}
184 
185 // True if either the arguments are equal or super_symbol identifies a
186 // parent symbol of sub_symbol (e.g. "foo.bar" is a parent of
187 // "foo.bar.baz", but not a parent of "foo.barbaz").
IsSubSymbol(StringPiece sub_symbol,StringPiece super_symbol)188 bool IsSubSymbol(StringPiece sub_symbol, StringPiece super_symbol) {
189   return sub_symbol == super_symbol ||
190          (HasPrefixString(super_symbol, sub_symbol) &&
191           super_symbol[sub_symbol.size()] == '.');
192 }
193 
194 }  // namespace
195 
196 template <typename Value>
AddSymbol(const std::string & name,Value value)197 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddSymbol(
198     const std::string& name, Value value) {
199   // We need to make sure not to violate our map invariant.
200 
201   // If the symbol name is invalid it could break our lookup algorithm (which
202   // relies on the fact that '.' sorts before all other characters that are
203   // valid in symbol names).
204   if (!ValidateSymbolName(name)) {
205     GOOGLE_LOG(ERROR) << "Invalid symbol name: " << name;
206     return false;
207   }
208 
209   // Try to look up the symbol to make sure a super-symbol doesn't already
210   // exist.
211   auto iter = FindLastLessOrEqual(&by_symbol_, name);
212 
213   if (iter == by_symbol_.end()) {
214     // Apparently the map is currently empty.  Just insert and be done with it.
215     by_symbol_.insert(
216         typename std::map<std::string, Value>::value_type(name, value));
217     return true;
218   }
219 
220   if (IsSubSymbol(iter->first, name)) {
221     GOOGLE_LOG(ERROR) << "Symbol name \"" << name
222                << "\" conflicts with the existing "
223                   "symbol \""
224                << iter->first << "\".";
225     return false;
226   }
227 
228   // OK, that worked.  Now we have to make sure that no symbol in the map is
229   // a sub-symbol of the one we are inserting.  The only symbol which could
230   // be so is the first symbol that is greater than the new symbol.  Since
231   // |iter| points at the last symbol that is less than or equal, we just have
232   // to increment it.
233   ++iter;
234 
235   if (iter != by_symbol_.end() && IsSubSymbol(name, iter->first)) {
236     GOOGLE_LOG(ERROR) << "Symbol name \"" << name
237                << "\" conflicts with the existing "
238                   "symbol \""
239                << iter->first << "\".";
240     return false;
241   }
242 
243   // OK, no conflicts.
244 
245   // Insert the new symbol using the iterator as a hint, the new entry will
246   // appear immediately before the one the iterator is pointing at.
247   by_symbol_.insert(
248       iter, typename std::map<std::string, Value>::value_type(name, value));
249 
250   return true;
251 }
252 
253 template <typename Value>
AddNestedExtensions(const std::string & filename,const DescriptorProto & message_type,Value value)254 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddNestedExtensions(
255     const std::string& filename, const DescriptorProto& message_type,
256     Value value) {
257   for (int i = 0; i < message_type.nested_type_size(); i++) {
258     if (!AddNestedExtensions(filename, message_type.nested_type(i), value))
259       return false;
260   }
261   for (int i = 0; i < message_type.extension_size(); i++) {
262     if (!AddExtension(filename, message_type.extension(i), value)) return false;
263   }
264   return true;
265 }
266 
267 template <typename Value>
AddExtension(const std::string & filename,const FieldDescriptorProto & field,Value value)268 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddExtension(
269     const std::string& filename, const FieldDescriptorProto& field,
270     Value value) {
271   if (!field.extendee().empty() && field.extendee()[0] == '.') {
272     // The extension is fully-qualified.  We can use it as a lookup key in
273     // the by_symbol_ table.
274     if (!InsertIfNotPresent(
275             &by_extension_,
276             std::make_pair(field.extendee().substr(1), field.number()),
277             value)) {
278       GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
279                     "extend "
280                  << field.extendee() << " { " << field.name() << " = "
281                  << field.number() << " } from:" << filename;
282       return false;
283     }
284   } else {
285     // Not fully-qualified.  We can't really do anything here, unfortunately.
286     // We don't consider this an error, though, because the descriptor is
287     // valid.
288   }
289   return true;
290 }
291 
292 template <typename Value>
FindFile(const std::string & filename)293 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindFile(
294     const std::string& filename) {
295   return FindWithDefault(by_name_, filename, Value());
296 }
297 
298 template <typename Value>
FindSymbol(const std::string & name)299 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindSymbol(
300     const std::string& name) {
301   auto iter = FindLastLessOrEqual(&by_symbol_, name);
302 
303   return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name))
304              ? iter->second
305              : Value();
306 }
307 
308 template <typename Value>
FindExtension(const std::string & containing_type,int field_number)309 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindExtension(
310     const std::string& containing_type, int field_number) {
311   return FindWithDefault(
312       by_extension_, std::make_pair(containing_type, field_number), Value());
313 }
314 
315 template <typename Value>
FindAllExtensionNumbers(const std::string & containing_type,std::vector<int> * output)316 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllExtensionNumbers(
317     const std::string& containing_type, std::vector<int>* output) {
318   typename std::map<std::pair<std::string, int>, Value>::const_iterator it =
319       by_extension_.lower_bound(std::make_pair(containing_type, 0));
320   bool success = false;
321 
322   for (; it != by_extension_.end() && it->first.first == containing_type;
323        ++it) {
324     output->push_back(it->first.second);
325     success = true;
326   }
327 
328   return success;
329 }
330 
331 template <typename Value>
FindAllFileNames(std::vector<std::string> * output)332 void SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllFileNames(
333     std::vector<std::string>* output) {
334   output->resize(by_name_.size());
335   int i = 0;
336   for (const auto& kv : by_name_) {
337     (*output)[i] = kv.first;
338     i++;
339   }
340 }
341 
342 // -------------------------------------------------------------------
343 
Add(const FileDescriptorProto & file)344 bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
345   FileDescriptorProto* new_file = new FileDescriptorProto;
346   new_file->CopyFrom(file);
347   return AddAndOwn(new_file);
348 }
349 
AddAndOwn(const FileDescriptorProto * file)350 bool SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
351   files_to_delete_.emplace_back(file);
352   return index_.AddFile(*file, file);
353 }
354 
FindFileByName(const std::string & filename,FileDescriptorProto * output)355 bool SimpleDescriptorDatabase::FindFileByName(const std::string& filename,
356                                               FileDescriptorProto* output) {
357   return MaybeCopy(index_.FindFile(filename), output);
358 }
359 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)360 bool SimpleDescriptorDatabase::FindFileContainingSymbol(
361     const std::string& symbol_name, FileDescriptorProto* output) {
362   return MaybeCopy(index_.FindSymbol(symbol_name), output);
363 }
364 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)365 bool SimpleDescriptorDatabase::FindFileContainingExtension(
366     const std::string& containing_type, int field_number,
367     FileDescriptorProto* output) {
368   return MaybeCopy(index_.FindExtension(containing_type, field_number), output);
369 }
370 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)371 bool SimpleDescriptorDatabase::FindAllExtensionNumbers(
372     const std::string& extendee_type, std::vector<int>* output) {
373   return index_.FindAllExtensionNumbers(extendee_type, output);
374 }
375 
376 
FindAllFileNames(std::vector<std::string> * output)377 bool SimpleDescriptorDatabase::FindAllFileNames(
378     std::vector<std::string>* output) {
379   index_.FindAllFileNames(output);
380   return true;
381 }
382 
MaybeCopy(const FileDescriptorProto * file,FileDescriptorProto * output)383 bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file,
384                                          FileDescriptorProto* output) {
385   if (file == NULL) return false;
386   output->CopyFrom(*file);
387   return true;
388 }
389 
390 // -------------------------------------------------------------------
391 
392 class EncodedDescriptorDatabase::DescriptorIndex {
393  public:
394   using Value = std::pair<const void*, int>;
395   // Helpers to recursively add particular descriptors and all their contents
396   // to the index.
397   template <typename FileProto>
398   bool AddFile(const FileProto& file, Value value);
399 
400   Value FindFile(StringPiece filename);
401   Value FindSymbol(StringPiece name);
402   Value FindSymbolOnlyFlat(StringPiece name) const;
403   Value FindExtension(StringPiece containing_type, int field_number);
404   bool FindAllExtensionNumbers(StringPiece containing_type,
405                                std::vector<int>* output);
406   void FindAllFileNames(std::vector<std::string>* output) const;
407 
408  private:
409   friend class EncodedDescriptorDatabase;
410 
411   bool AddSymbol(StringPiece symbol);
412 
413   template <typename DescProto>
414   bool AddNestedExtensions(StringPiece filename,
415                            const DescProto& message_type);
416   template <typename FieldProto>
417   bool AddExtension(StringPiece filename, const FieldProto& field);
418 
419   // All the maps below have two representations:
420   //  - a std::set<> where we insert initially.
421   //  - a std::vector<> where we flatten the structure on demand.
422   // The initial tree helps avoid O(N) behavior of inserting into a sorted
423   // vector, while the vector reduces the heap requirements of the data
424   // structure.
425 
426   void EnsureFlat();
427 
428   using String = std::string;
429 
EncodeString(StringPiece str) const430   String EncodeString(StringPiece str) const { return String(str); }
DecodeString(const String & str,int) const431   StringPiece DecodeString(const String& str, int) const { return str; }
432 
433   struct EncodedEntry {
434     // Do not use `Value` here to avoid the padding of that object.
435     const void* data;
436     int size;
437     // Keep the package here instead of each SymbolEntry to save space.
438     String encoded_package;
439 
valuegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::EncodedEntry440     Value value() const { return {data, size}; }
441   };
442   std::vector<EncodedEntry> all_values_;
443 
444   struct FileEntry {
445     int data_offset;
446     String encoded_name;
447 
namegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileEntry448     StringPiece name(const DescriptorIndex& index) const {
449       return index.DecodeString(encoded_name, data_offset);
450     }
451   };
452   struct FileCompare {
453     const DescriptorIndex& index;
454 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare455     bool operator()(const FileEntry& a, const FileEntry& b) const {
456       return a.name(index) < b.name(index);
457     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare458     bool operator()(const FileEntry& a, StringPiece b) const {
459       return a.name(index) < b;
460     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare461     bool operator()(StringPiece a, const FileEntry& b) const {
462       return a < b.name(index);
463     }
464   };
465   std::set<FileEntry, FileCompare> by_name_{FileCompare{*this}};
466   std::vector<FileEntry> by_name_flat_;
467 
468   struct SymbolEntry {
469     int data_offset;
470     String encoded_symbol;
471 
packagegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry472     StringPiece package(const DescriptorIndex& index) const {
473       return index.DecodeString(index.all_values_[data_offset].encoded_package,
474                                 data_offset);
475     }
symbolgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry476     StringPiece symbol(const DescriptorIndex& index) const {
477       return index.DecodeString(encoded_symbol, data_offset);
478     }
479 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry480     std::string AsString(const DescriptorIndex& index) const {
481       auto p = package(index);
482       return StrCat(p, p.empty() ? "" : ".", symbol(index));
483     }
484   };
485 
486   struct SymbolCompare {
487     const DescriptorIndex& index;
488 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare489     std::string AsString(const SymbolEntry& entry) const {
490       return entry.AsString(index);
491     }
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare492     static StringPiece AsString(StringPiece str) { return str; }
493 
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare494     std::pair<StringPiece, StringPiece> GetParts(
495         const SymbolEntry& entry) const {
496       auto package = entry.package(index);
497       if (package.empty()) return {entry.symbol(index), StringPiece{}};
498       return {package, entry.symbol(index)};
499     }
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare500     std::pair<StringPiece, StringPiece> GetParts(
501         StringPiece str) const {
502       return {str, {}};
503     }
504 
505     template <typename T, typename U>
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare506     bool operator()(const T& lhs, const U& rhs) const {
507       auto lhs_parts = GetParts(lhs);
508       auto rhs_parts = GetParts(rhs);
509 
510       // Fast path to avoid making the whole string for common cases.
511       if (int res =
512               lhs_parts.first.substr(0, rhs_parts.first.size())
513                   .compare(rhs_parts.first.substr(0, lhs_parts.first.size()))) {
514         // If the packages already differ, exit early.
515         return res < 0;
516       } else if (lhs_parts.first.size() == rhs_parts.first.size()) {
517         return lhs_parts.second < rhs_parts.second;
518       }
519       return AsString(lhs) < AsString(rhs);
520     }
521   };
522   std::set<SymbolEntry, SymbolCompare> by_symbol_{SymbolCompare{*this}};
523   std::vector<SymbolEntry> by_symbol_flat_;
524 
525   struct ExtensionEntry {
526     int data_offset;
527     String encoded_extendee;
extendeegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionEntry528     StringPiece extendee(const DescriptorIndex& index) const {
529       return index.DecodeString(encoded_extendee, data_offset).substr(1);
530     }
531     int extension_number;
532   };
533   struct ExtensionCompare {
534     const DescriptorIndex& index;
535 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare536     bool operator()(const ExtensionEntry& a, const ExtensionEntry& b) const {
537       return std::make_tuple(a.extendee(index), a.extension_number) <
538              std::make_tuple(b.extendee(index), b.extension_number);
539     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare540     bool operator()(const ExtensionEntry& a,
541                     std::tuple<StringPiece, int> b) const {
542       return std::make_tuple(a.extendee(index), a.extension_number) < b;
543     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare544     bool operator()(std::tuple<StringPiece, int> a,
545                     const ExtensionEntry& b) const {
546       return a < std::make_tuple(b.extendee(index), b.extension_number);
547     }
548   };
549   std::set<ExtensionEntry, ExtensionCompare> by_extension_{
550       ExtensionCompare{*this}};
551   std::vector<ExtensionEntry> by_extension_flat_;
552 };
553 
Add(const void * encoded_file_descriptor,int size)554 bool EncodedDescriptorDatabase::Add(const void* encoded_file_descriptor,
555                                     int size) {
556   google::protobuf::Arena arena;
557   auto* file = google::protobuf::Arena::CreateMessage<FileDescriptorProto>(&arena);
558   if (file->ParseFromArray(encoded_file_descriptor, size)) {
559     return index_->AddFile(*file,
560                            std::make_pair(encoded_file_descriptor, size));
561   } else {
562     GOOGLE_LOG(ERROR) << "Invalid file descriptor data passed to "
563                   "EncodedDescriptorDatabase::Add().";
564     return false;
565   }
566 }
567 
AddCopy(const void * encoded_file_descriptor,int size)568 bool EncodedDescriptorDatabase::AddCopy(const void* encoded_file_descriptor,
569                                         int size) {
570   void* copy = operator new(size);
571   memcpy(copy, encoded_file_descriptor, size);
572   files_to_delete_.push_back(copy);
573   return Add(copy, size);
574 }
575 
FindFileByName(const std::string & filename,FileDescriptorProto * output)576 bool EncodedDescriptorDatabase::FindFileByName(const std::string& filename,
577                                                FileDescriptorProto* output) {
578   return MaybeParse(index_->FindFile(filename), output);
579 }
580 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)581 bool EncodedDescriptorDatabase::FindFileContainingSymbol(
582     const std::string& symbol_name, FileDescriptorProto* output) {
583   return MaybeParse(index_->FindSymbol(symbol_name), output);
584 }
585 
FindNameOfFileContainingSymbol(const std::string & symbol_name,std::string * output)586 bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
587     const std::string& symbol_name, std::string* output) {
588   auto encoded_file = index_->FindSymbol(symbol_name);
589   if (encoded_file.first == NULL) return false;
590 
591   // Optimization:  The name should be the first field in the encoded message.
592   //   Try to just read it directly.
593   io::CodedInputStream input(static_cast<const uint8*>(encoded_file.first),
594                              encoded_file.second);
595 
596   const uint32 kNameTag = internal::WireFormatLite::MakeTag(
597       FileDescriptorProto::kNameFieldNumber,
598       internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
599 
600   if (input.ReadTagNoLastTag() == kNameTag) {
601     // Success!
602     return internal::WireFormatLite::ReadString(&input, output);
603   } else {
604     // Slow path.  Parse whole message.
605     FileDescriptorProto file_proto;
606     if (!file_proto.ParseFromArray(encoded_file.first, encoded_file.second)) {
607       return false;
608     }
609     *output = file_proto.name();
610     return true;
611   }
612 }
613 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)614 bool EncodedDescriptorDatabase::FindFileContainingExtension(
615     const std::string& containing_type, int field_number,
616     FileDescriptorProto* output) {
617   return MaybeParse(index_->FindExtension(containing_type, field_number),
618                     output);
619 }
620 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)621 bool EncodedDescriptorDatabase::FindAllExtensionNumbers(
622     const std::string& extendee_type, std::vector<int>* output) {
623   return index_->FindAllExtensionNumbers(extendee_type, output);
624 }
625 
626 template <typename FileProto>
AddFile(const FileProto & file,Value value)627 bool EncodedDescriptorDatabase::DescriptorIndex::AddFile(const FileProto& file,
628                                                          Value value) {
629   // We push `value` into the array first. This is important because the AddXXX
630   // functions below will expect it to be there.
631   all_values_.push_back({value.first, value.second});
632 
633   if (!ValidateSymbolName(file.package())) {
634     GOOGLE_LOG(ERROR) << "Invalid package name: " << file.package();
635     return false;
636   }
637   all_values_.back().encoded_package = EncodeString(file.package());
638 
639   if (!InsertIfNotPresent(
640           &by_name_, FileEntry{static_cast<int>(all_values_.size() - 1),
641                                EncodeString(file.name())}) ||
642       std::binary_search(by_name_flat_.begin(), by_name_flat_.end(),
643                          file.name(), by_name_.key_comp())) {
644     GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
645     return false;
646   }
647 
648   for (const auto& message_type : file.message_type()) {
649     if (!AddSymbol(message_type.name())) return false;
650     if (!AddNestedExtensions(file.name(), message_type)) return false;
651   }
652   for (const auto& enum_type : file.enum_type()) {
653     if (!AddSymbol(enum_type.name())) return false;
654   }
655   for (const auto& extension : file.extension()) {
656     if (!AddSymbol(extension.name())) return false;
657     if (!AddExtension(file.name(), extension)) return false;
658   }
659   for (const auto& service : file.service()) {
660     if (!AddSymbol(service.name())) return false;
661   }
662 
663   return true;
664 }
665 
666 template <typename Iter, typename Iter2, typename Index>
CheckForMutualSubsymbols(StringPiece symbol_name,Iter * iter,Iter2 end,const Index & index)667 static bool CheckForMutualSubsymbols(StringPiece symbol_name, Iter* iter,
668                                      Iter2 end, const Index& index) {
669   if (*iter != end) {
670     if (IsSubSymbol((*iter)->AsString(index), symbol_name)) {
671       GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
672                  << "\" conflicts with the existing symbol \""
673                  << (*iter)->AsString(index) << "\".";
674       return false;
675     }
676 
677     // OK, that worked.  Now we have to make sure that no symbol in the map is
678     // a sub-symbol of the one we are inserting.  The only symbol which could
679     // be so is the first symbol that is greater than the new symbol.  Since
680     // |iter| points at the last symbol that is less than or equal, we just have
681     // to increment it.
682     ++*iter;
683 
684     if (*iter != end && IsSubSymbol(symbol_name, (*iter)->AsString(index))) {
685       GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
686                  << "\" conflicts with the existing symbol \""
687                  << (*iter)->AsString(index) << "\".";
688       return false;
689     }
690   }
691   return true;
692 }
693 
AddSymbol(StringPiece symbol)694 bool EncodedDescriptorDatabase::DescriptorIndex::AddSymbol(
695     StringPiece symbol) {
696   SymbolEntry entry = {static_cast<int>(all_values_.size() - 1),
697                        EncodeString(symbol)};
698   std::string entry_as_string = entry.AsString(*this);
699 
700   // We need to make sure not to violate our map invariant.
701 
702   // If the symbol name is invalid it could break our lookup algorithm (which
703   // relies on the fact that '.' sorts before all other characters that are
704   // valid in symbol names).
705   if (!ValidateSymbolName(symbol)) {
706     GOOGLE_LOG(ERROR) << "Invalid symbol name: " << entry_as_string;
707     return false;
708   }
709 
710   auto iter = FindLastLessOrEqual(&by_symbol_, entry);
711   if (!CheckForMutualSubsymbols(entry_as_string, &iter, by_symbol_.end(),
712                                 *this)) {
713     return false;
714   }
715 
716   // Same, but on by_symbol_flat_
717   auto flat_iter =
718       FindLastLessOrEqual(&by_symbol_flat_, entry, by_symbol_.key_comp());
719   if (!CheckForMutualSubsymbols(entry_as_string, &flat_iter,
720                                 by_symbol_flat_.end(), *this)) {
721     return false;
722   }
723 
724   // OK, no conflicts.
725 
726   // Insert the new symbol using the iterator as a hint, the new entry will
727   // appear immediately before the one the iterator is pointing at.
728   by_symbol_.insert(iter, entry);
729 
730   return true;
731 }
732 
733 template <typename DescProto>
AddNestedExtensions(StringPiece filename,const DescProto & message_type)734 bool EncodedDescriptorDatabase::DescriptorIndex::AddNestedExtensions(
735     StringPiece filename, const DescProto& message_type) {
736   for (const auto& nested_type : message_type.nested_type()) {
737     if (!AddNestedExtensions(filename, nested_type)) return false;
738   }
739   for (const auto& extension : message_type.extension()) {
740     if (!AddExtension(filename, extension)) return false;
741   }
742   return true;
743 }
744 
745 template <typename FieldProto>
AddExtension(StringPiece filename,const FieldProto & field)746 bool EncodedDescriptorDatabase::DescriptorIndex::AddExtension(
747     StringPiece filename, const FieldProto& field) {
748   if (!field.extendee().empty() && field.extendee()[0] == '.') {
749     // The extension is fully-qualified.  We can use it as a lookup key in
750     // the by_symbol_ table.
751     if (!InsertIfNotPresent(
752             &by_extension_,
753             ExtensionEntry{static_cast<int>(all_values_.size() - 1),
754                            EncodeString(field.extendee()), field.number()}) ||
755         std::binary_search(
756             by_extension_flat_.begin(), by_extension_flat_.end(),
757             std::make_pair(field.extendee().substr(1), field.number()),
758             by_extension_.key_comp())) {
759       GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
760                     "extend "
761                  << field.extendee() << " { " << field.name() << " = "
762                  << field.number() << " } from:" << filename;
763       return false;
764     }
765   } else {
766     // Not fully-qualified.  We can't really do anything here, unfortunately.
767     // We don't consider this an error, though, because the descriptor is
768     // valid.
769   }
770   return true;
771 }
772 
773 std::pair<const void*, int>
FindSymbol(StringPiece name)774 EncodedDescriptorDatabase::DescriptorIndex::FindSymbol(StringPiece name) {
775   EnsureFlat();
776   return FindSymbolOnlyFlat(name);
777 }
778 
779 std::pair<const void*, int>
FindSymbolOnlyFlat(StringPiece name) const780 EncodedDescriptorDatabase::DescriptorIndex::FindSymbolOnlyFlat(
781     StringPiece name) const {
782   auto iter =
783       FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp());
784 
785   return iter != by_symbol_flat_.end() &&
786                  IsSubSymbol(iter->AsString(*this), name)
787              ? all_values_[iter->data_offset].value()
788              : Value();
789 }
790 
791 std::pair<const void*, int>
FindExtension(StringPiece containing_type,int field_number)792 EncodedDescriptorDatabase::DescriptorIndex::FindExtension(
793     StringPiece containing_type, int field_number) {
794   EnsureFlat();
795 
796   auto it = std::lower_bound(
797       by_extension_flat_.begin(), by_extension_flat_.end(),
798       std::make_tuple(containing_type, field_number), by_extension_.key_comp());
799   return it == by_extension_flat_.end() ||
800                  it->extendee(*this) != containing_type ||
801                  it->extension_number != field_number
802              ? std::make_pair(nullptr, 0)
803              : all_values_[it->data_offset].value();
804 }
805 
// Moves the contents of the ordered set |s| into the sorted vector |flat|.
//
// Both inputs are expected to be sorted according to s->key_comp().  When |s|
// is non-empty, |flat| is replaced by the sorted merge of both ranges and |s|
// is emptied.  When |s| is empty the function does nothing.
template <typename T, typename Less>
static void MergeIntoFlat(std::set<T, Less>* s, std::vector<T>* flat) {
  if (!s->empty()) {
    // Pre-size the destination so std::merge can write straight into it.
    std::vector<T> merged(s->size() + flat->size());
    std::merge(s->begin(), s->end(), flat->begin(), flat->end(),
               merged.begin(), s->key_comp());
    *flat = std::move(merged);
    s->clear();
  }
}
815 
EnsureFlat()816 void EncodedDescriptorDatabase::DescriptorIndex::EnsureFlat() {
817   all_values_.shrink_to_fit();
818   // Merge each of the sets into their flat counterpart.
819   MergeIntoFlat(&by_name_, &by_name_flat_);
820   MergeIntoFlat(&by_symbol_, &by_symbol_flat_);
821   MergeIntoFlat(&by_extension_, &by_extension_flat_);
822 }
823 
FindAllExtensionNumbers(StringPiece containing_type,std::vector<int> * output)824 bool EncodedDescriptorDatabase::DescriptorIndex::FindAllExtensionNumbers(
825     StringPiece containing_type, std::vector<int>* output) {
826   EnsureFlat();
827 
828   bool success = false;
829   auto it = std::lower_bound(
830       by_extension_flat_.begin(), by_extension_flat_.end(),
831       std::make_tuple(containing_type, 0), by_extension_.key_comp());
832   for (;
833        it != by_extension_flat_.end() && it->extendee(*this) == containing_type;
834        ++it) {
835     output->push_back(it->extension_number);
836     success = true;
837   }
838 
839   return success;
840 }
841 
FindAllFileNames(std::vector<std::string> * output) const842 void EncodedDescriptorDatabase::DescriptorIndex::FindAllFileNames(
843     std::vector<std::string>* output) const {
844   output->resize(by_name_.size() + by_name_flat_.size());
845   int i = 0;
846   for (const auto& entry : by_name_) {
847     (*output)[i] = std::string(entry.name(*this));
848     i++;
849   }
850   for (const auto& entry : by_name_flat_) {
851     (*output)[i] = std::string(entry.name(*this));
852     i++;
853   }
854 }
855 
856 std::pair<const void*, int>
FindFile(StringPiece filename)857 EncodedDescriptorDatabase::DescriptorIndex::FindFile(
858     StringPiece filename) {
859   EnsureFlat();
860 
861   auto it = std::lower_bound(by_name_flat_.begin(), by_name_flat_.end(),
862                              filename, by_name_.key_comp());
863   return it == by_name_flat_.end() || it->name(*this) != filename
864              ? std::make_pair(nullptr, 0)
865              : all_values_[it->data_offset].value();
866 }
867 
868 
FindAllFileNames(std::vector<std::string> * output)869 bool EncodedDescriptorDatabase::FindAllFileNames(
870     std::vector<std::string>* output) {
871   index_->FindAllFileNames(output);
872   return true;
873 }
874 
MaybeParse(std::pair<const void *,int> encoded_file,FileDescriptorProto * output)875 bool EncodedDescriptorDatabase::MaybeParse(
876     std::pair<const void*, int> encoded_file, FileDescriptorProto* output) {
877   if (encoded_file.first == NULL) return false;
878   return output->ParseFromArray(encoded_file.first, encoded_file.second);
879 }
880 
EncodedDescriptorDatabase()881 EncodedDescriptorDatabase::EncodedDescriptorDatabase()
882     : index_(new DescriptorIndex()) {}
883 
~EncodedDescriptorDatabase()884 EncodedDescriptorDatabase::~EncodedDescriptorDatabase() {
885   for (void* p : files_to_delete_) {
886     operator delete(p);
887   }
888 }
889 
890 // ===================================================================
891 
DescriptorPoolDatabase(const DescriptorPool & pool)892 DescriptorPoolDatabase::DescriptorPoolDatabase(const DescriptorPool& pool)
893     : pool_(pool) {}
~DescriptorPoolDatabase()894 DescriptorPoolDatabase::~DescriptorPoolDatabase() {}
895 
FindFileByName(const std::string & filename,FileDescriptorProto * output)896 bool DescriptorPoolDatabase::FindFileByName(const std::string& filename,
897                                             FileDescriptorProto* output) {
898   const FileDescriptor* file = pool_.FindFileByName(filename);
899   if (file == NULL) return false;
900   output->Clear();
901   file->CopyTo(output);
902   return true;
903 }
904 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)905 bool DescriptorPoolDatabase::FindFileContainingSymbol(
906     const std::string& symbol_name, FileDescriptorProto* output) {
907   const FileDescriptor* file = pool_.FindFileContainingSymbol(symbol_name);
908   if (file == NULL) return false;
909   output->Clear();
910   file->CopyTo(output);
911   return true;
912 }
913 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)914 bool DescriptorPoolDatabase::FindFileContainingExtension(
915     const std::string& containing_type, int field_number,
916     FileDescriptorProto* output) {
917   const Descriptor* extendee = pool_.FindMessageTypeByName(containing_type);
918   if (extendee == NULL) return false;
919 
920   const FieldDescriptor* extension =
921       pool_.FindExtensionByNumber(extendee, field_number);
922   if (extension == NULL) return false;
923 
924   output->Clear();
925   extension->file()->CopyTo(output);
926   return true;
927 }
928 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)929 bool DescriptorPoolDatabase::FindAllExtensionNumbers(
930     const std::string& extendee_type, std::vector<int>* output) {
931   const Descriptor* extendee = pool_.FindMessageTypeByName(extendee_type);
932   if (extendee == NULL) return false;
933 
934   std::vector<const FieldDescriptor*> extensions;
935   pool_.FindAllExtensions(extendee, &extensions);
936 
937   for (int i = 0; i < extensions.size(); ++i) {
938     output->push_back(extensions[i]->number());
939   }
940 
941   return true;
942 }
943 
944 // ===================================================================
945 
MergedDescriptorDatabase(DescriptorDatabase * source1,DescriptorDatabase * source2)946 MergedDescriptorDatabase::MergedDescriptorDatabase(
947     DescriptorDatabase* source1, DescriptorDatabase* source2) {
948   sources_.push_back(source1);
949   sources_.push_back(source2);
950 }
MergedDescriptorDatabase(const std::vector<DescriptorDatabase * > & sources)951 MergedDescriptorDatabase::MergedDescriptorDatabase(
952     const std::vector<DescriptorDatabase*>& sources)
953     : sources_(sources) {}
~MergedDescriptorDatabase()954 MergedDescriptorDatabase::~MergedDescriptorDatabase() {}
955 
FindFileByName(const std::string & filename,FileDescriptorProto * output)956 bool MergedDescriptorDatabase::FindFileByName(const std::string& filename,
957                                               FileDescriptorProto* output) {
958   for (int i = 0; i < sources_.size(); i++) {
959     if (sources_[i]->FindFileByName(filename, output)) {
960       return true;
961     }
962   }
963   return false;
964 }
965 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)966 bool MergedDescriptorDatabase::FindFileContainingSymbol(
967     const std::string& symbol_name, FileDescriptorProto* output) {
968   for (int i = 0; i < sources_.size(); i++) {
969     if (sources_[i]->FindFileContainingSymbol(symbol_name, output)) {
970       // The symbol was found in source i.  However, if one of the previous
971       // sources defines a file with the same name (which presumably doesn't
972       // contain the symbol, since it wasn't found in that source), then we
973       // must hide it from the caller.
974       FileDescriptorProto temp;
975       for (int j = 0; j < i; j++) {
976         if (sources_[j]->FindFileByName(output->name(), &temp)) {
977           // Found conflicting file in a previous source.
978           return false;
979         }
980       }
981       return true;
982     }
983   }
984   return false;
985 }
986 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)987 bool MergedDescriptorDatabase::FindFileContainingExtension(
988     const std::string& containing_type, int field_number,
989     FileDescriptorProto* output) {
990   for (int i = 0; i < sources_.size(); i++) {
991     if (sources_[i]->FindFileContainingExtension(containing_type, field_number,
992                                                  output)) {
993       // The symbol was found in source i.  However, if one of the previous
994       // sources defines a file with the same name (which presumably doesn't
995       // contain the symbol, since it wasn't found in that source), then we
996       // must hide it from the caller.
997       FileDescriptorProto temp;
998       for (int j = 0; j < i; j++) {
999         if (sources_[j]->FindFileByName(output->name(), &temp)) {
1000           // Found conflicting file in a previous source.
1001           return false;
1002         }
1003       }
1004       return true;
1005     }
1006   }
1007   return false;
1008 }
1009 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)1010 bool MergedDescriptorDatabase::FindAllExtensionNumbers(
1011     const std::string& extendee_type, std::vector<int>* output) {
1012   std::set<int> merged_results;
1013   std::vector<int> results;
1014   bool success = false;
1015 
1016   for (int i = 0; i < sources_.size(); i++) {
1017     if (sources_[i]->FindAllExtensionNumbers(extendee_type, &results)) {
1018       std::copy(results.begin(), results.end(),
1019                 std::insert_iterator<std::set<int> >(merged_results,
1020                                                      merged_results.begin()));
1021       success = true;
1022     }
1023     results.clear();
1024   }
1025 
1026   std::copy(merged_results.begin(), merged_results.end(),
1027             std::insert_iterator<std::vector<int> >(*output, output->end()));
1028 
1029   return success;
1030 }
1031 
1032 
1033 }  // namespace protobuf
1034 }  // namespace google
1035