1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 #include <google/protobuf/compiler/python/python_pyi_generator.h>
32
33 #include <string>
34
35 #include <google/protobuf/compiler/python/python_helpers.h>
36 #include <google/protobuf/io/printer.h>
37 #include <google/protobuf/io/zero_copy_stream.h>
38 #include <google/protobuf/descriptor.h>
39 #include <google/protobuf/stubs/strutil.h>
40 #include <google/protobuf/descriptor.pb.h>
41
42 namespace google {
43 namespace protobuf {
44 namespace compiler {
45 namespace python {
46
// Comparator for std::sort that orders descriptor pointers by their name(),
// giving the generated stub a stable, name-based ordering of fields, nested
// enums, messages and services.
template <typename DescriptorT>
struct SortByName {
  bool operator()(const DescriptorT* lhs, const DescriptorT* rhs) const {
    // compare() < 0 is exactly std::string's operator<.
    return lhs->name().compare(rhs->name()) < 0;
  }
};
53
PyiGenerator()54 PyiGenerator::PyiGenerator() : file_(nullptr) {}
55
~PyiGenerator()56 PyiGenerator::~PyiGenerator() {}
57
PrintItemMap(const std::map<std::string,std::string> & item_map) const58 void PyiGenerator::PrintItemMap(
59 const std::map<std::string, std::string>& item_map) const {
60 for (const auto& entry : item_map) {
61 printer_->Print("$key$: $value$\n", "key", entry.first, "value",
62 entry.second);
63 }
64 }
65
66 template <typename DescriptorT>
ModuleLevelName(const DescriptorT & descriptor) const67 std::string PyiGenerator::ModuleLevelName(const DescriptorT& descriptor) const {
68 std::string name = NamePrefixedWithNestedTypes(descriptor, ".");
69 if (descriptor.file() != file_) {
70 std::string module_name = ModuleName(descriptor.file()->name());
71 std::vector<std::string> tokens = Split(module_name, ".");
72 name = "_" + tokens.back() + "." + name;
73 }
74 return name;
75 }
76
// Flags recording which Python modules / typing names the generated stub
// must import. CheckImportModules() and PrintImports() only ever switch flags
// on, so a single instance accumulates requirements across all messages.
struct ImportModules {
  bool has_repeated = false;    // _containers
  bool has_iterable = false;    // typing.Iterable
  bool has_messages = false;    // _message
  bool has_enums = false;       // _enum_type_wrapper
  bool has_extendable = false;  // _python_message
  bool has_mapping = false;     // typing.Mapping
  bool has_optional = false;    // typing.Optional
  bool has_union = false;       // typing.Union
};
87
88 // Checks what modules should be imported for this message
89 // descriptor.
CheckImportModules(const Descriptor * descriptor,ImportModules * import_modules)90 void CheckImportModules(const Descriptor* descriptor,
91 ImportModules* import_modules) {
92 if (descriptor->extension_range_count() > 0) {
93 import_modules->has_extendable = true;
94 }
95 if (descriptor->enum_type_count() > 0) {
96 import_modules->has_enums = true;
97 }
98 for (int i = 0; i < descriptor->field_count(); ++i) {
99 const FieldDescriptor* field = descriptor->field(i);
100 if (IsPythonKeyword(field->name())) {
101 continue;
102 }
103 import_modules->has_optional = true;
104 if (field->is_repeated()) {
105 import_modules->has_repeated = true;
106 }
107 if (field->is_map()) {
108 import_modules->has_mapping = true;
109 const FieldDescriptor* value_des = field->message_type()->field(1);
110 if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
111 value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
112 import_modules->has_union = true;
113 }
114 } else {
115 if (field->is_repeated()) {
116 import_modules->has_iterable = true;
117 }
118 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
119 import_modules->has_union = true;
120 import_modules->has_mapping = true;
121 }
122 if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
123 import_modules->has_union = true;
124 }
125 }
126 }
127 for (int i = 0; i < descriptor->nested_type_count(); ++i) {
128 CheckImportModules(descriptor->nested_type(i), import_modules);
129 }
130 }
131
PrintImports(std::map<std::string,std::string> * item_map) const132 void PyiGenerator::PrintImports(
133 std::map<std::string, std::string>* item_map) const {
134 // Prints imported dependent _pb2 files.
135 for (int i = 0; i < file_->dependency_count(); ++i) {
136 const std::string& filename = file_->dependency(i)->name();
137 std::string module_name = StrippedModuleName(filename);
138 size_t last_dot_pos = module_name.rfind('.');
139 std::string import_statement;
140 if (last_dot_pos == std::string::npos) {
141 import_statement = "import " + module_name;
142 } else {
143 import_statement = "from " + module_name.substr(0, last_dot_pos) +
144 " import " + module_name.substr(last_dot_pos + 1);
145 module_name = module_name.substr(last_dot_pos + 1);
146 }
147 printer_->Print("$statement$ as _$module_name$\n", "statement",
148 import_statement, "module_name", module_name);
149 }
150
151 // Checks what modules should be imported.
152 ImportModules import_modules;
153 if (file_->message_type_count() > 0) {
154 import_modules.has_messages = true;
155 }
156 if (file_->enum_type_count() > 0) {
157 import_modules.has_enums = true;
158 }
159 for (int i = 0; i < file_->message_type_count(); i++) {
160 CheckImportModules(file_->message_type(i), &import_modules);
161 }
162
163 // Prints modules (e.g. _containers, _messages, typing) that are
164 // required in the proto file.
165 if (import_modules.has_repeated) {
166 printer_->Print(
167 "from google.protobuf.internal import containers as "
168 "_containers\n");
169 }
170 if (import_modules.has_enums) {
171 printer_->Print(
172 "from google.protobuf.internal import enum_type_wrapper"
173 " as _enum_type_wrapper\n");
174 }
175 if (import_modules.has_extendable) {
176 printer_->Print(
177 "from google.protobuf.internal import python_message"
178 " as _python_message\n");
179 }
180 printer_->Print(
181 "from google.protobuf import"
182 " descriptor as _descriptor\n");
183 if (import_modules.has_messages) {
184 printer_->Print(
185 "from google.protobuf import message as _message\n");
186 }
187 if (HasGenericServices(file_)) {
188 printer_->Print(
189 "from google.protobuf import service as"
190 " _service\n");
191 }
192 printer_->Print("from typing import ");
193 printer_->Print("ClassVar");
194 if (import_modules.has_iterable) {
195 printer_->Print(", Iterable");
196 }
197 if (import_modules.has_mapping) {
198 printer_->Print(", Mapping");
199 }
200 if (import_modules.has_optional) {
201 printer_->Print(", Optional");
202 }
203 if (file_->service_count() > 0) {
204 printer_->Print(", Text");
205 }
206 if (import_modules.has_union) {
207 printer_->Print(", Union");
208 }
209 printer_->Print("\n\n");
210
211 // Public imports
212 for (int i = 0; i < file_->public_dependency_count(); ++i) {
213 const FileDescriptor* public_dep = file_->public_dependency(i);
214 std::string module_name = StrippedModuleName(public_dep->name());
215 // Top level messages in public imports
216 for (int i = 0; i < public_dep->message_type_count(); ++i) {
217 printer_->Print("from $module$ import $message_class$\n", "module",
218 module_name, "message_class",
219 public_dep->message_type(i)->name());
220 }
221 // Top level enums for public imports
222 for (int i = 0; i < public_dep->enum_type_count(); ++i) {
223 printer_->Print("from $module$ import $enum_class$\n", "module",
224 module_name, "enum_class",
225 public_dep->enum_type(i)->name());
226 }
227 // Enum values for public imports
228 for (int i = 0; i < public_dep->enum_type_count(); ++i) {
229 const EnumDescriptor* enum_descriptor = public_dep->enum_type(i);
230 for (int j = 0; j < enum_descriptor->value_count(); ++j) {
231 (*item_map)[enum_descriptor->value(j)->name()] =
232 ModuleLevelName(*enum_descriptor);
233 }
234 }
235 // Top level extensions for public imports
236 AddExtensions(*public_dep, item_map);
237 }
238 }
239
PrintEnum(const EnumDescriptor & enum_descriptor) const240 void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
241 std::string enum_name = enum_descriptor.name();
242 printer_->Print(
243 "class $enum_name$(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):\n"
244 " __slots__ = []\n",
245 "enum_name", enum_name);
246 }
247
248 // Adds enum value to item map which will be ordered and printed later.
AddEnumValue(const EnumDescriptor & enum_descriptor,std::map<std::string,std::string> * item_map) const249 void PyiGenerator::AddEnumValue(
250 const EnumDescriptor& enum_descriptor,
251 std::map<std::string, std::string>* item_map) const {
252 // enum values
253 std::string module_enum_name = ModuleLevelName(enum_descriptor);
254 for (int j = 0; j < enum_descriptor.value_count(); ++j) {
255 const EnumValueDescriptor* value_descriptor = enum_descriptor.value(j);
256 (*item_map)[value_descriptor->name()] = module_enum_name;
257 }
258 }
259
260 // Prints top level enums
PrintTopLevelEnums() const261 void PyiGenerator::PrintTopLevelEnums() const {
262 for (int i = 0; i < file_->enum_type_count(); ++i) {
263 printer_->Print("\n");
264 PrintEnum(*file_->enum_type(i));
265 }
266 }
267
268 // Add top level extensions to item_map which will be ordered and
269 // printed later.
270 template <typename DescriptorT>
AddExtensions(const DescriptorT & descriptor,std::map<std::string,std::string> * item_map) const271 void PyiGenerator::AddExtensions(
272 const DescriptorT& descriptor,
273 std::map<std::string, std::string>* item_map) const {
274 for (int i = 0; i < descriptor.extension_count(); ++i) {
275 const FieldDescriptor* extension_field = descriptor.extension(i);
276 std::string constant_name = extension_field->name() + "_FIELD_NUMBER";
277 ToUpper(&constant_name);
278 (*item_map)[constant_name] = "ClassVar[int]";
279 (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor";
280 }
281 }
282
283 // Returns the string format of a field's cpp_type
GetFieldType(const FieldDescriptor & field_des) const284 std::string PyiGenerator::GetFieldType(const FieldDescriptor& field_des) const {
285 switch (field_des.cpp_type()) {
286 case FieldDescriptor::CPPTYPE_INT32:
287 case FieldDescriptor::CPPTYPE_UINT32:
288 case FieldDescriptor::CPPTYPE_INT64:
289 case FieldDescriptor::CPPTYPE_UINT64:
290 return "int";
291 case FieldDescriptor::CPPTYPE_DOUBLE:
292 case FieldDescriptor::CPPTYPE_FLOAT:
293 return "float";
294 case FieldDescriptor::CPPTYPE_BOOL:
295 return "bool";
296 case FieldDescriptor::CPPTYPE_ENUM:
297 return ModuleLevelName(*field_des.enum_type());
298 case FieldDescriptor::CPPTYPE_STRING:
299 if (field_des.type() == FieldDescriptor::TYPE_STRING) {
300 return "str";
301 } else {
302 return "bytes";
303 }
304 case FieldDescriptor::CPPTYPE_MESSAGE:
305 return ModuleLevelName(*field_des.message_type());
306 default:
307 GOOGLE_LOG(FATAL) << "Unsuppoted field type.";
308 }
309 return "";
310 }
311
PrintMessage(const Descriptor & message_descriptor,bool is_nested) const312 void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
313 bool is_nested) const {
314 if (!is_nested) {
315 printer_->Print("\n");
316 }
317 std::string class_name = message_descriptor.name();
318 printer_->Print("class $class_name$(_message.Message):\n", "class_name",
319 class_name);
320 printer_->Indent();
321 printer_->Indent();
322
323 std::vector<const FieldDescriptor*> fields;
324 fields.reserve(message_descriptor.field_count());
325 for (int i = 0; i < message_descriptor.field_count(); ++i) {
326 fields.push_back(message_descriptor.field(i));
327 }
328 std::sort(fields.begin(), fields.end(), SortByName<FieldDescriptor>());
329
330 // Prints slots
331 printer_->Print("__slots__ = [", "class_name", class_name);
332 bool first_item = true;
333 for (const auto& field_des : fields) {
334 if (IsPythonKeyword(field_des->name())) {
335 continue;
336 }
337 if (first_item) {
338 first_item = false;
339 } else {
340 printer_->Print(", ");
341 }
342 printer_->Print("\"$field_name$\"", "field_name", field_des->name());
343 }
344 printer_->Print("]\n");
345
346 std::map<std::string, std::string> item_map;
347 // Prints Extensions for extendable messages
348 if (message_descriptor.extension_range_count() > 0) {
349 item_map["Extensions"] = "_python_message._ExtensionDict";
350 }
351
352 // Prints nested enums
353 std::vector<const EnumDescriptor*> nested_enums;
354 nested_enums.reserve(message_descriptor.enum_type_count());
355 for (int i = 0; i < message_descriptor.enum_type_count(); ++i) {
356 nested_enums.push_back(message_descriptor.enum_type(i));
357 }
358 std::sort(nested_enums.begin(), nested_enums.end(),
359 SortByName<EnumDescriptor>());
360
361 for (const auto& entry : nested_enums) {
362 PrintEnum(*entry);
363 // Adds enum value to item_map which will be ordered and printed later
364 AddEnumValue(*entry, &item_map);
365 }
366
367 // Prints nested messages
368 std::vector<const Descriptor*> nested_messages;
369 nested_messages.reserve(message_descriptor.nested_type_count());
370 for (int i = 0; i < message_descriptor.nested_type_count(); ++i) {
371 nested_messages.push_back(message_descriptor.nested_type(i));
372 }
373 std::sort(nested_messages.begin(), nested_messages.end(),
374 SortByName<Descriptor>());
375
376 for (const auto& entry : nested_messages) {
377 PrintMessage(*entry, true);
378 }
379
380 // Adds extensions to item_map which will be ordered and printed later
381 AddExtensions(message_descriptor, &item_map);
382
383 // Adds field number and field descriptor to item_map
384 for (int i = 0; i < message_descriptor.field_count(); ++i) {
385 const FieldDescriptor& field_des = *message_descriptor.field(i);
386 item_map[ToUpper(field_des.name()) + "_FIELD_NUMBER"] =
387 "ClassVar[int]";
388 if (IsPythonKeyword(field_des.name())) {
389 continue;
390 }
391 std::string field_type = "";
392 if (field_des.is_map()) {
393 const FieldDescriptor* key_des = field_des.message_type()->field(0);
394 const FieldDescriptor* value_des = field_des.message_type()->field(1);
395 field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
396 ? "_containers.MessageMap["
397 : "_containers.ScalarMap[");
398 field_type += GetFieldType(*key_des);
399 field_type += ", ";
400 field_type += GetFieldType(*value_des);
401 } else {
402 if (field_des.is_repeated()) {
403 field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
404 ? "_containers.RepeatedCompositeFieldContainer["
405 : "_containers.RepeatedScalarFieldContainer[");
406 }
407 field_type += GetFieldType(field_des);
408 }
409
410 if (field_des.is_repeated()) {
411 field_type += "]";
412 }
413 item_map[field_des.name()] = field_type;
414 }
415
416 // Prints all items in item_map
417 PrintItemMap(item_map);
418
419 // Prints __init__
420 printer_->Print("def __init__(self");
421 bool has_key_words = false;
422 bool is_first = true;
423 for (int i = 0; i < message_descriptor.field_count(); ++i) {
424 const FieldDescriptor* field_des = message_descriptor.field(i);
425 if (IsPythonKeyword(field_des->name())) {
426 has_key_words = true;
427 continue;
428 }
429 std::string field_name = field_des->name();
430 if (is_first && field_name == "self") {
431 // See b/144146793 for an example of real code that generates a (self,
432 // self) method signature. Since repeating a parameter name is illegal in
433 // Python, we rename the duplicate self.
434 field_name = "self_";
435 }
436 is_first = false;
437 printer_->Print(", $field_name$: ", "field_name", field_name);
438 if (field_des->is_repeated() ||
439 field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
440 printer_->Print("Optional[");
441 }
442 if (field_des->is_map()) {
443 const Descriptor* map_entry = field_des->message_type();
444 printer_->Print("Mapping[$key_type$, $value_type$]", "key_type",
445 GetFieldType(*map_entry->field(0)), "value_type",
446 GetFieldType(*map_entry->field(1)));
447 } else {
448 if (field_des->is_repeated()) {
449 printer_->Print("Iterable[");
450 }
451 if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
452 printer_->Print("Union[$type_name$, Mapping]", "type_name",
453 GetFieldType(*field_des));
454 } else {
455 if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
456 printer_->Print("Union[$type_name$, str]", "type_name",
457 ModuleLevelName(*field_des->enum_type()));
458 } else {
459 printer_->Print("$type_name$", "type_name", GetFieldType(*field_des));
460 }
461 }
462 if (field_des->is_repeated()) {
463 printer_->Print("]");
464 }
465 }
466 if (field_des->is_repeated() ||
467 field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
468 printer_->Print("]");
469 }
470 printer_->Print(" = ...");
471 }
472 if (has_key_words) {
473 printer_->Print(", **kwargs");
474 }
475 printer_->Print(") -> None: ...\n");
476
477 printer_->Outdent();
478 printer_->Outdent();
479 }
480
PrintMessages() const481 void PyiGenerator::PrintMessages() const {
482 // Order the descriptors by name to have same output with proto_to_pyi.py
483 std::vector<const Descriptor*> messages;
484 messages.reserve(file_->message_type_count());
485 for (int i = 0; i < file_->message_type_count(); ++i) {
486 messages.push_back(file_->message_type(i));
487 }
488 std::sort(messages.begin(), messages.end(), SortByName<Descriptor>());
489
490 for (const auto& entry : messages) {
491 PrintMessage(*entry, false);
492 }
493 }
494
PrintServices() const495 void PyiGenerator::PrintServices() const {
496 std::vector<const ServiceDescriptor*> services;
497 services.reserve(file_->service_count());
498 for (int i = 0; i < file_->service_count(); ++i) {
499 services.push_back(file_->service(i));
500 }
501 std::sort(services.begin(), services.end(), SortByName<ServiceDescriptor>());
502
503 // Prints $Service$ and $Service$_Stub classes
504 for (const auto& entry : services) {
505 printer_->Print("\n");
506 printer_->Print(
507 "class $service_name$(_service.service): ...\n\n"
508 "class $service_name$_Stub($service_name$): ...\n",
509 "service_name", entry->name());
510 }
511 }
512
Generate(const FileDescriptor * file,const std::string & parameter,GeneratorContext * context,std::string * error) const513 bool PyiGenerator::Generate(const FileDescriptor* file,
514 const std::string& parameter,
515 GeneratorContext* context,
516 std::string* error) const {
517 MutexLock lock(&mutex_);
518 // Calculate file name.
519 file_ = file;
520 // proto_to_pyi.py may set the output file name directly. To replace
521 // proto_to_pyi.py in google3, protoc also accept --pyi_out to set
522 // the output file name.
523 std::string filename =
524 parameter.empty() ? GetFileName(file, ".pyi") : parameter;
525
526 std::unique_ptr<io::ZeroCopyOutputStream> output(context->Open(filename));
527 GOOGLE_CHECK(output.get());
528 io::Printer printer(output.get(), '$');
529 printer_ = &printer;
530
531 // item map will store "DESCRIPTOR", top level extensions, top level enum
532 // values. The items will be sorted and printed later.
533 std::map<std::string, std::string> item_map;
534
535 // Adds "DESCRIPTOR" into item_map.
536 item_map["DESCRIPTOR"] = "_descriptor.FileDescriptor";
537 PrintImports(&item_map);
538 // Adds top level enum values to item_map.
539 for (int i = 0; i < file_->enum_type_count(); ++i) {
540 AddEnumValue(*file_->enum_type(i), &item_map);
541 }
542 // Adds top level extensions to item_map.
543 AddExtensions(*file_, &item_map);
544 // Prints item map
545 PrintItemMap(item_map);
546
547 PrintMessages();
548 PrintTopLevelEnums();
549 if (HasGenericServices(file)) {
550 PrintServices();
551 }
552 return true;
553 }
554
555 } // namespace python
556 } // namespace compiler
557 } // namespace protobuf
558 } // namespace google
559