• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "protocol_proto.h"
17 #include <new>
18 #include <iterator>
19 #include "hash.h"
20 #include "securec.h"
21 #include "version.h"
22 #include "db_common.h"
23 #include "log_print.h"
24 #include "macro_utils.h"
25 #include "endian_convert.h"
26 #include "header_converter.h"
27 
28 namespace DistributedDB {
29 namespace {
// Magic value expected in the first 16-bit field of every CommPhyHeader.
const uint16_t MAGIC_CODE = 0xAAAA;
// Protocol version written into the headers built here and required from parsed packets.
const uint16_t PROTOCOL_VERSION = 0;
// Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
// Largest paddingLen accepted when checking a parsed CommPhyHeader.
const uint8_t MAX_PADDING_LEN = 7;
// Number of leading bytes of a buffer that are skipped when the XOR checksum is calculated.
const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
const uint32_t MAX_FRAME_LEN = 32 * 1024 * 1024; // Max 32 MB, 1024 is scale
const uint16_t MIN_FRAGMENT_COUNT = 2; // A fragmented frame is split into at least 2 parts
// LabelExchange(Ack) Frame Field Length
const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
// Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
46 // Local func to set and get frame Type from packet Type field
// Store inFrameType in the high four bits of inPacketType; the low four bits are kept as they are.
void SetFrameType(uint8_t &inPacketType, FrameType inFrameType)
{
    const uint8_t lowBits = static_cast<uint8_t>(inPacketType & 0x0F); // Drop whatever was in the high four bits
    const uint8_t highBits = static_cast<uint8_t>(static_cast<uint8_t>(inFrameType) << 4); // Type goes to bits 4-7
    inPacketType = static_cast<uint8_t>(lowBits | highBits);
}
GetFrameType(uint8_t inPacketType)52 FrameType GetFrameType(uint8_t inPacketType)
53 {
54     uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0x0F to get high 4 bits
55     if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
56         return FrameType::INVALID_MAX_FRAME_TYPE;
57     }
58     return static_cast<FrameType>(frameType);
59 }
60 }
61 
// Registry from messageId to its TransformFunc: filled by RegTransformFunction, cleared per id by
// UnRegTransformFunction, and looked up when message data is measured, serialized or deserialized.
std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
63 
GetAppLayerFrameHeaderLength()64 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
65 {
66     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
67     return length;
68 }
69 
GetLengthBeforeSerializedData()70 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
71 {
72     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
73     return length;
74 }
75 
GetCommLayerFrameHeaderLength()76 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
77 {
78     uint32_t length = sizeof(CommPhyHeader);
79     return length;
80 }
81 
ToSerialBuffer(const Message * inMsg,int & outErrorNo,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader)82 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg, int &outErrorNo,
83     std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader)
84 {
85     if (inMsg == nullptr) {
86         outErrorNo = -E_INVALID_ARGS;
87         return nullptr;
88     }
89 
90     uint32_t serializeLen = 0;
91     if (!onlyMsgHeader) {
92         int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
93         if (errCode != E_OK) {
94             outErrorNo = errCode;
95             return nullptr;
96         }
97     }
98     uint32_t headSize = 0;
99     int errCode = GetExtendHeadDataSize(extendHandle, headSize);
100     if (errCode != E_OK) {
101         outErrorNo = errCode;
102         return nullptr;
103     }
104 
105     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
106     if (buffer == nullptr) {
107         outErrorNo = -E_OUT_OF_MEMORY;
108         return nullptr;
109     }
110     if (headSize > 0) {
111         buffer->SetExtendHeadLength(headSize);
112     }
113     // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
114     uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
115     errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
116     if (errCode != E_OK) {
117         LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
118         goto ERROR_HANDLE;
119     }
120     errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
121     if (errCode != E_OK) {
122         goto ERROR_HANDLE;
123     }
124 
125     // Serialize the MessageHeader and data if need
126     errCode = SerializeMessage(buffer, inMsg);
127     if (errCode != E_OK) {
128         LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
129         goto ERROR_HANDLE;
130     }
131     outErrorNo = E_OK;
132     return buffer;
133 ERROR_HANDLE:
134     outErrorNo = errCode;
135     delete buffer;
136     buffer = nullptr;
137     return nullptr;
138 }
139 
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)140 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
141 {
142     if (inBuff == nullptr) {
143         outErrorNo = -E_INVALID_ARGS;
144         return nullptr;
145     }
146     Message *outMsg = new (std::nothrow) Message();
147     if (outMsg == nullptr) {
148         outErrorNo = -E_OUT_OF_MEMORY;
149         return nullptr;
150     }
151     int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
152     if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
153         LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
154         outErrorNo = errCode;
155         delete outMsg;
156         outMsg = nullptr;
157         return nullptr;
158     }
159     // If messageId not register in this software version, we return errCode and the Message without an object.
160     outErrorNo = errCode;
161     return outMsg;
162 }
163 
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)164 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
165 {
166     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
167     if (buffer == nullptr) {
168         outErrorNo = -E_OUT_OF_MEMORY;
169         return nullptr;
170     }
171 
172     // Empty frame has no payload, only header
173     int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
174     if (errCode != E_OK) {
175         LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
176         outErrorNo = errCode;
177         delete buffer;
178         buffer = nullptr;
179         return nullptr;
180     }
181     outErrorNo = E_OK;
182     return buffer;
183 }
184 
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)185 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
186     int &outErrorNo)
187 {
188     std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
189     SerialBuffer *buffer = ToSerialBuffer(inMsg, outErrorNo, extendHandle, true);
190     if (buffer == nullptr) {
191         // outErrorNo had already been set in ToSerialBuffer
192         return nullptr;
193     }
194     int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
195     if (errCode != E_OK) {
196         LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
197         outErrorNo = errCode;
198         delete buffer;
199         buffer = nullptr;
200         return nullptr;
201     }
202     outErrorNo = E_OK;
203     return buffer;
204 }
205 
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)206 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
207     const std::set<LabelType> &inLabels, int &outErrorNo)
208 {
209     // Size of inLabels won't be too large.
210     // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
211     uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
212         inLabels.size() * COMM_LABEL_LENGTH;
213     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
214     if (buffer == nullptr) {
215         outErrorNo = -E_OUT_OF_MEMORY;
216         return nullptr;
217     }
218     int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
219     if (errCode != E_OK) {
220         LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
221         outErrorNo = errCode;
222         delete buffer;
223         buffer = nullptr;
224         return nullptr;
225     }
226 
227     auto payloadByteLen = buffer->GetWritableBytesForPayload();
228     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
229     *fieldPtr++ = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
230     *fieldPtr++ = HostToNet(inDistinctValue);
231     *fieldPtr++ = HostToNet(inSequenceId);
232     *fieldPtr++ = HostToNet(static_cast<uint64_t>(inLabels.size()));
233     // Note: don't worry, memory length had been carefully calculated above
234     auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
235     for (auto &eachLabel : inLabels) {
236         for (auto &eachByte : eachLabel) {
237             *bytePtr++ = eachByte;
238         }
239     }
240     outErrorNo = E_OK;
241     return buffer;
242 }
243 
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)244 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
245 {
246     uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
247     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
248     if (buffer == nullptr) {
249         outErrorNo = -E_OUT_OF_MEMORY;
250         return nullptr;
251     }
252     int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
253     if (errCode != E_OK) {
254         LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
255         outErrorNo = errCode;
256         delete buffer;
257         buffer = nullptr;
258         return nullptr;
259     }
260 
261     auto payloadByteLen = buffer->GetWritableBytesForPayload();
262     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
263     *fieldPtr++ = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
264     *fieldPtr++ = HostToNet(inDistinctValue);
265     *fieldPtr++ = HostToNet(inSequenceId);
266     outErrorNo = E_OK;
267     return buffer;
268 }
269 
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)270 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
271     std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
272 {
273     auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
274     if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
275         return E_OK;
276     }
277     uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
278     // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
279     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
280     uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
281     // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
282     uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
283     // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
284     // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
285     uint16_t quotient = lengthToSplit / maxFragmentLen;
286     uint32_t remainder = lengthToSplit % maxFragmentLen;
287     // Finally we get the fragCount for this frame
288     uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
289     // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
290     auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
291     FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
292     return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
293 }
294 
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)295 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
296 {
297     uint32_t frameLen = inResult.GetFrameLen();
298     uint16_t fragCount = inResult.GetFragCount();
299     uint16_t fragNo = inResult.GetFragNo();
300 
301     // Firstly: Check frameLen
302     if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
303         LOGE("[Proto][ParsePhyOpt] FrameLen=%u illegal.", frameLen);
304         return -E_PARSE_FAIL;
305     }
306 
307     // Secondly: Check fragCount and fragNo
308     uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
309     if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
310         LOGE("[Proto][ParsePhyOpt] FragCount=%u or fragNo=%u illegal.", fragCount, fragNo);
311         return -E_PARSE_FAIL;
312     }
313 
314     // Finally: Check length relation deeply
315     uint32_t quotient = lengthBeSplit / fragCount;
316     uint16_t remainder = lengthBeSplit % fragCount;
317     outFragLen = quotient;
318     outLastFragLen = quotient + remainder;
319     uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
320     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
321         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
322         LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%u, FragCount=%u, fragNo=%u, PaddingLen=%u, PacketLen=%u",
323             frameLen, fragCount, fragNo, inResult.GetPaddingLen(), inResult.GetPacketLen());
324         return -E_PARSE_FAIL;
325     }
326 
327     return E_OK;
328 }
329 
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)330 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
331     uint32_t fragOffset, uint32_t fragLength)
332 {
333     // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
334     // Firstly: Check the length relation of source, even this check is not supposed to fail
335     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength) {
336         return -E_LENGTH_ERROR;
337     }
338     // Secondly: Check the length relation of destination, even this check is not supposed to fail
339     auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
340     if (sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second) {
341         return -E_LENGTH_ERROR;
342     }
343     // Finally: Do Combination!
344     const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
345     uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
346     uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
347     errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
348     if (errCode != EOK) {
349         return -E_SECUREC_ERROR;
350     }
351     return E_OK;
352 }
353 
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)354 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
355 {
356     if (msgIdMapFunc_.count(msgId) != 0) {
357         return -E_ALREADY_REGISTER;
358     }
359     if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
360         return -E_INVALID_ARGS;
361     }
362     msgIdMapFunc_[msgId] = inFunc;
363     return E_OK;
364 }
365 
UnRegTransformFunction(uint32_t msgId)366 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
367 {
368     if (msgIdMapFunc_.count(msgId) != 0) {
369         msgIdMapFunc_.erase(msgId);
370     }
371 }
372 
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)373 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
374 {
375     if (inBuff == nullptr) {
376         return -E_INVALID_ARGS;
377     }
378     auto headerByteLen = inBuff->GetWritableBytesForHeader();
379     if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
380         return -E_INVALID_ARGS;
381     }
382     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
383 
384     CommDivergeHeader divergeHeader;
385     divergeHeader.version = PROTOCOL_VERSION;
386     divergeHeader.reserved = 0;
387     divergeHeader.payLoadLen = payloadByteLen.second;
388     // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
389     for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
390         divergeHeader.commLabel[i] = inCommLabel[i];
391     }
392     HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
393 
394     errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
395         headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
396     if (errCode != EOK) {
397         return -E_SECUREC_ERROR;
398     }
399     return E_OK;
400 }
401 
402 namespace {
FillPhyHeaderLenInfo(CommPhyHeader & header,uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen)403 void FillPhyHeaderLenInfo(CommPhyHeader &header, uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen)
404 {
405     header.packetLen = packetLen;
406     header.checkSum = sum;
407     header.packetType |= type;
408     header.paddingLen = paddingLen;
409 }
410 }
411 
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)412 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
413 {
414     if (inBuff == nullptr) {
415         return -E_INVALID_ARGS;
416     }
417     auto headerByteLen = inBuff->GetWritableBytesForHeader();
418     if (headerByteLen.second < sizeof(CommPhyHeader)) {
419         return -E_INVALID_ARGS;
420     }
421     auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
422     auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
423 
424     uint32_t packetLen = bufferByteLen.second;
425     uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
426     uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
427     if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
428         SetFrameType(packetType, inInfo.frameType);
429     } else {
430         return -E_INVALID_ARGS;
431     }
432 
433     CommPhyHeader phyHeader;
434     phyHeader.magic = MAGIC_CODE;
435     phyHeader.version = PROTOCOL_VERSION;
436     phyHeader.sourceId = inInfo.sourceId;
437     phyHeader.frameId = inInfo.frameId;
438     phyHeader.packetType = 0;
439     phyHeader.dbIntVer = DB_GLOBAL_VERSION;
440     FillPhyHeaderLenInfo(phyHeader, packetLen, 0, packetType, paddingLen); // Sum is calculated afterwards
441     HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
442 
443     errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
444     if (retCode != EOK) {
445         return -E_SECUREC_ERROR;
446     }
447 
448     uint64_t sumResult = 0;
449     int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
450         bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
451     if (errCode != E_OK) {
452         return -E_SUM_CALCULATE_FAIL;
453     }
454 
455     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
456     ptrPhyHeader->checkSum = HostToNet(sumResult);
457 
458     return E_OK;
459 }
460 
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)461 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
462     ParseResult &outResult)
463 {
464     if (bytes == nullptr || length > MAX_TOTAL_LEN) {
465         return -E_INVALID_ARGS;
466     }
467     int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
468     if (errCode != E_OK) {
469         LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
470         return errCode;
471     }
472 
473     if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
474         return E_OK; // Do nothing more for empty frame
475     }
476 
477     if (outResult.IsFragment()) {
478         errCode = ParseCommPhyOptHeader(bytes, length, outResult);
479         if (errCode != E_OK) {
480             LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
481             return errCode;
482         }
483     } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
484         errCode = ParseCommLayerPayload(bytes, length, outResult);
485         if (errCode != E_OK) {
486             LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
487             return errCode;
488         }
489     } else {
490         errCode = ParseCommDivergeHeader(bytes, length, outResult);
491         if (errCode != E_OK) {
492             LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
493             return errCode;
494         }
495     }
496     return E_OK;
497 }
498 
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)499 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
500 {
501     if (inBuff == nullptr || outResult.IsFragment()) {
502         return -E_INTERNAL_ERROR;
503     }
504     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
505     if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
506         int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
507         if (errCode != E_OK) {
508             LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
509             return errCode;
510         }
511     } else {
512         int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
513         if (errCode != E_OK) {
514             LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
515             return errCode;
516         }
517     }
518     return E_OK;
519 }
520 
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)521 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
522 {
523     static std::map<FrameType, std::string> frameTypeStr{
524         {FrameType::EMPTY, "EmptyFrame"},
525         {FrameType::APPLICATION_MESSAGE, "AppLayerFrame"},
526         {FrameType::COMMUNICATION_LABEL_EXCHANGE, "CommLayerFrame_LabelExchange"},
527         {FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK, "CommLayerFrame_LabelExchangeAck"}};
528 
529     if (length < sizeof(CommPhyHeader)) {
530         return;
531     }
532     auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
533     uint32_t frameId = NetToHost(phyHeader->frameId);
534     uint8_t pktType = NetToHost(phyHeader->packetType);
535     bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
536     FrameType frameType = GetFrameType(pktType);
537     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
538         LOGW("[Proto][Display] This is unrecognized frame, pktType=%u.", pktType);
539         return;
540     }
541     if (isFragment) {
542         if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
543             return;
544         }
545         auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
546         LOGI("[Proto][Display] This is %s, frameId=%u, frameLen=%u, fragCount=%u, fragNo=%u.",
547             frameTypeStr[frameType].c_str(), frameId, NetToHost(phyOpt->frameLen),
548             NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
549     } else {
550         LOGI("[Proto][Display] This is %s, frameId=%u.", frameTypeStr[frameType].c_str(), frameId);
551     }
552 }
553 
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)554 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
555 {
556     if (length % sizeof(uint64_t) != 0) {
557         LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight.", length);
558         return -E_LENGTH_ERROR;
559     }
560     int count = length / sizeof(uint64_t);
561     auto array = reinterpret_cast<const uint64_t *>(bytes);
562     outSum = 0;
563     for (int i = 0; i < count; i++) {
564         outSum ^= array[i];
565     }
566     return E_OK;
567 }
568 
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)569 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
570 {
571     uint32_t messageId = inMsg->GetMessageId();
572     if (msgIdMapFunc_.count(messageId) == 0) {
573         LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%u.", messageId);
574         return -E_NOT_REGISTER;
575     }
576 
577     TransformFunc function = msgIdMapFunc_[messageId];
578     uint32_t serializeLen = function.computeFunc(inMsg);
579     uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
580     // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
581     if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
582         LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%u, serializeLen=%u, alignedLen=%u.",
583             messageId, serializeLen, alignedLen);
584         return -E_LENGTH_ERROR;
585     }
586     // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
587     outLength = serializeLen;
588     return E_OK;
589 }
590 
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)591 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
592 {
593     auto payloadByteLen = inBuff->GetWritableBytesForPayload();
594     if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
595         LOGE("[Proto][Serialize] Length error, payload length=%u.", payloadByteLen.second);
596         return -E_LENGTH_ERROR;
597     }
598     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
599 
600     auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
601     messageHdr->version = inMsg->GetVersion();
602     messageHdr->messageType = inMsg->GetMessageType();
603     messageHdr->messageId = inMsg->GetMessageId();
604     messageHdr->sessionId = inMsg->GetSessionId();
605     messageHdr->sequenceId = inMsg->GetSequenceId();
606     messageHdr->errorNo = inMsg->GetErrorNo();
607     messageHdr->dataLen = dataLen;
608     HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
609 
610     if (dataLen == 0) {
611         // For zero dataLen, we don't need to serialize data part
612         return E_OK;
613     }
614     // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
615     uint32_t messageId = inMsg->GetMessageId();
616     TransformFunc function = msgIdMapFunc_[messageId];
617     int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
618     if (result != E_OK) {
619         LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
620         return -E_SERIALIZE_ERROR;
621     }
622     return E_OK;
623 }
624 
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)625 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
626 {
627     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
628     // Check version before parse field
629     if (payloadByteLen.second < sizeof(uint16_t)) {
630         return -E_LENGTH_ERROR;
631     }
632     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
633     if (!IsSupportMessageVersion(version)) {
634         LOGE("[Proto][DeSerialize] Version=%u not support.", version);
635         return -E_VERSION_NOT_SUPPORT;
636     }
637 
638     if (payloadByteLen.second < sizeof(MessageHeader)) {
639         LOGE("[Proto][DeSerialize] Length error, payload length=%u.", payloadByteLen.second);
640         return -E_LENGTH_ERROR;
641     }
642     auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
643     MessageHeader messageHdr;
644     HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
645     inMsg->SetVersion(version);
646     inMsg->SetMessageType(messageHdr.messageType);
647     inMsg->SetMessageId(messageHdr.messageId);
648     inMsg->SetSessionId(messageHdr.sessionId);
649     inMsg->SetSequenceId(messageHdr.sequenceId);
650     inMsg->SetErrorNo(messageHdr.errorNo);
651     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
652     if (dataLen != messageHdr.dataLen) {
653         LOGE("[Proto][DeSerialize] dataLen=%u, msgDataLen=%u.", dataLen, messageHdr.dataLen);
654         return -E_LENGTH_ERROR;
655     }
656     // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
657     if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
658         LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%u.", messageHdr.errorNo);
659         return E_OK;
660     }
661     if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
662         return E_OK;
663     }
664     uint32_t messageId = inMsg->GetMessageId();
665     if (msgIdMapFunc_.count(messageId) == 0) {
666         LOGE("[Proto][DeSerialize] Not register, messageId=%u.", messageId);
667         return -E_NOT_REGISTER;
668     }
669     TransformFunc function = msgIdMapFunc_[messageId];
670     int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
671     if (result != E_OK) {
672         LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
673         return -E_DESERIALIZE_ERROR;
674     }
675     return E_OK;
676 }
677 
IsSupportMessageVersion(uint16_t version)678 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
679 {
680     if (version != MSG_VERSION_BASE && version != MSG_VERSION_EXT) {
681         return false;
682     }
683     return true;
684 }
685 
IsFeedbackErrorMessage(uint32_t errorNo)686 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
687 {
688     if (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND) {
689         return true;
690     }
691     return false;
692 }
693 
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)694 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
695 {
696     // At least magic and version should exist
697     if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
698         LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
699         return -E_LENGTH_ERROR;
700     }
701     auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
702     uint16_t magic = NetToHost(*fieldPtr++);
703     uint16_t version = NetToHost(*fieldPtr++);
704 
705     if (magic != MAGIC_CODE) {
706         LOGE("[Proto][ParsePhyCheckVer] MagicCode=%u Error.", magic);
707         return -E_PARSE_FAIL;
708     }
709     if (version != PROTOCOL_VERSION) {
710         LOGE("[Proto][ParsePhyCheckVer] Version=%u Error.", version);
711         return -E_VERSION_NOT_SUPPORT;
712     }
713     return E_OK;
714 }
715 
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)716 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
717     const uint8_t *bytes, uint32_t length)
718 {
719     if (phyHeader.sourceId != Hash::HashFunc(srcTarget)) {
720         LOGE("[Proto][ParsePhyCheck] SourceId Error: inSourceId=%llu, srcTarget=%s{private}, hashId=%llu.",
721             ULL(phyHeader.sourceId), srcTarget.c_str(), ULL(Hash::HashFunc(srcTarget)));
722         return -E_PARSE_FAIL;
723     }
724     if (phyHeader.packetLen != length) {
725         LOGE("[Proto][ParsePhyCheck] PacketLen=%u Mismatch length=%u.", phyHeader.packetLen, length);
726         return -E_PARSE_FAIL;
727     }
728     if (phyHeader.paddingLen > MAX_PADDING_LEN) {
729         LOGE("[Proto][ParsePhyCheck] PaddingLen=%u Error.", phyHeader.paddingLen);
730         return -E_PARSE_FAIL;
731     }
732     if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
733         LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
734         return -E_PARSE_FAIL;
735     }
736     uint64_t sumResult = 0;
737     int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
738     if (errCode != E_OK) {
739         LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
740         return -E_SUM_CALCULATE_FAIL;
741     }
742     if (phyHeader.checkSum != sumResult) {
743         LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%llu, sumResult=%llu.",
744             ULL(phyHeader.checkSum), ULL(sumResult));
745         return -E_SUM_MISMATCH;
746     }
747     return E_OK;
748 }
749 
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)750 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
751     ParseResult &inResult)
752 {
753     int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
754     if (errCode != E_OK) {
755         LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
756         return errCode;
757     }
758 
759     if (length < sizeof(CommPhyHeader)) {
760         LOGE("[Proto][ParsePhy] Length of Bytes Error.");
761         return -E_PARSE_FAIL;
762     }
763     auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
764     CommPhyHeader phyHeader;
765     HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
766     errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
767     if (errCode != E_OK) {
768         LOGE("[Proto][ParsePhy] Check Field Fail.");
769         return errCode;
770     }
771 
772     inResult.SetFrameId(phyHeader.frameId);
773     inResult.SetSourceId(phyHeader.sourceId);
774     inResult.SetPacketLen(phyHeader.packetLen);
775     inResult.SetPaddingLen(phyHeader.paddingLen);
776     inResult.SetDbVersion(phyHeader.dbIntVer);
777     if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
778         inResult.SetFragmentFlag(true);
779     } // FragmentFlag default is false
780     FrameType frameType = GetFrameType(phyHeader.packetType);
781     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
782         LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%u.", phyHeader.packetType);
783         return -E_FRAME_TYPE_NOT_SUPPORT;
784     }
785     inResult.SetFrameTypeInfo(frameType);
786     return E_OK;
787 }
788 
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)789 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
790 {
791     if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
792         LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
793         return -E_LENGTH_ERROR;
794     }
795     auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
796     CommPhyOptHeader phyOptHeader;
797     HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
798 
799     // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
800     inResult.SetFrameLen(phyOptHeader.frameLen);
801     inResult.SetFragCount(phyOptHeader.fragCount);
802     inResult.SetFragNo(phyOptHeader.fragNo);
803     return E_OK;
804 }
805 
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)806 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
807 {
808     // Check version before parse field
809     if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
810         return -E_LENGTH_ERROR;
811     }
812     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
813     if (version != PROTOCOL_VERSION) {
814         LOGE("[Proto][ParseDiverge] Version=%u not support.", version);
815         return -E_VERSION_NOT_SUPPORT;
816     }
817 
818     if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
819         LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
820         return -E_PARSE_FAIL;
821     }
822     auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
823     CommDivergeHeader divergeHeader;
824     HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
825     if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
826         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
827         LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
828         return -E_PARSE_FAIL;
829     }
830     inResult.SetPayloadLen(divergeHeader.payLoadLen);
831     inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
832     return E_OK;
833 }
834 
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)835 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
836 {
837     if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
838         int errCode = ParseLabelExchangeAck(bytes, length, inResult);
839         if (errCode != E_OK) {
840             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
841             return errCode;
842         }
843     } else {
844         int errCode = ParseLabelExchange(bytes, length, inResult);
845         if (errCode != E_OK) {
846             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
847             return errCode;
848         }
849     }
850     return E_OK;
851 }
852 
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)853 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
854 {
855     // Check version at very first
856     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
857         return -E_LENGTH_ERROR;
858     }
859     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
860     uint64_t version = NetToHost(*fieldPtr++);
861     if (version != PROTOCOL_VERSION) {
862         LOGE("[Proto][ParseLabel] Version=%llu not support.", ULL(version));
863         return -E_VERSION_NOT_SUPPORT;
864     }
865 
866     // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
867     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
868         LOGE("[Proto][ParseLabel] Length of Bytes Error.");
869         return -E_LENGTH_ERROR;
870     }
871     uint64_t distinctValue = NetToHost(*fieldPtr++);
872     inResult.SetLabelExchangeDistinctValue(distinctValue);
873     uint64_t sequenceId = NetToHost(*fieldPtr++);
874     inResult.SetLabelExchangeSequenceId(sequenceId);
875     uint64_t commLabelCount = NetToHost(*fieldPtr++);
876     if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
877         LOGE("[Proto][ParseLabel] commLabelCount=%llu invalid.", ULL(commLabelCount));
878         return -E_PARSE_FAIL;
879     }
880     // commLabelCount is expected to be not very large
881     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
882         commLabelCount * COMM_LABEL_LENGTH) {
883         LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%llu", ULL(commLabelCount));
884         return -E_LENGTH_ERROR;
885     }
886 
887     // Get each commLabel
888     std::set<LabelType> commLabels;
889     auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
890     for (uint64_t i = 0; i < commLabelCount; i++) {
891         // the length is checked just above
892         LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
893         if (commLabels.count(commLabel) != 0) {
894             LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%s.", VEC_TO_STR(commLabel));
895         } else {
896             commLabels.insert(commLabel);
897         }
898     }
899     inResult.SetLatestCommLabels(commLabels);
900     return E_OK;
901 }
902 
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)903 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
904 {
905     // Check version at very first
906     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
907         return -E_LENGTH_ERROR;
908     }
909     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
910     uint64_t version = NetToHost(*fieldPtr++);
911     if (version != PROTOCOL_VERSION) {
912         LOGE("[Proto][ParseLabelAck] Version=%llu not support.", ULL(version));
913         return -E_VERSION_NOT_SUPPORT;
914     }
915 
916     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
917         LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
918         return -E_LENGTH_ERROR;
919     }
920     uint64_t distinctValue = NetToHost(*fieldPtr++);
921     inResult.SetLabelExchangeDistinctValue(distinctValue);
922     uint64_t sequenceId = NetToHost(*fieldPtr++);
923     inResult.SetLabelExchangeSequenceId(sequenceId);
924     return E_OK;
925 }
926 
927 // Note: framePhyHeader is in network endian
928 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)929 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
930     const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
931 {
932     // It can be guaranteed that fragCount >= 2 and also won't be too large
933     if (fragmentInfo.fragCount < 2) {
934         return -E_INVALID_ARGS;
935     }
936     outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
937     uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
938     uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
939     uint16_t fragNo = 0; // Fragment index start from 0
940     uint32_t byteOffset = 0;
941 
942     for (auto &entry : outPieces) {
943         // subtract 1 for index
944         uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
945         uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
946         uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
947 
948         // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
949         entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
950         if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
951             LOGE("[Proto][FrameFrag] Resize failed for length=%u", pieceTotalLen);
952             return -E_OUT_OF_MEMORY;
953         }
954 
955         CommPhyHeader pktPhyHeader;
956         HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
957 
958         // The sum value need to be recalculated, and the packet is fragmented.
959         // The alignedFragLen is always larger than pieceFragLen
960         FillPhyHeaderLenInfo(pktPhyHeader, pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen);
961         HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
962 
963         CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
964             fragmentInfo.fragCount, fragNo};
965         HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
966         int err;
967         FragmentPacket packet;
968         uint8_t *ptrPacket = &(entry.first[0]);
969         if (fragmentInfo.extendHeadSize > 0) {
970             packet = {ptrPacket, fragmentInfo.extendHeadSize};
971             err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
972             if (err != E_OK) {
973                 return err;
974             }
975             ptrPacket += fragmentInfo.extendHeadSize;
976         }
977         packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
978         err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
979             pieceFragLen, packet);
980         entry.second = fragmentInfo.extendHeadSize;
981         if (err != E_OK) {
982             LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%u, fragNo=%u", fragmentInfo.fragCount, fragNo);
983             return err;
984         }
985 
986         fragNo++;
987         byteOffset += pieceFragLen;
988     }
989 
990     return E_OK;
991 }
992 
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)993 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
994 {
995     if (headLen > outPacket.leftLength) {
996         LOGE("[Proto][FrameFrag] headLen less than leftLength");
997         return -E_INVALID_ARGS;
998     }
999     errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1000     if (retCode != EOK) {
1001         LOGE("memcpy error:%d", retCode);
1002         return -E_SECUREC_ERROR;
1003     }
1004     return E_OK;
1005 }
1006 
1007 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1008 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1009     const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1010 {
1011     if (outPacket.leftLength == 0) {
1012         return -E_INVALID_ARGS;
1013     }
1014     uint8_t *ptrPacket = outPacket.ptrPacket;
1015     uint32_t leftLength = outPacket.leftLength;
1016 
1017     // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1018     // So, there will be no redundant check during subtraction
1019     errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1020     if (retCode != EOK) {
1021         return -E_SECUREC_ERROR;
1022     }
1023     ptrPacket += sizeof(CommPhyHeader);
1024     leftLength -= sizeof(CommPhyHeader);
1025 
1026     retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1027     if (retCode != EOK) {
1028         return -E_SECUREC_ERROR;
1029     }
1030     ptrPacket += sizeof(CommPhyOptHeader);
1031     leftLength -= sizeof(CommPhyOptHeader);
1032 
1033     retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1034     if (retCode != EOK) {
1035         return -E_SECUREC_ERROR;
1036     }
1037 
1038     // Calculate sum and set sum field
1039     uint64_t sumResult = 0;
1040     int errCode  = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1041         outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1042     if (errCode != E_OK) {
1043         return -E_SUM_CALCULATE_FAIL;
1044     }
1045     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1046     ptrPhyHeader->checkSum = HostToNet(sumResult);
1047 
1048     return E_OK;
1049 }
1050 
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1051 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1052 {
1053     if (extendHandle != nullptr) {
1054         DBStatus status = extendHandle->GetHeadDataSize(headSize);
1055         if (status != DBStatus::OK) {
1056             LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1057             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1058         }
1059         if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1060             LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1061             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1062         }
1063         return E_OK;
1064     }
1065     return E_OK;
1066 }
1067 
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1068 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1069     uint32_t headSize)
1070 {
1071     if (extendHandle != nullptr && headSize > 0) {
1072         if (buffer == nullptr) {
1073             return -E_INVALID_ARGS;
1074         }
1075         DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1076             buffer->GetSize() + headSize);
1077         if (status != DBStatus::OK) {
1078             LOGI("[Proto][ToSerial] fill head data failed");
1079             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1080         }
1081         return E_OK;
1082     }
1083     return E_OK;
1084 }
1085 } // namespace DistributedDB
1086