1 /*
2 * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "ipc_generator_impl.h"
17
IpcGeneratorImpl()18 IpcGeneratorImpl::IpcGeneratorImpl() {}
19
~IpcGeneratorImpl()20 IpcGeneratorImpl::~IpcGeneratorImpl() {}
21
22 namespace {
// Template for the generated "<name>.ipc.h" header. The #...# placeholders are
// substituted by IpcGeneratorImpl::GenHeader() and GenerateHeader():
//   #HEAD_FILE_NAME#        - proto file name up to its first '.'
//   #PROTOCOL_ENUM#         - enum of IpcProtocol<Base><Message> numbers (empty when
//                             no service is registered)
//   #SERVICE_CLASS_NAME#    - "<Service>Server"
//   #RESPONSE_DEFINE#       - virtual request handlers and SendResponse<Resp> declarations
//   #CLIENT_CLASS_NAME#     - "<Service>Client"
//   #VIRTUAL_RESPONSE_FUNC# - client request overloads and virtual On<Resp> callbacks
// NOTE(review): the client's presponse and waitingFor members have no in-class
// initializer here, so the generated constructor must set them before they are read.
const std::string BASE_HEADER_STRING = R"(
#pragma once

#include "#HEAD_FILE_NAME#.pb.h"
#include "service_base.h"
#include <cstdint>
#include <mutex>

class SocketContext;
class UnixSocketClient;

#PROTOCOL_ENUM#

class #SERVICE_CLASS_NAME#:public ServiceBase
{
public:
#SERVICE_CLASS_NAME#();
bool ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size) override;
#RESPONSE_DEFINE#
};

class #CLIENT_CLASS_NAME#:public ServiceBase
{
public:
#CLIENT_CLASS_NAME#();

std::shared_ptr<UnixSocketClient> unixSocketClient_;
bool Connect(const std::string addrname);
bool ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size) override;
google::protobuf::Message *presponse;
uint32_t waitingFor;
#VIRTUAL_RESPONSE_FUNC#
};
)";
57
// Template for the generated "<name>.ipc.cpp" source, filled in by
// IpcGeneratorImpl::GenSource(). Placeholders:
//   #HEAD_FILE_NAME#, #SERVICE_CLASS_NAME#, #SERVICE_NAME#, #CLIENT_CLASS_NAME#
//   #RESPONSE_IMPLEMENT#             - server SendResponse<Resp>() bodies
//   #SERVICE_PROTOCOL_PROC_FUNC#     - default server request-handler bodies
//   #SERVICE_PROTOCOL_PROC#          - case labels of the server ProtocolProc() switch
//   #CLIENT_SEND_REQUEST_PROC_FUNC#  - default client On<Resp>() bodies
//   #CLIENT_SEND_PROTOCOL_PROC_FUNC# - client request methods
//   #CLIENT_PROTOCOL_PROC#           - case labels of the client ProtocolProc() switch
// The generated client constructor initialises presponse and waitingFor: the client
// ProtocolProc() compares every incoming pnum against waitingFor, which would
// otherwise be indeterminate until the first blocking request sets it.
// <cstdio> is included because the generated Connect() calls printf().
const std::string BASE_SOURCE_STRING = R"(
#include "#HEAD_FILE_NAME#.ipc.h"
#include "#HEAD_FILE_NAME#.pb.h"
#include "socket_context.h"
#include "unix_socket_client.h"
#include "unix_socket_server.h"
#include <cstdio>
#include <unistd.h>

namespace {
constexpr uint32_t WAIT_FOR_EVER = 24 * 60 * 60 * 1000;
}

#SERVICE_CLASS_NAME#::#SERVICE_CLASS_NAME#()
{
serviceName_ = "#SERVICE_NAME#";
}

#RESPONSE_IMPLEMENT#

#SERVICE_PROTOCOL_PROC_FUNC#
bool #SERVICE_CLASS_NAME#::ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)
{
switch (pnum) {
#SERVICE_PROTOCOL_PROC#
}
return false;
}

#CLIENT_CLASS_NAME#::#CLIENT_CLASS_NAME#()
{
unixSocketClient_ = nullptr;
presponse = nullptr;
waitingFor = -1;
serviceName_ = "#SERVICE_NAME#";
}
bool #CLIENT_CLASS_NAME#::Connect(const std::string addrname)
{
if (unixSocketClient_ != nullptr) {
return false;
}
unixSocketClient_ = std::make_shared<UnixSocketClient>();
if (!unixSocketClient_->Connect(addrname, *this)) {
printf("Socket Connect failed\n");
unixSocketClient_ = nullptr;
return false;
}
return true;
}

#CLIENT_SEND_REQUEST_PROC_FUNC#

#CLIENT_SEND_PROTOCOL_PROC_FUNC#
bool #CLIENT_CLASS_NAME#::ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)
{
switch (pnum) {
#CLIENT_PROTOCOL_PROC#
}
if (waitingFor == pnum) {
waitingFor = -1;
mWait_.unlock();
}
return false;
}
)";
120
/**
 * Converts a snake_case stem (e.g. a proto file name) to PascalCase.
 *
 * Underscores are dropped. A lowercase ASCII letter at the start of the string or
 * right after an underscore is upper-cased ("foo_bar" -> "FooBar"). The pending
 * "capitalize" request is consumed by the first letter that follows, whatever its
 * case, so an already upper-case letter does not leak the request onto the next
 * lowercase letter ("my_Service" -> "MyService", not "MySErvice"). Digits and other
 * characters are copied unchanged and leave the request pending
 * ("plugin_2x" -> "Plugin2X").
 *
 * @param s input name.
 * @return the PascalCase form of s.
 */
std::string SwapName(const std::string& s)
{
    std::string ret;
    ret.reserve(s.size());
    bool capitalizeNext = true;
    for (char c : s) {
        if (c == '_') {
            capitalizeNext = true;
            continue;
        }
        const bool isLower = (c >= 'a' && c <= 'z');
        const bool isUpper = (c >= 'A' && c <= 'Z');
        ret += (capitalizeNext && isLower) ? static_cast<char>(c - 'a' + 'A') : c;
        if (isLower || isUpper) {
            // Any letter satisfies the capitalization request.
            capitalizeNext = false;
        }
    }
    return ret;
}
/**
 * Returns a copy of base in which every non-overlapping occurrence of _from is
 * replaced by _to, scanning left to right.
 *
 * The search resumes after each inserted replacement instead of restarting at the
 * beginning of the string. Replacement text is therefore never re-scanned, the call
 * terminates even when _to contains _from (e.g. "a" -> "aa"), and the cost is
 * linear in the size of base.
 * An empty _from matches nothing, and base is returned unchanged.
 *
 * @param base  text to search.
 * @param _from substring to replace.
 * @param _to   replacement text.
 * @return the substituted copy.
 */
std::string ReplaceStr(const std::string& base, const std::string& _from, const std::string& _to)
{
    if (_from.empty()) {
        return base;
    }
    std::string ret;
    ret.reserve(base.size());
    size_t start = 0;
    for (size_t pos = base.find(_from); pos != std::string::npos; pos = base.find(_from, start)) {
        ret.append(base, start, pos - start);
        ret += _to;
        start = pos + _from.size();
    }
    ret.append(base, start, std::string::npos);
    return ret;
}
150 } // namespace
151
SetNames(std::string fileName,std::string packageName)152 std::string IpcGeneratorImpl::SetNames(std::string fileName, std::string packageName)
153 {
154 fileName_ = fileName;
155 packageName_ = packageName + "::";
156 headFileName_ = "";
157
158 for (size_t i = 0; i < fileName.length(); i++) {
159 if (fileName.c_str()[i] == '.') {
160 break;
161 }
162 headFileName_ += fileName.c_str()[i];
163 }
164 baseName_ = SwapName(headFileName_);
165
166 serviceCount_ = 0;
167
168 serviceList_.clear();
169 enumMessageDict_.clear();
170
171 return headFileName_;
172 }
173
AddService(std::string serviceName)174 bool IpcGeneratorImpl::AddService(std::string serviceName)
175 {
176 for (int i = 0; i < serviceCount_; i++) {
177 if (serviceList_[i].serviceName_ == serviceName) {
178 return false;
179 }
180 }
181 serviceList_[serviceCount_].serviceName_ = serviceName;
182 serviceCount_++;
183 return true;
184 }
185
AddServiceMethod(std::string serviceName,std::string methodName,std::string requestName,std::string responseName)186 bool IpcGeneratorImpl::AddServiceMethod(std::string serviceName,
187 std::string methodName,
188 std::string requestName,
189 std::string responseName)
190 {
191 for (int i = 0; i < serviceCount_; i++) {
192 if (serviceList_[i].serviceName_ == serviceName) {
193 return serviceList_[i].AddMethod(methodName, requestName, responseName);
194 }
195 }
196 return false;
197 }
198
GenerateHeader(std::string & header_str)199 void IpcGeneratorImpl::GenerateHeader(std::string& header_str)
200 {
201 for (int i = 0; i < serviceCount_; i++) {
202 std::string server_class_name = serviceList_[i].serviceName_ + "Server";
203 header_str = ReplaceStr(header_str, "#SERVICE_CLASS_NAME#", server_class_name);
204
205 std::string tmp1 = "";
206 std::string tmp2 = "";
207 for (int j = 0; j < serviceList_[i].methodCount_; j++) {
208 tmp1 += "\tvirtual bool " + serviceList_[i].methodList_[j] + "(SocketContext &context," + packageName_ +
209 serviceList_[i].requestList_[j] + " &request," + packageName_ + serviceList_[i].responseList_[j] +
210 " &response);\n";
211
212 tmp2 += "\tbool SendResponse" + serviceList_[i].responseList_[j] + "(SocketContext &context," +
213 packageName_ + serviceList_[i].responseList_[j] + " &response);\n";
214 }
215 tmp1 += "\n" + tmp2;
216 header_str = ReplaceStr(header_str, "#RESPONSE_DEFINE#", tmp1);
217
218 std::string client_class_name = serviceList_[i].serviceName_ + "Client";
219 header_str = ReplaceStr(header_str, "#CLIENT_CLASS_NAME#", client_class_name);
220
221 tmp1 = "";
222 for (int j = 0; j < serviceList_[i].methodCount_; j++) {
223 tmp1 += "\tbool " + serviceList_[i].methodList_[j] + "(" + packageName_ + serviceList_[i].requestList_[j];
224 tmp1 += " &request," + packageName_ + serviceList_[i].responseList_[j];
225 tmp1 += " &response,uint32_t timeout_ms=5000);\n";
226 tmp1 += "\tbool " + serviceList_[i].methodList_[j] + "(" + packageName_ + serviceList_[i].requestList_[j];
227 tmp1 += " &request);\n";
228 }
229 tmp1 += "\n";
230 for (int j = 0; j < serviceList_[i].methodCount_; j++) {
231 tmp1 += "\tvirtual bool On" + serviceList_[i].responseList_[j] + "(SocketContext &context," + packageName_;
232 tmp1 += serviceList_[i].responseList_[j] + " &response);\n";
233 }
234
235 header_str = ReplaceStr(header_str, "#VIRTUAL_RESPONSE_FUNC#", tmp1);
236 }
237 }
238
GenHeader()239 std::string IpcGeneratorImpl::GenHeader()
240 {
241 std::string header_str = BASE_HEADER_STRING;
242 std::string tmp1;
243 header_str = ReplaceStr(header_str, "#HEAD_FILE_NAME#", headFileName_);
244 const int numTwo = 2;
245
246 if (serviceCount_ > 0) {
247 tmp1 = "enum {\n";
248 for (int i = 0; i < serviceCount_; i++) {
249 for (int j = 0; j < serviceList_[i].methodCount_; j++) {
250 tmp1 += "\tIpcProtocol" + baseName_ + serviceList_[i].requestList_[j];
251 tmp1 += "=" + std::to_string(j * numTwo) + ",\n";
252 tmp1 += "\tIpcProtocol" + baseName_ + serviceList_[i].responseList_[j];
253 tmp1 += "=" + std::to_string(j * numTwo + 1) + ",\n";
254 }
255 }
256 tmp1 += "};";
257 } else {
258 tmp1 = "";
259 }
260 header_str = ReplaceStr(header_str, "#PROTOCOL_ENUM#", tmp1);
261
262 GenerateHeader(header_str);
263 header_str = ReplaceStr(header_str, "\t", " ");
264 return header_str;
265 }
266
namespace {
// Template for a generated server SendResponse<Resp>() method. It sends the response
// through the connection's context, tagged with #ENUM_STR#
// (IpcProtocol<Base><Resp>, filled in by GenSendResponseImpl()).
// NOTE(review): the generated method always returns false, regardless of whether
// SendProtobuf() succeeded.
const std::string SEND_RESPONSE_IMPL_STRING = R"(
bool #SERVER_CLASS_NAME#::SendResponse#RESPONSE_NAME#(SocketContext &context,
#PACKAGE_NAME##RESPONSE_NAME# &response) {
context.SendProtobuf(#ENUM_STR#, response);
return false;
}
)";
}
GenSendResponseImpl(int servicep,const std::string & server_class_name)276 std::string IpcGeneratorImpl::GenSendResponseImpl(int servicep, const std::string& server_class_name)
277 {
278 std::string ret = "";
279 for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
280 std::string enum_str = "IpcProtocol" + baseName_ + serviceList_[servicep].responseList_[j];
281 std::string tmp = ReplaceStr(SEND_RESPONSE_IMPL_STRING, "#SERVER_CLASS_NAME#", server_class_name);
282 tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
283 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
284 tmp = ReplaceStr(tmp, "#ENUM_STR#", enum_str);
285 ret += tmp;
286 }
287 return ret;
288 }
namespace {
// Template for the default body of a generated client callback On<Resp>(). The
// client's ProtocolProc() invokes it for responses that no blocking request is
// waiting for. The stub ignores the response and returns false.
const std::string ON_RESPONSE_IMPL_STRING = R"(
bool #CLIENT_CLASS_NAME#::On#RESPONSE_NAME#(SocketContext &context, #PACKAGE_NAME##RESPONSE_NAME# &response) {
return false;
}
)";
}
GenOnResponseImpl(int servicep,const std::string & client_class_name)296 std::string IpcGeneratorImpl::GenOnResponseImpl(int servicep, const std::string& client_class_name)
297 {
298 std::string ret = "";
299 for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
300 std::string tmp = ReplaceStr(ON_RESPONSE_IMPL_STRING, "#CLIENT_CLASS_NAME#", client_class_name);
301 tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
302 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
303 ret += tmp;
304 }
305 return ret;
306 }
namespace {
// Template for the default body of a generated (virtual) server request handler
// #SERVER_CLASS_NAME#::#METHOD_NAME#(). The stub returns false, so the generated
// server ProtocolProc() sends no response for this request unless a derived class
// overrides the handler.
const std::string SERVICE_CALL_IMPL_STRING = R"(
bool #SERVER_CLASS_NAME#::#METHOD_NAME#(SocketContext &context,
#PACKAGE_NAME##REQUEST_NAME# &request,
#PACKAGE_NAME##RESPONSE_NAME# &response) {
return false;
}
)";
}
GenServiceCallImpl(int servicep,const std::string & server_class_name)316 std::string IpcGeneratorImpl::GenServiceCallImpl(int servicep, const std::string& server_class_name)
317 {
318 std::string ret = "";
319 for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
320 std::string tmp = ReplaceStr(SERVICE_CALL_IMPL_STRING, "#SERVER_CLASS_NAME#", server_class_name);
321 tmp = ReplaceStr(tmp, "#SERVER_CLASS_NAME#", server_class_name);
322 tmp = ReplaceStr(tmp, "#METHOD_NAME#", serviceList_[servicep].methodList_[j]);
323 tmp = ReplaceStr(tmp, "#REQUEST_NAME#", serviceList_[servicep].requestList_[j]);
324 tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
325 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
326 ret += tmp;
327 }
328 return ret;
329 }
namespace {
// Templates for one case label of the generated *server* ProtocolProc() switch,
// despite the CLIENT_ prefix. GenClientProcImpl() output is substituted for
// #SERVICE_PROTOCOL_PROC#.
//
// CLIENT_PROC_IMPL_STRING: parses the request, calls the virtual handler
// #METHOD_NAME#(), and sends the response back only if the handler returns true.
// NOTE(review): the result of ParseFromArray() is not checked, so the handler may
// run on a partially parsed request.
const std::string CLIENT_PROC_IMPL_STRING = R"(
case IpcProtocol#BASE_NAME##REQUEST_NAME#:{
#PACKAGE_NAME##REQUEST_NAME# request;
#PACKAGE_NAME##RESPONSE_NAME# response;
request.ParseFromArray(buf, size);
if (#METHOD_NAME#(context, request, response)) {
context.SendProtobuf(IpcProtocol#BASE_NAME##RESPONSE_NAME#, response);
}
}
break;
)";
// CLIENT_PROC_NOTIFYRESULT_STRING: variant used for methods named "NotifyResult".
// The handler is called, but its return value is ignored and no response is sent.
const std::string CLIENT_PROC_NOTIFYRESULT_STRING = R"(
case IpcProtocol#BASE_NAME##REQUEST_NAME#:{
#PACKAGE_NAME##REQUEST_NAME# request;
#PACKAGE_NAME##RESPONSE_NAME# response;
request.ParseFromArray(buf, size);
#METHOD_NAME#(context, request, response);
}
break;
)";
}
GenClientProcImpl(int servicep)352 std::string IpcGeneratorImpl::GenClientProcImpl(int servicep)
353 {
354 std::string ret = "";
355 for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
356 std::string tmp = ReplaceStr(CLIENT_PROC_IMPL_STRING, "#BASE_NAME#", baseName_);
357 if (serviceList_[servicep].methodList_[j] == "NotifyResult") {
358 tmp = ReplaceStr(CLIENT_PROC_NOTIFYRESULT_STRING, "#BASE_NAME#", baseName_);
359 }
360 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
361 tmp = ReplaceStr(tmp, "#METHOD_NAME#", serviceList_[servicep].methodList_[j]);
362 tmp = ReplaceStr(tmp, "#REQUEST_NAME#", serviceList_[servicep].requestList_[j]);
363 tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
364 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
365 ret += tmp;
366 }
367 return ret;
368 }
namespace {
// Template for the two generated client request methods of one RPC:
//  * a blocking overload. It records the awaited response (waitingFor / presponse),
//    sends the request and waits on mWait_ for up to timeout_ms (0 selects
//    WAIT_FOR_EVER). The client's ProtocolProc() parses into *presponse and unlocks
//    mWait_ when the awaited response arrives;
//  * a one-way overload that only sends the request.
// Both return false immediately when the client is not connected (unixSocketClient_
// is null), instead of dereferencing a null pointer or waiting out the whole timeout.
// NOTE(review): the hand-off relies on mWait_ being unlocked by the thread running
// ProtocolProc(). Whether that is valid depends on mWait_'s type in ServiceBase;
// confirm.
const std::string CLIENT_REQUEST_IMPL_STRING = R"(
bool #CLIENT_CLASS_NAME#::#METHOD_NAME#(#PACKAGE_NAME##REQUEST_NAME# &request,
#PACKAGE_NAME##RESPONSE_NAME# &response,
uint32_t timeout_ms)
{
if (unixSocketClient_==nullptr) {
return false;
}
mWait_.lock();
if (timeout_ms<=0) {
timeout_ms=WAIT_FOR_EVER;
}
waitingFor=IpcProtocol#BASE_NAME##RESPONSE_NAME#;
presponse=&response;
unixSocketClient_->SendProtobuf(IpcProtocol#BASE_NAME##REQUEST_NAME#, request);
if (mWait_.try_lock_for(std::chrono::milliseconds(timeout_ms))) {
mWait_.unlock();
return true;
}
waitingFor=-1;
mWait_.unlock();
return false;
}
bool #CLIENT_CLASS_NAME#::#METHOD_NAME#(#PACKAGE_NAME##REQUEST_NAME# &request)
{
if (unixSocketClient_==nullptr) {
return false;
}
unixSocketClient_->SendProtobuf(IpcProtocol#BASE_NAME##REQUEST_NAME#, request);
return true;
}
)";
}
GenClientRequestImpl(int servicep,const std::string & client_class_name)399 std::string IpcGeneratorImpl::GenClientRequestImpl(int servicep, const std::string& client_class_name)
400 {
401 std::string ret = "";
402 for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
403 std::string tmp = ReplaceStr(CLIENT_REQUEST_IMPL_STRING, "#CLIENT_CLASS_NAME#", client_class_name);
404 tmp = ReplaceStr(tmp, "#METHOD_NAME#", serviceList_[servicep].methodList_[j]);
405 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
406 tmp = ReplaceStr(tmp, "#REQUEST_NAME#", serviceList_[servicep].requestList_[j]);
407 tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
408 tmp = ReplaceStr(tmp, "#BASE_NAME#", baseName_);
409 ret += tmp;
410 }
411 return ret;
412 }
namespace {
// Template for one case label of the generated *client* ProtocolProc() switch,
// despite the SERVICE_ prefix. GenServiceProcImpl() output is substituted for
// #CLIENT_PROTOCOL_PROC#.
// If a blocking request is waiting for this message (waitingFor == pnum), the
// payload is parsed into the caller-supplied *presponse. Otherwise it is parsed into
// a local and passed to the virtual On<Resp>() callback. #NUM# (method index + 1)
// is appended to that local's name.
// NOTE(review): if ProtocolProc() runs on another thread while the blocking request
// times out, presponse may point at a response object that has gone out of scope;
// confirm synchronization.
const std::string SERVICE_PROC_IMPL_STRING = R"(
case IpcProtocol#BASE_NAME##RESPONSE_NAME#:
{
if (waitingFor==pnum) {
presponse->ParseFromArray(buf, size);
}
else {
#PACKAGE_NAME##RESPONSE_NAME# response#NUM#;
response#NUM#.ParseFromArray(buf, size);
On#RESPONSE_NAME#(context, response#NUM#);
}
}
break;
)";
}
GenServiceProcImpl(int servicep)429 std::string IpcGeneratorImpl::GenServiceProcImpl(int servicep)
430 {
431 std::string ret = "";
432 for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
433 std::string tmp = ReplaceStr(SERVICE_PROC_IMPL_STRING, "#BASE_NAME#", baseName_);
434 tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
435 tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
436 tmp = ReplaceStr(tmp, "#NUM#", std::to_string(j + 1));
437
438 ret += tmp;
439 }
440 return ret;
441 }
442
GenSource()443 std::string IpcGeneratorImpl::GenSource()
444 {
445 std::string source_str = BASE_SOURCE_STRING;
446
447 source_str = ReplaceStr(source_str, "#HEAD_FILE_NAME#", headFileName_);
448
449 for (int i = 0; i < serviceCount_; i++) {
450 std::string server_class_name = serviceList_[i].serviceName_ + "Server";
451 source_str = ReplaceStr(source_str, "#SERVICE_CLASS_NAME#", server_class_name);
452 source_str = ReplaceStr(source_str, "#SERVICE_NAME#", serviceList_[i].serviceName_);
453 std::string client_class_name = serviceList_[i].serviceName_ + "Client";
454 source_str = ReplaceStr(source_str, "#CLIENT_CLASS_NAME#", client_class_name);
455
456 source_str = ReplaceStr(source_str, "#RESPONSE_IMPLEMENT#", GenSendResponseImpl(i, server_class_name));
457 source_str = ReplaceStr(source_str, "#CLIENT_SEND_REQUEST_PROC_FUNC#", GenOnResponseImpl(i, client_class_name));
458
459 source_str = ReplaceStr(source_str, "#SERVICE_PROTOCOL_PROC_FUNC#", GenServiceCallImpl(i, server_class_name));
460 source_str = ReplaceStr(source_str, "#SERVICE_PROTOCOL_PROC#", GenClientProcImpl(i));
461 source_str = ReplaceStr(source_str, "#SERVICE_NAME#", serviceList_[i].serviceName_);
462
463 source_str = ReplaceStr(source_str, "#CLIENT_PROTOCOL_PROC#", GenServiceProcImpl(i));
464 source_str =
465 ReplaceStr(source_str, "#CLIENT_SEND_PROTOCOL_PROC_FUNC#", GenClientRequestImpl(i, client_class_name));
466 }
467
468 source_str = ReplaceStr(source_str, "\t", " ");
469 return source_str;
470 }
471