• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include "src/compiler/python_generator.h"
20 
21 #include <algorithm>
22 #include <cstddef>
23 #include <map>
24 #include <set>
25 #include <sstream>
26 #include <string>
27 
28 #include "codegen/idl_namer.h"
29 #include "codegen/namer.h"
30 #include "codegen/python.h"
31 #include "flatbuffers/idl.h"
32 #include "flatbuffers/util.h"
33 
34 namespace flatbuffers {
35 namespace python {
36 namespace grpc {
37 namespace {
ClientStreaming(const RPCCall * method)38 bool ClientStreaming(const RPCCall *method) {
39   const Value *val = method->attributes.Lookup("streaming");
40   return val != nullptr && (val->constant == "client" || val->constant == "bidi");
41 }
42 
ServerStreaming(const RPCCall * method)43 bool ServerStreaming(const RPCCall *method) {
44   const Value *val = method->attributes.Lookup("streaming");
45   return val != nullptr && (val->constant == "server" || val->constant == "bidi");
46 }
47 
FormatImports(std::stringstream & ss,const Imports & imports)48 void FormatImports(std::stringstream &ss, const Imports &imports) {
49   std::set<std::string> modules;
50   std::map<std::string, std::set<std::string>> names_by_module;
51   for (const Import &import : imports.imports) {
52     if (import.IsLocal()) continue;  // skip all local imports
53     if (import.name == "") {
54       modules.insert(import.module);
55     } else {
56       names_by_module[import.module].insert(import.name);
57     }
58   }
59 
60   for (const std::string &module : modules) {
61     ss << "import " << module << '\n';
62   }
63   ss << '\n';
64   for (const auto &import : names_by_module) {
65     ss << "from " << import.first << " import ";
66     size_t i = 0;
67     for (const std::string &name : import.second) {
68       if (i > 0) ss << ", ";
69       ss << name;
70       ++i;
71     }
72     ss << '\n';
73   }
74   ss << "\n\n";
75 }
76 
SaveStub(const std::string & filename,const Imports & imports,const std::string & content)77 bool SaveStub(const std::string &filename, const Imports &imports,
78               const std::string &content) {
79   std::stringstream ss;
80   ss << "# Generated by the gRPC FlatBuffers compiler. DO NOT EDIT!\n"
81      << '\n'
82      << "from __future__ import annotations\n"
83      << '\n';
84   FormatImports(ss, imports);
85   ss << content << '\n';
86 
87   EnsureDirExists(StripFileName(filename));
88   return flatbuffers::SaveFile(filename.c_str(), ss.str(), false);
89 }
90 
SaveService(const std::string & filename,const Imports & imports,const std::string & content)91 bool SaveService(const std::string &filename, const Imports &imports,
92                  const std::string &content) {
93   std::stringstream ss;
94   ss << "# Generated by the gRPC FlatBuffers compiler. DO NOT EDIT!\n" << '\n';
95   FormatImports(ss, imports);
96   ss << content << '\n';
97 
98   EnsureDirExists(StripFileName(filename));
99   return flatbuffers::SaveFile(filename.c_str(), ss.str(), false);
100 }
101 
102 class BaseGenerator {
103  protected:
BaseGenerator(const Parser & parser,const Namer::Config & config,const std::string & path,const Version & version)104   BaseGenerator(const Parser &parser, const Namer::Config &config,
105                 const std::string &path, const Version &version)
106       : parser_{parser},
107         namer_{WithFlagOptions(config, parser.opts, path), Keywords(version)},
108         version_{version} {}
109 
110  protected:
ModuleForFile(const std::string & file) const111   std::string ModuleForFile(const std::string &file) const {
112     std::string module = parser_.opts.include_prefix + StripExtension(file) +
113                          parser_.opts.filename_suffix;
114     std::replace(module.begin(), module.end(), '/', '.');
115     return module;
116   }
117 
118   template <typename T>
ModuleFor(const T * def) const119   std::string ModuleFor(const T *def) const {
120     if (parser_.opts.one_file) return ModuleForFile(def->file);
121     return namer_.NamespacedType(*def);
122   }
123 
124   const Parser &parser_;
125   const IdlNamer namer_;
126   const Version version_;
127 };
128 
129 class StubGenerator : public BaseGenerator {
130  public:
StubGenerator(const Parser & parser,const std::string & path,const Version & version)131   StubGenerator(const Parser &parser, const std::string &path,
132                 const Version &version)
133       : BaseGenerator(parser, kStubConfig, path, version) {}
134 
Generate()135   bool Generate() {
136     Imports imports;
137     std::stringstream stub;
138     for (const ServiceDef *service : parser_.services_.vec) {
139       Generate(stub, service, &imports);
140     }
141 
142     std::string filename =
143         namer_.config_.output_path +
144         StripPath(StripExtension(parser_.file_being_parsed_)) + "_grpc" +
145         parser_.opts.grpc_filename_suffix + namer_.config_.filename_extension;
146 
147     return SaveStub(filename, imports, stub.str());
148   }
149 
150  private:
Generate(std::stringstream & ss,const ServiceDef * service,Imports * imports)151   void Generate(std::stringstream &ss, const ServiceDef *service,
152                 Imports *imports) {
153     imports->Import("grpc");
154 
155     ss << "class " << service->name << "Stub(object):\n"
156        << "  def __init__(self, channel: grpc.Channel) -> None: ...\n";
157 
158     for (const RPCCall *method : service->calls.vec) {
159       std::string request = "bytes";
160       std::string response = "bytes";
161 
162       if (parser_.opts.grpc_python_typed_handlers) {
163         request = namer_.Type(*method->request);
164         response = namer_.Type(*method->response);
165 
166         imports->Import(ModuleFor(method->request), request);
167         imports->Import(ModuleFor(method->response), response);
168       }
169 
170       ss << "  def " << method->name << "(self, ";
171       if (ClientStreaming(method)) {
172         imports->Import("typing");
173         ss << "request_iterator: typing.Iterator[" << request << "]";
174       } else {
175         ss << "request: " << request;
176       }
177       ss << ") -> ";
178       if (ServerStreaming(method)) {
179         imports->Import("typing");
180         ss << "typing.Iterator[" << response << "]";
181       } else {
182         ss << response;
183       }
184       ss << ": ...\n";
185     }
186 
187     ss << "\n\n";
188     ss << "class " << service->name << "Servicer(object):\n";
189 
190     for (const RPCCall *method : service->calls.vec) {
191       std::string request = "bytes";
192       std::string response = "bytes";
193 
194       if (parser_.opts.grpc_python_typed_handlers) {
195         request = namer_.Type(*method->request);
196         response = namer_.Type(*method->response);
197 
198         imports->Import(ModuleFor(method->request), request);
199         imports->Import(ModuleFor(method->response), response);
200       }
201 
202       ss << "  def " << method->name << "(self, ";
203       if (ClientStreaming(method)) {
204         imports->Import("typing");
205         ss << "request_iterator: typing.Iterator[" << request << "]";
206       } else {
207         ss << "request: " << request;
208       }
209       ss << ", context: grpc.ServicerContext) -> ";
210       if (ServerStreaming(method)) {
211         imports->Import("typing");
212         ss << "typing.Iterator[" << response << "]";
213       } else {
214         ss << response;
215       }
216       ss << ": ...\n";
217     }
218 
219     ss << '\n'
220        << '\n'
221        << "def add_" << service->name
222        << "Servicer_to_server(servicer: " << service->name
223        << "Servicer, server: grpc.Server) -> None: ...\n";
224   }
225 };
226 
227 class ServiceGenerator : public BaseGenerator {
228  public:
ServiceGenerator(const Parser & parser,const std::string & path,const Version & version)229   ServiceGenerator(const Parser &parser, const std::string &path,
230                    const Version &version)
231       : BaseGenerator(parser, kConfig, path, version) {}
232 
Generate()233   bool Generate() {
234     Imports imports;
235     std::stringstream ss;
236 
237     imports.Import("flatbuffers");
238 
239     if (parser_.opts.grpc_python_typed_handlers) {
240       ss << "def _serialize_to_bytes(table):\n"
241          << "  buf = table._tab.Bytes\n"
242          << "  n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)\n"
243          << "  if table._tab.Pos != n:\n"
244          << "    raise ValueError('must be a top-level table')\n"
245          << "  return bytes(buf)\n"
246          << '\n'
247          << '\n';
248     }
249 
250     for (const ServiceDef *service : parser_.services_.vec) {
251       GenerateStub(ss, service, &imports);
252       GenerateServicer(ss, service, &imports);
253       GenerateRegister(ss, service, &imports);
254     }
255 
256     std::string filename =
257         namer_.config_.output_path +
258         StripPath(StripExtension(parser_.file_being_parsed_)) + "_grpc" +
259         parser_.opts.grpc_filename_suffix + namer_.config_.filename_extension;
260 
261     return SaveService(filename, imports, ss.str());
262   }
263 
264  private:
GenerateStub(std::stringstream & ss,const ServiceDef * service,Imports * imports)265   void GenerateStub(std::stringstream &ss, const ServiceDef *service,
266                     Imports *imports) {
267     ss << "class " << service->name << "Stub";
268     if (version_.major != 3) ss << "(object)";
269     ss << ":\n"
270        << "  '''Interface exported by the server.'''\n"
271        << '\n'
272        << "  def __init__(self, channel):\n"
273        << "    '''Constructor.\n"
274        << '\n'
275        << "    Args:\n"
276        << "      channel: A grpc.Channel.\n"
277        << "    '''\n"
278        << '\n';
279 
280     for (const RPCCall *method : service->calls.vec) {
281       std::string response = namer_.Type(*method->response);
282 
283       imports->Import(ModuleFor(method->response), response);
284 
285       ss << "    self." << method->name << " = channel."
286          << (ClientStreaming(method) ? "stream" : "unary") << "_"
287          << (ServerStreaming(method) ? "stream" : "unary") << "(\n"
288          << "      method='/"
289          << service->defined_namespace->GetFullyQualifiedName(service->name)
290          << "/" << method->name << "'";
291 
292       if (parser_.opts.grpc_python_typed_handlers) {
293         ss << ",\n"
294            << "      request_serializer=_serialize_to_bytes,\n"
295            << "      response_deserializer=" << response << ".GetRootAs";
296       }
297       ss << ")\n\n";
298     }
299 
300     ss << '\n';
301   }
302 
GenerateServicer(std::stringstream & ss,const ServiceDef * service,Imports * imports)303   void GenerateServicer(std::stringstream &ss, const ServiceDef *service,
304                         Imports *imports) {
305     imports->Import("grpc");
306 
307     ss << "class " << service->name << "Servicer";
308     if (version_.major != 3) ss << "(object)";
309     ss << ":\n"
310        << "  '''Interface exported by the server.'''\n"
311        << '\n';
312 
313     for (const RPCCall *method : service->calls.vec) {
314       const std::string request_param =
315           ClientStreaming(method) ? "request_iterator" : "request";
316       ss << "  def " << method->name << "(self, " << request_param
317          << ", context):\n"
318          << "    context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n"
319          << "    context.set_details('Method not implemented!')\n"
320          << "    raise NotImplementedError('Method not implemented!')\n"
321          << '\n';
322     }
323 
324     ss << '\n';
325   }
326 
GenerateRegister(std::stringstream & ss,const ServiceDef * service,Imports * imports)327   void GenerateRegister(std::stringstream &ss, const ServiceDef *service,
328                         Imports *imports) {
329     imports->Import("grpc");
330 
331     ss << "def add_" << service->name
332        << "Servicer_to_server(servicer, server):\n"
333        << "  rpc_method_handlers = {\n";
334 
335     for (const RPCCall *method : service->calls.vec) {
336       std::string request = namer_.Type(*method->request);
337 
338       imports->Import(ModuleFor(method->request), request);
339 
340       ss << "    '" << method->name << "': grpc."
341          << (ClientStreaming(method) ? "stream" : "unary") << "_"
342          << (ServerStreaming(method) ? "stream" : "unary")
343          << "_rpc_method_handler(\n"
344          << "      servicer." << method->name;
345 
346       if (parser_.opts.grpc_python_typed_handlers) {
347         ss << ",\n"
348            << "      request_deserializer=" << request << ".GetRootAs,\n"
349            << "      response_serializer=_serialize_to_bytes";
350       }
351       ss << "),\n";
352     }
353     ss << "  }\n"
354        << '\n'
355        << "  generic_handler = grpc.method_handlers_generic_handler(\n"
356        << "    '"
357        << service->defined_namespace->GetFullyQualifiedName(service->name)
358        << "', rpc_method_handlers)\n"
359        << '\n'
360        << "  server.add_generic_rpc_handlers((generic_handler,))\n"
361        << '\n';
362   }
363 };
364 }  // namespace
365 
Generate(const Parser & parser,const std::string & path,const Version & version)366 bool Generate(const Parser &parser, const std::string &path,
367               const Version &version) {
368   ServiceGenerator generator{parser, path, version};
369   return generator.Generate();
370 }
371 
GenerateStub(const Parser & parser,const std::string & path,const Version & version)372 bool GenerateStub(const Parser &parser, const std::string &path,
373                   const Version &version) {
374   StubGenerator generator{parser, path, version};
375   return generator.Generate();
376 }
377 
378 }  // namespace grpc
379 }  // namespace python
380 }  // namespace flatbuffers
381