• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <algorithm>
20 #include <cassert>
21 #include <cctype>
22 #include <cstring>
23 #include <fstream>
24 #include <iostream>
25 #include <map>
26 #include <memory>
27 #include <ostream>
28 #include <set>
29 #include <sstream>
30 #include <tuple>
31 #include <vector>
32 
33 #include "src/compiler/config.h"
34 #include "src/compiler/generator_helpers.h"
35 #include "src/compiler/protobuf_plugin.h"
36 #include "src/compiler/python_generator.h"
37 #include "src/compiler/python_generator_helpers.h"
38 #include "src/compiler/python_private_generator.h"
39 
40 using grpc::protobuf::FileDescriptor;
41 using grpc::protobuf::compiler::GeneratorContext;
42 using grpc::protobuf::io::CodedOutputStream;
43 using grpc::protobuf::io::ZeroCopyOutputStream;
44 using std::make_pair;
45 using std::map;
46 using std::pair;
47 using std::replace;
48 using std::set;
49 using std::tuple;
50 using std::vector;
51 
52 namespace grpc_python_generator {
53 
// Passed as the file-name argument to every
// get_module_and_message_path_input/_output call in this file; presumably
// the name of the file being generated. It starts out empty.
// NOTE(review): its assignment is not visible in this section — confirm where
// it is set before PrivateGenerator runs.
std::string generator_file_name{};
55 
56 namespace {
57 
// Shorthand container types used throughout the generator.
using StringMap = std::map<std::string, std::string>;      // Printer variables.
using StringVector = std::vector<std::string>;             // e.g. comment lines.
using StringPair = std::tuple<std::string, std::string>;   // e.g. (module, alias).
using StringPairSet = std::set<StringPair>;                // Ordered, de-duplicated.
62 
63 // Provides RAII indentation handling. Use as:
64 // {
65 //   IndentScope raii_my_indent_var_name_here(my_py_printer);
66 //   // constructor indented my_py_printer
67 //   ...
68 //   // destructor called at end of scope, un-indenting my_py_printer
69 // }
70 class IndentScope {
71  public:
IndentScope(grpc_generator::Printer * printer)72   explicit IndentScope(grpc_generator::Printer* printer) : printer_(printer) {
73     // NOTE(rbellevi): Two-space tabs are hard-coded in the protocol compiler.
74     // Doubling our indents and outdents guarantees compliance with PEP8.
75     printer_->Indent();
76     printer_->Indent();
77   }
78 
~IndentScope()79   ~IndentScope() {
80     printer_->Outdent();
81     printer_->Outdent();
82   }
83 
84  private:
85   grpc_generator::Printer* printer_;
86 };
87 
PrivateGenerator(const GeneratorConfiguration & config,const grpc_generator::File * file)88 PrivateGenerator::PrivateGenerator(const GeneratorConfiguration& config,
89                                    const grpc_generator::File* file)
90     : config(config), file(file) {}
91 
PrintAllComments(StringVector comments,grpc_generator::Printer * out)92 void PrivateGenerator::PrintAllComments(StringVector comments,
93                                         grpc_generator::Printer* out) {
94   if (comments.empty()) {
95     // Python requires code structures like class and def to have
96     // a body, even if it is just "pass" or a docstring.  We need
97     // to ensure not to generate empty bodies. We could do something
98     // smarter and more sophisticated, but at the moment, if there is
99     // no docstring to print, we simply emit "pass" to ensure validity
100     // of the generated code.
101     out->Print(
102         "\"\"\"Missing associated documentation comment in .proto "
103         "file.\"\"\"\n");
104     return;
105   }
106   out->Print("\"\"\"");
107   for (StringVector::iterator it = comments.begin(); it != comments.end();
108        ++it) {
109     size_t start_pos = it->find_first_not_of(' ');
110     if (start_pos != std::string::npos) {
111       out->PrintRaw(it->c_str() + start_pos);
112     }
113     out->Print("\n");
114   }
115   out->Print("\"\"\"\n");
116 }
117 
PrintBetaServicer(const grpc_generator::Service * service,grpc_generator::Printer * out)118 bool PrivateGenerator::PrintBetaServicer(const grpc_generator::Service* service,
119                                          grpc_generator::Printer* out) {
120   StringMap service_dict;
121   service_dict["Service"] = service->name();
122   out->Print("\n\n");
123   out->Print(service_dict, "class Beta$Service$Servicer(object):\n");
124   {
125     IndentScope raii_class_indent(out);
126     out->Print(
127         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
128         "\nIt is recommended to use the GA API (classes and functions in this\n"
129         "file not marked beta) for all further purposes. This class was "
130         "generated\n"
131         "only to ease transition from grpcio<0.15.0 to "
132         "grpcio>=0.15.0.\"\"\"\n");
133     StringVector service_comments = service->GetAllComments();
134     PrintAllComments(service_comments, out);
135     for (int i = 0; i < service->method_count(); ++i) {
136       auto method = service->method(i);
137       std::string arg_name =
138           method->ClientStreaming() ? "request_iterator" : "request";
139       StringMap method_dict;
140       method_dict["Method"] = method->name();
141       method_dict["ArgName"] = arg_name;
142       out->Print(method_dict, "def $Method$(self, $ArgName$, context):\n");
143       {
144         IndentScope raii_method_indent(out);
145         StringVector method_comments = method->GetAllComments();
146         PrintAllComments(method_comments, out);
147         out->Print("context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)\n");
148       }
149     }
150   }
151   return true;
152 }
153 
PrintBetaStub(const grpc_generator::Service * service,grpc_generator::Printer * out)154 bool PrivateGenerator::PrintBetaStub(const grpc_generator::Service* service,
155                                      grpc_generator::Printer* out) {
156   StringMap service_dict;
157   service_dict["Service"] = service->name();
158   out->Print("\n\n");
159   out->Print(service_dict, "class Beta$Service$Stub(object):\n");
160   {
161     IndentScope raii_class_indent(out);
162     out->Print(
163         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
164         "\nIt is recommended to use the GA API (classes and functions in this\n"
165         "file not marked beta) for all further purposes. This class was "
166         "generated\n"
167         "only to ease transition from grpcio<0.15.0 to "
168         "grpcio>=0.15.0.\"\"\"\n");
169     StringVector service_comments = service->GetAllComments();
170     PrintAllComments(service_comments, out);
171     for (int i = 0; i < service->method_count(); ++i) {
172       auto method = service->method(i);
173       std::string arg_name =
174           method->ClientStreaming() ? "request_iterator" : "request";
175       StringMap method_dict;
176       method_dict["Method"] = method->name();
177       method_dict["ArgName"] = arg_name;
178       out->Print(method_dict,
179                  "def $Method$(self, $ArgName$, timeout, metadata=None, "
180                  "with_call=False, protocol_options=None):\n");
181       {
182         IndentScope raii_method_indent(out);
183         StringVector method_comments = method->GetAllComments();
184         PrintAllComments(method_comments, out);
185         out->Print("raise NotImplementedError()\n");
186       }
187       if (!method->ServerStreaming()) {
188         out->Print(method_dict, "$Method$.future = None\n");
189       }
190     }
191   }
192   return true;
193 }
194 
PrintBetaServerFactory(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)195 bool PrivateGenerator::PrintBetaServerFactory(
196     const std::string& package_qualified_service_name,
197     const grpc_generator::Service* service, grpc_generator::Printer* out) {
198   StringMap service_dict;
199   service_dict["Service"] = service->name();
200   out->Print("\n\n");
201   out->Print(service_dict,
202              "def beta_create_$Service$_server(servicer, pool=None, "
203              "pool_size=None, default_timeout=None, maximum_timeout=None):\n");
204   {
205     IndentScope raii_create_server_indent(out);
206     out->Print(
207         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
208         "\nIt is recommended to use the GA API (classes and functions in this\n"
209         "file not marked beta) for all further purposes. This function was\n"
210         "generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"
211         "\"\"\"\n");
212     StringMap method_implementation_constructors;
213     StringMap input_message_modules_and_classes;
214     StringMap output_message_modules_and_classes;
215     for (int i = 0; i < service->method_count(); ++i) {
216       auto method = service->method(i);
217       const std::string method_implementation_constructor =
218           std::string(method->ClientStreaming() ? "stream_" : "unary_") +
219           std::string(method->ServerStreaming() ? "stream_" : "unary_") +
220           "inline";
221       std::string input_message_module_and_class;
222       if (!method->get_module_and_message_path_input(
223               &input_message_module_and_class, generator_file_name,
224               generate_in_pb2_grpc, config.import_prefix,
225               config.prefixes_to_filter)) {
226         return false;
227       }
228       std::string output_message_module_and_class;
229       if (!method->get_module_and_message_path_output(
230               &output_message_module_and_class, generator_file_name,
231               generate_in_pb2_grpc, config.import_prefix,
232               config.prefixes_to_filter)) {
233         return false;
234       }
235       method_implementation_constructors.insert(
236           make_pair(method->name(), method_implementation_constructor));
237       input_message_modules_and_classes.insert(
238           make_pair(method->name(), input_message_module_and_class));
239       output_message_modules_and_classes.insert(
240           make_pair(method->name(), output_message_module_and_class));
241     }
242     StringMap method_dict;
243     method_dict["PackageQualifiedServiceName"] = package_qualified_service_name;
244     out->Print("request_deserializers = {\n");
245     for (StringMap::iterator name_and_input_module_class_pair =
246              input_message_modules_and_classes.begin();
247          name_and_input_module_class_pair !=
248          input_message_modules_and_classes.end();
249          name_and_input_module_class_pair++) {
250       method_dict["MethodName"] = name_and_input_module_class_pair->first;
251       method_dict["InputTypeModuleAndClass"] =
252           name_and_input_module_class_pair->second;
253       IndentScope raii_indent(out);
254       out->Print(method_dict,
255                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
256                  "$InputTypeModuleAndClass$.FromString,\n");
257     }
258     out->Print("}\n");
259     out->Print("response_serializers = {\n");
260     for (StringMap::iterator name_and_output_module_class_pair =
261              output_message_modules_and_classes.begin();
262          name_and_output_module_class_pair !=
263          output_message_modules_and_classes.end();
264          name_and_output_module_class_pair++) {
265       method_dict["MethodName"] = name_and_output_module_class_pair->first;
266       method_dict["OutputTypeModuleAndClass"] =
267           name_and_output_module_class_pair->second;
268       IndentScope raii_indent(out);
269       out->Print(method_dict,
270                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
271                  "$OutputTypeModuleAndClass$.SerializeToString,\n");
272     }
273     out->Print("}\n");
274     out->Print("method_implementations = {\n");
275     for (StringMap::iterator name_and_implementation_constructor =
276              method_implementation_constructors.begin();
277          name_and_implementation_constructor !=
278          method_implementation_constructors.end();
279          name_and_implementation_constructor++) {
280       method_dict["Method"] = name_and_implementation_constructor->first;
281       method_dict["Constructor"] = name_and_implementation_constructor->second;
282       IndentScope raii_descriptions_indent(out);
283       const std::string method_name =
284           name_and_implementation_constructor->first;
285       out->Print(method_dict,
286                  "(\'$PackageQualifiedServiceName$\', \'$Method$\'): "
287                  "face_utilities.$Constructor$(servicer.$Method$),\n");
288     }
289     out->Print("}\n");
290     out->Print(
291         "server_options = beta_implementations.server_options("
292         "request_deserializers=request_deserializers, "
293         "response_serializers=response_serializers, "
294         "thread_pool=pool, thread_pool_size=pool_size, "
295         "default_timeout=default_timeout, "
296         "maximum_timeout=maximum_timeout)\n");
297     out->Print(
298         "return beta_implementations.server(method_implementations, "
299         "options=server_options)\n");
300   }
301   return true;
302 }
303 
PrintBetaStubFactory(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)304 bool PrivateGenerator::PrintBetaStubFactory(
305     const std::string& package_qualified_service_name,
306     const grpc_generator::Service* service, grpc_generator::Printer* out) {
307   StringMap dict;
308   dict["Service"] = service->name();
309   out->Print("\n\n");
310   out->Print(dict,
311              "def beta_create_$Service$_stub(channel, host=None,"
312              " metadata_transformer=None, pool=None, pool_size=None):\n");
313   {
314     IndentScope raii_create_server_indent(out);
315     out->Print(
316         "\"\"\"The Beta API is deprecated for 0.15.0 and later.\n"
317         "\nIt is recommended to use the GA API (classes and functions in this\n"
318         "file not marked beta) for all further purposes. This function was\n"
319         "generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"
320         "\"\"\"\n");
321     StringMap method_cardinalities;
322     StringMap input_message_modules_and_classes;
323     StringMap output_message_modules_and_classes;
324     for (int i = 0; i < service->method_count(); ++i) {
325       auto method = service->method(i);
326       const std::string method_cardinality =
327           std::string(method->ClientStreaming() ? "STREAM" : "UNARY") + "_" +
328           std::string(method->ServerStreaming() ? "STREAM" : "UNARY");
329       std::string input_message_module_and_class;
330       if (!method->get_module_and_message_path_input(
331               &input_message_module_and_class, generator_file_name,
332               generate_in_pb2_grpc, config.import_prefix,
333               config.prefixes_to_filter)) {
334         return false;
335       }
336       std::string output_message_module_and_class;
337       if (!method->get_module_and_message_path_output(
338               &output_message_module_and_class, generator_file_name,
339               generate_in_pb2_grpc, config.import_prefix,
340               config.prefixes_to_filter)) {
341         return false;
342       }
343       method_cardinalities.insert(
344           make_pair(method->name(), method_cardinality));
345       input_message_modules_and_classes.insert(
346           make_pair(method->name(), input_message_module_and_class));
347       output_message_modules_and_classes.insert(
348           make_pair(method->name(), output_message_module_and_class));
349     }
350     StringMap method_dict;
351     method_dict["PackageQualifiedServiceName"] = package_qualified_service_name;
352     out->Print("request_serializers = {\n");
353     for (StringMap::iterator name_and_input_module_class_pair =
354              input_message_modules_and_classes.begin();
355          name_and_input_module_class_pair !=
356          input_message_modules_and_classes.end();
357          name_and_input_module_class_pair++) {
358       method_dict["MethodName"] = name_and_input_module_class_pair->first;
359       method_dict["InputTypeModuleAndClass"] =
360           name_and_input_module_class_pair->second;
361       IndentScope raii_indent(out);
362       out->Print(method_dict,
363                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
364                  "$InputTypeModuleAndClass$.SerializeToString,\n");
365     }
366     out->Print("}\n");
367     out->Print("response_deserializers = {\n");
368     for (StringMap::iterator name_and_output_module_class_pair =
369              output_message_modules_and_classes.begin();
370          name_and_output_module_class_pair !=
371          output_message_modules_and_classes.end();
372          name_and_output_module_class_pair++) {
373       method_dict["MethodName"] = name_and_output_module_class_pair->first;
374       method_dict["OutputTypeModuleAndClass"] =
375           name_and_output_module_class_pair->second;
376       IndentScope raii_indent(out);
377       out->Print(method_dict,
378                  "(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): "
379                  "$OutputTypeModuleAndClass$.FromString,\n");
380     }
381     out->Print("}\n");
382     out->Print("cardinalities = {\n");
383     for (StringMap::iterator name_and_cardinality =
384              method_cardinalities.begin();
385          name_and_cardinality != method_cardinalities.end();
386          name_and_cardinality++) {
387       method_dict["Method"] = name_and_cardinality->first;
388       method_dict["Cardinality"] = name_and_cardinality->second;
389       IndentScope raii_descriptions_indent(out);
390       out->Print(method_dict,
391                  "\'$Method$\': cardinality.Cardinality.$Cardinality$,\n");
392     }
393     out->Print("}\n");
394     out->Print(
395         "stub_options = beta_implementations.stub_options("
396         "host=host, metadata_transformer=metadata_transformer, "
397         "request_serializers=request_serializers, "
398         "response_deserializers=response_deserializers, "
399         "thread_pool=pool, thread_pool_size=pool_size)\n");
400     out->Print(method_dict,
401                "return beta_implementations.dynamic_stub(channel, "
402                "\'$PackageQualifiedServiceName$\', "
403                "cardinalities, options=stub_options)\n");
404   }
405   return true;
406 }
407 
PrintStub(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)408 bool PrivateGenerator::PrintStub(
409     const std::string& package_qualified_service_name,
410     const grpc_generator::Service* service, grpc_generator::Printer* out) {
411   StringMap dict;
412   dict["Service"] = service->name();
413   out->Print("\n\n");
414   out->Print(dict, "class $Service$Stub(object):\n");
415   {
416     IndentScope raii_class_indent(out);
417     StringVector service_comments = service->GetAllComments();
418     PrintAllComments(service_comments, out);
419     out->Print("\n");
420     out->Print("def __init__(self, channel):\n");
421     {
422       IndentScope raii_init_indent(out);
423       out->Print("\"\"\"Constructor.\n");
424       out->Print("\n");
425       out->Print("Args:\n");
426       {
427         IndentScope raii_args_indent(out);
428         out->Print("channel: A grpc.Channel.\n");
429       }
430       out->Print("\"\"\"\n");
431       for (int i = 0; i < service->method_count(); ++i) {
432         auto method = service->method(i);
433         std::string multi_callable_constructor =
434             std::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
435             std::string(method->ServerStreaming() ? "stream" : "unary");
436         std::string request_module_and_class;
437         if (!method->get_module_and_message_path_input(
438                 &request_module_and_class, generator_file_name,
439                 generate_in_pb2_grpc, config.import_prefix,
440                 config.prefixes_to_filter)) {
441           return false;
442         }
443         std::string response_module_and_class;
444         if (!method->get_module_and_message_path_output(
445                 &response_module_and_class, generator_file_name,
446                 generate_in_pb2_grpc, config.import_prefix,
447                 config.prefixes_to_filter)) {
448           return false;
449         }
450         StringMap method_dict;
451         method_dict["Method"] = method->name();
452         method_dict["MultiCallableConstructor"] = multi_callable_constructor;
453         out->Print(method_dict,
454                    "self.$Method$ = channel.$MultiCallableConstructor$(\n");
455         {
456           method_dict["PackageQualifiedService"] =
457               package_qualified_service_name;
458           method_dict["RequestModuleAndClass"] = request_module_and_class;
459           method_dict["ResponseModuleAndClass"] = response_module_and_class;
460           IndentScope raii_first_attribute_indent(out);
461           IndentScope raii_second_attribute_indent(out);
462           out->Print(method_dict, "'/$PackageQualifiedService$/$Method$',\n");
463           out->Print(method_dict,
464                      "request_serializer=$RequestModuleAndClass$."
465                      "SerializeToString,\n");
466           out->Print(
467               method_dict,
468               "response_deserializer=$ResponseModuleAndClass$.FromString,\n");
469           out->Print(")\n");
470         }
471       }
472     }
473   }
474   return true;
475 }
476 
PrintServicer(const grpc_generator::Service * service,grpc_generator::Printer * out)477 bool PrivateGenerator::PrintServicer(const grpc_generator::Service* service,
478                                      grpc_generator::Printer* out) {
479   StringMap service_dict;
480   service_dict["Service"] = service->name();
481   out->Print("\n\n");
482   out->Print(service_dict, "class $Service$Servicer(object):\n");
483   {
484     IndentScope raii_class_indent(out);
485     StringVector service_comments = service->GetAllComments();
486     PrintAllComments(service_comments, out);
487     for (int i = 0; i < service->method_count(); ++i) {
488       auto method = service->method(i);
489       std::string arg_name =
490           method->ClientStreaming() ? "request_iterator" : "request";
491       StringMap method_dict;
492       method_dict["Method"] = method->name();
493       method_dict["ArgName"] = arg_name;
494       out->Print("\n");
495       out->Print(method_dict, "def $Method$(self, $ArgName$, context):\n");
496       {
497         IndentScope raii_method_indent(out);
498         StringVector method_comments = method->GetAllComments();
499         PrintAllComments(method_comments, out);
500         out->Print("context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n");
501         out->Print("context.set_details('Method not implemented!')\n");
502         out->Print("raise NotImplementedError('Method not implemented!')\n");
503       }
504     }
505   }
506   return true;
507 }
508 
PrintAddServicerToServer(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)509 bool PrivateGenerator::PrintAddServicerToServer(
510     const std::string& package_qualified_service_name,
511     const grpc_generator::Service* service, grpc_generator::Printer* out) {
512   StringMap service_dict;
513   service_dict["Service"] = service->name();
514   out->Print("\n\n");
515   out->Print(service_dict,
516              "def add_$Service$Servicer_to_server(servicer, server):\n");
517   {
518     IndentScope raii_class_indent(out);
519     out->Print("rpc_method_handlers = {\n");
520     {
521       IndentScope raii_dict_first_indent(out);
522       IndentScope raii_dict_second_indent(out);
523       for (int i = 0; i < service->method_count(); ++i) {
524         auto method = service->method(i);
525         std::string method_handler_constructor =
526             std::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
527             std::string(method->ServerStreaming() ? "stream" : "unary") +
528             "_rpc_method_handler";
529         std::string request_module_and_class;
530         if (!method->get_module_and_message_path_input(
531                 &request_module_and_class, generator_file_name,
532                 generate_in_pb2_grpc, config.import_prefix,
533                 config.prefixes_to_filter)) {
534           return false;
535         }
536         std::string response_module_and_class;
537         if (!method->get_module_and_message_path_output(
538                 &response_module_and_class, generator_file_name,
539                 generate_in_pb2_grpc, config.import_prefix,
540                 config.prefixes_to_filter)) {
541           return false;
542         }
543         StringMap method_dict;
544         method_dict["Method"] = method->name();
545         method_dict["MethodHandlerConstructor"] = method_handler_constructor;
546         method_dict["RequestModuleAndClass"] = request_module_and_class;
547         method_dict["ResponseModuleAndClass"] = response_module_and_class;
548         out->Print(method_dict,
549                    "'$Method$': grpc.$MethodHandlerConstructor$(\n");
550         {
551           IndentScope raii_call_first_indent(out);
552           IndentScope raii_call_second_indent(out);
553           out->Print(method_dict, "servicer.$Method$,\n");
554           out->Print(
555               method_dict,
556               "request_deserializer=$RequestModuleAndClass$.FromString,\n");
557           out->Print(
558               method_dict,
559               "response_serializer=$ResponseModuleAndClass$.SerializeToString,"
560               "\n");
561         }
562         out->Print("),\n");
563       }
564     }
565     StringMap method_dict;
566     method_dict["PackageQualifiedServiceName"] = package_qualified_service_name;
567     out->Print("}\n");
568     out->Print("generic_handler = grpc.method_handlers_generic_handler(\n");
569     {
570       IndentScope raii_call_first_indent(out);
571       IndentScope raii_call_second_indent(out);
572       out->Print(method_dict,
573                  "'$PackageQualifiedServiceName$', rpc_method_handlers)\n");
574     }
575     out->Print("server.add_generic_rpc_handlers((generic_handler,))\n");
576   }
577   return true;
578 }
579 
580 /* Prints out a service class used as a container for static methods pertaining
581  * to a class. This class has the exact name of service written in the ".proto"
582  * file, with no suffixes. Since this class merely acts as a namespace, it
583  * should never be instantiated.
584  */
PrintServiceClass(const std::string & package_qualified_service_name,const grpc_generator::Service * service,grpc_generator::Printer * out)585 bool PrivateGenerator::PrintServiceClass(
586     const std::string& package_qualified_service_name,
587     const grpc_generator::Service* service, grpc_generator::Printer* out) {
588   StringMap dict;
589   dict["Service"] = service->name();
590   out->Print("\n\n");
591   out->Print(" # This class is part of an EXPERIMENTAL API.\n");
592   out->Print(dict, "class $Service$(object):\n");
593   {
594     IndentScope class_indent(out);
595     StringVector service_comments = service->GetAllComments();
596     PrintAllComments(service_comments, out);
597     for (int i = 0; i < service->method_count(); ++i) {
598       const auto& method = service->method(i);
599       std::string request_module_and_class;
600       if (!method->get_module_and_message_path_input(
601               &request_module_and_class, generator_file_name,
602               generate_in_pb2_grpc, config.import_prefix,
603               config.prefixes_to_filter)) {
604         return false;
605       }
606       std::string response_module_and_class;
607       if (!method->get_module_and_message_path_output(
608               &response_module_and_class, generator_file_name,
609               generate_in_pb2_grpc, config.import_prefix,
610               config.prefixes_to_filter)) {
611         return false;
612       }
613       out->Print("\n");
614       StringMap method_dict;
615       method_dict["Method"] = method->name();
616       out->Print("@staticmethod\n");
617       out->Print(method_dict, "def $Method$(");
618       std::string request_parameter(
619           method->ClientStreaming() ? "request_iterator" : "request");
620       StringMap args_dict;
621       args_dict["RequestParameter"] = request_parameter;
622       {
623         IndentScope args_indent(out);
624         IndentScope args_double_indent(out);
625         out->Print(args_dict, "$RequestParameter$,\n");
626         out->Print("target,\n");
627         out->Print("options=(),\n");
628         out->Print("channel_credentials=None,\n");
629         out->Print("call_credentials=None,\n");
630         out->Print("insecure=False,\n");
631         out->Print("compression=None,\n");
632         out->Print("wait_for_ready=None,\n");
633         out->Print("timeout=None,\n");
634         out->Print("metadata=None):\n");
635       }
636       {
637         IndentScope method_indent(out);
638         std::string arity_method_name =
639             std::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
640             std::string(method->ServerStreaming() ? "stream" : "unary");
641         args_dict["ArityMethodName"] = arity_method_name;
642         args_dict["PackageQualifiedService"] = package_qualified_service_name;
643         args_dict["Method"] = method->name();
644         out->Print(args_dict,
645                    "return "
646                    "grpc.experimental.$ArityMethodName$($RequestParameter$, "
647                    "target, '/$PackageQualifiedService$/$Method$',\n");
648         {
649           IndentScope continuation_indent(out);
650           StringMap serializer_dict;
651           serializer_dict["RequestModuleAndClass"] = request_module_and_class;
652           serializer_dict["ResponseModuleAndClass"] = response_module_and_class;
653           out->Print(serializer_dict,
654                      "$RequestModuleAndClass$.SerializeToString,\n");
655           out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n");
656           out->Print("options, channel_credentials,\n");
657           out->Print(
658               "insecure, call_credentials, compression, wait_for_ready, "
659               "timeout, metadata)\n");
660         }
661       }
662     }
663   }
664   // TODO(rbellevi): Add methods pertinent to the server side as well.
665   return true;
666 }
667 
PrintBetaPreamble(grpc_generator::Printer * out)668 bool PrivateGenerator::PrintBetaPreamble(grpc_generator::Printer* out) {
669   StringMap var;
670   var["Package"] = config.beta_package_root;
671   out->Print(var,
672              "from $Package$ import implementations as beta_implementations\n");
673   out->Print(var, "from $Package$ import interfaces as beta_interfaces\n");
674   out->Print("from grpc.framework.common import cardinality\n");
675   out->Print(
676       "from grpc.framework.interfaces.face import utilities as "
677       "face_utilities\n");
678   return true;
679 }
680 
PrintPreamble(grpc_generator::Printer * out)681 bool PrivateGenerator::PrintPreamble(grpc_generator::Printer* out) {
682   StringMap var;
683   var["Package"] = config.grpc_package_root;
684   out->Print(var, "import $Package$\n");
685   if (generate_in_pb2_grpc) {
686     out->Print("\n");
687     StringPairSet imports_set;
688     for (int i = 0; i < file->service_count(); ++i) {
689       auto service = file->service(i);
690       for (int j = 0; j < service->method_count(); ++j) {
691         auto method = service.get()->method(j);
692 
693         std::string input_type_file_name = method->get_input_type_name();
694         std::string input_module_name =
695             ModuleName(input_type_file_name, config.import_prefix,
696                        config.prefixes_to_filter);
697         std::string input_module_alias =
698             ModuleAlias(input_type_file_name, config.import_prefix,
699                         config.prefixes_to_filter);
700         imports_set.insert(
701             std::make_tuple(input_module_name, input_module_alias));
702 
703         std::string output_type_file_name = method->get_output_type_name();
704         std::string output_module_name =
705             ModuleName(output_type_file_name, config.import_prefix,
706                        config.prefixes_to_filter);
707         std::string output_module_alias =
708             ModuleAlias(output_type_file_name, config.import_prefix,
709                         config.prefixes_to_filter);
710         imports_set.insert(
711             std::make_tuple(output_module_name, output_module_alias));
712       }
713     }
714 
715     for (StringPairSet::iterator it = imports_set.begin();
716          it != imports_set.end(); ++it) {
717       auto module_name = std::get<0>(*it);
718       var["ModuleAlias"] = std::get<1>(*it);
719       const size_t last_dot_pos = module_name.rfind('.');
720       if (last_dot_pos == std::string::npos) {
721         var["ImportStatement"] = "import " + module_name;
722       } else {
723         var["ImportStatement"] = "from " + module_name.substr(0, last_dot_pos) +
724                                  " import " +
725                                  module_name.substr(last_dot_pos + 1);
726       }
727       out->Print(var, "$ImportStatement$ as $ModuleAlias$\n");
728     }
729   }
730   return true;
731 }
732 
PrintGAServices(grpc_generator::Printer * out)733 bool PrivateGenerator::PrintGAServices(grpc_generator::Printer* out) {
734   std::string package = file->package();
735   if (!package.empty()) {
736     package = package.append(".");
737   }
738   for (int i = 0; i < file->service_count(); ++i) {
739     auto service = file->service(i);
740     std::string package_qualified_service_name = package + service->name();
741     if (!(PrintStub(package_qualified_service_name, service.get(), out) &&
742           PrintServicer(service.get(), out) &&
743           PrintAddServicerToServer(package_qualified_service_name,
744                                    service.get(), out) &&
745           PrintServiceClass(package_qualified_service_name, service.get(),
746                             out))) {
747       return false;
748     }
749   }
750   return true;
751 }
752 
PrintBetaServices(grpc_generator::Printer * out)753 bool PrivateGenerator::PrintBetaServices(grpc_generator::Printer* out) {
754   std::string package = file->package();
755   if (!package.empty()) {
756     package = package.append(".");
757   }
758   for (int i = 0; i < file->service_count(); ++i) {
759     auto service = file->service(i);
760     std::string package_qualified_service_name = package + service->name();
761     if (!(PrintBetaServicer(service.get(), out) &&
762           PrintBetaStub(service.get(), out) &&
763           PrintBetaServerFactory(package_qualified_service_name, service.get(),
764                                  out) &&
765           PrintBetaStubFactory(package_qualified_service_name, service.get(),
766                                out))) {
767       return false;
768     }
769   }
770   return true;
771 }
772 
GetGrpcServices()773 pair<bool, std::string> PrivateGenerator::GetGrpcServices() {
774   std::string output;
775   {
776     // Scope the output stream so it closes and finalizes output to the string.
777     auto out = file->CreatePrinter(&output);
778     if (generate_in_pb2_grpc) {
779       out->Print(
780           "# Generated by the gRPC Python protocol compiler plugin. "
781           "DO NOT EDIT!\n\"\"\""
782           "Client and server classes corresponding to protobuf-defined "
783           "services.\"\"\"\n");
784       if (!PrintPreamble(out.get())) {
785         return make_pair(false, "");
786       }
787       if (!PrintGAServices(out.get())) {
788         return make_pair(false, "");
789       }
790     } else {
791       out->Print("try:\n");
792       {
793         IndentScope raii_dict_try_indent(out.get());
794         out->Print(
795             "# THESE ELEMENTS WILL BE DEPRECATED.\n"
796             "# Please use the generated *_pb2_grpc.py files instead.\n");
797         if (!PrintPreamble(out.get())) {
798           return make_pair(false, "");
799         }
800         if (!PrintBetaPreamble(out.get())) {
801           return make_pair(false, "");
802         }
803         if (!PrintGAServices(out.get())) {
804           return make_pair(false, "");
805         }
806         if (!PrintBetaServices(out.get())) {
807           return make_pair(false, "");
808         }
809       }
810       out->Print("except ImportError:\n");
811       {
812         IndentScope raii_dict_except_indent(out.get());
813         out->Print("pass");
814       }
815     }
816   }
817   return make_pair(true, std::move(output));
818 }
819 
820 }  // namespace
821 
GeneratorConfiguration()822 GeneratorConfiguration::GeneratorConfiguration()
823     : grpc_package_root("grpc"),
824       beta_package_root("grpc.beta"),
825       import_prefix("") {}
826 
PythonGrpcGenerator(const GeneratorConfiguration & config)827 PythonGrpcGenerator::PythonGrpcGenerator(const GeneratorConfiguration& config)
828     : config_(config) {}
829 
~PythonGrpcGenerator()830 PythonGrpcGenerator::~PythonGrpcGenerator() {}
831 
GenerateGrpc(GeneratorContext * context,PrivateGenerator & generator,std::string file_name,bool generate_in_pb2_grpc)832 static bool GenerateGrpc(GeneratorContext* context, PrivateGenerator& generator,
833                          std::string file_name, bool generate_in_pb2_grpc) {
834   bool success;
835   std::unique_ptr<ZeroCopyOutputStream> output;
836   std::unique_ptr<CodedOutputStream> coded_output;
837   std::string grpc_code;
838 
839   if (generate_in_pb2_grpc) {
840     output.reset(context->Open(file_name));
841     generator.generate_in_pb2_grpc = true;
842   } else {
843     output.reset(context->OpenForInsert(file_name, "module_scope"));
844     generator.generate_in_pb2_grpc = false;
845   }
846 
847   coded_output.reset(new CodedOutputStream(output.get()));
848   tie(success, grpc_code) = generator.GetGrpcServices();
849 
850   if (success) {
851     coded_output->WriteRaw(grpc_code.data(), grpc_code.size());
852     return true;
853   } else {
854     return false;
855   }
856 }
857 
ParseParameters(const std::string & parameter,std::string * grpc_version,std::vector<std::string> * strip_prefixes,std::string * error)858 static bool ParseParameters(const std::string& parameter,
859                             std::string* grpc_version,
860                             std::vector<std::string>* strip_prefixes,
861                             std::string* error) {
862   std::vector<std::string> comma_delimited_parameters;
863   grpc_python_generator::Split(parameter, ',', &comma_delimited_parameters);
864   if (comma_delimited_parameters.size() == 1 &&
865       comma_delimited_parameters[0].empty()) {
866     *grpc_version = "grpc_2_0";
867   } else if (comma_delimited_parameters.size() == 1) {
868     *grpc_version = comma_delimited_parameters[0];
869   } else if (comma_delimited_parameters.size() == 2) {
870     *grpc_version = comma_delimited_parameters[0];
871     std::copy(comma_delimited_parameters.begin() + 1,
872               comma_delimited_parameters.end(),
873               std::back_inserter(*strip_prefixes));
874   } else {
875     *error = "--grpc_python_out received too many comma-delimited parameters.";
876     return false;
877   }
878   return true;
879 }
880 
GetSupportedFeatures() const881 uint64_t PythonGrpcGenerator::GetSupportedFeatures() const {
882   return FEATURE_PROTO3_OPTIONAL;
883 }
884 
Generate(const FileDescriptor * file,const std::string & parameter,GeneratorContext * context,std::string * error) const885 bool PythonGrpcGenerator::Generate(const FileDescriptor* file,
886                                    const std::string& parameter,
887                                    GeneratorContext* context,
888                                    std::string* error) const {
889   // Get output file name.
890   std::string pb2_file_name;
891   std::string pb2_grpc_file_name;
892   static const int proto_suffix_length = strlen(".proto");
893   if (file->name().size() > static_cast<size_t>(proto_suffix_length) &&
894       file->name().find_last_of(".proto") == file->name().size() - 1) {
895     std::string base =
896         file->name().substr(0, file->name().size() - proto_suffix_length);
897     std::replace(base.begin(), base.end(), '-', '_');
898     pb2_file_name = base + "_pb2.py";
899     pb2_grpc_file_name = base + "_pb2_grpc.py";
900   } else {
901     *error = "Invalid proto file name. Proto file must end with .proto";
902     return false;
903   }
904   generator_file_name = file->name();
905 
906   ProtoBufFile pbfile(file);
907   std::string grpc_version;
908   GeneratorConfiguration extended_config(config_);
909   bool success = ParseParameters(parameter, &grpc_version,
910                                  &(extended_config.prefixes_to_filter), error);
911   PrivateGenerator generator(extended_config, &pbfile);
912   if (!success) return false;
913   if (grpc_version == "grpc_2_0") {
914     return GenerateGrpc(context, generator, pb2_grpc_file_name, true);
915   } else if (grpc_version == "grpc_1_0") {
916     return GenerateGrpc(context, generator, pb2_grpc_file_name, true) &&
917            GenerateGrpc(context, generator, pb2_file_name, false);
918   } else {
919     *error = "Invalid grpc version '" + grpc_version + "'.";
920     return false;
921   }
922 }
923 
924 }  // namespace grpc_python_generator
925