• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 #include <google/protobuf/compiler/python/python_pyi_generator.h>
32 
33 #include <string>
34 
35 #include <google/protobuf/compiler/python/python_helpers.h>
36 #include <google/protobuf/io/printer.h>
37 #include <google/protobuf/io/zero_copy_stream.h>
38 #include <google/protobuf/descriptor.h>
39 #include <google/protobuf/stubs/strutil.h>
40 #include <google/protobuf/descriptor.pb.h>
41 
42 namespace google {
43 namespace protobuf {
44 namespace compiler {
45 namespace python {
46 
// Comparison functor for std::sort: orders pointers to descriptors by the
// value their name() accessor returns, ascending.
template <typename DescriptorT>
struct SortByName {
  // Returns true when `lhs` must be ordered before `rhs`.
  bool operator()(const DescriptorT* lhs, const DescriptorT* rhs) const {
    const auto& left_name = lhs->name();
    const auto& right_name = rhs->name();
    return left_name < right_name;
  }
};
53 
PyiGenerator()54 PyiGenerator::PyiGenerator() : file_(nullptr) {}
55 
~PyiGenerator()56 PyiGenerator::~PyiGenerator() {}
57 
PrintItemMap(const std::map<std::string,std::string> & item_map) const58 void PyiGenerator::PrintItemMap(
59     const std::map<std::string, std::string>& item_map) const {
60   for (const auto& entry : item_map) {
61     printer_->Print("$key$: $value$\n", "key", entry.first, "value",
62                     entry.second);
63   }
64 }
65 
66 template <typename DescriptorT>
ModuleLevelName(const DescriptorT & descriptor) const67 std::string PyiGenerator::ModuleLevelName(const DescriptorT& descriptor) const {
68   std::string name = NamePrefixedWithNestedTypes(descriptor, ".");
69   if (descriptor.file() != file_) {
70     std::string module_name = ModuleName(descriptor.file()->name());
71     std::vector<std::string> tokens = Split(module_name, ".");
72     name = "_" + tokens.back() + "." + name;
73   }
74   return name;
75 }
76 
// Flags recording which imports the generated stub needs. Every flag starts
// out false; the trailing comment names the import the flag controls.
struct ImportModules {
  bool has_repeated{false};    // _containers
  bool has_iterable{false};    // typing.Iterable
  bool has_messages{false};    // _message
  bool has_enums{false};       // _enum_type_wrapper
  bool has_extendable{false};  // _python_message
  bool has_mapping{false};     // typing.Mapping
  bool has_optional{false};    // typing.Optional
  bool has_union{false};       // typing.Union
};
87 
88 // Checks what modules should be imported for this message
89 // descriptor.
CheckImportModules(const Descriptor * descriptor,ImportModules * import_modules)90 void CheckImportModules(const Descriptor* descriptor,
91                         ImportModules* import_modules) {
92   if (descriptor->extension_range_count() > 0) {
93     import_modules->has_extendable = true;
94   }
95   if (descriptor->enum_type_count() > 0) {
96     import_modules->has_enums = true;
97   }
98   for (int i = 0; i < descriptor->field_count(); ++i) {
99     const FieldDescriptor* field = descriptor->field(i);
100     if (IsPythonKeyword(field->name())) {
101       continue;
102     }
103     import_modules->has_optional = true;
104     if (field->is_repeated()) {
105       import_modules->has_repeated = true;
106     }
107     if (field->is_map()) {
108       import_modules->has_mapping = true;
109       const FieldDescriptor* value_des = field->message_type()->field(1);
110       if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
111           value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
112         import_modules->has_union = true;
113       }
114     } else {
115       if (field->is_repeated()) {
116         import_modules->has_iterable = true;
117       }
118       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
119         import_modules->has_union = true;
120         import_modules->has_mapping = true;
121       }
122       if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
123         import_modules->has_union = true;
124       }
125     }
126   }
127   for (int i = 0; i < descriptor->nested_type_count(); ++i) {
128     CheckImportModules(descriptor->nested_type(i), import_modules);
129   }
130 }
131 
PrintImports(std::map<std::string,std::string> * item_map) const132 void PyiGenerator::PrintImports(
133     std::map<std::string, std::string>* item_map) const {
134   // Prints imported dependent _pb2 files.
135   for (int i = 0; i < file_->dependency_count(); ++i) {
136     const std::string& filename = file_->dependency(i)->name();
137     std::string module_name = StrippedModuleName(filename);
138     size_t last_dot_pos = module_name.rfind('.');
139     std::string import_statement;
140     if (last_dot_pos == std::string::npos) {
141       import_statement = "import " + module_name;
142     } else {
143       import_statement = "from " + module_name.substr(0, last_dot_pos) +
144                          " import " + module_name.substr(last_dot_pos + 1);
145       module_name = module_name.substr(last_dot_pos + 1);
146     }
147     printer_->Print("$statement$ as _$module_name$\n", "statement",
148                     import_statement, "module_name", module_name);
149   }
150 
151   // Checks what modules should be imported.
152   ImportModules import_modules;
153   if (file_->message_type_count() > 0) {
154     import_modules.has_messages = true;
155   }
156   if (file_->enum_type_count() > 0) {
157     import_modules.has_enums = true;
158   }
159   for (int i = 0; i < file_->message_type_count(); i++) {
160     CheckImportModules(file_->message_type(i), &import_modules);
161   }
162 
163   // Prints modules (e.g. _containers, _messages, typing) that are
164   // required in the proto file.
165   if (import_modules.has_repeated) {
166     printer_->Print(
167         "from google.protobuf.internal import containers as "
168         "_containers\n");
169   }
170   if (import_modules.has_enums) {
171     printer_->Print(
172         "from google.protobuf.internal import enum_type_wrapper"
173         " as _enum_type_wrapper\n");
174   }
175   if (import_modules.has_extendable) {
176     printer_->Print(
177         "from google.protobuf.internal import python_message"
178         " as _python_message\n");
179   }
180   printer_->Print(
181       "from google.protobuf import"
182       " descriptor as _descriptor\n");
183   if (import_modules.has_messages) {
184     printer_->Print(
185         "from google.protobuf import message as _message\n");
186   }
187   if (HasGenericServices(file_)) {
188     printer_->Print(
189         "from google.protobuf import service as"
190         " _service\n");
191   }
192   printer_->Print("from typing import ");
193   printer_->Print("ClassVar");
194   if (import_modules.has_iterable) {
195     printer_->Print(", Iterable");
196   }
197   if (import_modules.has_mapping) {
198     printer_->Print(", Mapping");
199   }
200   if (import_modules.has_optional) {
201     printer_->Print(", Optional");
202   }
203   if (file_->service_count() > 0) {
204     printer_->Print(", Text");
205   }
206   if (import_modules.has_union) {
207     printer_->Print(", Union");
208   }
209   printer_->Print("\n\n");
210 
211   // Public imports
212   for (int i = 0; i < file_->public_dependency_count(); ++i) {
213     const FileDescriptor* public_dep = file_->public_dependency(i);
214     std::string module_name = StrippedModuleName(public_dep->name());
215     // Top level messages in public imports
216     for (int i = 0; i < public_dep->message_type_count(); ++i) {
217       printer_->Print("from $module$ import $message_class$\n", "module",
218                       module_name, "message_class",
219                       public_dep->message_type(i)->name());
220     }
221     // Top level enums for public imports
222     for (int i = 0; i < public_dep->enum_type_count(); ++i) {
223       printer_->Print("from $module$ import $enum_class$\n", "module",
224                       module_name, "enum_class",
225                       public_dep->enum_type(i)->name());
226     }
227     // Enum values for public imports
228     for (int i = 0; i < public_dep->enum_type_count(); ++i) {
229       const EnumDescriptor* enum_descriptor = public_dep->enum_type(i);
230       for (int j = 0; j < enum_descriptor->value_count(); ++j) {
231         (*item_map)[enum_descriptor->value(j)->name()] =
232             ModuleLevelName(*enum_descriptor);
233       }
234     }
235     // Top level extensions for public imports
236     AddExtensions(*public_dep, item_map);
237   }
238 }
239 
PrintEnum(const EnumDescriptor & enum_descriptor) const240 void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
241   std::string enum_name = enum_descriptor.name();
242   printer_->Print(
243       "class $enum_name$(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):\n"
244       "    __slots__ = []\n",
245       "enum_name", enum_name);
246 }
247 
248 // Adds enum value to item map which will be ordered and printed later.
AddEnumValue(const EnumDescriptor & enum_descriptor,std::map<std::string,std::string> * item_map) const249 void PyiGenerator::AddEnumValue(
250     const EnumDescriptor& enum_descriptor,
251     std::map<std::string, std::string>* item_map) const {
252   // enum values
253   std::string module_enum_name = ModuleLevelName(enum_descriptor);
254   for (int j = 0; j < enum_descriptor.value_count(); ++j) {
255     const EnumValueDescriptor* value_descriptor = enum_descriptor.value(j);
256     (*item_map)[value_descriptor->name()] = module_enum_name;
257   }
258 }
259 
260 // Prints top level enums
PrintTopLevelEnums() const261 void PyiGenerator::PrintTopLevelEnums() const {
262   for (int i = 0; i < file_->enum_type_count(); ++i) {
263     printer_->Print("\n");
264     PrintEnum(*file_->enum_type(i));
265   }
266 }
267 
268 // Add top level extensions to item_map which will be ordered and
269 // printed later.
270 template <typename DescriptorT>
AddExtensions(const DescriptorT & descriptor,std::map<std::string,std::string> * item_map) const271 void PyiGenerator::AddExtensions(
272     const DescriptorT& descriptor,
273     std::map<std::string, std::string>* item_map) const {
274   for (int i = 0; i < descriptor.extension_count(); ++i) {
275     const FieldDescriptor* extension_field = descriptor.extension(i);
276     std::string constant_name = extension_field->name() + "_FIELD_NUMBER";
277     ToUpper(&constant_name);
278     (*item_map)[constant_name] = "ClassVar[int]";
279     (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor";
280   }
281 }
282 
283 // Returns the string format of a field's cpp_type
GetFieldType(const FieldDescriptor & field_des) const284 std::string PyiGenerator::GetFieldType(const FieldDescriptor& field_des) const {
285   switch (field_des.cpp_type()) {
286     case FieldDescriptor::CPPTYPE_INT32:
287     case FieldDescriptor::CPPTYPE_UINT32:
288     case FieldDescriptor::CPPTYPE_INT64:
289     case FieldDescriptor::CPPTYPE_UINT64:
290       return "int";
291     case FieldDescriptor::CPPTYPE_DOUBLE:
292     case FieldDescriptor::CPPTYPE_FLOAT:
293       return "float";
294     case FieldDescriptor::CPPTYPE_BOOL:
295       return "bool";
296     case FieldDescriptor::CPPTYPE_ENUM:
297       return ModuleLevelName(*field_des.enum_type());
298     case FieldDescriptor::CPPTYPE_STRING:
299       if (field_des.type() == FieldDescriptor::TYPE_STRING) {
300         return "str";
301       } else {
302         return "bytes";
303       }
304     case FieldDescriptor::CPPTYPE_MESSAGE:
305       return ModuleLevelName(*field_des.message_type());
306     default:
307       GOOGLE_LOG(FATAL) << "Unsuppoted field type.";
308   }
309   return "";
310 }
311 
PrintMessage(const Descriptor & message_descriptor,bool is_nested) const312 void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
313                                 bool is_nested) const {
314   if (!is_nested) {
315     printer_->Print("\n");
316   }
317   std::string class_name = message_descriptor.name();
318   printer_->Print("class $class_name$(_message.Message):\n", "class_name",
319                   class_name);
320   printer_->Indent();
321   printer_->Indent();
322 
323   std::vector<const FieldDescriptor*> fields;
324   fields.reserve(message_descriptor.field_count());
325   for (int i = 0; i < message_descriptor.field_count(); ++i) {
326     fields.push_back(message_descriptor.field(i));
327   }
328   std::sort(fields.begin(), fields.end(), SortByName<FieldDescriptor>());
329 
330   // Prints slots
331   printer_->Print("__slots__ = [", "class_name", class_name);
332   bool first_item = true;
333   for (const auto& field_des : fields) {
334     if (IsPythonKeyword(field_des->name())) {
335       continue;
336     }
337     if (first_item) {
338       first_item = false;
339     } else {
340       printer_->Print(", ");
341     }
342     printer_->Print("\"$field_name$\"", "field_name", field_des->name());
343   }
344   printer_->Print("]\n");
345 
346   std::map<std::string, std::string> item_map;
347   // Prints Extensions for extendable messages
348   if (message_descriptor.extension_range_count() > 0) {
349     item_map["Extensions"] = "_python_message._ExtensionDict";
350   }
351 
352   // Prints nested enums
353   std::vector<const EnumDescriptor*> nested_enums;
354   nested_enums.reserve(message_descriptor.enum_type_count());
355   for (int i = 0; i < message_descriptor.enum_type_count(); ++i) {
356     nested_enums.push_back(message_descriptor.enum_type(i));
357   }
358   std::sort(nested_enums.begin(), nested_enums.end(),
359             SortByName<EnumDescriptor>());
360 
361   for (const auto& entry : nested_enums) {
362     PrintEnum(*entry);
363     // Adds enum value to item_map which will be ordered and printed later
364     AddEnumValue(*entry, &item_map);
365   }
366 
367   // Prints nested messages
368   std::vector<const Descriptor*> nested_messages;
369   nested_messages.reserve(message_descriptor.nested_type_count());
370   for (int i = 0; i < message_descriptor.nested_type_count(); ++i) {
371     nested_messages.push_back(message_descriptor.nested_type(i));
372   }
373   std::sort(nested_messages.begin(), nested_messages.end(),
374             SortByName<Descriptor>());
375 
376   for (const auto& entry : nested_messages) {
377     PrintMessage(*entry, true);
378   }
379 
380   // Adds extensions to item_map which will be ordered and printed later
381   AddExtensions(message_descriptor, &item_map);
382 
383   // Adds field number and field descriptor to item_map
384   for (int i = 0; i < message_descriptor.field_count(); ++i) {
385     const FieldDescriptor& field_des = *message_descriptor.field(i);
386     item_map[ToUpper(field_des.name()) + "_FIELD_NUMBER"] =
387         "ClassVar[int]";
388     if (IsPythonKeyword(field_des.name())) {
389       continue;
390     }
391     std::string field_type = "";
392     if (field_des.is_map()) {
393       const FieldDescriptor* key_des = field_des.message_type()->field(0);
394       const FieldDescriptor* value_des = field_des.message_type()->field(1);
395       field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
396                         ? "_containers.MessageMap["
397                         : "_containers.ScalarMap[");
398       field_type += GetFieldType(*key_des);
399       field_type += ", ";
400       field_type += GetFieldType(*value_des);
401     } else {
402       if (field_des.is_repeated()) {
403         field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
404                           ? "_containers.RepeatedCompositeFieldContainer["
405                           : "_containers.RepeatedScalarFieldContainer[");
406       }
407       field_type += GetFieldType(field_des);
408     }
409 
410     if (field_des.is_repeated()) {
411       field_type += "]";
412     }
413     item_map[field_des.name()] = field_type;
414   }
415 
416   // Prints all items in item_map
417   PrintItemMap(item_map);
418 
419   // Prints __init__
420   printer_->Print("def __init__(self");
421   bool has_key_words = false;
422   bool is_first = true;
423   for (int i = 0; i < message_descriptor.field_count(); ++i) {
424     const FieldDescriptor* field_des = message_descriptor.field(i);
425     if (IsPythonKeyword(field_des->name())) {
426       has_key_words = true;
427       continue;
428     }
429     std::string field_name = field_des->name();
430     if (is_first && field_name == "self") {
431       // See b/144146793 for an example of real code that generates a (self,
432       // self) method signature. Since repeating a parameter name is illegal in
433       // Python, we rename the duplicate self.
434       field_name = "self_";
435     }
436     is_first = false;
437     printer_->Print(", $field_name$: ", "field_name", field_name);
438     if (field_des->is_repeated() ||
439         field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
440       printer_->Print("Optional[");
441     }
442     if (field_des->is_map()) {
443       const Descriptor* map_entry = field_des->message_type();
444       printer_->Print("Mapping[$key_type$, $value_type$]", "key_type",
445                       GetFieldType(*map_entry->field(0)), "value_type",
446                       GetFieldType(*map_entry->field(1)));
447     } else {
448       if (field_des->is_repeated()) {
449         printer_->Print("Iterable[");
450       }
451       if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
452         printer_->Print("Union[$type_name$, Mapping]", "type_name",
453                         GetFieldType(*field_des));
454       } else {
455         if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
456           printer_->Print("Union[$type_name$, str]", "type_name",
457                           ModuleLevelName(*field_des->enum_type()));
458         } else {
459           printer_->Print("$type_name$", "type_name", GetFieldType(*field_des));
460         }
461       }
462       if (field_des->is_repeated()) {
463         printer_->Print("]");
464       }
465     }
466     if (field_des->is_repeated() ||
467         field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
468       printer_->Print("]");
469     }
470     printer_->Print(" = ...");
471   }
472   if (has_key_words) {
473     printer_->Print(", **kwargs");
474   }
475   printer_->Print(") -> None: ...\n");
476 
477   printer_->Outdent();
478   printer_->Outdent();
479 }
480 
PrintMessages() const481 void PyiGenerator::PrintMessages() const {
482   // Order the descriptors by name to have same output with proto_to_pyi.py
483   std::vector<const Descriptor*> messages;
484   messages.reserve(file_->message_type_count());
485   for (int i = 0; i < file_->message_type_count(); ++i) {
486     messages.push_back(file_->message_type(i));
487   }
488   std::sort(messages.begin(), messages.end(), SortByName<Descriptor>());
489 
490   for (const auto& entry : messages) {
491     PrintMessage(*entry, false);
492   }
493 }
494 
PrintServices() const495 void PyiGenerator::PrintServices() const {
496   std::vector<const ServiceDescriptor*> services;
497   services.reserve(file_->service_count());
498   for (int i = 0; i < file_->service_count(); ++i) {
499     services.push_back(file_->service(i));
500   }
501   std::sort(services.begin(), services.end(), SortByName<ServiceDescriptor>());
502 
503   // Prints $Service$ and $Service$_Stub classes
504   for (const auto& entry : services) {
505     printer_->Print("\n");
506     printer_->Print(
507         "class $service_name$(_service.service): ...\n\n"
508         "class $service_name$_Stub($service_name$): ...\n",
509         "service_name", entry->name());
510   }
511 }
512 
Generate(const FileDescriptor * file,const std::string & parameter,GeneratorContext * context,std::string * error) const513 bool PyiGenerator::Generate(const FileDescriptor* file,
514                             const std::string& parameter,
515                             GeneratorContext* context,
516                             std::string* error) const {
517   MutexLock lock(&mutex_);
518   // Calculate file name.
519   file_ = file;
520   // proto_to_pyi.py may set the output file name directly. To replace
521   // proto_to_pyi.py in google3, protoc also accept --pyi_out to set
522   // the output file name.
523   std::string filename =
524       parameter.empty() ? GetFileName(file, ".pyi") : parameter;
525 
526   std::unique_ptr<io::ZeroCopyOutputStream> output(context->Open(filename));
527   GOOGLE_CHECK(output.get());
528   io::Printer printer(output.get(), '$');
529   printer_ = &printer;
530 
531   // item map will store "DESCRIPTOR", top level extensions, top level enum
532   // values. The items will be sorted and printed later.
533   std::map<std::string, std::string> item_map;
534 
535   // Adds "DESCRIPTOR" into item_map.
536   item_map["DESCRIPTOR"] = "_descriptor.FileDescriptor";
537   PrintImports(&item_map);
538   // Adds top level enum values to item_map.
539   for (int i = 0; i < file_->enum_type_count(); ++i) {
540     AddEnumValue(*file_->enum_type(i), &item_map);
541   }
542   // Adds top level extensions to item_map.
543   AddExtensions(*file_, &item_map);
544   // Prints item map
545   PrintItemMap(item_map);
546 
547   PrintMessages();
548   PrintTopLevelEnums();
549   if (HasGenericServices(file)) {
550     PrintServices();
551   }
552   return true;
553 }
554 
555 }  // namespace python
556 }  // namespace compiler
557 }  // namespace protobuf
558 }  // namespace google
559