1 /*
2 *
3 * Copyright 2015 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include "src/compiler/python_generator.h"
20
21 #include <algorithm>
22 #include <cstddef>
23 #include <map>
24 #include <set>
25 #include <sstream>
26 #include <string>
27
28 #include "codegen/idl_namer.h"
29 #include "codegen/namer.h"
30 #include "codegen/python.h"
31 #include "flatbuffers/idl.h"
32 #include "flatbuffers/util.h"
33
34 namespace flatbuffers {
35 namespace python {
36 namespace grpc {
37 namespace {
ClientStreaming(const RPCCall * method)38 bool ClientStreaming(const RPCCall *method) {
39 const Value *val = method->attributes.Lookup("streaming");
40 return val != nullptr && (val->constant == "client" || val->constant == "bidi");
41 }
42
ServerStreaming(const RPCCall * method)43 bool ServerStreaming(const RPCCall *method) {
44 const Value *val = method->attributes.Lookup("streaming");
45 return val != nullptr && (val->constant == "server" || val->constant == "bidi");
46 }
47
FormatImports(std::stringstream & ss,const Imports & imports)48 void FormatImports(std::stringstream &ss, const Imports &imports) {
49 std::set<std::string> modules;
50 std::map<std::string, std::set<std::string>> names_by_module;
51 for (const Import &import : imports.imports) {
52 if (import.IsLocal()) continue; // skip all local imports
53 if (import.name == "") {
54 modules.insert(import.module);
55 } else {
56 names_by_module[import.module].insert(import.name);
57 }
58 }
59
60 for (const std::string &module : modules) {
61 ss << "import " << module << '\n';
62 }
63 ss << '\n';
64 for (const auto &import : names_by_module) {
65 ss << "from " << import.first << " import ";
66 size_t i = 0;
67 for (const std::string &name : import.second) {
68 if (i > 0) ss << ", ";
69 ss << name;
70 ++i;
71 }
72 ss << '\n';
73 }
74 ss << "\n\n";
75 }
76
SaveStub(const std::string & filename,const Imports & imports,const std::string & content)77 bool SaveStub(const std::string &filename, const Imports &imports,
78 const std::string &content) {
79 std::stringstream ss;
80 ss << "# Generated by the gRPC FlatBuffers compiler. DO NOT EDIT!\n"
81 << '\n'
82 << "from __future__ import annotations\n"
83 << '\n';
84 FormatImports(ss, imports);
85 ss << content << '\n';
86
87 EnsureDirExists(StripFileName(filename));
88 return flatbuffers::SaveFile(filename.c_str(), ss.str(), false);
89 }
90
SaveService(const std::string & filename,const Imports & imports,const std::string & content)91 bool SaveService(const std::string &filename, const Imports &imports,
92 const std::string &content) {
93 std::stringstream ss;
94 ss << "# Generated by the gRPC FlatBuffers compiler. DO NOT EDIT!\n" << '\n';
95 FormatImports(ss, imports);
96 ss << content << '\n';
97
98 EnsureDirExists(StripFileName(filename));
99 return flatbuffers::SaveFile(filename.c_str(), ss.str(), false);
100 }
101
102 class BaseGenerator {
103 protected:
BaseGenerator(const Parser & parser,const Namer::Config & config,const std::string & path,const Version & version)104 BaseGenerator(const Parser &parser, const Namer::Config &config,
105 const std::string &path, const Version &version)
106 : parser_{parser},
107 namer_{WithFlagOptions(config, parser.opts, path), Keywords(version)},
108 version_{version} {}
109
110 protected:
ModuleForFile(const std::string & file) const111 std::string ModuleForFile(const std::string &file) const {
112 std::string module = parser_.opts.include_prefix + StripExtension(file) +
113 parser_.opts.filename_suffix;
114 std::replace(module.begin(), module.end(), '/', '.');
115 return module;
116 }
117
118 template <typename T>
ModuleFor(const T * def) const119 std::string ModuleFor(const T *def) const {
120 if (parser_.opts.one_file) return ModuleForFile(def->file);
121 return namer_.NamespacedType(*def);
122 }
123
124 const Parser &parser_;
125 const IdlNamer namer_;
126 const Version version_;
127 };
128
129 class StubGenerator : public BaseGenerator {
130 public:
StubGenerator(const Parser & parser,const std::string & path,const Version & version)131 StubGenerator(const Parser &parser, const std::string &path,
132 const Version &version)
133 : BaseGenerator(parser, kStubConfig, path, version) {}
134
Generate()135 bool Generate() {
136 Imports imports;
137 std::stringstream stub;
138 for (const ServiceDef *service : parser_.services_.vec) {
139 Generate(stub, service, &imports);
140 }
141
142 std::string filename =
143 namer_.config_.output_path +
144 StripPath(StripExtension(parser_.file_being_parsed_)) + "_grpc" +
145 parser_.opts.grpc_filename_suffix + namer_.config_.filename_extension;
146
147 return SaveStub(filename, imports, stub.str());
148 }
149
150 private:
Generate(std::stringstream & ss,const ServiceDef * service,Imports * imports)151 void Generate(std::stringstream &ss, const ServiceDef *service,
152 Imports *imports) {
153 imports->Import("grpc");
154
155 ss << "class " << service->name << "Stub(object):\n"
156 << " def __init__(self, channel: grpc.Channel) -> None: ...\n";
157
158 for (const RPCCall *method : service->calls.vec) {
159 std::string request = "bytes";
160 std::string response = "bytes";
161
162 if (parser_.opts.grpc_python_typed_handlers) {
163 request = namer_.Type(*method->request);
164 response = namer_.Type(*method->response);
165
166 imports->Import(ModuleFor(method->request), request);
167 imports->Import(ModuleFor(method->response), response);
168 }
169
170 ss << " def " << method->name << "(self, ";
171 if (ClientStreaming(method)) {
172 imports->Import("typing");
173 ss << "request_iterator: typing.Iterator[" << request << "]";
174 } else {
175 ss << "request: " << request;
176 }
177 ss << ") -> ";
178 if (ServerStreaming(method)) {
179 imports->Import("typing");
180 ss << "typing.Iterator[" << response << "]";
181 } else {
182 ss << response;
183 }
184 ss << ": ...\n";
185 }
186
187 ss << "\n\n";
188 ss << "class " << service->name << "Servicer(object):\n";
189
190 for (const RPCCall *method : service->calls.vec) {
191 std::string request = "bytes";
192 std::string response = "bytes";
193
194 if (parser_.opts.grpc_python_typed_handlers) {
195 request = namer_.Type(*method->request);
196 response = namer_.Type(*method->response);
197
198 imports->Import(ModuleFor(method->request), request);
199 imports->Import(ModuleFor(method->response), response);
200 }
201
202 ss << " def " << method->name << "(self, ";
203 if (ClientStreaming(method)) {
204 imports->Import("typing");
205 ss << "request_iterator: typing.Iterator[" << request << "]";
206 } else {
207 ss << "request: " << request;
208 }
209 ss << ", context: grpc.ServicerContext) -> ";
210 if (ServerStreaming(method)) {
211 imports->Import("typing");
212 ss << "typing.Iterator[" << response << "]";
213 } else {
214 ss << response;
215 }
216 ss << ": ...\n";
217 }
218
219 ss << '\n'
220 << '\n'
221 << "def add_" << service->name
222 << "Servicer_to_server(servicer: " << service->name
223 << "Servicer, server: grpc.Server) -> None: ...\n";
224 }
225 };
226
227 class ServiceGenerator : public BaseGenerator {
228 public:
ServiceGenerator(const Parser & parser,const std::string & path,const Version & version)229 ServiceGenerator(const Parser &parser, const std::string &path,
230 const Version &version)
231 : BaseGenerator(parser, kConfig, path, version) {}
232
Generate()233 bool Generate() {
234 Imports imports;
235 std::stringstream ss;
236
237 imports.Import("flatbuffers");
238
239 if (parser_.opts.grpc_python_typed_handlers) {
240 ss << "def _serialize_to_bytes(table):\n"
241 << " buf = table._tab.Bytes\n"
242 << " n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)\n"
243 << " if table._tab.Pos != n:\n"
244 << " raise ValueError('must be a top-level table')\n"
245 << " return bytes(buf)\n"
246 << '\n'
247 << '\n';
248 }
249
250 for (const ServiceDef *service : parser_.services_.vec) {
251 GenerateStub(ss, service, &imports);
252 GenerateServicer(ss, service, &imports);
253 GenerateRegister(ss, service, &imports);
254 }
255
256 std::string filename =
257 namer_.config_.output_path +
258 StripPath(StripExtension(parser_.file_being_parsed_)) + "_grpc" +
259 parser_.opts.grpc_filename_suffix + namer_.config_.filename_extension;
260
261 return SaveService(filename, imports, ss.str());
262 }
263
264 private:
GenerateStub(std::stringstream & ss,const ServiceDef * service,Imports * imports)265 void GenerateStub(std::stringstream &ss, const ServiceDef *service,
266 Imports *imports) {
267 ss << "class " << service->name << "Stub";
268 if (version_.major != 3) ss << "(object)";
269 ss << ":\n"
270 << " '''Interface exported by the server.'''\n"
271 << '\n'
272 << " def __init__(self, channel):\n"
273 << " '''Constructor.\n"
274 << '\n'
275 << " Args:\n"
276 << " channel: A grpc.Channel.\n"
277 << " '''\n"
278 << '\n';
279
280 for (const RPCCall *method : service->calls.vec) {
281 std::string response = namer_.Type(*method->response);
282
283 imports->Import(ModuleFor(method->response), response);
284
285 ss << " self." << method->name << " = channel."
286 << (ClientStreaming(method) ? "stream" : "unary") << "_"
287 << (ServerStreaming(method) ? "stream" : "unary") << "(\n"
288 << " method='/"
289 << service->defined_namespace->GetFullyQualifiedName(service->name)
290 << "/" << method->name << "'";
291
292 if (parser_.opts.grpc_python_typed_handlers) {
293 ss << ",\n"
294 << " request_serializer=_serialize_to_bytes,\n"
295 << " response_deserializer=" << response << ".GetRootAs";
296 }
297 ss << ")\n\n";
298 }
299
300 ss << '\n';
301 }
302
GenerateServicer(std::stringstream & ss,const ServiceDef * service,Imports * imports)303 void GenerateServicer(std::stringstream &ss, const ServiceDef *service,
304 Imports *imports) {
305 imports->Import("grpc");
306
307 ss << "class " << service->name << "Servicer";
308 if (version_.major != 3) ss << "(object)";
309 ss << ":\n"
310 << " '''Interface exported by the server.'''\n"
311 << '\n';
312
313 for (const RPCCall *method : service->calls.vec) {
314 const std::string request_param =
315 ClientStreaming(method) ? "request_iterator" : "request";
316 ss << " def " << method->name << "(self, " << request_param
317 << ", context):\n"
318 << " context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n"
319 << " context.set_details('Method not implemented!')\n"
320 << " raise NotImplementedError('Method not implemented!')\n"
321 << '\n';
322 }
323
324 ss << '\n';
325 }
326
GenerateRegister(std::stringstream & ss,const ServiceDef * service,Imports * imports)327 void GenerateRegister(std::stringstream &ss, const ServiceDef *service,
328 Imports *imports) {
329 imports->Import("grpc");
330
331 ss << "def add_" << service->name
332 << "Servicer_to_server(servicer, server):\n"
333 << " rpc_method_handlers = {\n";
334
335 for (const RPCCall *method : service->calls.vec) {
336 std::string request = namer_.Type(*method->request);
337
338 imports->Import(ModuleFor(method->request), request);
339
340 ss << " '" << method->name << "': grpc."
341 << (ClientStreaming(method) ? "stream" : "unary") << "_"
342 << (ServerStreaming(method) ? "stream" : "unary")
343 << "_rpc_method_handler(\n"
344 << " servicer." << method->name;
345
346 if (parser_.opts.grpc_python_typed_handlers) {
347 ss << ",\n"
348 << " request_deserializer=" << request << ".GetRootAs,\n"
349 << " response_serializer=_serialize_to_bytes";
350 }
351 ss << "),\n";
352 }
353 ss << " }\n"
354 << '\n'
355 << " generic_handler = grpc.method_handlers_generic_handler(\n"
356 << " '"
357 << service->defined_namespace->GetFullyQualifiedName(service->name)
358 << "', rpc_method_handlers)\n"
359 << '\n'
360 << " server.add_generic_rpc_handlers((generic_handler,))\n"
361 << '\n';
362 }
363 };
364 } // namespace
365
Generate(const Parser & parser,const std::string & path,const Version & version)366 bool Generate(const Parser &parser, const std::string &path,
367 const Version &version) {
368 ServiceGenerator generator{parser, path, version};
369 return generator.Generate();
370 }
371
GenerateStub(const Parser & parser,const std::string & path,const Version & version)372 bool GenerateStub(const Parser &parser, const std::string &path,
373 const Version &version) {
374 StubGenerator generator{parser, path, version};
375 return generator.Generate();
376 }
377
378 } // namespace grpc
379 } // namespace python
380 } // namespace flatbuffers
381