• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "protocol_proto.h"
17 #include <iterator>
18 #include <mutex>
19 #include <new>
20 #include "db_common.h"
21 #include "endian_convert.h"
22 #include "hash.h"
23 #include "header_converter.h"
24 #include "log_print.h"
25 #include "macro_utils.h"
26 #include "securec.h"
27 #include "version.h"
28 
29 namespace DistributedDB {
30 namespace {
31 const uint16_t MAGIC_CODE = 0xAAAA;
32 const uint16_t PROTOCOL_VERSION = 0;
33 // Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
34 const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
35 const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
36 const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
37 const uint8_t MAX_PADDING_LEN = 7;
38 const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
39 const uint32_t MAX_FRAME_LEN = 128 * 1024 * 1024; // Max 128 MB, 1024 is scale
40 const uint16_t MIN_FRAGMENT_COUNT = 2; // At least a frame will be splited into 2 parts
41 // LabelExchange(Ack) Frame Field Length
42 const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
43 const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
44 const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
45 // Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
46 const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
47 // Local func to set and get frame Type from packet Type field
SetFrameType(FrameType inFrameType,uint8_t & inPacketType)48 void SetFrameType(FrameType inFrameType, uint8_t &inPacketType)
49 {
50     inPacketType &= 0x0F; // Use 0x0F to clear high four bits
51     inPacketType |= (static_cast<uint8_t>(inFrameType) << 4); // frame type is on high 4 bits
52 }
53 
GetFrameType(uint8_t inPacketType)54 FrameType GetFrameType(uint8_t inPacketType)
55 {
56     uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0xF0 to get high 4 bits
57     if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
58         return FrameType::INVALID_MAX_FRAME_TYPE;
59     }
60     return static_cast<FrameType>(frameType);
61 }
62 
IsSendLabelExchange(uint8_t inPacketType)63 bool IsSendLabelExchange(uint8_t inPacketType)
64 {
65     return ((inPacketType & 0x08) >> 3) == 0; // Use 0x08 and remove low 3 bit, it is Communication negotiation mark
66 }
67 
SetSendLabelExchange(uint8_t & inPacketType,bool sendLabelExchange)68 void SetSendLabelExchange(uint8_t &inPacketType, bool sendLabelExchange)
69 {
70     if (!sendLabelExchange) {
71         inPacketType |= 0x08; // mark 0x08 when not support communication
72     }
73 }
74 }
75 
76 std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
77 std::shared_mutex ProtocolProto::msgIdMutex_;
78 
GetAppLayerFrameHeaderLength()79 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
80 {
81     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
82     return length;
83 }
84 
GetLengthBeforeSerializedData()85 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
86 {
87     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
88     return length;
89 }
90 
GetCommLayerFrameHeaderLength()91 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
92 {
93     uint32_t length = sizeof(CommPhyHeader);
94     return length;
95 }
96 
ToSerialBuffer(const Message * inMsg,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader,int & outErrorNo)97 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg,
98     std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader, int &outErrorNo)
99 {
100     if (inMsg == nullptr) {
101         outErrorNo = -E_INVALID_ARGS;
102         return nullptr;
103     }
104 
105     uint32_t serializeLen = 0;
106     if (!onlyMsgHeader) {
107         int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
108         if (errCode != E_OK) {
109             outErrorNo = errCode;
110             return nullptr;
111         }
112     }
113     uint32_t headSize = 0;
114     int errCode = GetExtendHeadDataSize(extendHandle, headSize);
115     if (errCode != E_OK) {
116         outErrorNo = errCode;
117         return nullptr;
118     }
119 
120     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
121     if (buffer == nullptr) {
122         outErrorNo = -E_OUT_OF_MEMORY;
123         return nullptr;
124     }
125     if (headSize > 0) {
126         buffer->SetExtendHeadLength(headSize);
127     }
128     // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
129     uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
130     errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
131     if (errCode != E_OK) {
132         LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
133         goto ERROR_HANDLE;
134     }
135     errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
136     if (errCode != E_OK) {
137         goto ERROR_HANDLE;
138     }
139 
140     // Serialize the MessageHeader and data if need
141     errCode = SerializeMessage(buffer, inMsg);
142     if (errCode != E_OK) {
143         LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
144         goto ERROR_HANDLE;
145     }
146     outErrorNo = E_OK;
147     return buffer;
148 ERROR_HANDLE:
149     outErrorNo = errCode;
150     delete buffer;
151     buffer = nullptr;
152     return nullptr;
153 }
154 
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)155 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
156 {
157     if (inBuff == nullptr) {
158         outErrorNo = -E_INVALID_ARGS;
159         return nullptr;
160     }
161     Message *outMsg = new (std::nothrow) Message();
162     if (outMsg == nullptr) {
163         outErrorNo = -E_OUT_OF_MEMORY;
164         return nullptr;
165     }
166     int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
167     if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
168         LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
169         outErrorNo = errCode;
170         delete outMsg;
171         outMsg = nullptr;
172         return nullptr;
173     }
174     // If messageId not register in this software version, we return errCode and the Message without an object.
175     outErrorNo = errCode;
176     return outMsg;
177 }
178 
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)179 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
180 {
181     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
182     if (buffer == nullptr) {
183         outErrorNo = -E_OUT_OF_MEMORY;
184         return nullptr;
185     }
186 
187     // Empty frame has no payload, only header
188     int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
189     if (errCode != E_OK) {
190         LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
191         outErrorNo = errCode;
192         delete buffer;
193         buffer = nullptr;
194         return nullptr;
195     }
196     outErrorNo = E_OK;
197     return buffer;
198 }
199 
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)200 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
201     int &outErrorNo)
202 {
203     std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
204     SerialBuffer *buffer = ToSerialBuffer(inMsg, extendHandle, true, outErrorNo);
205     if (buffer == nullptr) {
206         // outErrorNo had already been set in ToSerialBuffer
207         return nullptr;
208     }
209     int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
210     if (errCode != E_OK) {
211         LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%.3s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
212         outErrorNo = errCode;
213         delete buffer;
214         buffer = nullptr;
215         return nullptr;
216     }
217     outErrorNo = E_OK;
218     return buffer;
219 }
220 
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)221 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
222     const std::set<LabelType> &inLabels, int &outErrorNo)
223 {
224     // Size of inLabels won't be too large.
225     // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
226     uint64_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
227         inLabels.size() * COMM_LABEL_LENGTH;
228     if (payloadLen > INT32_MAX) {
229         outErrorNo = -E_INVALID_ARGS;
230         return nullptr;
231     }
232     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
233     if (buffer == nullptr) {
234         outErrorNo = -E_OUT_OF_MEMORY;
235         return nullptr;
236     }
237     int errCode = buffer->AllocBufferByPayloadLength(static_cast<uint32_t>(payloadLen),
238         GetCommLayerFrameHeaderLength());
239     if (errCode != E_OK) {
240         LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
241         outErrorNo = errCode;
242         delete buffer;
243         buffer = nullptr;
244         return nullptr;
245     }
246 
247     auto payloadByteLen = buffer->GetWritableBytesForPayload();
248     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
249     *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
250     fieldPtr++;
251     *fieldPtr = HostToNet(inDistinctValue);
252     fieldPtr++;
253     *fieldPtr = HostToNet(inSequenceId);
254     fieldPtr++;
255     *fieldPtr = HostToNet(static_cast<uint64_t>(inLabels.size()));
256     fieldPtr++;
257     // Note: don't worry, memory length had been carefully calculated above
258     auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
259     for (const auto &eachLabel : inLabels) {
260         for (const auto &eachByte : eachLabel) {
261             *bytePtr++ = eachByte;
262         }
263     }
264     outErrorNo = E_OK;
265     return buffer;
266 }
267 
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)268 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
269 {
270     uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
271     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
272     if (buffer == nullptr) {
273         outErrorNo = -E_OUT_OF_MEMORY;
274         return nullptr;
275     }
276     int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
277     if (errCode != E_OK) {
278         LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
279         outErrorNo = errCode;
280         delete buffer;
281         buffer = nullptr;
282         return nullptr;
283     }
284 
285     auto payloadByteLen = buffer->GetWritableBytesForPayload();
286     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
287     *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
288     fieldPtr++;
289     *fieldPtr = HostToNet(inDistinctValue);
290     fieldPtr++;
291     *fieldPtr = HostToNet(inSequenceId);
292     fieldPtr++;
293     outErrorNo = E_OK;
294     return buffer;
295 }
296 
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)297 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
298     std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
299 {
300     auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
301     if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
302         return E_OK;
303     }
304     uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
305     // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
306     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
307     uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
308     // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
309     uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
310     // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
311     // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
312     uint16_t quotient = lengthToSplit / maxFragmentLen;
313     uint32_t remainder = lengthToSplit % maxFragmentLen;
314     // Finally we get the fragCount for this frame
315     uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
316     // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
317     auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
318     FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
319     return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
320 }
321 
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)322 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
323 {
324     uint32_t frameLen = inResult.GetFrameLen();
325     uint16_t fragCount = inResult.GetFragCount();
326     uint16_t fragNo = inResult.GetFragNo();
327 
328     // Firstly: Check frameLen
329     if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
330         LOGE("[Proto][ParsePhyOpt] FrameLen=%" PRIu32 " illegal.", frameLen);
331         return -E_PARSE_FAIL;
332     }
333 
334     // Secondly: Check fragCount and fragNo
335     uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
336     if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
337         LOGE("[Proto][ParsePhyOpt] FragCount=%" PRIu32 " or fragNo=%" PRIu32 " illegal.", fragCount, fragNo);
338         return -E_PARSE_FAIL;
339     }
340 
341     // Finally: Check length relation deeply
342     uint32_t quotient = lengthBeSplit / fragCount;
343     uint16_t remainder = lengthBeSplit % fragCount;
344     outFragLen = quotient;
345     outLastFragLen = quotient + remainder;
346     uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
347     if ((sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
348         inResult.GetPaddingLen()) != inResult.GetPacketLen()) {
349         LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%" PRIu32 ", FragCount=%" PRIu32 ", fragNo=%" PRIu32
350             ", PaddingLen=%" PRIu32 ", PacketLen=%" PRIu32, frameLen, fragCount, fragNo, inResult.GetPaddingLen(),
351             inResult.GetPacketLen());
352         return -E_PARSE_FAIL;
353     }
354 
355     return E_OK;
356 }
357 
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)358 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
359     uint32_t fragOffset, uint32_t fragLength)
360 {
361     // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
362     // Firstly: Check the length relation of source, even this check is not supposed to fail
363     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength ||
364         sizeof(CommPhyHeader) + static_cast<uint64_t>(fragOffset) + static_cast<uint64_t>(fragLength) > UINT32_MAX) {
365         return -E_LENGTH_ERROR;
366     }
367     // Secondly: Check the length relation of destination, even this check is not supposed to fail
368     auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
369     if (frameByteLen.first == nullptr || sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second ||
370         sizeof(CommPhyHeader) + fragOffset > frameByteLen.second) {
371         return -E_LENGTH_ERROR;
372     }
373     // Finally: Do Combination!
374     const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
375     uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
376     uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
377     errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
378     if (errCode != EOK) {
379         return -E_SECUREC_ERROR;
380     }
381     return E_OK;
382 }
383 
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)384 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
385 {
386     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
387     if (msgIdMapFunc_.count(msgId) != 0) {
388         return -E_ALREADY_REGISTER;
389     }
390     if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
391         return -E_INVALID_ARGS;
392     }
393     msgIdMapFunc_[msgId] = inFunc;
394     return E_OK;
395 }
396 
UnRegTransformFunction(uint32_t msgId)397 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
398 {
399     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
400     if (msgIdMapFunc_.count(msgId) != 0) {
401         msgIdMapFunc_.erase(msgId);
402     }
403 }
404 
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)405 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
406 {
407     if (inBuff == nullptr) {
408         return -E_INVALID_ARGS;
409     }
410     auto headerByteLen = inBuff->GetWritableBytesForHeader();
411     if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
412         return -E_INVALID_ARGS;
413     }
414     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
415 
416     CommDivergeHeader divergeHeader;
417     divergeHeader.version = PROTOCOL_VERSION;
418     divergeHeader.reserved = 0;
419     divergeHeader.payLoadLen = payloadByteLen.second;
420     // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
421     for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
422         divergeHeader.commLabel[i] = inCommLabel[i];
423     }
424     HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
425 
426     errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
427         headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
428     if (errCode != EOK) {
429         return -E_SECUREC_ERROR;
430     }
431     return E_OK;
432 }
433 
434 namespace {
FillPhyHeaderLenInfo(uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen,CommPhyHeader & header)435 void FillPhyHeaderLenInfo(uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen, CommPhyHeader &header)
436 {
437     header.packetLen = packetLen;
438     header.checkSum = sum;
439     header.packetType |= type;
440     header.paddingLen = paddingLen;
441 }
442 }
443 
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)444 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
445 {
446     if (inBuff == nullptr) {
447         return -E_INVALID_ARGS;
448     }
449     auto headerByteLen = inBuff->GetWritableBytesForHeader();
450     if (headerByteLen.second < sizeof(CommPhyHeader)) {
451         return -E_INVALID_ARGS;
452     }
453     auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
454     auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
455 
456     uint32_t packetLen = bufferByteLen.second;
457     uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
458     uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
459     if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
460         SetFrameType(inInfo.frameType, packetType);
461     } else {
462         return -E_INVALID_ARGS;
463     }
464     SetSendLabelExchange(packetType, inInfo.sendLabelExchange);
465 
466     CommPhyHeader phyHeader;
467     phyHeader.magic = MAGIC_CODE;
468     phyHeader.version = PROTOCOL_VERSION;
469     phyHeader.sourceId = inInfo.sourceId;
470     phyHeader.frameId = inInfo.frameId;
471     phyHeader.packetType = 0;
472     phyHeader.dbIntVer = DB_GLOBAL_VERSION;
473     FillPhyHeaderLenInfo(packetLen, 0, packetType, paddingLen, phyHeader); // Sum is calculated afterwards
474     HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
475 
476     errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
477     if (retCode != EOK) {
478         return -E_SECUREC_ERROR;
479     }
480 
481     uint64_t sumResult = 0;
482     int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
483         bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
484     if (errCode != E_OK) {
485         return -E_SUM_CALCULATE_FAIL;
486     }
487 
488     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
489     ptrPhyHeader->checkSum = HostToNet(sumResult);
490 
491     return E_OK;
492 }
493 
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)494 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
495     ParseResult &outResult)
496 {
497     if (bytes == nullptr || length > MAX_TOTAL_LEN) {
498         return -E_INVALID_ARGS;
499     }
500     int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
501     if (errCode != E_OK) {
502         LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
503         return errCode;
504     }
505 
506     if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
507         return E_OK; // Do nothing more for empty frame
508     }
509 
510     if (outResult.IsFragment()) {
511         errCode = ParseCommPhyOptHeader(bytes, length, outResult);
512         if (errCode != E_OK) {
513             LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
514         }
515     } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
516         errCode = ParseCommLayerPayload(bytes, length, outResult);
517         if (errCode != E_OK) {
518             LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
519         }
520     } else {
521         errCode = ParseCommDivergeHeader(bytes, length, outResult);
522         if (errCode != E_OK) {
523             LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
524         }
525     }
526     return errCode;
527 }
528 
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)529 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
530 {
531     if (inBuff == nullptr || outResult.IsFragment()) {
532         return -E_INTERNAL_ERROR;
533     }
534     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
535     if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
536         int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
537         if (errCode != E_OK) {
538             LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
539             return errCode;
540         }
541     } else {
542         int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
543         if (errCode != E_OK) {
544             LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
545             return errCode;
546         }
547     }
548     return E_OK;
549 }
550 
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)551 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
552 {
553     static const char *frameTypeStr[] = {
554         "EmptyFrame",
555         "AppLayerFrame",
556         "CommLayerFrame_LabelExchange",
557         "CommLayerFrame_LabelExchangeAck"
558     };
559 
560     if (length < sizeof(CommPhyHeader)) {
561         return;
562     }
563     auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
564     uint32_t frameId = NetToHost(phyHeader->frameId);
565     uint8_t pktType = NetToHost(phyHeader->packetType);
566     bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
567     FrameType frameType = GetFrameType(pktType);
568     if (frameType >= FrameType::INVALID_MAX_FRAME_TYPE) {
569         LOGW("[Proto][Display] This is unrecognized frame, pktType=%" PRIu8 ".", pktType);
570         return;
571     }
572     if (isFragment) {
573         if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
574             return;
575         }
576         auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
577         LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ", frameLen=%" PRIu32 ", fragCount=%" PRIu32
578             ", fragNo=%" PRIu32 ".", frameTypeStr[static_cast<int32_t>(frameType)],
579             frameId, NetToHost(phyOpt->frameLen),
580             NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
581     } else {
582         LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ".",
583             frameTypeStr[static_cast<int32_t>(frameType)], frameId);
584     }
585 }
586 
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)587 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
588 {
589     if ((length > INT32_MAX) || (length % sizeof(uint64_t) != 0)) {
590         LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight or larget than int32_max.", length);
591         return -E_LENGTH_ERROR;
592     }
593     int count = length / sizeof(uint64_t);
594     auto array = reinterpret_cast<const uint64_t *>(bytes);
595     outSum = 0;
596     for (int i = 0; i < count; i++) {
597         outSum ^= array[i];
598     }
599     return E_OK;
600 }
601 
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)602 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
603 {
604     uint32_t messageId = inMsg->GetMessageId();
605     TransformFunc function;
606     if (GetTransformFunc(messageId, function) != E_OK) {
607         LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%" PRIu32 ".", messageId);
608         return -E_NOT_REGISTER;
609     }
610 
611     uint32_t serializeLen = function.computeFunc(inMsg);
612     uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
613     // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
614     if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
615         LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%" PRIu32 ", serializeLen=%" PRIu32
616             ", alignedLen=%" PRIu32 ".", messageId, serializeLen, alignedLen);
617         return -E_LENGTH_ERROR;
618     }
619     // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
620     outLength = serializeLen;
621     return E_OK;
622 }
623 
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)624 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
625 {
626     auto payloadByteLen = inBuff->GetWritableBytesForPayload();
627     if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
628         LOGE("[Proto][Serialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
629         return -E_LENGTH_ERROR;
630     }
631     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
632 
633     auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
634     messageHdr->version = inMsg->GetVersion();
635     messageHdr->messageType = inMsg->GetMessageType();
636     messageHdr->messageId = inMsg->GetMessageId();
637     messageHdr->sessionId = inMsg->GetSessionId();
638     messageHdr->sequenceId = inMsg->GetSequenceId();
639     messageHdr->errorNo = inMsg->GetErrorNo();
640     messageHdr->dataLen = dataLen;
641     HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
642 
643     if (dataLen == 0) {
644         // For zero dataLen, we don't need to serialize data part
645         return E_OK;
646     }
647     // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
648     TransformFunc function;
649     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
650         LOGE("[Proto][Serialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
651         return -E_NOT_REGISTER;
652     }
653     int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
654     if (result != E_OK) {
655         LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
656         return -E_SERIALIZE_ERROR;
657     }
658     return E_OK;
659 }
660 
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)661 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
662 {
663     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
664     // Check version before parse field
665     if (payloadByteLen.second < sizeof(uint16_t)) {
666         return -E_LENGTH_ERROR;
667     }
668     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
669     if (!IsSupportMessageVersion(version)) {
670         LOGE("[Proto][DeSerialize] Version=%" PRIu32 " not support.", version);
671         return -E_VERSION_NOT_SUPPORT;
672     }
673 
674     if (payloadByteLen.second < sizeof(MessageHeader)) {
675         LOGE("[Proto][DeSerialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
676         return -E_LENGTH_ERROR;
677     }
678     auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
679     MessageHeader messageHdr;
680     HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
681     inMsg->SetVersion(version);
682     inMsg->SetMessageType(messageHdr.messageType);
683     inMsg->SetMessageId(messageHdr.messageId);
684     inMsg->SetSessionId(messageHdr.sessionId);
685     inMsg->SetSequenceId(messageHdr.sequenceId);
686     inMsg->SetErrorNo(messageHdr.errorNo);
687     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
688     if (dataLen != messageHdr.dataLen) {
689         LOGE("[Proto][DeSerialize] dataLen=%" PRIu32 ", msgDataLen=%" PRIu32 ".", dataLen, messageHdr.dataLen);
690         return -E_LENGTH_ERROR;
691     }
692     // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
693     if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
694         LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%" PRIu32 ".", messageHdr.errorNo);
695         return E_OK;
696     }
697     if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
698         return E_OK;
699     }
700     TransformFunc function;
701     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
702         LOGE("[Proto][DeSerialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
703         return -E_NOT_REGISTER;
704     }
705     int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
706     if (result != E_OK) {
707         LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
708         return -E_DESERIALIZE_ERROR;
709     }
710     return E_OK;
711 }
712 
IsSupportMessageVersion(uint16_t version)713 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
714 {
715     return (version == MSG_VERSION_BASE || version == MSG_VERSION_EXT);
716 }
717 
IsFeedbackErrorMessage(uint32_t errorNo)718 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
719 {
720     return (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND ||
721         errorNo == E_FEEDBACK_DB_CLOSING);
722 }
723 
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)724 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
725 {
726     // At least magic and version should exist
727     if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
728         LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
729         return -E_LENGTH_ERROR;
730     }
731     auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
732     uint16_t magic = NetToHost(*fieldPtr++);
733     uint16_t version = NetToHost(*fieldPtr++);
734 
735     if (magic != MAGIC_CODE) {
736         LOGE("[Proto][ParsePhyCheckVer] MagicCode=%" PRIu32 " Error.", magic);
737         return -E_PARSE_FAIL;
738     }
739     if (version != PROTOCOL_VERSION) {
740         LOGE("[Proto][ParsePhyCheckVer] Version=%" PRIu32 " Error.", version);
741         return -E_VERSION_NOT_SUPPORT;
742     }
743     return E_OK;
744 }
745 
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)746 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
747     const uint8_t *bytes, uint32_t length)
748 {
749     if (phyHeader.packetLen != length) {
750         LOGE("[Proto][ParsePhyCheck] PacketLen=%" PRIu32 " Mismatch length=%" PRIu32 ".", phyHeader.packetLen, length);
751         return -E_PARSE_FAIL;
752     }
753     if (phyHeader.paddingLen > MAX_PADDING_LEN) {
754         LOGE("[Proto][ParsePhyCheck] PaddingLen=%" PRIu32 " Error.", phyHeader.paddingLen);
755         return -E_PARSE_FAIL;
756     }
757     if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
758         LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
759         return -E_PARSE_FAIL;
760     }
761     uint64_t sumResult = 0;
762     int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
763     if (errCode != E_OK) {
764         LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
765         return -E_SUM_CALCULATE_FAIL;
766     }
767     if (phyHeader.checkSum != sumResult) {
768         LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%" PRIu64 ", sumResult=%" PRIu64 ".",
769             ULL(phyHeader.checkSum), ULL(sumResult));
770         return -E_SUM_MISMATCH;
771     }
772     return E_OK;
773 }
774 
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)775 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
776     ParseResult &inResult)
777 {
778     int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
779     if (errCode != E_OK) {
780         LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
781         return errCode;
782     }
783 
784     if (length < sizeof(CommPhyHeader)) {
785         LOGE("[Proto][ParsePhy] Length of Bytes Error.");
786         return -E_PARSE_FAIL;
787     }
788     auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
789     CommPhyHeader phyHeader;
790     HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
791     errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
792     if (errCode != E_OK) {
793         LOGE("[Proto][ParsePhy] Check Field Fail.");
794         return errCode;
795     }
796 
797     inResult.SetFrameId(phyHeader.frameId);
798     inResult.SetSourceId(phyHeader.sourceId);
799     inResult.SetPacketLen(phyHeader.packetLen);
800     inResult.SetPaddingLen(phyHeader.paddingLen);
801     inResult.SetDbVersion(phyHeader.dbIntVer);
802     if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
803         inResult.SetFragmentFlag(true);
804     } // FragmentFlag default is false
805     FrameType frameType = GetFrameType(phyHeader.packetType);
806     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
807         LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%" PRIu32 ".", phyHeader.packetType);
808         return -E_FRAME_TYPE_NOT_SUPPORT;
809     }
810     inResult.SetFrameTypeInfo(frameType);
811     inResult.SetSendLabelExchange(IsSendLabelExchange(phyHeader.packetType));
812     return E_OK;
813 }
814 
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)815 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
816 {
817     if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
818         LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
819         return -E_LENGTH_ERROR;
820     }
821     auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
822     CommPhyOptHeader phyOptHeader;
823     HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
824 
825     // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
826     inResult.SetFrameLen(phyOptHeader.frameLen);
827     inResult.SetFragCount(phyOptHeader.fragCount);
828     inResult.SetFragNo(phyOptHeader.fragNo);
829     return E_OK;
830 }
831 
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)832 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
833 {
834     // Check version before parse field
835     if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
836         return -E_LENGTH_ERROR;
837     }
838     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
839     if (version != PROTOCOL_VERSION) {
840         LOGE("[Proto][ParseDiverge] Version=%" PRIu16 " not support.", version);
841         return -E_VERSION_NOT_SUPPORT;
842     }
843 
844     if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
845         LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
846         return -E_PARSE_FAIL;
847     }
848     auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
849     CommDivergeHeader divergeHeader;
850     HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
851     if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
852         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
853         LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
854         return -E_PARSE_FAIL;
855     }
856     inResult.SetPayloadLen(divergeHeader.payLoadLen);
857     inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
858     return E_OK;
859 }
860 
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)861 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
862 {
863     if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
864         int errCode = ParseLabelExchangeAck(bytes, length, inResult);
865         if (errCode != E_OK) {
866             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
867             return errCode;
868         }
869     } else {
870         int errCode = ParseLabelExchange(bytes, length, inResult);
871         if (errCode != E_OK) {
872             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
873             return errCode;
874         }
875     }
876     return E_OK;
877 }
878 
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)879 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
880 {
881     // Check version at very first
882     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
883         return -E_LENGTH_ERROR;
884     }
885     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
886     uint64_t version = NetToHost(*fieldPtr++);
887     if (version != PROTOCOL_VERSION) {
888         LOGE("[Proto][ParseLabel] Version=%" PRIu64 " not support.", ULL(version));
889         return -E_VERSION_NOT_SUPPORT;
890     }
891 
892     // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
893     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
894         LOGE("[Proto][ParseLabel] Length of Bytes Error.");
895         return -E_LENGTH_ERROR;
896     }
897     uint64_t distinctValue = NetToHost(*fieldPtr++);
898     inResult.SetLabelExchangeDistinctValue(distinctValue);
899     uint64_t sequenceId = NetToHost(*fieldPtr++);
900     inResult.SetLabelExchangeSequenceId(sequenceId);
901     uint64_t commLabelCount = NetToHost(*fieldPtr++);
902     if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
903         LOGE("[Proto][ParseLabel] commLabelCount=%" PRIu64 " invalid.", ULL(commLabelCount));
904         return -E_PARSE_FAIL;
905     }
906     // commLabelCount is expected to be not very large
907     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
908         commLabelCount * COMM_LABEL_LENGTH) {
909         LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%" PRIu64, ULL(commLabelCount));
910         return -E_LENGTH_ERROR;
911     }
912 
913     // Get each commLabel
914     std::set<LabelType> commLabels;
915     auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
916     for (uint64_t i = 0; i < commLabelCount; i++) {
917         // the length is checked just above
918         LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
919         if (commLabels.count(commLabel) != 0) {
920             LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%.3s.", VEC_TO_STR(commLabel));
921         } else {
922             commLabels.insert(commLabel);
923         }
924     }
925     inResult.SetLatestCommLabels(commLabels);
926     return E_OK;
927 }
928 
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)929 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
930 {
931     // Check version at very first
932     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
933         return -E_LENGTH_ERROR;
934     }
935     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
936     uint64_t version = NetToHost(*fieldPtr++);
937     if (version != PROTOCOL_VERSION) {
938         LOGE("[Proto][ParseLabelAck] Version=%" PRIu64 " not support.", ULL(version));
939         return -E_VERSION_NOT_SUPPORT;
940     }
941 
942     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
943         LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
944         return -E_LENGTH_ERROR;
945     }
946     uint64_t distinctValue = NetToHost(*fieldPtr++);
947     inResult.SetLabelExchangeDistinctValue(distinctValue);
948     uint64_t sequenceId = NetToHost(*fieldPtr++);
949     inResult.SetLabelExchangeSequenceId(sequenceId);
950     return E_OK;
951 }
952 
953 // Note: framePhyHeader is in network endian
954 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)955 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
956     const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
957 {
958     // It can be guaranteed that fragCount >= 2 and also won't be too large
959     if (fragmentInfo.fragCount < MIN_FRAGMENT_COUNT) {
960         return -E_INVALID_ARGS;
961     }
962     outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
963     uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
964     uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
965     uint16_t fragNo = 0; // Fragment index start from 0
966     uint32_t byteOffset = 0;
967 
968     for (auto &entry : outPieces) {
969         // subtract 1 for index
970         uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
971         uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
972         uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
973 
974         // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
975         entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
976         if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
977             LOGE("[Proto][FrameFrag] Resize failed for length=%" PRIu32, pieceTotalLen);
978             return -E_OUT_OF_MEMORY;
979         }
980 
981         CommPhyHeader pktPhyHeader;
982         HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
983 
984         // The sum value need to be recalculated, and the packet is fragmented.
985         // The alignedFragLen is always larger than pieceFragLen
986         FillPhyHeaderLenInfo(pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen, pktPhyHeader);
987         HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
988 
989         CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
990             fragmentInfo.fragCount, fragNo};
991         HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
992         int err;
993         FragmentPacket packet;
994         uint8_t *ptrPacket = &(entry.first[0]);
995         if (fragmentInfo.extendHeadSize > 0) {
996             packet = {ptrPacket, fragmentInfo.extendHeadSize};
997             err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
998             if (err != E_OK) {
999                 return err;
1000             }
1001             ptrPacket += fragmentInfo.extendHeadSize;
1002         }
1003         packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
1004         err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
1005             pieceFragLen, packet);
1006         entry.second = fragmentInfo.extendHeadSize;
1007         if (err != E_OK) {
1008             LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%" PRIu16 ", fragNo=%" PRIu16, fragmentInfo.fragCount,
1009                 fragNo);
1010             return err;
1011         }
1012 
1013         fragNo++;
1014         byteOffset += pieceFragLen;
1015     }
1016 
1017     return E_OK;
1018 }
1019 
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)1020 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
1021 {
1022     if (headLen > outPacket.leftLength) {
1023         LOGE("[Proto][FrameFrag] headLen less than leftLength");
1024         return -E_INVALID_ARGS;
1025     }
1026     errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1027     if (retCode != EOK) {
1028         LOGE("memcpy error:%d", retCode);
1029         return -E_SECUREC_ERROR;
1030     }
1031     return E_OK;
1032 }
1033 
1034 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1035 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1036     const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1037 {
1038     if (outPacket.leftLength == 0) {
1039         return -E_INVALID_ARGS;
1040     }
1041     uint8_t *ptrPacket = outPacket.ptrPacket;
1042     uint32_t leftLength = outPacket.leftLength;
1043 
1044     // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1045     // So, there will be no redundant check during subtraction
1046     errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1047     if (retCode != EOK) {
1048         return -E_SECUREC_ERROR;
1049     }
1050     ptrPacket += sizeof(CommPhyHeader);
1051     leftLength -= sizeof(CommPhyHeader);
1052 
1053     retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1054     if (retCode != EOK) {
1055         return -E_SECUREC_ERROR;
1056     }
1057     ptrPacket += sizeof(CommPhyOptHeader);
1058     leftLength -= sizeof(CommPhyOptHeader);
1059 
1060     retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1061     if (retCode != EOK) {
1062         return -E_SECUREC_ERROR;
1063     }
1064 
1065     // Calculate sum and set sum field
1066     uint64_t sumResult = 0;
1067     int errCode  = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1068         outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1069     if (errCode != E_OK) {
1070         return -E_SUM_CALCULATE_FAIL;
1071     }
1072     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1073     if (ptrPhyHeader == nullptr) {
1074         return -E_INVALID_ARGS;
1075     }
1076     ptrPhyHeader->checkSum = HostToNet(sumResult);
1077 
1078     return E_OK;
1079 }
1080 
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1081 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1082 {
1083     if (extendHandle != nullptr) {
1084         DBStatus status = extendHandle->GetHeadDataSize(headSize);
1085         if (status != DBStatus::OK) {
1086             LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1087             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1088         }
1089         if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1090             LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1091             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1092         }
1093         return E_OK;
1094     }
1095     return E_OK;
1096 }
1097 
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1098 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1099     uint32_t headSize)
1100 {
1101     if (extendHandle != nullptr && headSize > 0) {
1102         if (buffer == nullptr) {
1103             return -E_INVALID_ARGS;
1104         }
1105         DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1106             buffer->GetSize() + headSize);
1107         if (status != DBStatus::OK) {
1108             LOGI("[Proto][ToSerial] fill head data failed");
1109             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1110         }
1111     }
1112     return E_OK;
1113 }
1114 
GetTransformFunc(uint32_t messageId,DistributedDB::TransformFunc & function)1115 int ProtocolProto::GetTransformFunc(uint32_t messageId, DistributedDB::TransformFunc &function)
1116 {
1117     std::shared_lock<std::shared_mutex> autoLock(msgIdMutex_);
1118     const auto &entry = msgIdMapFunc_.find(messageId);
1119     if (entry == msgIdMapFunc_.end()) {
1120         return -E_NOT_REGISTER;
1121     }
1122     function = entry->second;
1123     return E_OK;
1124 }
1125 } // namespace DistributedDB
1126