1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/strutil.h>

#include "nugget/protobuf/options.pb.h"
30
31 using ::google::protobuf::FileDescriptor;
32 using ::google::protobuf::JoinStrings;
33 using ::google::protobuf::MethodDescriptor;
34 using ::google::protobuf::ServiceDescriptor;
35 using ::google::protobuf::Split;
36 using ::google::protobuf::SplitStringUsing;
37 using ::google::protobuf::StripSuffixString;
38 using ::google::protobuf::compiler::CodeGenerator;
39 using ::google::protobuf::compiler::OutputDirectory;
40 using ::google::protobuf::io::Printer;
41 using ::google::protobuf::io::ZeroCopyOutputStream;
42
43 using ::nugget::protobuf::app_id;
44 using ::nugget::protobuf::request_buffer_size;
45 using ::nugget::protobuf::response_buffer_size;
46
47 namespace {
48
validateServiceOptions(const ServiceDescriptor & service)49 std::string validateServiceOptions(const ServiceDescriptor& service) {
50 if (!service.options().HasExtension(app_id)) {
51 return "nugget.protobuf.app_id is not defined for service " + service.name();
52 }
53 if (!service.options().HasExtension(request_buffer_size)) {
54 return "nugget.protobuf.request_buffer_size is not defined for service " + service.name();
55 }
56 if (!service.options().HasExtension(response_buffer_size)) {
57 return "nugget.protobuf.response_buffer_size is not defined for service " + service.name();
58 }
59 return "";
60 }
61
// Returns the enclosing scopes of `descriptor`.
//
// These are the '.'-separated components of descriptor.full_name() with the
// last component (the descriptor's own name) dropped. For a top-level type
// that is its package; for a nested type the enclosing message names follow
// the package. Empty components (e.g. from "a..b") are skipped.
// An empty full name yields an empty result; previously it called pop_back()
// on an empty vector, which is undefined behaviour.
template <typename Descriptor>
std::vector<std::string> Packages(const Descriptor& descriptor) {
  // Copy so this works whether full_name() returns a string or a string view.
  const std::string full_name(descriptor.full_name());
  std::vector<std::string> scopes;
  std::string::size_type begin = 0;
  while (begin < full_name.size()) {
    std::string::size_type end = full_name.find('.', begin);
    if (end == std::string::npos) {
      end = full_name.size();
    }
    if (end > begin) {
      scopes.emplace_back(full_name, begin, end - begin);
    }
    begin = end + 1;
  }
  if (!scopes.empty()) {
    scopes.pop_back();  // drop the descriptor's own name; keep only its scopes
  }
  return scopes;
}
69
// Returns the fully qualified C++ name of `descriptor`, anchored at the global
// namespace: "::" followed by each enclosing scope from Packages() and the
// descriptor's own name, all separated by "::" (e.g. "::a::b::Name", or just
// "::Name" when there are no enclosing scopes).
template <typename Descriptor>
std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
  std::string qualified;
  for (const std::string& scope : Packages(descriptor)) {
    qualified += "::";
    qualified += scope;
  }
  return qualified + "::" + descriptor.name();
}
81
82 template <typename Descriptor>
FullyQualifiedHeader(const Descriptor & descriptor)83 std::string FullyQualifiedHeader(const Descriptor& descriptor) {
84 const auto packages = Packages(descriptor);
85 const auto file = Split(descriptor.file()->name(), "/").back();
86 const auto header = StripSuffixString(file, ".proto") + ".pb.h";
87 if (packages.empty()) {
88 return header;
89 } else {
90 std::string package_path;
91 JoinStrings(packages, "/", &package_path);
92 return package_path + "/" + header;
93 }
94 }
95
96 template <typename Descriptor>
OpenNamespaces(Printer & printer,const Descriptor & descriptor)97 void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
98 const auto namespaces = Packages(descriptor);
99 for (const auto& ns : namespaces) {
100 std::map<std::string, std::string> namespaceVars;
101 namespaceVars["namespace"] = ns;
102 printer.Print(namespaceVars, R"(
103 namespace $namespace$ {)");
104 }
105 }
106
107 template <typename Descriptor>
CloseNamespaces(Printer & printer,const Descriptor & descriptor)108 void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
109 const auto namespaces = Packages(descriptor);
110 for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) {
111 std::map<std::string, std::string> namespaceVars;
112 namespaceVars["namespace"] = *it;
113 printer.Print(namespaceVars, R"(
114 } // namespace $namespace$)");
115 }
116 }
117
ForEachMethod(const ServiceDescriptor & service,std::function<void (std::map<std::string,std::string>)> handler)118 void ForEachMethod(const ServiceDescriptor& service,
119 std::function<void(std::map<std::string, std::string>)> handler) {
120 for (int i = 0; i < service.method_count(); ++i) {
121 const MethodDescriptor& method = *service.method(i);
122 std::map<std::string, std::string> vars;
123 vars["method_id"] = std::to_string(i);
124 vars["method_name"] = method.name();
125 vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
126 vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
127 handler(vars);
128 }
129 }
130
GenerateMockClient(Printer & printer,const ServiceDescriptor & service)131 void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
132 std::map<std::string, std::string> vars;
133 vars["include_guard"] = "PROTOC_GENERATED_MOCK_" + service.name() + "_CLIENT_H";
134 vars["service_header"] = service.name() + ".client.h";
135 vars["mock_class"] = "Mock" + service.name();
136 vars["class"] = service.name();
137
138 printer.Print(vars, R"(
139 #ifndef $include_guard$
140 #define $include_guard$
141
142 #include <gmock/gmock.h>
143
144 #include <$service_header$>)");
145
146 OpenNamespaces(printer, service);
147
148 printer.Print(vars, R"(
149 struct $mock_class$ : public I$class$ {)");
150
151 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
152 printer.Print(methodVars, R"(
153 MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
154 });
155
156 printer.Print(vars, R"(
157 };)");
158
159 CloseNamespaces(printer, service);
160
161 printer.Print(vars, R"(
162 #endif)");
163 }
164
GenerateClientHeader(Printer & printer,const ServiceDescriptor & service)165 void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
166 std::map<std::string, std::string> vars;
167 vars["include_guard"] = "PROTOC_GENERATED_" + service.name() + "_CLIENT_H";
168 vars["protobuf_header"] = FullyQualifiedHeader(service);
169 vars["class"] = service.name();
170 vars["iface_class"] = "I" + service.name();
171 vars["app_id"] = "APP_ID_" + service.options().GetExtension(app_id);
172
173 printer.Print(vars, R"(
174 #ifndef $include_guard$
175 #define $include_guard$
176
177 #include <application.h>
178 #include <nos/AppClient.h>
179 #include <nos/NuggetClientInterface.h>
180
181 #include "$protobuf_header$")");
182
183 OpenNamespaces(printer, service);
184
185 // Pure virtual interface to make testing easier
186 printer.Print(vars, R"(
187 class $iface_class$ {
188 public:
189 virtual ~$iface_class$() = default;)");
190
191 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
192 printer.Print(methodVars, R"(
193 virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
194 });
195
196 printer.Print(vars, R"(
197 };)");
198
199 // Implementation of the interface for Nugget
200 printer.Print(vars, R"(
201 class $class$ : public $iface_class$ {
202 ::nos::AppClient _app;
203 public:
204 $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
205 ~$class$() override = default;)");
206
207 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
208 printer.Print(methodVars, R"(
209 uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
210 });
211
212 printer.Print(vars, R"(
213 };)");
214
215 CloseNamespaces(printer, service);
216
217 printer.Print(vars, R"(
218 #endif)");
219 }
220
GenerateClientSource(Printer & printer,const ServiceDescriptor & service)221 void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
222 std::map<std::string, std::string> vars;
223 vars["generated_header"] = service.name() + ".client.h";
224 vars["class"] = service.name();
225
226 const uint32_t max_request_size = service.options().GetExtension(request_buffer_size);
227 const uint32_t max_response_size = service.options().GetExtension(response_buffer_size);
228 vars["max_request_size"] = std::to_string(max_request_size);
229 vars["max_response_size"] = std::to_string(max_response_size);
230
231 printer.Print(vars, R"(
232 #include <$generated_header$>
233
234 #include <application.h>)");
235
236 OpenNamespaces(printer, service);
237
238 // Methods
239 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
240 methodVars.insert(vars.begin(), vars.end());
241 printer.Print(methodVars, R"(
242 uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
243 const size_t request_size = request.ByteSize();
244 if (request_size > $max_request_size$) {
245 return APP_ERROR_TOO_MUCH;
246 }
247 std::vector<uint8_t> buffer(request_size);
248 if (!request.SerializeToArray(buffer.data(), buffer.size())) {
249 return APP_ERROR_RPC;
250 }
251 std::vector<uint8_t> responseBuffer;
252 if (response != nullptr) {
253 responseBuffer.resize($max_response_size$);
254 }
255 const uint32_t appStatus = _app.Call($method_id$, buffer,
256 (response != nullptr) ? &responseBuffer : nullptr);
257 if (appStatus == APP_SUCCESS && response != nullptr) {
258 if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
259 return APP_ERROR_RPC;
260 }
261 }
262 return appStatus;
263 })");
264 });
265
266 CloseNamespaces(printer, service);
267 }
268
269 // Generator for C++ Nugget service client
270 class CppNuggetServiceClientGenerator : public CodeGenerator {
271 public:
272 CppNuggetServiceClientGenerator() = default;
273 ~CppNuggetServiceClientGenerator() override = default;
274
Generate(const FileDescriptor * file,const std::string & parameter,OutputDirectory * output_directory,std::string * error) const275 bool Generate(const FileDescriptor* file,
276 const std::string& parameter,
277 OutputDirectory* output_directory,
278 std::string* error) const override {
279 for (int i = 0; i < file->service_count(); ++i) {
280 const auto& service = *file->service(i);
281
282 *error = validateServiceOptions(service);
283 if (!error->empty()) {
284 return false;
285 }
286
287 if (parameter == "mock") {
288 std::unique_ptr<ZeroCopyOutputStream> output{
289 output_directory->Open("Mock" + service.name() + ".client.h")};
290 Printer printer(output.get(), '$');
291 GenerateMockClient(printer, service);
292 } else if (parameter == "header") {
293 std::unique_ptr<ZeroCopyOutputStream> output{
294 output_directory->Open(service.name() + ".client.h")};
295 Printer printer(output.get(), '$');
296 GenerateClientHeader(printer, service);
297 } else if (parameter == "source") {
298 std::unique_ptr<ZeroCopyOutputStream> output{
299 output_directory->Open(service.name() + ".client.cpp")};
300 Printer printer(output.get(), '$');
301 GenerateClientSource(printer, service);
302 } else {
303 *error = "Illegal parameter: must be mock|header|source";
304 return false;
305 }
306 }
307
308 return true;
309 }
310
311 private:
312 GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppNuggetServiceClientGenerator);
313 };
314
315 } // namespace
316
main(int argc,char * argv[])317 int main(int argc, char* argv[]) {
318 GOOGLE_PROTOBUF_VERIFY_VERSION;
319 CppNuggetServiceClientGenerator generator;
320 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
321 }
322