• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*
18  * NOTE: The following implementation is a translation of the Swift-grpc
19  * generator, since flatbuffers doesn't allow plugins for now. If an issue
20  * arises, please open an issue in the flatbuffers repository. This file should
21  * always be kept in sync with the Swift-grpc repository.
22  */
23 #include <map>
24 #include <sstream>
25 
26 #include "flatbuffers/util.h"
27 #include "src/compiler/schema_interface.h"
28 #include "src/compiler/swift_generator.h"
29 
30 namespace grpc_swift_generator {
31 
WrapInNameSpace(const std::vector<std::string> & components,const grpc::string & name)32 std::string WrapInNameSpace(const std::vector<std::string> &components,
33                             const grpc::string &name) {
34   std::string qualified_name;
35   for (auto it = components.begin(); it != components.end(); ++it)
36     qualified_name += *it + "_";
37   return qualified_name + name;
38 }
39 
// Returns the Swift type used for a request/response payload: the flattened,
// namespace-qualified type name wrapped in FlatBuffers' `Message<...>`.
grpc::string GenerateMessage(const std::vector<std::string> &components,
                             const grpc::string &name) {
  grpc::string wrapped = "Message<";
  wrapped += WrapInNameSpace(components, name);
  wrapped += '>';
  return wrapped;
}
44 
45 // MARK: - Client
46 
GenerateClientFuncName(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)47 void GenerateClientFuncName(const grpc_generator::Method *method,
48                             grpc_generator::Printer *printer,
49                             std::map<grpc::string, grpc::string> *dictonary) {
50   auto vars = *dictonary;
51   if (method->NoStreaming()) {
52     printer->Print(vars,
53                    "  $GenAccess$func $MethodName$(\n"
54                    "    _ request: $Input$\n"
55                    "    , callOptions: CallOptions?$isNil$\n"
56                    "  ) -> UnaryCall<$Input$, $Output$>");
57     return;
58   }
59 
60   if (method->ServerStreaming()) {
61     printer->Print(vars,
62                    "  $GenAccess$func $MethodName$(\n"
63                    "    _ request: $Input$\n"
64                    "    , callOptions: CallOptions?$isNil$,\n"
65                    "    handler: @escaping ($Output$) -> Void\n"
66                    "  ) -> ServerStreamingCall<$Input$, $Output$>");
67     return;
68   }
69 
70   if (method->ClientStreaming()) {
71     printer->Print(vars,
72                    "  $GenAccess$func $MethodName$(\n"
73                    "    callOptions: CallOptions?$isNil$\n"
74                    "  ) -> ClientStreamingCall<$Input$, $Output$>");
75     return;
76   }
77 
78   printer->Print(vars,
79                  "  $GenAccess$func $MethodName$(\n"
80                  "    callOptions: CallOptions?$isNil$,\n"
81                  "    handler: @escaping ($Output$ ) -> Void\n"
82                  "  ) -> BidirectionalStreamingCall<$Input$, $Output$>");
83 }
84 
GenerateClientFuncBody(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)85 void GenerateClientFuncBody(const grpc_generator::Method *method,
86                             grpc_generator::Printer *printer,
87                             std::map<grpc::string, grpc::string> *dictonary) {
88   auto vars = *dictonary;
89   vars["Interceptor"] =
90       "interceptors: self.interceptors?.make$MethodName$Interceptors() ?? []";
91   if (method->NoStreaming()) {
92     printer->Print(
93         vars,
94         "    return self.makeUnaryCall(\n"
95         "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
96         "      request: request,\n"
97         "      callOptions: callOptions ?? self.defaultCallOptions,\n"
98         "      $Interceptor$\n"
99         "    )\n");
100     return;
101   }
102 
103   if (method->ServerStreaming()) {
104     printer->Print(
105         vars,
106         "    return self.makeServerStreamingCall(\n"
107         "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
108         "      request: request,\n"
109         "      callOptions: callOptions ?? self.defaultCallOptions,\n"
110         "      $Interceptor$,\n"
111         "      handler: handler\n"
112         "    )\n");
113     return;
114   }
115 
116   if (method->ClientStreaming()) {
117     printer->Print(
118         vars,
119         "    return self.makeClientStreamingCall(\n"
120         "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
121         "      callOptions: callOptions ?? self.defaultCallOptions,\n"
122         "      $Interceptor$\n"
123         "    )\n");
124     return;
125   }
126   printer->Print(vars,
127                  "    return self.makeBidirectionalStreamingCall(\n"
128                  "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
129                  "      callOptions: callOptions ?? self.defaultCallOptions,\n"
130                  "      $Interceptor$,\n"
131                  "      handler: handler\n"
132                  "    )\n");
133 }
134 
GenerateClientProtocol(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)135 void GenerateClientProtocol(const grpc_generator::Service *service,
136                             grpc_generator::Printer *printer,
137                             std::map<grpc::string, grpc::string> *dictonary) {
138   auto vars = *dictonary;
139   printer->Print(
140       vars,
141       "$ACCESS$ protocol $ServiceQualifiedName$ClientProtocol: GRPCClient {");
142   printer->Print("\n\n");
143   printer->Print("  var serviceName: String { get }");
144   printer->Print("\n\n");
145   printer->Print(
146       vars,
147       "  var interceptors: "
148       "$ServiceQualifiedName$ClientInterceptorFactoryProtocol? { get }");
149   printer->Print("\n\n");
150 
151   vars["GenAccess"] = "";
152   for (auto it = 0; it < service->method_count(); it++) {
153     auto method = service->method(it);
154     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
155                                     method->get_input_type_name());
156     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
157                                      method->get_output_type_name());
158     vars["MethodName"] = method->name();
159     vars["isNil"] = "";
160     GenerateClientFuncName(method.get(), &*printer, &vars);
161     printer->Print("\n\n");
162   }
163   printer->Print("}\n\n");
164 
165   printer->Print(vars, "extension $ServiceQualifiedName$ClientProtocol {");
166   printer->Print("\n\n");
167   printer->Print(vars,
168                  "  $ACCESS$ var serviceName: String { "
169                  "\"$PATH$$ServiceName$\" }\n");
170 
171   vars["GenAccess"] = service->is_internal() ? "internal " : "public ";
172   for (auto it = 0; it < service->method_count(); it++) {
173     auto method = service->method(it);
174     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
175                                     method->get_input_type_name());
176     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
177                                      method->get_output_type_name());
178     vars["MethodName"] = method->name();
179     vars["isNil"] = " = nil";
180     printer->Print("\n");
181     GenerateClientFuncName(method.get(), &*printer, &vars);
182     printer->Print(" {\n");
183     GenerateClientFuncBody(method.get(), &*printer, &vars);
184     printer->Print("  }\n");
185   }
186   printer->Print("}\n\n");
187 
188   printer->Print(vars,
189                  "$ACCESS$ protocol "
190                  "$ServiceQualifiedName$ClientInterceptorFactoryProtocol {\n");
191 
192   for (auto it = 0; it < service->method_count(); it++) {
193     auto method = service->method(it);
194     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
195                                     method->get_input_type_name());
196     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
197                                      method->get_output_type_name());
198     vars["MethodName"] = method->name();
199     printer->Print(
200         vars,
201         "  /// - Returns: Interceptors to use when invoking '$MethodName$'.\n");
202     printer->Print(vars,
203                    "  func make$MethodName$Interceptors() -> "
204                    "[ClientInterceptor<$Input$, $Output$>]\n\n");
205   }
206   printer->Print("}\n\n");
207 }
208 
GenerateClientClass(grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)209 void GenerateClientClass(grpc_generator::Printer *printer,
210                          std::map<grpc::string, grpc::string> *dictonary) {
211   auto vars = *dictonary;
212   printer->Print(vars,
213                  "$ACCESS$ final class $ServiceQualifiedName$ServiceClient: "
214                  "$ServiceQualifiedName$ClientProtocol {\n");
215   printer->Print(vars, "  $ACCESS$ let channel: GRPCChannel\n");
216   printer->Print(vars, "  $ACCESS$ var defaultCallOptions: CallOptions\n");
217   printer->Print(vars,
218                  "  $ACCESS$ var interceptors: "
219                  "$ServiceQualifiedName$ClientInterceptorFactoryProtocol?\n");
220   printer->Print("\n");
221   printer->Print(
222       vars,
223       "  $ACCESS$ init(\n"
224       "    channel: GRPCChannel,\n"
225       "    defaultCallOptions: CallOptions = CallOptions(),\n"
226       "    interceptors: "
227       "$ServiceQualifiedName$ClientInterceptorFactoryProtocol? = nil\n"
228       "  ) {\n");
229   printer->Print("    self.channel = channel\n");
230   printer->Print("    self.defaultCallOptions = defaultCallOptions\n");
231   printer->Print("    self.interceptors = interceptors\n");
232   printer->Print("  }");
233   printer->Print("\n");
234   printer->Print("}\n");
235 }
236 
237 // MARK: - Server
238 
GenerateServerFuncName(const grpc_generator::Method * method)239 grpc::string GenerateServerFuncName(const grpc_generator::Method *method) {
240   if (method->NoStreaming()) {
241     return "func $MethodName$(request: $Input$"
242            ", context: StatusOnlyCallContext) -> EventLoopFuture<$Output$>";
243   }
244 
245   if (method->ClientStreaming()) {
246     return "func $MethodName$(context: UnaryResponseCallContext<$Output$>) -> "
247            "EventLoopFuture<(StreamEvent<$Input$"
248            ">) -> Void>";
249   }
250 
251   if (method->ServerStreaming()) {
252     return "func $MethodName$(request: $Input$"
253            ", context: StreamingResponseCallContext<$Output$>) -> "
254            "EventLoopFuture<GRPCStatus>";
255   }
256   return "func $MethodName$(context: StreamingResponseCallContext<$Output$>) "
257          "-> EventLoopFuture<(StreamEvent<$Input$>) -> Void>";
258 }
259 
GenerateServerExtensionBody(const grpc_generator::Method * method)260 grpc::string GenerateServerExtensionBody(const grpc_generator::Method *method) {
261   grpc::string start = "    case \"$MethodName$\":\n    ";
262   grpc::string interceptors =
263       "      interceptors: self.interceptors?.make$MethodName$Interceptors() "
264       "?? [],\n";
265   if (method->NoStreaming()) {
266     return start +
267            "return UnaryServerHandler(\n"
268            "      context: context,\n"
269            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
270            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
271            interceptors +
272            "      userFunction: self.$MethodName$(request:context:))\n";
273   }
274   if (method->ServerStreaming()) {
275     return start +
276            "return ServerStreamingServerHandler(\n"
277            "      context: context,\n"
278            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
279            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
280            interceptors +
281            "      userFunction: self.$MethodName$(request:context:))\n";
282   }
283   if (method->ClientStreaming()) {
284     return start +
285            "return ClientStreamingServerHandler(\n"
286            "      context: context,\n"
287            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
288            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
289            interceptors +
290            "      observerFactory: self.$MethodName$(context:))\n";
291   }
292   if (method->BidiStreaming()) {
293     return start +
294            "return BidirectionalStreamingServerHandler(\n"
295            "      context: context,\n"
296            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
297            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
298            interceptors +
299            "      observerFactory: self.$MethodName$(context:))\n";
300   }
301   return "";
302 }
303 
GenerateServerProtocol(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)304 void GenerateServerProtocol(const grpc_generator::Service *service,
305                             grpc_generator::Printer *printer,
306                             std::map<grpc::string, grpc::string> *dictonary) {
307   auto vars = *dictonary;
308   printer->Print(vars,
309                  "$ACCESS$ protocol $ServiceQualifiedName$Provider: "
310                  "CallHandlerProvider {\n");
311   printer->Print(
312       vars,
313       "  var interceptors: "
314       "$ServiceQualifiedName$ServerInterceptorFactoryProtocol? { get }\n");
315   for (auto it = 0; it < service->method_count(); it++) {
316     auto method = service->method(it);
317     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
318                                     method->get_input_type_name());
319     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
320                                      method->get_output_type_name());
321     vars["MethodName"] = method->name();
322     printer->Print("  ");
323     auto func = GenerateServerFuncName(method.get());
324     printer->Print(vars, func.c_str());
325     printer->Print("\n");
326   }
327   printer->Print("}\n\n");
328 
329   printer->Print(vars, "$ACCESS$ extension $ServiceQualifiedName$Provider {\n");
330   printer->Print("\n");
331   printer->Print(vars,
332                  "  var serviceName: Substring { return "
333                  "\"$PATH$$ServiceName$\" }\n");
334   printer->Print("\n");
335   printer->Print(
336       "  func handle(method name: Substring, context: "
337       "CallHandlerContext) -> GRPCServerHandlerProtocol? {\n");
338   printer->Print("    switch name {\n");
339   for (auto it = 0; it < service->method_count(); it++) {
340     auto method = service->method(it);
341     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
342                                     method->get_input_type_name());
343     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
344                                      method->get_output_type_name());
345     vars["MethodName"] = method->name();
346     auto body = GenerateServerExtensionBody(method.get());
347     printer->Print(vars, body.c_str());
348     printer->Print("\n");
349   }
350   printer->Print("    default: return nil;\n");
351   printer->Print("    }\n");
352   printer->Print("  }\n\n");
353   printer->Print("}\n\n");
354 
355   printer->Print(vars,
356                  "$ACCESS$ protocol "
357                  "$ServiceQualifiedName$ServerInterceptorFactoryProtocol {\n");
358   for (auto it = 0; it < service->method_count(); it++) {
359     auto method = service->method(it);
360     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
361                                     method->get_input_type_name());
362     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
363                                      method->get_output_type_name());
364     vars["MethodName"] = method->name();
365     printer->Print(
366         vars,
367         "  /// - Returns: Interceptors to use when handling '$MethodName$'.\n"
368         "  ///   Defaults to calling `self.makeInterceptors()`.\n");
369     printer->Print(vars,
370                    "  func make$MethodName$Interceptors() -> "
371                    "[ServerInterceptor<$Input$, $Output$>]\n\n");
372   }
373   printer->Print("}");
374 }
375 
Generate(grpc_generator::File * file,const grpc_generator::Service * service)376 grpc::string Generate(grpc_generator::File *file,
377                       const grpc_generator::Service *service) {
378   grpc::string output;
379   std::map<grpc::string, grpc::string> vars;
380   vars["PATH"] = file->package();
381   if (!file->package().empty()) { vars["PATH"].append("."); }
382   vars["ServiceQualifiedName"] =
383       WrapInNameSpace(service->namespace_parts(), service->name());
384   vars["ServiceName"] = service->name();
385   vars["ACCESS"] = service->is_internal() ? "internal" : "public";
386   auto printer = file->CreatePrinter(&output);
387   printer->Print(
388       vars,
389       "/// Usage: instantiate $ServiceQualifiedName$ServiceClient, then call "
390       "methods of this protocol to make API calls.\n");
391   GenerateClientProtocol(service, &*printer, &vars);
392   GenerateClientClass(&*printer, &vars);
393   printer->Print("\n");
394   GenerateServerProtocol(service, &*printer, &vars);
395   return output;
396 }
397 
GenerateHeader()398 grpc::string GenerateHeader() {
399   grpc::string code;
400   code +=
401       "/// The following code is generated by the Flatbuffers library which "
402       "might not be in sync with grpc-swift\n";
403   code +=
404       "/// in case of an issue please open github issue, though it would be "
405       "maintained\n";
406   code += "\n";
407   code += "// swiftlint:disable all\n";
408   code += "// swiftformat:disable all\n";
409   code += "\n";
410   code += "import Foundation\n";
411   code += "import GRPC\n";
412   code += "import NIO\n";
413   code += "import NIOHTTP1\n";
414   code += "import FlatBuffers\n";
415   code += "\n";
416   code +=
417       "public protocol GRPCFlatBufPayload: GRPCPayload, FlatBufferGRPCMessage "
418       "{}\n";
419 
420   code += "public extension GRPCFlatBufPayload {\n";
421   code += "  init(serializedByteBuffer: inout NIO.ByteBuffer) throws {\n";
422   code +=
423       "    self.init(byteBuffer: FlatBuffers.ByteBuffer(contiguousBytes: "
424       "serializedByteBuffer.readableBytesView, count: "
425       "serializedByteBuffer.readableBytes))\n";
426   code += "  }\n";
427 
428   code += "  func serialize(into buffer: inout NIO.ByteBuffer) throws {\n";
429   code +=
430       "    let buf = UnsafeRawBufferPointer(start: self.rawPointer, count: "
431       "Int(self.size))\n";
432   code += "    buffer.writeBytes(buf)\n";
433   code += "  }\n";
434   code += "}\n";
435   code += "extension Message: GRPCFlatBufPayload {}\n";
436   return code;
437 }
438 }  // namespace grpc_swift_generator
439