• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "protocol_proto.h"
#include <cstring>
#include <iterator>
#include <new>
#include "hash.h"
#include "securec.h"
#include "version.h"
#include "db_common.h"
#include "log_print.h"
#include "macro_utils.h"
#include "endian_convert.h"
#include "header_converter.h"
27 
28 namespace DistributedDB {
29 namespace {
30 const uint16_t MAGIC_CODE = 0xAAAA;
31 const uint16_t PROTOCOL_VERSION = 0;
32 // Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
33 const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
34 const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
35 const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
36 const uint8_t MAX_PADDING_LEN = 7;
37 const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
38 const uint32_t MAX_FRAME_LEN = 32 * 1024 * 1024; // Max 32 MB, 1024 is scale
39 const uint16_t MIN_FRAGMENT_COUNT = 2; // At least a frame will be splited into 2 parts
40 // LabelExchange(Ack) Frame Field Length
41 const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
42 const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
43 const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
44 // Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
45 const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
46 // Local func to set and get frame Type from packet Type field
SetFrameType(uint8_t & inPacketType,FrameType inFrameType)47 void SetFrameType(uint8_t &inPacketType, FrameType inFrameType)
48 {
49     inPacketType &= 0x0F; // Use 0x0F to clear high for bits
50     inPacketType |= (static_cast<uint8_t>(inFrameType) << 4); // frame type is on high 4 bits
51 }
GetFrameType(uint8_t inPacketType)52 FrameType GetFrameType(uint8_t inPacketType)
53 {
54     uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0xF0 to get high 4 bits
55     if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
56         return FrameType::INVALID_MAX_FRAME_TYPE;
57     }
58     return static_cast<FrameType>(frameType);
59 }
60 }
61 
62 std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
63 std::shared_mutex ProtocolProto::msgIdMutex_;
64 
GetAppLayerFrameHeaderLength()65 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
66 {
67     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
68     return length;
69 }
70 
GetLengthBeforeSerializedData()71 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
72 {
73     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
74     return length;
75 }
76 
GetCommLayerFrameHeaderLength()77 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
78 {
79     uint32_t length = sizeof(CommPhyHeader);
80     return length;
81 }
82 
ToSerialBuffer(const Message * inMsg,int & outErrorNo,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader)83 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg, int &outErrorNo,
84     std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader)
85 {
86     if (inMsg == nullptr) {
87         outErrorNo = -E_INVALID_ARGS;
88         return nullptr;
89     }
90 
91     uint32_t serializeLen = 0;
92     if (!onlyMsgHeader) {
93         int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
94         if (errCode != E_OK) {
95             outErrorNo = errCode;
96             return nullptr;
97         }
98     }
99     uint32_t headSize = 0;
100     int errCode = GetExtendHeadDataSize(extendHandle, headSize);
101     if (errCode != E_OK) {
102         outErrorNo = errCode;
103         return nullptr;
104     }
105 
106     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
107     if (buffer == nullptr) {
108         outErrorNo = -E_OUT_OF_MEMORY;
109         return nullptr;
110     }
111     if (headSize > 0) {
112         buffer->SetExtendHeadLength(headSize);
113     }
114     // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
115     uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
116     errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
117     if (errCode != E_OK) {
118         LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
119         goto ERROR_HANDLE;
120     }
121     errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
122     if (errCode != E_OK) {
123         goto ERROR_HANDLE;
124     }
125 
126     // Serialize the MessageHeader and data if need
127     errCode = SerializeMessage(buffer, inMsg);
128     if (errCode != E_OK) {
129         LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
130         goto ERROR_HANDLE;
131     }
132     outErrorNo = E_OK;
133     return buffer;
134 ERROR_HANDLE:
135     outErrorNo = errCode;
136     delete buffer;
137     buffer = nullptr;
138     return nullptr;
139 }
140 
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)141 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
142 {
143     if (inBuff == nullptr) {
144         outErrorNo = -E_INVALID_ARGS;
145         return nullptr;
146     }
147     Message *outMsg = new (std::nothrow) Message();
148     if (outMsg == nullptr) {
149         outErrorNo = -E_OUT_OF_MEMORY;
150         return nullptr;
151     }
152     int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
153     if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
154         LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
155         outErrorNo = errCode;
156         delete outMsg;
157         outMsg = nullptr;
158         return nullptr;
159     }
160     // If messageId not register in this software version, we return errCode and the Message without an object.
161     outErrorNo = errCode;
162     return outMsg;
163 }
164 
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)165 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
166 {
167     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
168     if (buffer == nullptr) {
169         outErrorNo = -E_OUT_OF_MEMORY;
170         return nullptr;
171     }
172 
173     // Empty frame has no payload, only header
174     int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
175     if (errCode != E_OK) {
176         LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
177         outErrorNo = errCode;
178         delete buffer;
179         buffer = nullptr;
180         return nullptr;
181     }
182     outErrorNo = E_OK;
183     return buffer;
184 }
185 
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)186 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
187     int &outErrorNo)
188 {
189     std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
190     SerialBuffer *buffer = ToSerialBuffer(inMsg, outErrorNo, extendHandle, true);
191     if (buffer == nullptr) {
192         // outErrorNo had already been set in ToSerialBuffer
193         return nullptr;
194     }
195     int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
196     if (errCode != E_OK) {
197         LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
198         outErrorNo = errCode;
199         delete buffer;
200         buffer = nullptr;
201         return nullptr;
202     }
203     outErrorNo = E_OK;
204     return buffer;
205 }
206 
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)207 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
208     const std::set<LabelType> &inLabels, int &outErrorNo)
209 {
210     // Size of inLabels won't be too large.
211     // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
212     uint64_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
213         inLabels.size() * COMM_LABEL_LENGTH;
214     if (payloadLen > INT32_MAX) {
215         outErrorNo = -E_INVALID_ARGS;
216         return nullptr;
217     }
218     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
219     if (buffer == nullptr) {
220         outErrorNo = -E_OUT_OF_MEMORY;
221         return nullptr;
222     }
223     int errCode = buffer->AllocBufferByPayloadLength(static_cast<uint32_t>(payloadLen),
224         GetCommLayerFrameHeaderLength());
225     if (errCode != E_OK) {
226         LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
227         outErrorNo = errCode;
228         delete buffer;
229         buffer = nullptr;
230         return nullptr;
231     }
232 
233     auto payloadByteLen = buffer->GetWritableBytesForPayload();
234     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
235     *fieldPtr++ = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
236     *fieldPtr++ = HostToNet(inDistinctValue);
237     *fieldPtr++ = HostToNet(inSequenceId);
238     *fieldPtr++ = HostToNet(static_cast<uint64_t>(inLabels.size()));
239     // Note: don't worry, memory length had been carefully calculated above
240     auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
241     for (const auto &eachLabel : inLabels) {
242         for (const auto &eachByte : eachLabel) {
243             *bytePtr++ = eachByte;
244         }
245     }
246     outErrorNo = E_OK;
247     return buffer;
248 }
249 
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)250 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
251 {
252     uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
253     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
254     if (buffer == nullptr) {
255         outErrorNo = -E_OUT_OF_MEMORY;
256         return nullptr;
257     }
258     int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
259     if (errCode != E_OK) {
260         LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
261         outErrorNo = errCode;
262         delete buffer;
263         buffer = nullptr;
264         return nullptr;
265     }
266 
267     auto payloadByteLen = buffer->GetWritableBytesForPayload();
268     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
269     *fieldPtr++ = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
270     *fieldPtr++ = HostToNet(inDistinctValue);
271     *fieldPtr++ = HostToNet(inSequenceId);
272     outErrorNo = E_OK;
273     return buffer;
274 }
275 
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)276 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
277     std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
278 {
279     auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
280     if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
281         return E_OK;
282     }
283     uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
284     // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
285     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
286     uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
287     // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
288     uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
289     // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
290     // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
291     uint16_t quotient = lengthToSplit / maxFragmentLen;
292     uint32_t remainder = lengthToSplit % maxFragmentLen;
293     // Finally we get the fragCount for this frame
294     uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
295     // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
296     auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
297     FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
298     return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
299 }
300 
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)301 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
302 {
303     uint32_t frameLen = inResult.GetFrameLen();
304     uint16_t fragCount = inResult.GetFragCount();
305     uint16_t fragNo = inResult.GetFragNo();
306 
307     // Firstly: Check frameLen
308     if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
309         LOGE("[Proto][ParsePhyOpt] FrameLen=%u illegal.", frameLen);
310         return -E_PARSE_FAIL;
311     }
312 
313     // Secondly: Check fragCount and fragNo
314     uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
315     if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
316         LOGE("[Proto][ParsePhyOpt] FragCount=%u or fragNo=%u illegal.", fragCount, fragNo);
317         return -E_PARSE_FAIL;
318     }
319 
320     // Finally: Check length relation deeply
321     uint32_t quotient = lengthBeSplit / fragCount;
322     uint16_t remainder = lengthBeSplit % fragCount;
323     outFragLen = quotient;
324     outLastFragLen = quotient + remainder;
325     uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
326     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
327         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
328         LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%u, FragCount=%u, fragNo=%u, PaddingLen=%u, PacketLen=%u",
329             frameLen, fragCount, fragNo, inResult.GetPaddingLen(), inResult.GetPacketLen());
330         return -E_PARSE_FAIL;
331     }
332 
333     return E_OK;
334 }
335 
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)336 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
337     uint32_t fragOffset, uint32_t fragLength)
338 {
339     // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
340     // Firstly: Check the length relation of source, even this check is not supposed to fail
341     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength) {
342         return -E_LENGTH_ERROR;
343     }
344     // Secondly: Check the length relation of destination, even this check is not supposed to fail
345     auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
346     if (sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second) {
347         return -E_LENGTH_ERROR;
348     }
349     // Finally: Do Combination!
350     const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
351     uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
352     uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
353     errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
354     if (errCode != EOK) {
355         return -E_SECUREC_ERROR;
356     }
357     return E_OK;
358 }
359 
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)360 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
361 {
362     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
363     if (msgIdMapFunc_.count(msgId) != 0) {
364         return -E_ALREADY_REGISTER;
365     }
366     if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
367         return -E_INVALID_ARGS;
368     }
369     msgIdMapFunc_[msgId] = inFunc;
370     return E_OK;
371 }
372 
UnRegTransformFunction(uint32_t msgId)373 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
374 {
375     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
376     if (msgIdMapFunc_.count(msgId) != 0) {
377         msgIdMapFunc_.erase(msgId);
378     }
379 }
380 
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)381 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
382 {
383     if (inBuff == nullptr) {
384         return -E_INVALID_ARGS;
385     }
386     auto headerByteLen = inBuff->GetWritableBytesForHeader();
387     if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
388         return -E_INVALID_ARGS;
389     }
390     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
391 
392     CommDivergeHeader divergeHeader;
393     divergeHeader.version = PROTOCOL_VERSION;
394     divergeHeader.reserved = 0;
395     divergeHeader.payLoadLen = payloadByteLen.second;
396     // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
397     for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
398         divergeHeader.commLabel[i] = inCommLabel[i];
399     }
400     HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
401 
402     errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
403         headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
404     if (errCode != EOK) {
405         return -E_SECUREC_ERROR;
406     }
407     return E_OK;
408 }
409 
410 namespace {
FillPhyHeaderLenInfo(CommPhyHeader & header,uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen)411 void FillPhyHeaderLenInfo(CommPhyHeader &header, uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen)
412 {
413     header.packetLen = packetLen;
414     header.checkSum = sum;
415     header.packetType |= type;
416     header.paddingLen = paddingLen;
417 }
418 }
419 
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)420 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
421 {
422     if (inBuff == nullptr) {
423         return -E_INVALID_ARGS;
424     }
425     auto headerByteLen = inBuff->GetWritableBytesForHeader();
426     if (headerByteLen.second < sizeof(CommPhyHeader)) {
427         return -E_INVALID_ARGS;
428     }
429     auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
430     auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
431 
432     uint32_t packetLen = bufferByteLen.second;
433     uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
434     uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
435     if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
436         SetFrameType(packetType, inInfo.frameType);
437     } else {
438         return -E_INVALID_ARGS;
439     }
440 
441     CommPhyHeader phyHeader;
442     phyHeader.magic = MAGIC_CODE;
443     phyHeader.version = PROTOCOL_VERSION;
444     phyHeader.sourceId = inInfo.sourceId;
445     phyHeader.frameId = inInfo.frameId;
446     phyHeader.packetType = 0;
447     phyHeader.dbIntVer = DB_GLOBAL_VERSION;
448     FillPhyHeaderLenInfo(phyHeader, packetLen, 0, packetType, paddingLen); // Sum is calculated afterwards
449     HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
450 
451     errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
452     if (retCode != EOK) {
453         return -E_SECUREC_ERROR;
454     }
455 
456     uint64_t sumResult = 0;
457     int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
458         bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
459     if (errCode != E_OK) {
460         return -E_SUM_CALCULATE_FAIL;
461     }
462 
463     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
464     ptrPhyHeader->checkSum = HostToNet(sumResult);
465 
466     return E_OK;
467 }
468 
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)469 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
470     ParseResult &outResult)
471 {
472     if (bytes == nullptr || length > MAX_TOTAL_LEN) {
473         return -E_INVALID_ARGS;
474     }
475     int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
476     if (errCode != E_OK) {
477         LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
478         return errCode;
479     }
480 
481     if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
482         return E_OK; // Do nothing more for empty frame
483     }
484 
485     if (outResult.IsFragment()) {
486         errCode = ParseCommPhyOptHeader(bytes, length, outResult);
487         if (errCode != E_OK) {
488             LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
489             return errCode;
490         }
491     } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
492         errCode = ParseCommLayerPayload(bytes, length, outResult);
493         if (errCode != E_OK) {
494             LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
495             return errCode;
496         }
497     } else {
498         errCode = ParseCommDivergeHeader(bytes, length, outResult);
499         if (errCode != E_OK) {
500             LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
501             return errCode;
502         }
503     }
504     return E_OK;
505 }
506 
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)507 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
508 {
509     if (inBuff == nullptr || outResult.IsFragment()) {
510         return -E_INTERNAL_ERROR;
511     }
512     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
513     if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
514         int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
515         if (errCode != E_OK) {
516             LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
517             return errCode;
518         }
519     } else {
520         int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
521         if (errCode != E_OK) {
522             LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
523             return errCode;
524         }
525     }
526     return E_OK;
527 }
528 
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)529 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
530 {
531     static std::map<FrameType, std::string> frameTypeStr{
532         {FrameType::EMPTY, "EmptyFrame"},
533         {FrameType::APPLICATION_MESSAGE, "AppLayerFrame"},
534         {FrameType::COMMUNICATION_LABEL_EXCHANGE, "CommLayerFrame_LabelExchange"},
535         {FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK, "CommLayerFrame_LabelExchangeAck"}};
536 
537     if (length < sizeof(CommPhyHeader)) {
538         return;
539     }
540     auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
541     uint32_t frameId = NetToHost(phyHeader->frameId);
542     uint8_t pktType = NetToHost(phyHeader->packetType);
543     bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
544     FrameType frameType = GetFrameType(pktType);
545     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
546         LOGW("[Proto][Display] This is unrecognized frame, pktType=%" PRIu8 ".", pktType);
547         return;
548     }
549     if (isFragment) {
550         if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
551             return;
552         }
553         auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
554         LOGI("[Proto][Display] This is %s, frameId=%u, frameLen=%u, fragCount=%u, fragNo=%u.",
555             frameTypeStr[frameType].c_str(), frameId, NetToHost(phyOpt->frameLen),
556             NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
557     } else {
558         LOGI("[Proto][Display] This is %s, frameId=%u.", frameTypeStr[frameType].c_str(), frameId);
559     }
560 }
561 
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)562 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
563 {
564     if (length % sizeof(uint64_t) != 0) {
565         LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight.", length);
566         return -E_LENGTH_ERROR;
567     }
568     int count = length / sizeof(uint64_t);
569     auto array = reinterpret_cast<const uint64_t *>(bytes);
570     outSum = 0;
571     for (int i = 0; i < count; i++) {
572         outSum ^= array[i];
573     }
574     return E_OK;
575 }
576 
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)577 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
578 {
579     uint32_t messageId = inMsg->GetMessageId();
580     TransformFunc function;
581     if (GetTransformFunc(messageId, function) != E_OK) {
582         LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%" PRIu32 ".", messageId);
583         return -E_NOT_REGISTER;
584     }
585 
586     uint32_t serializeLen = function.computeFunc(inMsg);
587     uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
588     // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
589     if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
590         LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%u, serializeLen=%u, alignedLen=%u.",
591             messageId, serializeLen, alignedLen);
592         return -E_LENGTH_ERROR;
593     }
594     // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
595     outLength = serializeLen;
596     return E_OK;
597 }
598 
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)599 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
600 {
601     auto payloadByteLen = inBuff->GetWritableBytesForPayload();
602     if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
603         LOGE("[Proto][Serialize] Length error, payload length=%u.", payloadByteLen.second);
604         return -E_LENGTH_ERROR;
605     }
606     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
607 
608     auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
609     messageHdr->version = inMsg->GetVersion();
610     messageHdr->messageType = inMsg->GetMessageType();
611     messageHdr->messageId = inMsg->GetMessageId();
612     messageHdr->sessionId = inMsg->GetSessionId();
613     messageHdr->sequenceId = inMsg->GetSequenceId();
614     messageHdr->errorNo = inMsg->GetErrorNo();
615     messageHdr->dataLen = dataLen;
616     HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
617 
618     if (dataLen == 0) {
619         // For zero dataLen, we don't need to serialize data part
620         return E_OK;
621     }
622     // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
623     TransformFunc function;
624     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
625         LOGE("[Proto][Serialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
626         return -E_NOT_REGISTER;
627     }
628     int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
629     if (result != E_OK) {
630         LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
631         return -E_SERIALIZE_ERROR;
632     }
633     return E_OK;
634 }
635 
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)636 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
637 {
638     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
639     // Check version before parse field
640     if (payloadByteLen.second < sizeof(uint16_t)) {
641         return -E_LENGTH_ERROR;
642     }
643     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
644     if (!IsSupportMessageVersion(version)) {
645         LOGE("[Proto][DeSerialize] Version=%u not support.", version);
646         return -E_VERSION_NOT_SUPPORT;
647     }
648 
649     if (payloadByteLen.second < sizeof(MessageHeader)) {
650         LOGE("[Proto][DeSerialize] Length error, payload length=%u.", payloadByteLen.second);
651         return -E_LENGTH_ERROR;
652     }
653     auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
654     MessageHeader messageHdr;
655     HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
656     inMsg->SetVersion(version);
657     inMsg->SetMessageType(messageHdr.messageType);
658     inMsg->SetMessageId(messageHdr.messageId);
659     inMsg->SetSessionId(messageHdr.sessionId);
660     inMsg->SetSequenceId(messageHdr.sequenceId);
661     inMsg->SetErrorNo(messageHdr.errorNo);
662     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
663     if (dataLen != messageHdr.dataLen) {
664         LOGE("[Proto][DeSerialize] dataLen=%u, msgDataLen=%u.", dataLen, messageHdr.dataLen);
665         return -E_LENGTH_ERROR;
666     }
667     // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
668     if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
669         LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%u.", messageHdr.errorNo);
670         return E_OK;
671     }
672     if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
673         return E_OK;
674     }
675     TransformFunc function;
676     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
677         LOGE("[Proto][DeSerialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
678         return -E_NOT_REGISTER;
679     }
680     int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
681     if (result != E_OK) {
682         LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
683         return -E_DESERIALIZE_ERROR;
684     }
685     return E_OK;
686 }
687 
IsSupportMessageVersion(uint16_t version)688 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
689 {
690     return (version == MSG_VERSION_BASE || version == MSG_VERSION_EXT);
691 }
692 
IsFeedbackErrorMessage(uint32_t errorNo)693 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
694 {
695     return (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND);
696 }
697 
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)698 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
699 {
700     // At least magic and version should exist
701     if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
702         LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
703         return -E_LENGTH_ERROR;
704     }
705     auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
706     uint16_t magic = NetToHost(*fieldPtr++);
707     uint16_t version = NetToHost(*fieldPtr++);
708 
709     if (magic != MAGIC_CODE) {
710         LOGE("[Proto][ParsePhyCheckVer] MagicCode=%u Error.", magic);
711         return -E_PARSE_FAIL;
712     }
713     if (version != PROTOCOL_VERSION) {
714         LOGE("[Proto][ParsePhyCheckVer] Version=%u Error.", version);
715         return -E_VERSION_NOT_SUPPORT;
716     }
717     return E_OK;
718 }
719 
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)720 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
721     const uint8_t *bytes, uint32_t length)
722 {
723     if (phyHeader.sourceId != Hash::HashFunc(srcTarget)) {
724         LOGE("[Proto][ParsePhyCheck] SourceId Error: inSourceId=%llu, srcTarget=%s{private}, hashId=%llu.",
725             ULL(phyHeader.sourceId), srcTarget.c_str(), ULL(Hash::HashFunc(srcTarget)));
726         return -E_PARSE_FAIL;
727     }
728     if (phyHeader.packetLen != length) {
729         LOGE("[Proto][ParsePhyCheck] PacketLen=%u Mismatch length=%u.", phyHeader.packetLen, length);
730         return -E_PARSE_FAIL;
731     }
732     if (phyHeader.paddingLen > MAX_PADDING_LEN) {
733         LOGE("[Proto][ParsePhyCheck] PaddingLen=%u Error.", phyHeader.paddingLen);
734         return -E_PARSE_FAIL;
735     }
736     if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
737         LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
738         return -E_PARSE_FAIL;
739     }
740     uint64_t sumResult = 0;
741     int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
742     if (errCode != E_OK) {
743         LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
744         return -E_SUM_CALCULATE_FAIL;
745     }
746     if (phyHeader.checkSum != sumResult) {
747         LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%llu, sumResult=%llu.",
748             ULL(phyHeader.checkSum), ULL(sumResult));
749         return -E_SUM_MISMATCH;
750     }
751     return E_OK;
752 }
753 
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)754 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
755     ParseResult &inResult)
756 {
757     int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
758     if (errCode != E_OK) {
759         LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
760         return errCode;
761     }
762 
763     if (length < sizeof(CommPhyHeader)) {
764         LOGE("[Proto][ParsePhy] Length of Bytes Error.");
765         return -E_PARSE_FAIL;
766     }
767     auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
768     CommPhyHeader phyHeader;
769     HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
770     errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
771     if (errCode != E_OK) {
772         LOGE("[Proto][ParsePhy] Check Field Fail.");
773         return errCode;
774     }
775 
776     inResult.SetFrameId(phyHeader.frameId);
777     inResult.SetSourceId(phyHeader.sourceId);
778     inResult.SetPacketLen(phyHeader.packetLen);
779     inResult.SetPaddingLen(phyHeader.paddingLen);
780     inResult.SetDbVersion(phyHeader.dbIntVer);
781     if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
782         inResult.SetFragmentFlag(true);
783     } // FragmentFlag default is false
784     FrameType frameType = GetFrameType(phyHeader.packetType);
785     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
786         LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%u.", phyHeader.packetType);
787         return -E_FRAME_TYPE_NOT_SUPPORT;
788     }
789     inResult.SetFrameTypeInfo(frameType);
790     return E_OK;
791 }
792 
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)793 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
794 {
795     if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
796         LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
797         return -E_LENGTH_ERROR;
798     }
799     auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
800     CommPhyOptHeader phyOptHeader;
801     HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
802 
803     // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
804     inResult.SetFrameLen(phyOptHeader.frameLen);
805     inResult.SetFragCount(phyOptHeader.fragCount);
806     inResult.SetFragNo(phyOptHeader.fragNo);
807     return E_OK;
808 }
809 
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)810 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
811 {
812     // Check version before parse field
813     if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
814         return -E_LENGTH_ERROR;
815     }
816     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
817     if (version != PROTOCOL_VERSION) {
818         LOGE("[Proto][ParseDiverge] Version=%" PRIu16 " not support.", version);
819         return -E_VERSION_NOT_SUPPORT;
820     }
821 
822     if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
823         LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
824         return -E_PARSE_FAIL;
825     }
826     auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
827     CommDivergeHeader divergeHeader;
828     HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
829     if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
830         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
831         LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
832         return -E_PARSE_FAIL;
833     }
834     inResult.SetPayloadLen(divergeHeader.payLoadLen);
835     inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
836     return E_OK;
837 }
838 
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)839 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
840 {
841     if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
842         int errCode = ParseLabelExchangeAck(bytes, length, inResult);
843         if (errCode != E_OK) {
844             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
845             return errCode;
846         }
847     } else {
848         int errCode = ParseLabelExchange(bytes, length, inResult);
849         if (errCode != E_OK) {
850             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
851             return errCode;
852         }
853     }
854     return E_OK;
855 }
856 
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)857 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
858 {
859     // Check version at very first
860     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
861         return -E_LENGTH_ERROR;
862     }
863     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
864     uint64_t version = NetToHost(*fieldPtr++);
865     if (version != PROTOCOL_VERSION) {
866         LOGE("[Proto][ParseLabel] Version=%llu not support.", ULL(version));
867         return -E_VERSION_NOT_SUPPORT;
868     }
869 
870     // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
871     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
872         LOGE("[Proto][ParseLabel] Length of Bytes Error.");
873         return -E_LENGTH_ERROR;
874     }
875     uint64_t distinctValue = NetToHost(*fieldPtr++);
876     inResult.SetLabelExchangeDistinctValue(distinctValue);
877     uint64_t sequenceId = NetToHost(*fieldPtr++);
878     inResult.SetLabelExchangeSequenceId(sequenceId);
879     uint64_t commLabelCount = NetToHost(*fieldPtr++);
880     if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
881         LOGE("[Proto][ParseLabel] commLabelCount=%llu invalid.", ULL(commLabelCount));
882         return -E_PARSE_FAIL;
883     }
884     // commLabelCount is expected to be not very large
885     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
886         commLabelCount * COMM_LABEL_LENGTH) {
887         LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%llu", ULL(commLabelCount));
888         return -E_LENGTH_ERROR;
889     }
890 
891     // Get each commLabel
892     std::set<LabelType> commLabels;
893     auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
894     for (uint64_t i = 0; i < commLabelCount; i++) {
895         // the length is checked just above
896         LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
897         if (commLabels.count(commLabel) != 0) {
898             LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%s.", VEC_TO_STR(commLabel));
899         } else {
900             commLabels.insert(commLabel);
901         }
902     }
903     inResult.SetLatestCommLabels(commLabels);
904     return E_OK;
905 }
906 
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)907 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
908 {
909     // Check version at very first
910     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
911         return -E_LENGTH_ERROR;
912     }
913     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
914     uint64_t version = NetToHost(*fieldPtr++);
915     if (version != PROTOCOL_VERSION) {
916         LOGE("[Proto][ParseLabelAck] Version=%llu not support.", ULL(version));
917         return -E_VERSION_NOT_SUPPORT;
918     }
919 
920     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
921         LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
922         return -E_LENGTH_ERROR;
923     }
924     uint64_t distinctValue = NetToHost(*fieldPtr++);
925     inResult.SetLabelExchangeDistinctValue(distinctValue);
926     uint64_t sequenceId = NetToHost(*fieldPtr++);
927     inResult.SetLabelExchangeSequenceId(sequenceId);
928     return E_OK;
929 }
930 
931 // Note: framePhyHeader is in network endian
932 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)933 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
934     const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
935 {
936     // It can be guaranteed that fragCount >= 2 and also won't be too large
937     if (fragmentInfo.fragCount < MIN_FRAGMENT_COUNT) {
938         return -E_INVALID_ARGS;
939     }
940     outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
941     uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
942     uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
943     uint16_t fragNo = 0; // Fragment index start from 0
944     uint32_t byteOffset = 0;
945 
946     for (auto &entry : outPieces) {
947         // subtract 1 for index
948         uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
949         uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
950         uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
951 
952         // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
953         entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
954         if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
955             LOGE("[Proto][FrameFrag] Resize failed for length=%u", pieceTotalLen);
956             return -E_OUT_OF_MEMORY;
957         }
958 
959         CommPhyHeader pktPhyHeader;
960         HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
961 
962         // The sum value need to be recalculated, and the packet is fragmented.
963         // The alignedFragLen is always larger than pieceFragLen
964         FillPhyHeaderLenInfo(pktPhyHeader, pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen);
965         HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
966 
967         CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
968             fragmentInfo.fragCount, fragNo};
969         HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
970         int err;
971         FragmentPacket packet;
972         uint8_t *ptrPacket = &(entry.first[0]);
973         if (fragmentInfo.extendHeadSize > 0) {
974             packet = {ptrPacket, fragmentInfo.extendHeadSize};
975             err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
976             if (err != E_OK) {
977                 return err;
978             }
979             ptrPacket += fragmentInfo.extendHeadSize;
980         }
981         packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
982         err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
983             pieceFragLen, packet);
984         entry.second = fragmentInfo.extendHeadSize;
985         if (err != E_OK) {
986             LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%" PRIu16 ", fragNo=%" PRIu16, fragmentInfo.fragCount,
987                 fragNo);
988             return err;
989         }
990 
991         fragNo++;
992         byteOffset += pieceFragLen;
993     }
994 
995     return E_OK;
996 }
997 
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)998 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
999 {
1000     if (headLen > outPacket.leftLength) {
1001         LOGE("[Proto][FrameFrag] headLen less than leftLength");
1002         return -E_INVALID_ARGS;
1003     }
1004     errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1005     if (retCode != EOK) {
1006         LOGE("memcpy error:%d", retCode);
1007         return -E_SECUREC_ERROR;
1008     }
1009     return E_OK;
1010 }
1011 
1012 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1013 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1014     const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1015 {
1016     if (outPacket.leftLength == 0) {
1017         return -E_INVALID_ARGS;
1018     }
1019     uint8_t *ptrPacket = outPacket.ptrPacket;
1020     uint32_t leftLength = outPacket.leftLength;
1021 
1022     // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1023     // So, there will be no redundant check during subtraction
1024     errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1025     if (retCode != EOK) {
1026         return -E_SECUREC_ERROR;
1027     }
1028     ptrPacket += sizeof(CommPhyHeader);
1029     leftLength -= sizeof(CommPhyHeader);
1030 
1031     retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1032     if (retCode != EOK) {
1033         return -E_SECUREC_ERROR;
1034     }
1035     ptrPacket += sizeof(CommPhyOptHeader);
1036     leftLength -= sizeof(CommPhyOptHeader);
1037 
1038     retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1039     if (retCode != EOK) {
1040         return -E_SECUREC_ERROR;
1041     }
1042 
1043     // Calculate sum and set sum field
1044     uint64_t sumResult = 0;
1045     int errCode  = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1046         outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1047     if (errCode != E_OK) {
1048         return -E_SUM_CALCULATE_FAIL;
1049     }
1050     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1051     if (ptrPhyHeader == nullptr) {
1052         return -E_INVALID_ARGS;
1053     }
1054     ptrPhyHeader->checkSum = HostToNet(sumResult);
1055 
1056     return E_OK;
1057 }
1058 
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1059 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1060 {
1061     if (extendHandle != nullptr) {
1062         DBStatus status = extendHandle->GetHeadDataSize(headSize);
1063         if (status != DBStatus::OK) {
1064             LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1065             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1066         }
1067         if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1068             LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1069             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1070         }
1071         return E_OK;
1072     }
1073     return E_OK;
1074 }
1075 
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1076 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1077     uint32_t headSize)
1078 {
1079     if (extendHandle != nullptr && headSize > 0) {
1080         if (buffer == nullptr) {
1081             return -E_INVALID_ARGS;
1082         }
1083         DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1084             buffer->GetSize() + headSize);
1085         if (status != DBStatus::OK) {
1086             LOGI("[Proto][ToSerial] fill head data failed");
1087             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1088         }
1089         return E_OK;
1090     }
1091     return E_OK;
1092 }
1093 
GetTransformFunc(uint32_t messageId,DistributedDB::TransformFunc & function)1094 int ProtocolProto::GetTransformFunc(uint32_t messageId, DistributedDB::TransformFunc &function)
1095 {
1096     std::shared_lock<std::shared_mutex> autoLock(msgIdMutex_);
1097     const auto &entry = msgIdMapFunc_.find(messageId);
1098     if (entry == msgIdMapFunc_.end()) {
1099         return -E_NOT_REGISTER;
1100     }
1101     function = entry->second;
1102     return E_OK;
1103 }
1104 } // namespace DistributedDB
1105