1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "protocol_proto.h"
17 #include <new>
18 #include <iterator>
19 #include "hash.h"
20 #include "securec.h"
21 #include "version.h"
22 #include "db_common.h"
23 #include "log_print.h"
24 #include "macro_utils.h"
25 #include "endian_convert.h"
26 #include "header_converter.h"
27
28 namespace DistributedDB {
29 namespace {
// Magic value written into CommPhyHeader::magic and verified when a packet is parsed
const uint16_t MAGIC_CODE = 0xAAAA;
// Protocol version written into the phy header, the diverge header and the label exchange payload
const uint16_t PROTOCOL_VERSION = 0;
// Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
const uint8_t MAX_PADDING_LEN = 7; // Largest paddingLen accepted in a parsed CommPhyHeader
// Leading bytes of a packet that are skipped when the XOR checksum is calculated
const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
const uint32_t MAX_FRAME_LEN = 32 * 1024 * 1024; // Max 32 MB, 1024 is scale
const uint16_t MIN_FRAGMENT_COUNT = 2; // A fragmented frame is split into at least 2 parts
// LabelExchange(Ack) Frame Field Length
const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
// Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
46 // Local func to set and get frame Type from packet Type field
SetFrameType(uint8_t & inPacketType,FrameType inFrameType)47 void SetFrameType(uint8_t &inPacketType, FrameType inFrameType)
48 {
49 inPacketType &= 0x0F; // Use 0x0F to clear high for bits
50 inPacketType |= (static_cast<uint8_t>(inFrameType) << 4); // frame type is on high 4 bits
51 }
GetFrameType(uint8_t inPacketType)52 FrameType GetFrameType(uint8_t inPacketType)
53 {
54 uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0x0F to get high 4 bits
55 if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
56 return FrameType::INVALID_MAX_FRAME_TYPE;
57 }
58 return static_cast<FrameType>(frameType);
59 }
60 }
61
// Definition of the static registry that maps a message id to its TransformFunc
// (computeFunc / serializeFunc / deserializeFunc). Entries are added by RegTransformFunction
// and removed by UnRegTransformFunction.
std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
63
GetAppLayerFrameHeaderLength()64 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
65 {
66 uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
67 return length;
68 }
69
GetLengthBeforeSerializedData()70 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
71 {
72 uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
73 return length;
74 }
75
GetCommLayerFrameHeaderLength()76 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
77 {
78 uint32_t length = sizeof(CommPhyHeader);
79 return length;
80 }
81
ToSerialBuffer(const Message * inMsg,int & outErrorNo,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader)82 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg, int &outErrorNo,
83 std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader)
84 {
85 if (inMsg == nullptr) {
86 outErrorNo = -E_INVALID_ARGS;
87 return nullptr;
88 }
89
90 uint32_t serializeLen = 0;
91 if (!onlyMsgHeader) {
92 int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
93 if (errCode != E_OK) {
94 outErrorNo = errCode;
95 return nullptr;
96 }
97 }
98 uint32_t headSize = 0;
99 int errCode = GetExtendHeadDataSize(extendHandle, headSize);
100 if (errCode != E_OK) {
101 outErrorNo = errCode;
102 return nullptr;
103 }
104
105 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
106 if (buffer == nullptr) {
107 outErrorNo = -E_OUT_OF_MEMORY;
108 return nullptr;
109 }
110 if (headSize > 0) {
111 buffer->SetExtendHeadLength(headSize);
112 }
113 // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
114 uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
115 errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
116 if (errCode != E_OK) {
117 LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
118 goto ERROR_HANDLE;
119 }
120 errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
121 if (errCode != E_OK) {
122 goto ERROR_HANDLE;
123 }
124
125 // Serialize the MessageHeader and data if need
126 errCode = SerializeMessage(buffer, inMsg);
127 if (errCode != E_OK) {
128 LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
129 goto ERROR_HANDLE;
130 }
131 outErrorNo = E_OK;
132 return buffer;
133 ERROR_HANDLE:
134 outErrorNo = errCode;
135 delete buffer;
136 buffer = nullptr;
137 return nullptr;
138 }
139
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)140 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
141 {
142 if (inBuff == nullptr) {
143 outErrorNo = -E_INVALID_ARGS;
144 return nullptr;
145 }
146 Message *outMsg = new (std::nothrow) Message();
147 if (outMsg == nullptr) {
148 outErrorNo = -E_OUT_OF_MEMORY;
149 return nullptr;
150 }
151 int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
152 if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
153 LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
154 outErrorNo = errCode;
155 delete outMsg;
156 outMsg = nullptr;
157 return nullptr;
158 }
159 // If messageId not register in this software version, we return errCode and the Message without an object.
160 outErrorNo = errCode;
161 return outMsg;
162 }
163
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)164 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
165 {
166 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
167 if (buffer == nullptr) {
168 outErrorNo = -E_OUT_OF_MEMORY;
169 return nullptr;
170 }
171
172 // Empty frame has no payload, only header
173 int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
174 if (errCode != E_OK) {
175 LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
176 outErrorNo = errCode;
177 delete buffer;
178 buffer = nullptr;
179 return nullptr;
180 }
181 outErrorNo = E_OK;
182 return buffer;
183 }
184
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)185 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
186 int &outErrorNo)
187 {
188 std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
189 SerialBuffer *buffer = ToSerialBuffer(inMsg, outErrorNo, extendHandle, true);
190 if (buffer == nullptr) {
191 // outErrorNo had already been set in ToSerialBuffer
192 return nullptr;
193 }
194 int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
195 if (errCode != E_OK) {
196 LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
197 outErrorNo = errCode;
198 delete buffer;
199 buffer = nullptr;
200 return nullptr;
201 }
202 outErrorNo = E_OK;
203 return buffer;
204 }
205
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)206 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
207 const std::set<LabelType> &inLabels, int &outErrorNo)
208 {
209 // Size of inLabels won't be too large.
210 // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
211 uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
212 inLabels.size() * COMM_LABEL_LENGTH;
213 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
214 if (buffer == nullptr) {
215 outErrorNo = -E_OUT_OF_MEMORY;
216 return nullptr;
217 }
218 int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
219 if (errCode != E_OK) {
220 LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
221 outErrorNo = errCode;
222 delete buffer;
223 buffer = nullptr;
224 return nullptr;
225 }
226
227 auto payloadByteLen = buffer->GetWritableBytesForPayload();
228 auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
229 *fieldPtr++ = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
230 *fieldPtr++ = HostToNet(inDistinctValue);
231 *fieldPtr++ = HostToNet(inSequenceId);
232 *fieldPtr++ = HostToNet(static_cast<uint64_t>(inLabels.size()));
233 // Note: don't worry, memory length had been carefully calculated above
234 auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
235 for (auto &eachLabel : inLabels) {
236 for (auto &eachByte : eachLabel) {
237 *bytePtr++ = eachByte;
238 }
239 }
240 outErrorNo = E_OK;
241 return buffer;
242 }
243
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)244 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
245 {
246 uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
247 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
248 if (buffer == nullptr) {
249 outErrorNo = -E_OUT_OF_MEMORY;
250 return nullptr;
251 }
252 int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
253 if (errCode != E_OK) {
254 LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
255 outErrorNo = errCode;
256 delete buffer;
257 buffer = nullptr;
258 return nullptr;
259 }
260
261 auto payloadByteLen = buffer->GetWritableBytesForPayload();
262 auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
263 *fieldPtr++ = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
264 *fieldPtr++ = HostToNet(inDistinctValue);
265 *fieldPtr++ = HostToNet(inSequenceId);
266 outErrorNo = E_OK;
267 return buffer;
268 }
269
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)270 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
271 std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
272 {
273 auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
274 if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
275 return E_OK;
276 }
277 uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
278 // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
279 auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
280 uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
281 // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
282 uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
283 // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
284 // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
285 uint16_t quotient = lengthToSplit / maxFragmentLen;
286 uint32_t remainder = lengthToSplit % maxFragmentLen;
287 // Finally we get the fragCount for this frame
288 uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
289 // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
290 auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
291 FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
292 return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
293 }
294
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)295 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
296 {
297 uint32_t frameLen = inResult.GetFrameLen();
298 uint16_t fragCount = inResult.GetFragCount();
299 uint16_t fragNo = inResult.GetFragNo();
300
301 // Firstly: Check frameLen
302 if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
303 LOGE("[Proto][ParsePhyOpt] FrameLen=%u illegal.", frameLen);
304 return -E_PARSE_FAIL;
305 }
306
307 // Secondly: Check fragCount and fragNo
308 uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
309 if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
310 LOGE("[Proto][ParsePhyOpt] FragCount=%u or fragNo=%u illegal.", fragCount, fragNo);
311 return -E_PARSE_FAIL;
312 }
313
314 // Finally: Check length relation deeply
315 uint32_t quotient = lengthBeSplit / fragCount;
316 uint16_t remainder = lengthBeSplit % fragCount;
317 outFragLen = quotient;
318 outLastFragLen = quotient + remainder;
319 uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
320 if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
321 inResult.GetPaddingLen() != inResult.GetPacketLen()) {
322 LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%u, FragCount=%u, fragNo=%u, PaddingLen=%u, PacketLen=%u",
323 frameLen, fragCount, fragNo, inResult.GetPaddingLen(), inResult.GetPacketLen());
324 return -E_PARSE_FAIL;
325 }
326
327 return E_OK;
328 }
329
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)330 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
331 uint32_t fragOffset, uint32_t fragLength)
332 {
333 // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
334 // Firstly: Check the length relation of source, even this check is not supposed to fail
335 if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength) {
336 return -E_LENGTH_ERROR;
337 }
338 // Secondly: Check the length relation of destination, even this check is not supposed to fail
339 auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
340 if (sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second) {
341 return -E_LENGTH_ERROR;
342 }
343 // Finally: Do Combination!
344 const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
345 uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
346 uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
347 errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
348 if (errCode != EOK) {
349 return -E_SECUREC_ERROR;
350 }
351 return E_OK;
352 }
353
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)354 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
355 {
356 if (msgIdMapFunc_.count(msgId) != 0) {
357 return -E_ALREADY_REGISTER;
358 }
359 if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
360 return -E_INVALID_ARGS;
361 }
362 msgIdMapFunc_[msgId] = inFunc;
363 return E_OK;
364 }
365
UnRegTransformFunction(uint32_t msgId)366 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
367 {
368 if (msgIdMapFunc_.count(msgId) != 0) {
369 msgIdMapFunc_.erase(msgId);
370 }
371 }
372
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)373 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
374 {
375 if (inBuff == nullptr) {
376 return -E_INVALID_ARGS;
377 }
378 auto headerByteLen = inBuff->GetWritableBytesForHeader();
379 if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
380 return -E_INVALID_ARGS;
381 }
382 auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
383
384 CommDivergeHeader divergeHeader;
385 divergeHeader.version = PROTOCOL_VERSION;
386 divergeHeader.reserved = 0;
387 divergeHeader.payLoadLen = payloadByteLen.second;
388 // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
389 for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
390 divergeHeader.commLabel[i] = inCommLabel[i];
391 }
392 HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
393
394 errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
395 headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
396 if (errCode != EOK) {
397 return -E_SECUREC_ERROR;
398 }
399 return E_OK;
400 }
401
402 namespace {
FillPhyHeaderLenInfo(CommPhyHeader & header,uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen)403 void FillPhyHeaderLenInfo(CommPhyHeader &header, uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen)
404 {
405 header.packetLen = packetLen;
406 header.checkSum = sum;
407 header.packetType |= type;
408 header.paddingLen = paddingLen;
409 }
410 }
411
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)412 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
413 {
414 if (inBuff == nullptr) {
415 return -E_INVALID_ARGS;
416 }
417 auto headerByteLen = inBuff->GetWritableBytesForHeader();
418 if (headerByteLen.second < sizeof(CommPhyHeader)) {
419 return -E_INVALID_ARGS;
420 }
421 auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
422 auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
423
424 uint32_t packetLen = bufferByteLen.second;
425 uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
426 uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
427 if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
428 SetFrameType(packetType, inInfo.frameType);
429 } else {
430 return -E_INVALID_ARGS;
431 }
432
433 CommPhyHeader phyHeader;
434 phyHeader.magic = MAGIC_CODE;
435 phyHeader.version = PROTOCOL_VERSION;
436 phyHeader.sourceId = inInfo.sourceId;
437 phyHeader.frameId = inInfo.frameId;
438 phyHeader.packetType = 0;
439 phyHeader.dbIntVer = DB_GLOBAL_VERSION;
440 FillPhyHeaderLenInfo(phyHeader, packetLen, 0, packetType, paddingLen); // Sum is calculated afterwards
441 HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
442
443 errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
444 if (retCode != EOK) {
445 return -E_SECUREC_ERROR;
446 }
447
448 uint64_t sumResult = 0;
449 int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
450 bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
451 if (errCode != E_OK) {
452 return -E_SUM_CALCULATE_FAIL;
453 }
454
455 auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
456 ptrPhyHeader->checkSum = HostToNet(sumResult);
457
458 return E_OK;
459 }
460
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)461 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
462 ParseResult &outResult)
463 {
464 if (bytes == nullptr || length > MAX_TOTAL_LEN) {
465 return -E_INVALID_ARGS;
466 }
467 int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
468 if (errCode != E_OK) {
469 LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
470 return errCode;
471 }
472
473 if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
474 return E_OK; // Do nothing more for empty frame
475 }
476
477 if (outResult.IsFragment()) {
478 errCode = ParseCommPhyOptHeader(bytes, length, outResult);
479 if (errCode != E_OK) {
480 LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
481 return errCode;
482 }
483 } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
484 errCode = ParseCommLayerPayload(bytes, length, outResult);
485 if (errCode != E_OK) {
486 LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
487 return errCode;
488 }
489 } else {
490 errCode = ParseCommDivergeHeader(bytes, length, outResult);
491 if (errCode != E_OK) {
492 LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
493 return errCode;
494 }
495 }
496 return E_OK;
497 }
498
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)499 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
500 {
501 if (inBuff == nullptr || outResult.IsFragment()) {
502 return -E_INTERNAL_ERROR;
503 }
504 auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
505 if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
506 int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
507 if (errCode != E_OK) {
508 LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
509 return errCode;
510 }
511 } else {
512 int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
513 if (errCode != E_OK) {
514 LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
515 return errCode;
516 }
517 }
518 return E_OK;
519 }
520
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)521 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
522 {
523 static std::map<FrameType, std::string> frameTypeStr{
524 {FrameType::EMPTY, "EmptyFrame"},
525 {FrameType::APPLICATION_MESSAGE, "AppLayerFrame"},
526 {FrameType::COMMUNICATION_LABEL_EXCHANGE, "CommLayerFrame_LabelExchange"},
527 {FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK, "CommLayerFrame_LabelExchangeAck"}};
528
529 if (length < sizeof(CommPhyHeader)) {
530 return;
531 }
532 auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
533 uint32_t frameId = NetToHost(phyHeader->frameId);
534 uint8_t pktType = NetToHost(phyHeader->packetType);
535 bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
536 FrameType frameType = GetFrameType(pktType);
537 if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
538 LOGW("[Proto][Display] This is unrecognized frame, pktType=%u.", pktType);
539 return;
540 }
541 if (isFragment) {
542 if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
543 return;
544 }
545 auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
546 LOGI("[Proto][Display] This is %s, frameId=%u, frameLen=%u, fragCount=%u, fragNo=%u.",
547 frameTypeStr[frameType].c_str(), frameId, NetToHost(phyOpt->frameLen),
548 NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
549 } else {
550 LOGI("[Proto][Display] This is %s, frameId=%u.", frameTypeStr[frameType].c_str(), frameId);
551 }
552 }
553
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)554 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
555 {
556 if (length % sizeof(uint64_t) != 0) {
557 LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight.", length);
558 return -E_LENGTH_ERROR;
559 }
560 int count = length / sizeof(uint64_t);
561 auto array = reinterpret_cast<const uint64_t *>(bytes);
562 outSum = 0;
563 for (int i = 0; i < count; i++) {
564 outSum ^= array[i];
565 }
566 return E_OK;
567 }
568
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)569 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
570 {
571 uint32_t messageId = inMsg->GetMessageId();
572 if (msgIdMapFunc_.count(messageId) == 0) {
573 LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%u.", messageId);
574 return -E_NOT_REGISTER;
575 }
576
577 TransformFunc function = msgIdMapFunc_[messageId];
578 uint32_t serializeLen = function.computeFunc(inMsg);
579 uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
580 // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
581 if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
582 LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%u, serializeLen=%u, alignedLen=%u.",
583 messageId, serializeLen, alignedLen);
584 return -E_LENGTH_ERROR;
585 }
586 // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
587 outLength = serializeLen;
588 return E_OK;
589 }
590
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)591 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
592 {
593 auto payloadByteLen = inBuff->GetWritableBytesForPayload();
594 if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
595 LOGE("[Proto][Serialize] Length error, payload length=%u.", payloadByteLen.second);
596 return -E_LENGTH_ERROR;
597 }
598 uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
599
600 auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
601 messageHdr->version = inMsg->GetVersion();
602 messageHdr->messageType = inMsg->GetMessageType();
603 messageHdr->messageId = inMsg->GetMessageId();
604 messageHdr->sessionId = inMsg->GetSessionId();
605 messageHdr->sequenceId = inMsg->GetSequenceId();
606 messageHdr->errorNo = inMsg->GetErrorNo();
607 messageHdr->dataLen = dataLen;
608 HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
609
610 if (dataLen == 0) {
611 // For zero dataLen, we don't need to serialize data part
612 return E_OK;
613 }
614 // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
615 uint32_t messageId = inMsg->GetMessageId();
616 TransformFunc function = msgIdMapFunc_[messageId];
617 int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
618 if (result != E_OK) {
619 LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
620 return -E_SERIALIZE_ERROR;
621 }
622 return E_OK;
623 }
624
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)625 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
626 {
627 auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
628 // Check version before parse field
629 if (payloadByteLen.second < sizeof(uint16_t)) {
630 return -E_LENGTH_ERROR;
631 }
632 uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
633 if (!IsSupportMessageVersion(version)) {
634 LOGE("[Proto][DeSerialize] Version=%u not support.", version);
635 return -E_VERSION_NOT_SUPPORT;
636 }
637
638 if (payloadByteLen.second < sizeof(MessageHeader)) {
639 LOGE("[Proto][DeSerialize] Length error, payload length=%u.", payloadByteLen.second);
640 return -E_LENGTH_ERROR;
641 }
642 auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
643 MessageHeader messageHdr;
644 HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
645 inMsg->SetVersion(version);
646 inMsg->SetMessageType(messageHdr.messageType);
647 inMsg->SetMessageId(messageHdr.messageId);
648 inMsg->SetSessionId(messageHdr.sessionId);
649 inMsg->SetSequenceId(messageHdr.sequenceId);
650 inMsg->SetErrorNo(messageHdr.errorNo);
651 uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
652 if (dataLen != messageHdr.dataLen) {
653 LOGE("[Proto][DeSerialize] dataLen=%u, msgDataLen=%u.", dataLen, messageHdr.dataLen);
654 return -E_LENGTH_ERROR;
655 }
656 // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
657 if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
658 LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%u.", messageHdr.errorNo);
659 return E_OK;
660 }
661 if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
662 return E_OK;
663 }
664 uint32_t messageId = inMsg->GetMessageId();
665 if (msgIdMapFunc_.count(messageId) == 0) {
666 LOGE("[Proto][DeSerialize] Not register, messageId=%u.", messageId);
667 return -E_NOT_REGISTER;
668 }
669 TransformFunc function = msgIdMapFunc_[messageId];
670 int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
671 if (result != E_OK) {
672 LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
673 return -E_DESERIALIZE_ERROR;
674 }
675 return E_OK;
676 }
677
IsSupportMessageVersion(uint16_t version)678 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
679 {
680 if (version != MSG_VERSION_BASE && version != MSG_VERSION_EXT) {
681 return false;
682 }
683 return true;
684 }
685
IsFeedbackErrorMessage(uint32_t errorNo)686 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
687 {
688 if (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND) {
689 return true;
690 }
691 return false;
692 }
693
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)694 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
695 {
696 // At least magic and version should exist
697 if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
698 LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
699 return -E_LENGTH_ERROR;
700 }
701 auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
702 uint16_t magic = NetToHost(*fieldPtr++);
703 uint16_t version = NetToHost(*fieldPtr++);
704
705 if (magic != MAGIC_CODE) {
706 LOGE("[Proto][ParsePhyCheckVer] MagicCode=%u Error.", magic);
707 return -E_PARSE_FAIL;
708 }
709 if (version != PROTOCOL_VERSION) {
710 LOGE("[Proto][ParsePhyCheckVer] Version=%u Error.", version);
711 return -E_VERSION_NOT_SUPPORT;
712 }
713 return E_OK;
714 }
715
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)716 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
717 const uint8_t *bytes, uint32_t length)
718 {
719 if (phyHeader.sourceId != Hash::HashFunc(srcTarget)) {
720 LOGE("[Proto][ParsePhyCheck] SourceId Error: inSourceId=%llu, srcTarget=%s{private}, hashId=%llu.",
721 ULL(phyHeader.sourceId), srcTarget.c_str(), ULL(Hash::HashFunc(srcTarget)));
722 return -E_PARSE_FAIL;
723 }
724 if (phyHeader.packetLen != length) {
725 LOGE("[Proto][ParsePhyCheck] PacketLen=%u Mismatch length=%u.", phyHeader.packetLen, length);
726 return -E_PARSE_FAIL;
727 }
728 if (phyHeader.paddingLen > MAX_PADDING_LEN) {
729 LOGE("[Proto][ParsePhyCheck] PaddingLen=%u Error.", phyHeader.paddingLen);
730 return -E_PARSE_FAIL;
731 }
732 if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
733 LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
734 return -E_PARSE_FAIL;
735 }
736 uint64_t sumResult = 0;
737 int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
738 if (errCode != E_OK) {
739 LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
740 return -E_SUM_CALCULATE_FAIL;
741 }
742 if (phyHeader.checkSum != sumResult) {
743 LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%llu, sumResult=%llu.",
744 ULL(phyHeader.checkSum), ULL(sumResult));
745 return -E_SUM_MISMATCH;
746 }
747 return E_OK;
748 }
749
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)750 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
751 ParseResult &inResult)
752 {
753 int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
754 if (errCode != E_OK) {
755 LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
756 return errCode;
757 }
758
759 if (length < sizeof(CommPhyHeader)) {
760 LOGE("[Proto][ParsePhy] Length of Bytes Error.");
761 return -E_PARSE_FAIL;
762 }
763 auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
764 CommPhyHeader phyHeader;
765 HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
766 errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
767 if (errCode != E_OK) {
768 LOGE("[Proto][ParsePhy] Check Field Fail.");
769 return errCode;
770 }
771
772 inResult.SetFrameId(phyHeader.frameId);
773 inResult.SetSourceId(phyHeader.sourceId);
774 inResult.SetPacketLen(phyHeader.packetLen);
775 inResult.SetPaddingLen(phyHeader.paddingLen);
776 inResult.SetDbVersion(phyHeader.dbIntVer);
777 if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
778 inResult.SetFragmentFlag(true);
779 } // FragmentFlag default is false
780 FrameType frameType = GetFrameType(phyHeader.packetType);
781 if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
782 LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%u.", phyHeader.packetType);
783 return -E_FRAME_TYPE_NOT_SUPPORT;
784 }
785 inResult.SetFrameTypeInfo(frameType);
786 return E_OK;
787 }
788
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)789 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
790 {
791 if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
792 LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
793 return -E_LENGTH_ERROR;
794 }
795 auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
796 CommPhyOptHeader phyOptHeader;
797 HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
798
799 // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
800 inResult.SetFrameLen(phyOptHeader.frameLen);
801 inResult.SetFragCount(phyOptHeader.fragCount);
802 inResult.SetFragNo(phyOptHeader.fragNo);
803 return E_OK;
804 }
805
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)806 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
807 {
808 // Check version before parse field
809 if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
810 return -E_LENGTH_ERROR;
811 }
812 uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
813 if (version != PROTOCOL_VERSION) {
814 LOGE("[Proto][ParseDiverge] Version=%u not support.", version);
815 return -E_VERSION_NOT_SUPPORT;
816 }
817
818 if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
819 LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
820 return -E_PARSE_FAIL;
821 }
822 auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
823 CommDivergeHeader divergeHeader;
824 HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
825 if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
826 inResult.GetPaddingLen() != inResult.GetPacketLen()) {
827 LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
828 return -E_PARSE_FAIL;
829 }
830 inResult.SetPayloadLen(divergeHeader.payLoadLen);
831 inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
832 return E_OK;
833 }
834
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)835 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
836 {
837 if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
838 int errCode = ParseLabelExchangeAck(bytes, length, inResult);
839 if (errCode != E_OK) {
840 LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
841 return errCode;
842 }
843 } else {
844 int errCode = ParseLabelExchange(bytes, length, inResult);
845 if (errCode != E_OK) {
846 LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
847 return errCode;
848 }
849 }
850 return E_OK;
851 }
852
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)853 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
854 {
855 // Check version at very first
856 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
857 return -E_LENGTH_ERROR;
858 }
859 auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
860 uint64_t version = NetToHost(*fieldPtr++);
861 if (version != PROTOCOL_VERSION) {
862 LOGE("[Proto][ParseLabel] Version=%llu not support.", ULL(version));
863 return -E_VERSION_NOT_SUPPORT;
864 }
865
866 // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
867 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
868 LOGE("[Proto][ParseLabel] Length of Bytes Error.");
869 return -E_LENGTH_ERROR;
870 }
871 uint64_t distinctValue = NetToHost(*fieldPtr++);
872 inResult.SetLabelExchangeDistinctValue(distinctValue);
873 uint64_t sequenceId = NetToHost(*fieldPtr++);
874 inResult.SetLabelExchangeSequenceId(sequenceId);
875 uint64_t commLabelCount = NetToHost(*fieldPtr++);
876 if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
877 LOGE("[Proto][ParseLabel] commLabelCount=%llu invalid.", ULL(commLabelCount));
878 return -E_PARSE_FAIL;
879 }
880 // commLabelCount is expected to be not very large
881 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
882 commLabelCount * COMM_LABEL_LENGTH) {
883 LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%llu", ULL(commLabelCount));
884 return -E_LENGTH_ERROR;
885 }
886
887 // Get each commLabel
888 std::set<LabelType> commLabels;
889 auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
890 for (uint64_t i = 0; i < commLabelCount; i++) {
891 // the length is checked just above
892 LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
893 if (commLabels.count(commLabel) != 0) {
894 LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%s.", VEC_TO_STR(commLabel));
895 } else {
896 commLabels.insert(commLabel);
897 }
898 }
899 inResult.SetLatestCommLabels(commLabels);
900 return E_OK;
901 }
902
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)903 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
904 {
905 // Check version at very first
906 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
907 return -E_LENGTH_ERROR;
908 }
909 auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
910 uint64_t version = NetToHost(*fieldPtr++);
911 if (version != PROTOCOL_VERSION) {
912 LOGE("[Proto][ParseLabelAck] Version=%llu not support.", ULL(version));
913 return -E_VERSION_NOT_SUPPORT;
914 }
915
916 if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
917 LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
918 return -E_LENGTH_ERROR;
919 }
920 uint64_t distinctValue = NetToHost(*fieldPtr++);
921 inResult.SetLabelExchangeDistinctValue(distinctValue);
922 uint64_t sequenceId = NetToHost(*fieldPtr++);
923 inResult.SetLabelExchangeSequenceId(sequenceId);
924 return E_OK;
925 }
926
927 // Note: framePhyHeader is in network endian
928 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)929 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
930 const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
931 {
932 // It can be guaranteed that fragCount >= 2 and also won't be too large
933 if (fragmentInfo.fragCount < 2) {
934 return -E_INVALID_ARGS;
935 }
936 outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
937 uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
938 uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
939 uint16_t fragNo = 0; // Fragment index start from 0
940 uint32_t byteOffset = 0;
941
942 for (auto &entry : outPieces) {
943 // subtract 1 for index
944 uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
945 uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
946 uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
947
948 // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
949 entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
950 if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
951 LOGE("[Proto][FrameFrag] Resize failed for length=%u", pieceTotalLen);
952 return -E_OUT_OF_MEMORY;
953 }
954
955 CommPhyHeader pktPhyHeader;
956 HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
957
958 // The sum value need to be recalculated, and the packet is fragmented.
959 // The alignedFragLen is always larger than pieceFragLen
960 FillPhyHeaderLenInfo(pktPhyHeader, pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen);
961 HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
962
963 CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
964 fragmentInfo.fragCount, fragNo};
965 HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
966 int err;
967 FragmentPacket packet;
968 uint8_t *ptrPacket = &(entry.first[0]);
969 if (fragmentInfo.extendHeadSize > 0) {
970 packet = {ptrPacket, fragmentInfo.extendHeadSize};
971 err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
972 if (err != E_OK) {
973 return err;
974 }
975 ptrPacket += fragmentInfo.extendHeadSize;
976 }
977 packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
978 err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
979 pieceFragLen, packet);
980 entry.second = fragmentInfo.extendHeadSize;
981 if (err != E_OK) {
982 LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%u, fragNo=%u", fragmentInfo.fragCount, fragNo);
983 return err;
984 }
985
986 fragNo++;
987 byteOffset += pieceFragLen;
988 }
989
990 return E_OK;
991 }
992
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)993 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
994 {
995 if (headLen > outPacket.leftLength) {
996 LOGE("[Proto][FrameFrag] headLen less than leftLength");
997 return -E_INVALID_ARGS;
998 }
999 errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1000 if (retCode != EOK) {
1001 LOGE("memcpy error:%d", retCode);
1002 return -E_SECUREC_ERROR;
1003 }
1004 return E_OK;
1005 }
1006
1007 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1008 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1009 const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1010 {
1011 if (outPacket.leftLength == 0) {
1012 return -E_INVALID_ARGS;
1013 }
1014 uint8_t *ptrPacket = outPacket.ptrPacket;
1015 uint32_t leftLength = outPacket.leftLength;
1016
1017 // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1018 // So, there will be no redundant check during subtraction
1019 errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1020 if (retCode != EOK) {
1021 return -E_SECUREC_ERROR;
1022 }
1023 ptrPacket += sizeof(CommPhyHeader);
1024 leftLength -= sizeof(CommPhyHeader);
1025
1026 retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1027 if (retCode != EOK) {
1028 return -E_SECUREC_ERROR;
1029 }
1030 ptrPacket += sizeof(CommPhyOptHeader);
1031 leftLength -= sizeof(CommPhyOptHeader);
1032
1033 retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1034 if (retCode != EOK) {
1035 return -E_SECUREC_ERROR;
1036 }
1037
1038 // Calculate sum and set sum field
1039 uint64_t sumResult = 0;
1040 int errCode = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1041 outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1042 if (errCode != E_OK) {
1043 return -E_SUM_CALCULATE_FAIL;
1044 }
1045 auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1046 ptrPhyHeader->checkSum = HostToNet(sumResult);
1047
1048 return E_OK;
1049 }
1050
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1051 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1052 {
1053 if (extendHandle != nullptr) {
1054 DBStatus status = extendHandle->GetHeadDataSize(headSize);
1055 if (status != DBStatus::OK) {
1056 LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1057 return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1058 }
1059 if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1060 LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1061 return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1062 }
1063 return E_OK;
1064 }
1065 return E_OK;
1066 }
1067
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1068 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1069 uint32_t headSize)
1070 {
1071 if (extendHandle != nullptr && headSize > 0) {
1072 if (buffer == nullptr) {
1073 return -E_INVALID_ARGS;
1074 }
1075 DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1076 buffer->GetSize() + headSize);
1077 if (status != DBStatus::OK) {
1078 LOGI("[Proto][ToSerial] fill head data failed");
1079 return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1080 }
1081 return E_OK;
1082 }
1083 return E_OK;
1084 }
1085 } // namespace DistributedDB
1086