• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 #include <google/protobuf/compiler/python/pyi_generator.h>
32 
33 #include <string>
34 
35 #include <google/protobuf/stubs/strutil.h>
36 #include <google/protobuf/compiler/python/helpers.h>
37 #include <google/protobuf/descriptor.h>
38 #include <google/protobuf/descriptor.pb.h>
39 #include <google/protobuf/io/printer.h>
40 #include <google/protobuf/io/zero_copy_stream.h>
41 
42 namespace google {
43 namespace protobuf {
44 namespace compiler {
45 namespace python {
46 
// Comparator that orders descriptor pointers by the string returned from
// their name() accessor. Used with std::sort so generated output does not
// depend on declaration order.
template <typename DescriptorT>
struct SortByName {
  bool operator()(const DescriptorT* lhs, const DescriptorT* rhs) const {
    const auto& lhs_name = lhs->name();
    const auto& rhs_name = rhs->name();
    return lhs_name < rhs_name;
  }
};
53 
PyiGenerator()54 PyiGenerator::PyiGenerator() : file_(nullptr) {}
55 
~PyiGenerator()56 PyiGenerator::~PyiGenerator() {}
57 
PrintItemMap(const std::map<std::string,std::string> & item_map) const58 void PyiGenerator::PrintItemMap(
59     const std::map<std::string, std::string>& item_map) const {
60   for (const auto& entry : item_map) {
61     printer_->Print("$key$: $value$\n", "key", entry.first, "value",
62                     entry.second);
63   }
64 }
65 
66 template <typename DescriptorT>
ModuleLevelName(const DescriptorT & descriptor,const std::map<std::string,std::string> & import_map) const67 std::string PyiGenerator::ModuleLevelName(
68     const DescriptorT& descriptor,
69     const std::map<std::string, std::string>& import_map) const {
70   std::string name = NamePrefixedWithNestedTypes(descriptor, ".");
71   if (descriptor.file() != file_) {
72     std::string module_alias;
73     std::string filename = descriptor.file()->name();
74     if (import_map.find(filename) == import_map.end()) {
75       std::string module_name = ModuleName(descriptor.file()->name());
76       std::vector<std::string> tokens = Split(module_name, ".");
77       module_alias = "_" + tokens.back();
78     } else {
79       module_alias = import_map.at(filename);
80     }
81     name = module_alias + "." + name;
82   }
83   return name;
84 }
85 
// Flags recording which Python modules / typing names the generated stub has
// to import. Every flag starts false and is only ever switched on when a
// construct needing that import is found.
struct ImportModules {
  bool has_repeated{false};         // google.protobuf.internal.containers
  bool has_iterable{false};         // typing.Iterable
  bool has_messages{false};         // google.protobuf.message
  bool has_enums{false};            // google.protobuf.internal.enum_type_wrapper
  bool has_extendable{false};       // google.protobuf.internal.python_message
  bool has_mapping{false};          // typing.Mapping
  bool has_optional{false};         // typing.Optional
  bool has_union{false};            // typing.Union
  bool has_well_known_type{false};  // google.protobuf.internal.well_known_types
};
97 
// Returns true when `name` (callers pass a message's full_name()) is one of
// the well-known types whose stub class gets an extra
// _well_known_types.<Name> base class.
bool IsWellKnownType(const std::string& name) {
  // LINT.IfChange(wktbases)
  static const char* const kWellKnownTypeNames[] = {
      "google.protobuf.Any",       "google.protobuf.Duration",
      "google.protobuf.FieldMask", "google.protobuf.ListValue",
      "google.protobuf.Struct",    "google.protobuf.Timestamp",
  };
  // LINT.ThenChange(//depot/google3/net/proto2/python/internal/well_known_types.py:wktbases)
  for (const char* candidate : kWellKnownTypeNames) {
    if (name == candidate) {
      return true;
    }
  }
  return false;
}
109 
110 // Checks what modules should be imported for this message
111 // descriptor.
CheckImportModules(const Descriptor * descriptor,ImportModules * import_modules)112 void CheckImportModules(const Descriptor* descriptor,
113                         ImportModules* import_modules) {
114   if (descriptor->extension_range_count() > 0) {
115     import_modules->has_extendable = true;
116   }
117   if (descriptor->enum_type_count() > 0) {
118     import_modules->has_enums = true;
119   }
120   if (IsWellKnownType(descriptor->full_name())) {
121     import_modules->has_well_known_type = true;
122   }
123   for (int i = 0; i < descriptor->field_count(); ++i) {
124     const FieldDescriptor* field = descriptor->field(i);
125     if (IsPythonKeyword(field->name())) {
126       continue;
127     }
128     import_modules->has_optional = true;
129     if (field->is_repeated()) {
130       import_modules->has_repeated = true;
131     }
132     if (field->is_map()) {
133       import_modules->has_mapping = true;
134       const FieldDescriptor* value_des = field->message_type()->field(1);
135       if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
136           value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
137         import_modules->has_union = true;
138       }
139     } else {
140       if (field->is_repeated()) {
141         import_modules->has_iterable = true;
142       }
143       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
144         import_modules->has_union = true;
145         import_modules->has_mapping = true;
146       }
147       if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
148         import_modules->has_union = true;
149       }
150     }
151   }
152   for (int i = 0; i < descriptor->nested_type_count(); ++i) {
153     CheckImportModules(descriptor->nested_type(i), import_modules);
154   }
155 }
156 
PrintImportForDescriptor(const FileDescriptor & desc,std::map<std::string,std::string> * import_map,std::set<std::string> * seen_aliases) const157 void PyiGenerator::PrintImportForDescriptor(
158     const FileDescriptor& desc,
159     std::map<std::string, std::string>* import_map,
160     std::set<std::string>* seen_aliases) const {
161   const std::string& filename = desc.name();
162   std::string module_name = StrippedModuleName(filename);
163   size_t last_dot_pos = module_name.rfind('.');
164   std::string import_statement;
165   if (last_dot_pos == std::string::npos) {
166     import_statement = "import " + module_name;
167   } else {
168     import_statement = "from " + module_name.substr(0, last_dot_pos) +
169                        " import " + module_name.substr(last_dot_pos + 1);
170     module_name = module_name.substr(last_dot_pos + 1);
171   }
172   std::string alias = "_" + module_name;
173   // Generate a unique alias by adding _1 suffixes until we get an unused alias.
174   while (seen_aliases->find(alias) != seen_aliases->end()) {
175     alias = alias + "_1";
176   }
177   printer_->Print("$statement$ as $alias$\n", "statement",
178                   import_statement, "alias", alias);
179   (*import_map)[filename] = alias;
180   seen_aliases->insert(alias);
181 }
182 
PrintImports(std::map<std::string,std::string> * item_map,std::map<std::string,std::string> * import_map) const183 void PyiGenerator::PrintImports(
184     std::map<std::string, std::string>* item_map,
185     std::map<std::string, std::string>* import_map) const {
186   // Prints imported dependent _pb2 files.
187   std::set<std::string> seen_aliases;
188   for (int i = 0; i < file_->dependency_count(); ++i) {
189     const FileDescriptor* dep = file_->dependency(i);
190     PrintImportForDescriptor(*dep, import_map, &seen_aliases);
191     for (int j = 0; j < dep->public_dependency_count(); ++j) {
192       PrintImportForDescriptor(
193           *dep->public_dependency(j), import_map, &seen_aliases);
194     }
195   }
196 
197   // Checks what modules should be imported.
198   ImportModules import_modules;
199   if (file_->message_type_count() > 0) {
200     import_modules.has_messages = true;
201   }
202   if (file_->enum_type_count() > 0) {
203     import_modules.has_enums = true;
204   }
205   for (int i = 0; i < file_->message_type_count(); i++) {
206     CheckImportModules(file_->message_type(i), &import_modules);
207   }
208 
209   // Prints modules (e.g. _containers, _messages, typing) that are
210   // required in the proto file.
211   if (import_modules.has_repeated) {
212     printer_->Print(
213         "from google.protobuf.internal import containers as "
214         "_containers\n");
215   }
216   if (import_modules.has_enums) {
217     printer_->Print(
218         "from google.protobuf.internal import enum_type_wrapper"
219         " as _enum_type_wrapper\n");
220   }
221   if (import_modules.has_extendable) {
222     printer_->Print(
223         "from google.protobuf.internal import python_message"
224         " as _python_message\n");
225   }
226   if (import_modules.has_well_known_type) {
227     printer_->Print(
228         "from google.protobuf.internal import well_known_types"
229         " as _well_known_types\n");
230   }
231   printer_->Print(
232       "from google.protobuf import"
233       " descriptor as _descriptor\n");
234   if (import_modules.has_messages) {
235     printer_->Print(
236         "from google.protobuf import message as _message\n");
237   }
238   if (HasGenericServices(file_)) {
239     printer_->Print(
240         "from google.protobuf import service as"
241         " _service\n");
242   }
243   printer_->Print("from typing import ");
244   printer_->Print("ClassVar as _ClassVar");
245   if (import_modules.has_iterable) {
246     printer_->Print(", Iterable as _Iterable");
247   }
248   if (import_modules.has_mapping) {
249     printer_->Print(", Mapping as _Mapping");
250   }
251   if (import_modules.has_optional) {
252     printer_->Print(", Optional as _Optional");
253   }
254   if (import_modules.has_union) {
255     printer_->Print(", Union as _Union");
256   }
257   printer_->Print("\n\n");
258 
259   // Public imports
260   for (int i = 0; i < file_->public_dependency_count(); ++i) {
261     const FileDescriptor* public_dep = file_->public_dependency(i);
262     std::string module_name = StrippedModuleName(public_dep->name());
263     // Top level messages in public imports
264     for (int i = 0; i < public_dep->message_type_count(); ++i) {
265       printer_->Print("from $module$ import $message_class$\n", "module",
266                       module_name, "message_class",
267                       public_dep->message_type(i)->name());
268     }
269     // Top level enums for public imports
270     for (int i = 0; i < public_dep->enum_type_count(); ++i) {
271       printer_->Print("from $module$ import $enum_class$\n", "module",
272                       module_name, "enum_class",
273                       public_dep->enum_type(i)->name());
274     }
275     // Enum values for public imports
276     for (int i = 0; i < public_dep->enum_type_count(); ++i) {
277       const EnumDescriptor* enum_descriptor = public_dep->enum_type(i);
278       for (int j = 0; j < enum_descriptor->value_count(); ++j) {
279         (*item_map)[enum_descriptor->value(j)->name()] =
280             ModuleLevelName(*enum_descriptor, *import_map);
281       }
282     }
283     // Top level extensions for public imports
284     AddExtensions(*public_dep, item_map);
285   }
286 }
287 
PrintEnum(const EnumDescriptor & enum_descriptor) const288 void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
289   std::string enum_name = enum_descriptor.name();
290   printer_->Print(
291       "class $enum_name$(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):\n"
292       "    __slots__ = []\n",
293       "enum_name", enum_name);
294 }
295 
296 // Adds enum value to item map which will be ordered and printed later.
AddEnumValue(const EnumDescriptor & enum_descriptor,std::map<std::string,std::string> * item_map,const std::map<std::string,std::string> & import_map) const297 void PyiGenerator::AddEnumValue(
298     const EnumDescriptor& enum_descriptor,
299     std::map<std::string, std::string>* item_map,
300     const std::map<std::string, std::string>& import_map) const {
301   // enum values
302   std::string module_enum_name = ModuleLevelName(enum_descriptor, import_map);
303   for (int j = 0; j < enum_descriptor.value_count(); ++j) {
304     const EnumValueDescriptor* value_descriptor = enum_descriptor.value(j);
305     (*item_map)[value_descriptor->name()] = module_enum_name;
306   }
307 }
308 
309 // Prints top level enums
PrintTopLevelEnums() const310 void PyiGenerator::PrintTopLevelEnums() const {
311   for (int i = 0; i < file_->enum_type_count(); ++i) {
312     printer_->Print("\n");
313     PrintEnum(*file_->enum_type(i));
314   }
315 }
316 
317 // Add top level extensions to item_map which will be ordered and
318 // printed later.
319 template <typename DescriptorT>
AddExtensions(const DescriptorT & descriptor,std::map<std::string,std::string> * item_map) const320 void PyiGenerator::AddExtensions(
321     const DescriptorT& descriptor,
322     std::map<std::string, std::string>* item_map) const {
323   for (int i = 0; i < descriptor.extension_count(); ++i) {
324     const FieldDescriptor* extension_field = descriptor.extension(i);
325     std::string constant_name = extension_field->name() + "_FIELD_NUMBER";
326     ToUpper(&constant_name);
327     (*item_map)[constant_name] = "_ClassVar[int]";
328     (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor";
329   }
330 }
331 
332 // Returns the string format of a field's cpp_type
GetFieldType(const FieldDescriptor & field_des,const Descriptor & containing_des,const std::map<std::string,std::string> & import_map) const333 std::string PyiGenerator::GetFieldType(
334     const FieldDescriptor& field_des, const Descriptor& containing_des,
335     const std::map<std::string, std::string>& import_map) const {
336   switch (field_des.cpp_type()) {
337     case FieldDescriptor::CPPTYPE_INT32:
338     case FieldDescriptor::CPPTYPE_UINT32:
339     case FieldDescriptor::CPPTYPE_INT64:
340     case FieldDescriptor::CPPTYPE_UINT64:
341       return "int";
342     case FieldDescriptor::CPPTYPE_DOUBLE:
343     case FieldDescriptor::CPPTYPE_FLOAT:
344       return "float";
345     case FieldDescriptor::CPPTYPE_BOOL:
346       return "bool";
347     case FieldDescriptor::CPPTYPE_ENUM:
348       return ModuleLevelName(*field_des.enum_type(), import_map);
349     case FieldDescriptor::CPPTYPE_STRING:
350       if (field_des.type() == FieldDescriptor::TYPE_STRING) {
351         return "str";
352       } else {
353         return "bytes";
354       }
355     case FieldDescriptor::CPPTYPE_MESSAGE: {
356       // If the field is inside a nested message and the nested message has the
357       // same name as a top-level message, then we need to prefix the field type
358       // with the module name for disambiguation.
359       std::string name = ModuleLevelName(*field_des.message_type(), import_map);
360       if ((containing_des.containing_type() != nullptr &&
361            name == containing_des.name())) {
362         std::string module = ModuleName(field_des.file()->name());
363         name = module + "." + name;
364       }
365       return name;
366     }
367     default:
368       GOOGLE_LOG(FATAL) << "Unsupported field type.";
369   }
370   return "";
371 }
372 
PrintMessage(const Descriptor & message_descriptor,bool is_nested,const std::map<std::string,std::string> & import_map) const373 void PyiGenerator::PrintMessage(
374     const Descriptor& message_descriptor, bool is_nested,
375     const std::map<std::string, std::string>& import_map) const {
376   if (!is_nested) {
377     printer_->Print("\n");
378   }
379   std::string class_name = message_descriptor.name();
380   std::string extra_base;
381   // A well-known type needs to inherit from its corresponding base class in
382   // net/proto2/python/internal/well_known_types.
383   if (IsWellKnownType(message_descriptor.full_name())) {
384     extra_base = ", _well_known_types." + message_descriptor.name();
385   } else {
386     extra_base = "";
387   }
388   printer_->Print("class $class_name$(_message.Message$extra_base$):\n",
389                   "class_name", class_name, "extra_base", extra_base);
390   printer_->Indent();
391   printer_->Indent();
392 
393   std::vector<const FieldDescriptor*> fields;
394   fields.reserve(message_descriptor.field_count());
395   for (int i = 0; i < message_descriptor.field_count(); ++i) {
396     fields.push_back(message_descriptor.field(i));
397   }
398   std::sort(fields.begin(), fields.end(), SortByName<FieldDescriptor>());
399 
400   // Prints slots
401   printer_->Print("__slots__ = [", "class_name", class_name);
402   bool first_item = true;
403   for (const auto& field_des : fields) {
404     if (IsPythonKeyword(field_des->name())) {
405       continue;
406     }
407     if (first_item) {
408       first_item = false;
409     } else {
410       printer_->Print(", ");
411     }
412     printer_->Print("\"$field_name$\"", "field_name", field_des->name());
413   }
414   printer_->Print("]\n");
415 
416   std::map<std::string, std::string> item_map;
417   // Prints Extensions for extendable messages
418   if (message_descriptor.extension_range_count() > 0) {
419     item_map["Extensions"] = "_python_message._ExtensionDict";
420   }
421 
422   // Prints nested enums
423   std::vector<const EnumDescriptor*> nested_enums;
424   nested_enums.reserve(message_descriptor.enum_type_count());
425   for (int i = 0; i < message_descriptor.enum_type_count(); ++i) {
426     nested_enums.push_back(message_descriptor.enum_type(i));
427   }
428   std::sort(nested_enums.begin(), nested_enums.end(),
429             SortByName<EnumDescriptor>());
430 
431   for (const auto& entry : nested_enums) {
432     PrintEnum(*entry);
433     // Adds enum value to item_map which will be ordered and printed later
434     AddEnumValue(*entry, &item_map, import_map);
435   }
436 
437   // Prints nested messages
438   std::vector<const Descriptor*> nested_messages;
439   nested_messages.reserve(message_descriptor.nested_type_count());
440   for (int i = 0; i < message_descriptor.nested_type_count(); ++i) {
441     nested_messages.push_back(message_descriptor.nested_type(i));
442   }
443   std::sort(nested_messages.begin(), nested_messages.end(),
444             SortByName<Descriptor>());
445 
446   for (const auto& entry : nested_messages) {
447     PrintMessage(*entry, true, import_map);
448   }
449 
450   // Adds extensions to item_map which will be ordered and printed later
451   AddExtensions(message_descriptor, &item_map);
452 
453   // Adds field number and field descriptor to item_map
454   for (int i = 0; i < message_descriptor.field_count(); ++i) {
455     const FieldDescriptor& field_des = *message_descriptor.field(i);
456     item_map[ToUpper(field_des.name()) + "_FIELD_NUMBER"] =
457         "_ClassVar[int]";
458     if (IsPythonKeyword(field_des.name())) {
459       continue;
460     }
461     std::string field_type = "";
462     if (field_des.is_map()) {
463       const FieldDescriptor* key_des = field_des.message_type()->field(0);
464       const FieldDescriptor* value_des = field_des.message_type()->field(1);
465       field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
466                         ? "_containers.MessageMap["
467                         : "_containers.ScalarMap[");
468       field_type += GetFieldType(*key_des, message_descriptor, import_map);
469       field_type += ", ";
470       field_type += GetFieldType(*value_des, message_descriptor, import_map);
471     } else {
472       if (field_des.is_repeated()) {
473         field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
474                           ? "_containers.RepeatedCompositeFieldContainer["
475                           : "_containers.RepeatedScalarFieldContainer[");
476       }
477       field_type += GetFieldType(field_des, message_descriptor, import_map);
478     }
479 
480     if (field_des.is_repeated()) {
481       field_type += "]";
482     }
483     item_map[field_des.name()] = field_type;
484   }
485 
486   // Prints all items in item_map
487   PrintItemMap(item_map);
488 
489   // Prints __init__
490   printer_->Print("def __init__(self");
491   bool has_key_words = false;
492   bool is_first = true;
493   for (int i = 0; i < message_descriptor.field_count(); ++i) {
494     const FieldDescriptor* field_des = message_descriptor.field(i);
495     if (IsPythonKeyword(field_des->name())) {
496       has_key_words = true;
497       continue;
498     }
499     std::string field_name = field_des->name();
500     if (is_first && field_name == "self") {
501       // See b/144146793 for an example of real code that generates a (self,
502       // self) method signature. Since repeating a parameter name is illegal in
503       // Python, we rename the duplicate self.
504       field_name = "self_";
505     }
506     is_first = false;
507     printer_->Print(", $field_name$: ", "field_name", field_name);
508     if (field_des->is_repeated() ||
509         field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
510       printer_->Print("_Optional[");
511     }
512     if (field_des->is_map()) {
513       const Descriptor* map_entry = field_des->message_type();
514       printer_->Print(
515           "_Mapping[$key_type$, $value_type$]", "key_type",
516           GetFieldType(*map_entry->field(0), message_descriptor, import_map),
517           "value_type",
518           GetFieldType(*map_entry->field(1), message_descriptor, import_map));
519     } else {
520       if (field_des->is_repeated()) {
521         printer_->Print("_Iterable[");
522       }
523       if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
524         printer_->Print(
525             "_Union[$type_name$, _Mapping]", "type_name",
526             GetFieldType(*field_des, message_descriptor, import_map));
527       } else {
528         if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
529           printer_->Print("_Union[$type_name$, str]", "type_name",
530                           ModuleLevelName(*field_des->enum_type(), import_map));
531         } else {
532           printer_->Print(
533               "$type_name$", "type_name",
534               GetFieldType(*field_des, message_descriptor, import_map));
535         }
536       }
537       if (field_des->is_repeated()) {
538         printer_->Print("]");
539       }
540     }
541     if (field_des->is_repeated() ||
542         field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
543       printer_->Print("]");
544     }
545     printer_->Print(" = ...");
546   }
547   if (has_key_words) {
548     printer_->Print(", **kwargs");
549   }
550   printer_->Print(") -> None: ...\n");
551 
552   printer_->Outdent();
553   printer_->Outdent();
554 }
555 
PrintMessages(const std::map<std::string,std::string> & import_map) const556 void PyiGenerator::PrintMessages(
557     const std::map<std::string, std::string>& import_map) const {
558   // Deterministically order the descriptors.
559   std::vector<const Descriptor*> messages;
560   messages.reserve(file_->message_type_count());
561   for (int i = 0; i < file_->message_type_count(); ++i) {
562     messages.push_back(file_->message_type(i));
563   }
564   std::sort(messages.begin(), messages.end(), SortByName<Descriptor>());
565 
566   for (const auto& entry : messages) {
567     PrintMessage(*entry, false, import_map);
568   }
569 }
570 
PrintServices() const571 void PyiGenerator::PrintServices() const {
572   std::vector<const ServiceDescriptor*> services;
573   services.reserve(file_->service_count());
574   for (int i = 0; i < file_->service_count(); ++i) {
575     services.push_back(file_->service(i));
576   }
577   std::sort(services.begin(), services.end(), SortByName<ServiceDescriptor>());
578 
579   // Prints $Service$ and $Service$_Stub classes
580   for (const auto& entry : services) {
581     printer_->Print("\n");
582     printer_->Print(
583         "class $service_name$(_service.service): ...\n\n"
584         "class $service_name$_Stub($service_name$): ...\n",
585         "service_name", entry->name());
586   }
587 }
588 
Generate(const FileDescriptor * file,const std::string & parameter,GeneratorContext * context,std::string * error) const589 bool PyiGenerator::Generate(const FileDescriptor* file,
590                             const std::string& parameter,
591                             GeneratorContext* context,
592                             std::string* error) const {
593   MutexLock lock(&mutex_);
594   // Calculate file name.
595   file_ = file;
596   std::string filename =
597       parameter.empty() ? GetFileName(file, ".pyi") : parameter;
598 
599   std::unique_ptr<io::ZeroCopyOutputStream> output(context->Open(filename));
600   GOOGLE_CHECK(output.get());
601   io::Printer printer(output.get(), '$');
602   printer_ = &printer;
603 
604   // item map will store "DESCRIPTOR", top level extensions, top level enum
605   // values. The items will be sorted and printed later.
606   std::map<std::string, std::string> item_map;
607 
608   // Adds "DESCRIPTOR" into item_map.
609   item_map["DESCRIPTOR"] = "_descriptor.FileDescriptor";
610 
611   // import_map will be a mapping from filename to module alias, e.g.
612   // "google3/foo/bar.py" -> "_bar"
613   std::map<std::string, std::string> import_map;
614 
615   PrintImports(&item_map, &import_map);
616   // Adds top level enum values to item_map.
617   for (int i = 0; i < file_->enum_type_count(); ++i) {
618     AddEnumValue(*file_->enum_type(i), &item_map, import_map);
619   }
620   // Adds top level extensions to item_map.
621   AddExtensions(*file_, &item_map);
622   // Prints item map
623   PrintItemMap(item_map);
624 
625   PrintMessages(import_map);
626   PrintTopLevelEnums();
627   if (HasGenericServices(file)) {
628     PrintServices();
629   }
630   return true;
631 }
632 
633 }  // namespace python
634 }  // namespace compiler
635 }  // namespace protobuf
636 }  // namespace google
637