• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include "src/compiler/python_generator.h"
20 
21 #include <algorithm>
22 #include <cassert>
23 #include <cctype>
24 #include <cstring>
25 #include <fstream>
26 #include <iostream>
27 #include <map>
28 #include <memory>
29 #include <ostream>
30 #include <set>
31 #include <sstream>
32 #include <tuple>
33 #include <vector>
34 
35 #include "src/compiler/config.h"
36 #include "src/compiler/generator_helpers.h"
37 #include "src/compiler/protobuf_plugin.h"
38 #include "src/compiler/python_generator_helpers.h"
39 #include "src/compiler/python_private_generator.h"
40 
41 using grpc::protobuf::FileDescriptor;
42 using grpc::protobuf::compiler::GeneratorContext;
43 using grpc::protobuf::io::CodedOutputStream;
44 using grpc::protobuf::io::ZeroCopyOutputStream;
45 using std::make_pair;
46 using std::map;
47 using std::pair;
48 using std::replace;
49 using std::set;
50 using std::tuple;
51 using std::vector;
52 
53 namespace grpc_python_generator {
54 
// File name handed to get_module_and_message_path_input/output by the Print*
// methods below when they resolve request/response message paths.
// NOTE(review): this has external linkage and is never assigned in the part of
// the file visible here; presumably it is set by the generator entry point
// before any Print* method runs -- confirm.
std::string generator_file_name;
56 
57 namespace {
58 
// Substitution variables handed to Printer::Print, keyed by placeholder name.
using StringMap = std::map<std::string, std::string>;
// Ordered list of strings (e.g. the comment lines attached to a declaration).
using StringVector = std::vector<std::string>;
// Two strings bundled together; PrintPreamble stores (module name, alias).
using StringPair = std::tuple<std::string, std::string>;
// Ordered, de-duplicated collection of StringPair values.
using StringPairSet = std::set<StringPair>;
63 
64 // Provides RAII indentation handling. Use as:
65 // {
66 //   IndentScope raii_my_indent_var_name_here(my_py_printer);
67 //   // constructor indented my_py_printer
68 //   ...
69 //   // destructor called at end of scope, un-indenting my_py_printer
70 // }
71 class IndentScope {
72  public:
IndentScope(grpc_generator::Printer * printer)73   explicit IndentScope(grpc_generator::Printer* printer) : printer_(printer) {
74     // NOTE(rbellevi): Two-space tabs are hard-coded in the protocol compiler.
75     // Doubling our indents and outdents guarantees compliance with PEP8.
76     printer_->Indent();
77     printer_->Indent();
78   }
79 
~IndentScope()80   ~IndentScope() {
81     printer_->Outdent();
82     printer_->Outdent();
83   }
84 
85  private:
86   grpc_generator::Printer* printer_;
87 };
88 
PrivateGenerator(const GeneratorConfiguration & config,const grpc_generator::File * file)89 PrivateGenerator::PrivateGenerator(const GeneratorConfiguration& config,
90                                    const grpc_generator::File* file)
91     : config(config), file(file) {}
92 
PrintAllComments(StringVector comments,grpc_generator::Printer * out)93 void PrivateGenerator::PrintAllComments(StringVector comments,
94                                         grpc_generator::Printer* out) {
95   if (comments.empty()) {
96     // Python requires code structures like class and def to have
97     // a body, even if it is just "pass" or a docstring.  We need
98     // to ensure not to generate empty bodies. We could do something
99     // smarter and more sophisticated, but at the moment, if there is
100     // no docstring to print, we simply emit "pass" to ensure validity
101     // of the generated code.
102     out->Print(
103         "\"\"\"Missing associated documentation comment in .proto "
104         "file.\"\"\"\n");
105     return;
106   }
107   out->Print("\"\"\"");
108   for (StringVector::iterator it = comments.begin(); it != comments.end();
109        ++it) {
110     size_t start_pos = it->find_first_not_of(' ');
111     if (start_pos != std::string::npos) {
112       out->PrintRaw(it->c_str() + start_pos);
113     }
114     out->Print("\n");
115   }
116   out->Print("\"\"\"\n");
117 }
118 
PrintBetaServicer(const grpc_generator::Service * service,grpc_generator::Printer * out)119 bool PrivateGenerator::PrintBetaServicer(const grpc_generator::Service* service,
120                                          grpc_generator::Printer* out) {
121   StringMap service_dict;
122   service_dict["Service"] = service->name();
123   out->Print("\n\n");
124   out->Print(service_dict, "class Beta$Service$Servicer(object):\n");
125   {
126     IndentScope raii_class_indent(out);
127     out->Print(
128         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
129         "\nIt is recommended to use the GA API (classes and functions in this\n"
130         "file not marked beta) for all further purposes. This class was "
131         "generated\n"
132         "only to ease transition from grpcio<0.15.0 to "
133         "grpcio>=0.15.0.\"\"\"\n");
134     StringVector service_comments = service->GetAllComments();
135     PrintAllComments(service_comments, out);
136     for (int i = 0; i < service->method_count(); ++i) {
137       auto method = service->method(i);
138       std::string arg_name =
139           method->ClientStreaming() ? "request_iterator" : "request";
140       StringMap method_dict;
141       method_dict["Method"] = method->name();
142       method_dict["ArgName"] = arg_name;
143       out->Print(method_dict, "def $Method$(self, $ArgName$, context):\n");
144       {
145         IndentScope raii_method_indent(out);
146         StringVector method_comments = method->GetAllComments();
147         PrintAllComments(method_comments, out);
148         out->Print("context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)\n");
149       }
150     }
151   }
152   return true;
153 }
154 
PrintBetaStub(const grpc_generator::Service * service,grpc_generator::Printer * out)155 bool PrivateGenerator::PrintBetaStub(const grpc_generator::Service* service,
156                                      grpc_generator::Printer* out) {
157   StringMap service_dict;
158   service_dict["Service"] = service->name();
159   out->Print("\n\n");
160   out->Print(service_dict, "class Beta$Service$Stub(object):\n");
161   {
162     IndentScope raii_class_indent(out);
163     out->Print(
164         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
165         "\nIt is recommended to use the GA API (classes and functions in this\n"
166         "file not marked beta) for all further purposes. This class was "
167         "generated\n"
168         "only to ease transition from grpcio<0.15.0 to "
169         "grpcio>=0.15.0.\"\"\"\n");
170     StringVector service_comments = service->GetAllComments();
171     PrintAllComments(service_comments, out);
172     for (int i = 0; i < service->method_count(); ++i) {
173       auto method = service->method(i);
174       std::string arg_name =
175           method->ClientStreaming() ? "request_iterator" : "request";
176       StringMap method_dict;
177       method_dict["Method"] = method->name();
178       method_dict["ArgName"] = arg_name;
179       out->Print(method_dict,
180                  "def $Method$(self, $ArgName$, timeout, metadata=None, "
181                  "with_call=False, protocol_options=None):\n");
182       {
183         IndentScope raii_method_indent(out);
184         StringVector method_comments = method->GetAllComments();
185         PrintAllComments(method_comments, out);
186         out->Print("raise NotImplementedError()\n");
187       }
188       if (!method->ServerStreaming()) {
189         out->Print(method_dict, "$Method$.future = None\n");
190       }
191     }
192   }
193   return true;
194 }
195 
PrintBetaServerFactory(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)196 bool PrivateGenerator::PrintBetaServerFactory(
197     const std::string& package_qualified_service_name,
198     const grpc_generator::Service* service, grpc_generator::Printer* out) {
199   StringMap service_dict;
200   service_dict["Service"] = service->name();
201   out->Print("\n\n");
202   out->Print(service_dict,
203              "def beta_create_$Service$_server(servicer, pool=None, "
204              "pool_size=None, default_timeout=None, maximum_timeout=None):\n");
205   {
206     IndentScope raii_create_server_indent(out);
207     out->Print(
208         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
209         "\nIt is recommended to use the GA API (classes and functions in this\n"
210         "file not marked beta) for all further purposes. This function was\n"
211         "generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"
212         "\"\"\"\n");
213     StringMap method_implementation_constructors;
214     StringMap input_message_modules_and_classes;
215     StringMap output_message_modules_and_classes;
216     for (int i = 0; i < service->method_count(); ++i) {
217       auto method = service->method(i);
218       const std::string method_implementation_constructor =
219           std::string(method->ClientStreaming() ? "stream_" : "unary_") +
220           std::string(method->ServerStreaming() ? "stream_" : "unary_") +
221           "inline";
222       std::string input_message_module_and_class;
223       if (!method->get_module_and_message_path_input(
224               &input_message_module_and_class, generator_file_name,
225               generate_in_pb2_grpc, config.import_prefix,
226               config.prefixes_to_filter)) {
227         return false;
228       }
229       std::string output_message_module_and_class;
230       if (!method->get_module_and_message_path_output(
231               &output_message_module_and_class, generator_file_name,
232               generate_in_pb2_grpc, config.import_prefix,
233               config.prefixes_to_filter)) {
234         return false;
235       }
236       method_implementation_constructors.insert(
237           make_pair(method->name(), method_implementation_constructor));
238       input_message_modules_and_classes.insert(
239           make_pair(method->name(), input_message_module_and_class));
240       output_message_modules_and_classes.insert(
241           make_pair(method->name(), output_message_module_and_class));
242     }
243     StringMap method_dict;
244     method_dict["PackageQualifiedServiceName"] = package_qualified_service_name;
245     out->Print("request_deserializers = {\n");
246     for (StringMap::iterator name_and_input_module_class_pair =
247              input_message_modules_and_classes.begin();
248          name_and_input_module_class_pair !=
249          input_message_modules_and_classes.end();
250          name_and_input_module_class_pair++) {
251       method_dict["MethodName"] = name_and_input_module_class_pair->first;
252       method_dict["InputTypeModuleAndClass"] =
253           name_and_input_module_class_pair->second;
254       IndentScope raii_indent(out);
255       out->Print(method_dict,
256                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
257                  "$InputTypeModuleAndClass$.FromString,\n");
258     }
259     out->Print("}\n");
260     out->Print("response_serializers = {\n");
261     for (StringMap::iterator name_and_output_module_class_pair =
262              output_message_modules_and_classes.begin();
263          name_and_output_module_class_pair !=
264          output_message_modules_and_classes.end();
265          name_and_output_module_class_pair++) {
266       method_dict["MethodName"] = name_and_output_module_class_pair->first;
267       method_dict["OutputTypeModuleAndClass"] =
268           name_and_output_module_class_pair->second;
269       IndentScope raii_indent(out);
270       out->Print(method_dict,
271                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
272                  "$OutputTypeModuleAndClass$.SerializeToString,\n");
273     }
274     out->Print("}\n");
275     out->Print("method_implementations = {\n");
276     for (StringMap::iterator name_and_implementation_constructor =
277              method_implementation_constructors.begin();
278          name_and_implementation_constructor !=
279          method_implementation_constructors.end();
280          name_and_implementation_constructor++) {
281       method_dict["Method"] = name_and_implementation_constructor->first;
282       method_dict["Constructor"] = name_and_implementation_constructor->second;
283       IndentScope raii_descriptions_indent(out);
284       const std::string method_name =
285           name_and_implementation_constructor->first;
286       out->Print(method_dict,
287                  "(\'$PackageQualifiedServiceName$\', \'$Method$\'): "
288                  "face_utilities.$Constructor$(servicer.$Method$),\n");
289     }
290     out->Print("}\n");
291     out->Print(
292         "server_options = beta_implementations.server_options("
293         "request_deserializers=request_deserializers, "
294         "response_serializers=response_serializers, "
295         "thread_pool=pool, thread_pool_size=pool_size, "
296         "default_timeout=default_timeout, "
297         "maximum_timeout=maximum_timeout)\n");
298     out->Print(
299         "return beta_implementations.server(method_implementations, "
300         "options=server_options)\n");
301   }
302   return true;
303 }
304 
PrintBetaStubFactory(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)305 bool PrivateGenerator::PrintBetaStubFactory(
306     const std::string& package_qualified_service_name,
307     const grpc_generator::Service* service, grpc_generator::Printer* out) {
308   StringMap dict;
309   dict["Service"] = service->name();
310   out->Print("\n\n");
311   out->Print(dict,
312              "def beta_create_$Service$_stub(channel, host=None,"
313              " metadata_transformer=None, pool=None, pool_size=None):\n");
314   {
315     IndentScope raii_create_server_indent(out);
316     out->Print(
317         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
318         "\nIt is recommended to use the GA API (classes and functions in this\n"
319         "file not marked beta) for all further purposes. This function was\n"
320         "generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"
321         "\"\"\"\n");
322     StringMap method_cardinalities;
323     StringMap input_message_modules_and_classes;
324     StringMap output_message_modules_and_classes;
325     for (int i = 0; i < service->method_count(); ++i) {
326       auto method = service->method(i);
327       const std::string method_cardinality =
328           std::string(method->ClientStreaming() ? "STREAM" : "UNARY") + "_" +
329           std::string(method->ServerStreaming() ? "STREAM" : "UNARY");
330       std::string input_message_module_and_class;
331       if (!method->get_module_and_message_path_input(
332               &input_message_module_and_class, generator_file_name,
333               generate_in_pb2_grpc, config.import_prefix,
334               config.prefixes_to_filter)) {
335         return false;
336       }
337       std::string output_message_module_and_class;
338       if (!method->get_module_and_message_path_output(
339               &output_message_module_and_class, generator_file_name,
340               generate_in_pb2_grpc, config.import_prefix,
341               config.prefixes_to_filter)) {
342         return false;
343       }
344       method_cardinalities.insert(
345           make_pair(method->name(), method_cardinality));
346       input_message_modules_and_classes.insert(
347           make_pair(method->name(), input_message_module_and_class));
348       output_message_modules_and_classes.insert(
349           make_pair(method->name(), output_message_module_and_class));
350     }
351     StringMap method_dict;
352     method_dict["PackageQualifiedServiceName"] = package_qualified_service_name;
353     out->Print("request_serializers = {\n");
354     for (StringMap::iterator name_and_input_module_class_pair =
355              input_message_modules_and_classes.begin();
356          name_and_input_module_class_pair !=
357          input_message_modules_and_classes.end();
358          name_and_input_module_class_pair++) {
359       method_dict["MethodName"] = name_and_input_module_class_pair->first;
360       method_dict["InputTypeModuleAndClass"] =
361           name_and_input_module_class_pair->second;
362       IndentScope raii_indent(out);
363       out->Print(method_dict,
364                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
365                  "$InputTypeModuleAndClass$.SerializeToString,\n");
366     }
367     out->Print("}\n");
368     out->Print("response_deserializers = {\n");
369     for (StringMap::iterator name_and_output_module_class_pair =
370              output_message_modules_and_classes.begin();
371          name_and_output_module_class_pair !=
372          output_message_modules_and_classes.end();
373          name_and_output_module_class_pair++) {
374       method_dict["MethodName"] = name_and_output_module_class_pair->first;
375       method_dict["OutputTypeModuleAndClass"] =
376           name_and_output_module_class_pair->second;
377       IndentScope raii_indent(out);
378       out->Print(method_dict,
379                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
380                  "$OutputTypeModuleAndClass$.FromString,\n");
381     }
382     out->Print("}\n");
383     out->Print("cardinalities = {\n");
384     for (StringMap::iterator name_and_cardinality =
385              method_cardinalities.begin();
386          name_and_cardinality != method_cardinalities.end();
387          name_and_cardinality++) {
388       method_dict["Method"] = name_and_cardinality->first;
389       method_dict["Cardinality"] = name_and_cardinality->second;
390       IndentScope raii_descriptions_indent(out);
391       out->Print(method_dict,
392                  "\'$Method$\': cardinality.Cardinality.$Cardinality$,\n");
393     }
394     out->Print("}\n");
395     out->Print(
396         "stub_options = beta_implementations.stub_options("
397         "host=host, metadata_transformer=metadata_transformer, "
398         "request_serializers=request_serializers, "
399         "response_deserializers=response_deserializers, "
400         "thread_pool=pool, thread_pool_size=pool_size)\n");
401     out->Print(method_dict,
402                "return beta_implementations.dynamic_stub(channel, "
403                "\'$PackageQualifiedServiceName$\', "
404                "cardinalities, options=stub_options)\n");
405   }
406   return true;
407 }
408 
PrintStub(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)409 bool PrivateGenerator::PrintStub(
410     const std::string& package_qualified_service_name,
411     const grpc_generator::Service* service, grpc_generator::Printer* out) {
412   StringMap dict;
413   dict["Service"] = service->name();
414   out->Print("\n\n");
415   out->Print(dict, "class $Service$Stub(object):\n");
416   {
417     IndentScope raii_class_indent(out);
418     StringVector service_comments = service->GetAllComments();
419     PrintAllComments(service_comments, out);
420     out->Print("\n");
421     out->Print("def __init__(self, channel):\n");
422     {
423       IndentScope raii_init_indent(out);
424       out->Print("\"\"\"Constructor.\n");
425       out->Print("\n");
426       out->Print("Args:\n");
427       {
428         IndentScope raii_args_indent(out);
429         out->Print("channel: A grpc.Channel.\n");
430       }
431       out->Print("\"\"\"\n");
432       for (int i = 0; i < service->method_count(); ++i) {
433         auto method = service->method(i);
434         std::string multi_callable_constructor =
435             std::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
436             std::string(method->ServerStreaming() ? "stream" : "unary");
437         std::string request_module_and_class;
438         if (!method->get_module_and_message_path_input(
439                 &request_module_and_class, generator_file_name,
440                 generate_in_pb2_grpc, config.import_prefix,
441                 config.prefixes_to_filter)) {
442           return false;
443         }
444         std::string response_module_and_class;
445         if (!method->get_module_and_message_path_output(
446                 &response_module_and_class, generator_file_name,
447                 generate_in_pb2_grpc, config.import_prefix,
448                 config.prefixes_to_filter)) {
449           return false;
450         }
451         StringMap method_dict;
452         method_dict["Method"] = method->name();
453         method_dict["MultiCallableConstructor"] = multi_callable_constructor;
454         out->Print(method_dict,
455                    "self.$Method$ = channel.$MultiCallableConstructor$(\n");
456         {
457           method_dict["PackageQualifiedService"] =
458               package_qualified_service_name;
459           method_dict["RequestModuleAndClass"] = request_module_and_class;
460           method_dict["ResponseModuleAndClass"] = response_module_and_class;
461           IndentScope raii_first_attribute_indent(out);
462           IndentScope raii_second_attribute_indent(out);
463           out->Print(method_dict, "'/$PackageQualifiedService$/$Method$',\n");
464           out->Print(method_dict,
465                      "request_serializer=$RequestModuleAndClass$."
466                      "SerializeToString,\n");
467           out->Print(
468               method_dict,
469               "response_deserializer=$ResponseModuleAndClass$.FromString,\n");
470           out->Print(")\n");
471         }
472       }
473     }
474   }
475   return true;
476 }
477 
PrintServicer(const grpc_generator::Service * service,grpc_generator::Printer * out)478 bool PrivateGenerator::PrintServicer(const grpc_generator::Service* service,
479                                      grpc_generator::Printer* out) {
480   StringMap service_dict;
481   service_dict["Service"] = service->name();
482   out->Print("\n\n");
483   out->Print(service_dict, "class $Service$Servicer(object):\n");
484   {
485     IndentScope raii_class_indent(out);
486     StringVector service_comments = service->GetAllComments();
487     PrintAllComments(service_comments, out);
488     for (int i = 0; i < service->method_count(); ++i) {
489       auto method = service->method(i);
490       std::string arg_name =
491           method->ClientStreaming() ? "request_iterator" : "request";
492       StringMap method_dict;
493       method_dict["Method"] = method->name();
494       method_dict["ArgName"] = arg_name;
495       out->Print("\n");
496       out->Print(method_dict, "def $Method$(self, $ArgName$, context):\n");
497       {
498         IndentScope raii_method_indent(out);
499         StringVector method_comments = method->GetAllComments();
500         PrintAllComments(method_comments, out);
501         out->Print("context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n");
502         out->Print("context.set_details('Method not implemented!')\n");
503         out->Print("raise NotImplementedError('Method not implemented!')\n");
504       }
505     }
506   }
507   return true;
508 }
509 
PrintAddServicerToServer(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)510 bool PrivateGenerator::PrintAddServicerToServer(
511     const std::string& package_qualified_service_name,
512     const grpc_generator::Service* service, grpc_generator::Printer* out) {
513   StringMap service_dict;
514   service_dict["Service"] = service->name();
515   out->Print("\n\n");
516   out->Print(service_dict,
517              "def add_$Service$Servicer_to_server(servicer, server):\n");
518   {
519     IndentScope raii_class_indent(out);
520     out->Print("rpc_method_handlers = {\n");
521     {
522       IndentScope raii_dict_first_indent(out);
523       IndentScope raii_dict_second_indent(out);
524       for (int i = 0; i < service->method_count(); ++i) {
525         auto method = service->method(i);
526         std::string method_handler_constructor =
527             std::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
528             std::string(method->ServerStreaming() ? "stream" : "unary") +
529             "_rpc_method_handler";
530         std::string request_module_and_class;
531         if (!method->get_module_and_message_path_input(
532                 &request_module_and_class, generator_file_name,
533                 generate_in_pb2_grpc, config.import_prefix,
534                 config.prefixes_to_filter)) {
535           return false;
536         }
537         std::string response_module_and_class;
538         if (!method->get_module_and_message_path_output(
539                 &response_module_and_class, generator_file_name,
540                 generate_in_pb2_grpc, config.import_prefix,
541                 config.prefixes_to_filter)) {
542           return false;
543         }
544         StringMap method_dict;
545         method_dict["Method"] = method->name();
546         method_dict["MethodHandlerConstructor"] = method_handler_constructor;
547         method_dict["RequestModuleAndClass"] = request_module_and_class;
548         method_dict["ResponseModuleAndClass"] = response_module_and_class;
549         out->Print(method_dict,
550                    "'$Method$': grpc.$MethodHandlerConstructor$(\n");
551         {
552           IndentScope raii_call_first_indent(out);
553           IndentScope raii_call_second_indent(out);
554           out->Print(method_dict, "servicer.$Method$,\n");
555           out->Print(
556               method_dict,
557               "request_deserializer=$RequestModuleAndClass$.FromString,\n");
558           out->Print(
559               method_dict,
560               "response_serializer=$ResponseModuleAndClass$.SerializeToString,"
561               "\n");
562         }
563         out->Print("),\n");
564       }
565     }
566     StringMap method_dict;
567     method_dict["PackageQualifiedServiceName"] = package_qualified_service_name;
568     out->Print("}\n");
569     out->Print("generic_handler = grpc.method_handlers_generic_handler(\n");
570     {
571       IndentScope raii_call_first_indent(out);
572       IndentScope raii_call_second_indent(out);
573       out->Print(method_dict,
574                  "'$PackageQualifiedServiceName$', rpc_method_handlers)\n");
575     }
576     out->Print("server.add_generic_rpc_handlers((generic_handler,))\n");
577   }
578   return true;
579 }
580 
581 /* Prints out a service class used as a container for static methods pertaining
582  * to a class. This class has the exact name of service written in the ".proto"
583  * file, with no suffixes. Since this class merely acts as a namespace, it
584  * should never be instantiated.
585  */
PrintServiceClass(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)586 bool PrivateGenerator::PrintServiceClass(
587     const std::string& package_qualified_service_name,
588     const grpc_generator::Service* service, grpc_generator::Printer* out) {
589   StringMap dict;
590   dict["Service"] = service->name();
591   out->Print("\n\n");
592   out->Print(" # This class is part of an EXPERIMENTAL API.\n");
593   out->Print(dict, "class $Service$(object):\n");
594   {
595     IndentScope class_indent(out);
596     StringVector service_comments = service->GetAllComments();
597     PrintAllComments(service_comments, out);
598     for (int i = 0; i < service->method_count(); ++i) {
599       const auto& method = service->method(i);
600       std::string request_module_and_class;
601       if (!method->get_module_and_message_path_input(
602               &request_module_and_class, generator_file_name,
603               generate_in_pb2_grpc, config.import_prefix,
604               config.prefixes_to_filter)) {
605         return false;
606       }
607       std::string response_module_and_class;
608       if (!method->get_module_and_message_path_output(
609               &response_module_and_class, generator_file_name,
610               generate_in_pb2_grpc, config.import_prefix,
611               config.prefixes_to_filter)) {
612         return false;
613       }
614       out->Print("\n");
615       StringMap method_dict;
616       method_dict["Method"] = method->name();
617       out->Print("@staticmethod\n");
618       out->Print(method_dict, "def $Method$(");
619       std::string request_parameter(
620           method->ClientStreaming() ? "request_iterator" : "request");
621       StringMap args_dict;
622       args_dict["RequestParameter"] = request_parameter;
623       {
624         IndentScope args_indent(out);
625         IndentScope args_double_indent(out);
626         out->Print(args_dict, "$RequestParameter$,\n");
627         out->Print("target,\n");
628         out->Print("options=(),\n");
629         out->Print("channel_credentials=None,\n");
630         out->Print("call_credentials=None,\n");
631         out->Print("insecure=False,\n");
632         out->Print("compression=None,\n");
633         out->Print("wait_for_ready=None,\n");
634         out->Print("timeout=None,\n");
635         out->Print("metadata=None):\n");
636       }
637       {
638         IndentScope method_indent(out);
639         std::string arity_method_name =
640             std::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
641             std::string(method->ServerStreaming() ? "stream" : "unary");
642         args_dict["ArityMethodName"] = arity_method_name;
643         args_dict["PackageQualifiedService"] = package_qualified_service_name;
644         args_dict["Method"] = method->name();
645         out->Print(args_dict,
646                    "return "
647                    "grpc.experimental.$ArityMethodName$($RequestParameter$, "
648                    "target, '/$PackageQualifiedService$/$Method$',\n");
649         {
650           IndentScope continuation_indent(out);
651           StringMap serializer_dict;
652           serializer_dict["RequestModuleAndClass"] = request_module_and_class;
653           serializer_dict["ResponseModuleAndClass"] = response_module_and_class;
654           out->Print(serializer_dict,
655                      "$RequestModuleAndClass$.SerializeToString,\n");
656           out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n");
657           out->Print("options, channel_credentials,\n");
658           out->Print(
659               "insecure, call_credentials, compression, wait_for_ready, "
660               "timeout, metadata)\n");
661         }
662       }
663     }
664   }
665   // TODO(rbellevi): Add methods pertinent to the server side as well.
666   return true;
667 }
668 
PrintBetaPreamble(grpc_generator::Printer * out)669 bool PrivateGenerator::PrintBetaPreamble(grpc_generator::Printer* out) {
670   StringMap var;
671   var["Package"] = config.beta_package_root;
672   out->Print(var,
673              "from $Package$ import implementations as beta_implementations\n");
674   out->Print(var, "from $Package$ import interfaces as beta_interfaces\n");
675   out->Print("from grpc.framework.common import cardinality\n");
676   out->Print(
677       "from grpc.framework.interfaces.face import utilities as "
678       "face_utilities\n");
679   return true;
680 }
681 
PrintPreamble(grpc_generator::Printer * out)682 bool PrivateGenerator::PrintPreamble(grpc_generator::Printer* out) {
683   StringMap var;
684   var["Package"] = config.grpc_package_root;
685   out->Print(var, "import $Package$\n");
686   if (generate_in_pb2_grpc) {
687     out->Print("\n");
688     StringPairSet imports_set;
689     for (int i = 0; i < file->service_count(); ++i) {
690       auto service = file->service(i);
691       for (int j = 0; j < service->method_count(); ++j) {
692         auto method = service.get()->method(j);
693 
694         std::string input_type_file_name = method->get_input_type_name();
695         std::string input_module_name =
696             ModuleName(input_type_file_name, config.import_prefix,
697                        config.prefixes_to_filter);
698         std::string input_module_alias =
699             ModuleAlias(input_type_file_name, config.import_prefix,
700                         config.prefixes_to_filter);
701         imports_set.insert(
702             std::make_tuple(input_module_name, input_module_alias));
703 
704         std::string output_type_file_name = method->get_output_type_name();
705         std::string output_module_name =
706             ModuleName(output_type_file_name, config.import_prefix,
707                        config.prefixes_to_filter);
708         std::string output_module_alias =
709             ModuleAlias(output_type_file_name, config.import_prefix,
710                         config.prefixes_to_filter);
711         imports_set.insert(
712             std::make_tuple(output_module_name, output_module_alias));
713       }
714     }
715 
716     for (StringPairSet::iterator it = imports_set.begin();
717          it != imports_set.end(); ++it) {
718       auto module_name = std::get<0>(*it);
719       var["ModuleAlias"] = std::get<1>(*it);
720       const size_t last_dot_pos = module_name.rfind('.');
721       if (last_dot_pos == std::string::npos) {
722         var["ImportStatement"] = "import " + module_name;
723       } else {
724         var["ImportStatement"] = "from " + module_name.substr(0, last_dot_pos) +
725                                  " import " +
726                                  module_name.substr(last_dot_pos + 1);
727       }
728       out->Print(var, "$ImportStatement$ as $ModuleAlias$\n");
729     }
730   }
731   return true;
732 }
733 
PrintGAServices(grpc_generator::Printer * out)734 bool PrivateGenerator::PrintGAServices(grpc_generator::Printer* out) {
735   std::string package = file->package();
736   if (!package.empty()) {
737     package = package.append(".");
738   }
739   for (int i = 0; i < file->service_count(); ++i) {
740     auto service = file->service(i);
741     std::string package_qualified_service_name = package + service->name();
742     if (!(PrintStub(package_qualified_service_name, service.get(), out) &&
743           PrintServicer(service.get(), out) &&
744           PrintAddServicerToServer(package_qualified_service_name,
745                                    service.get(), out) &&
746           PrintServiceClass(package_qualified_service_name, service.get(),
747                             out))) {
748       return false;
749     }
750   }
751   return true;
752 }
753 
PrintBetaServices(grpc_generator::Printer * out)754 bool PrivateGenerator::PrintBetaServices(grpc_generator::Printer* out) {
755   std::string package = file->package();
756   if (!package.empty()) {
757     package = package.append(".");
758   }
759   for (int i = 0; i < file->service_count(); ++i) {
760     auto service = file->service(i);
761     std::string package_qualified_service_name = package + service->name();
762     if (!(PrintBetaServicer(service.get(), out) &&
763           PrintBetaStub(service.get(), out) &&
764           PrintBetaServerFactory(package_qualified_service_name, service.get(),
765                                  out) &&
766           PrintBetaStubFactory(package_qualified_service_name, service.get(),
767                                out))) {
768       return false;
769     }
770   }
771   return true;
772 }
773 
GetGrpcServices()774 pair<bool, std::string> PrivateGenerator::GetGrpcServices() {
775   std::string output;
776   {
777     // Scope the output stream so it closes and finalizes output to the string.
778     auto out = file->CreatePrinter(&output);
779     if (generate_in_pb2_grpc) {
780       out->Print(
781           "# Generated by the gRPC Python protocol compiler plugin. "
782           "DO NOT EDIT!\n\"\"\""
783           "Client and server classes corresponding to protobuf-defined "
784           "services.\"\"\"\n");
785       if (!PrintPreamble(out.get())) {
786         return make_pair(false, "");
787       }
788       if (!PrintGAServices(out.get())) {
789         return make_pair(false, "");
790       }
791     } else {
792       out->Print("try:\n");
793       {
794         IndentScope raii_dict_try_indent(out.get());
795         out->Print(
796             "# THESE ELEMENTS WILL BE DEPRECATED.\n"
797             "# Please use the generated *_pb2_grpc.py files instead.\n");
798         if (!PrintPreamble(out.get())) {
799           return make_pair(false, "");
800         }
801         if (!PrintBetaPreamble(out.get())) {
802           return make_pair(false, "");
803         }
804         if (!PrintGAServices(out.get())) {
805           return make_pair(false, "");
806         }
807         if (!PrintBetaServices(out.get())) {
808           return make_pair(false, "");
809         }
810       }
811       out->Print("except ImportError:\n");
812       {
813         IndentScope raii_dict_except_indent(out.get());
814         out->Print("pass");
815       }
816     }
817   }
818   return make_pair(true, std::move(output));
819 }
820 
821 }  // namespace
822 
GeneratorConfiguration()823 GeneratorConfiguration::GeneratorConfiguration()
824     : grpc_package_root("grpc"),
825       beta_package_root("grpc.beta"),
826       import_prefix("") {}
827 
PythonGrpcGenerator(const GeneratorConfiguration & config)828 PythonGrpcGenerator::PythonGrpcGenerator(const GeneratorConfiguration& config)
829     : config_(config) {}
830 
~PythonGrpcGenerator()831 PythonGrpcGenerator::~PythonGrpcGenerator() {}
832 
GenerateGrpc(GeneratorContext * context,PrivateGenerator & generator,std::string file_name,bool generate_in_pb2_grpc)833 static bool GenerateGrpc(GeneratorContext* context, PrivateGenerator& generator,
834                          std::string file_name, bool generate_in_pb2_grpc) {
835   bool success;
836   std::unique_ptr<ZeroCopyOutputStream> output;
837   std::unique_ptr<CodedOutputStream> coded_output;
838   std::string grpc_code;
839 
840   if (generate_in_pb2_grpc) {
841     output.reset(context->Open(file_name));
842     generator.generate_in_pb2_grpc = true;
843   } else {
844     output.reset(context->OpenForInsert(file_name, "module_scope"));
845     generator.generate_in_pb2_grpc = false;
846   }
847 
848   coded_output.reset(new CodedOutputStream(output.get()));
849   tie(success, grpc_code) = generator.GetGrpcServices();
850 
851   if (success) {
852     coded_output->WriteRaw(grpc_code.data(), grpc_code.size());
853     return true;
854   } else {
855     return false;
856   }
857 }
858 
ParseParameters(const std::string & parameter,std::string * grpc_version,std::vector<std::string> * strip_prefixes,std::string * error)859 static bool ParseParameters(const std::string& parameter,
860                             std::string* grpc_version,
861                             std::vector<std::string>* strip_prefixes,
862                             std::string* error) {
863   std::vector<std::string> comma_delimited_parameters;
864   grpc_python_generator::Split(parameter, ',', &comma_delimited_parameters);
865   if (comma_delimited_parameters.size() == 1 &&
866       comma_delimited_parameters[0].empty()) {
867     *grpc_version = "grpc_2_0";
868   } else if (comma_delimited_parameters.size() == 1) {
869     *grpc_version = comma_delimited_parameters[0];
870   } else if (comma_delimited_parameters.size() == 2) {
871     *grpc_version = comma_delimited_parameters[0];
872     std::copy(comma_delimited_parameters.begin() + 1,
873               comma_delimited_parameters.end(),
874               std::back_inserter(*strip_prefixes));
875   } else {
876     *error = "--grpc_python_out received too many comma-delimited parameters.";
877     return false;
878   }
879   return true;
880 }
881 
GetSupportedFeatures() const882 uint64_t PythonGrpcGenerator::GetSupportedFeatures() const {
883   return FEATURE_PROTO3_OPTIONAL;
884 }
885 
Generate(const FileDescriptor * file,const std::string & parameter,GeneratorContext * context,std::string * error) const886 bool PythonGrpcGenerator::Generate(const FileDescriptor* file,
887                                    const std::string& parameter,
888                                    GeneratorContext* context,
889                                    std::string* error) const {
890   // Get output file name.
891   std::string pb2_file_name;
892   std::string pb2_grpc_file_name;
893   static const int proto_suffix_length = strlen(".proto");
894   if (file->name().size() > static_cast<size_t>(proto_suffix_length) &&
895       file->name().find_last_of(".proto") == file->name().size() - 1) {
896     std::string base =
897         file->name().substr(0, file->name().size() - proto_suffix_length);
898     std::replace(base.begin(), base.end(), '-', '_');
899     pb2_file_name = base + "_pb2.py";
900     pb2_grpc_file_name = base + "_pb2_grpc.py";
901   } else {
902     *error = "Invalid proto file name. Proto file must end with .proto";
903     return false;
904   }
905   generator_file_name = file->name();
906 
907   ProtoBufFile pbfile(file);
908   std::string grpc_version;
909   GeneratorConfiguration extended_config(config_);
910   bool success = ParseParameters(parameter, &grpc_version,
911                                  &(extended_config.prefixes_to_filter), error);
912   PrivateGenerator generator(extended_config, &pbfile);
913   if (!success) return false;
914   if (grpc_version == "grpc_2_0") {
915     return GenerateGrpc(context, generator, pb2_grpc_file_name, true);
916   } else if (grpc_version == "grpc_1_0") {
917     return GenerateGrpc(context, generator, pb2_grpc_file_name, true) &&
918            GenerateGrpc(context, generator, pb2_file_name, false);
919   } else {
920     *error = "Invalid grpc version '" + grpc_version + "'.";
921     return false;
922   }
923 }
924 
925 }  // namespace grpc_python_generator
926