• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "protocol_proto.h"
17 #include <iterator>
18 #include <mutex>
19 #include <new>
20 #include "db_common.h"
21 #include "endian_convert.h"
22 #include "hash.h"
23 #include "header_converter.h"
24 #include "log_print.h"
25 #include "macro_utils.h"
26 #include "securec.h"
27 #include "version.h"
28 
29 namespace DistributedDB {
30 namespace {
31 const uint16_t MAGIC_CODE = 0xAAAA;
32 const uint16_t PROTOCOL_VERSION = 0;
33 // Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
34 const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
35 const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
36 const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
37 const uint8_t MAX_PADDING_LEN = 7;
38 const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
39 const uint32_t MAX_FRAME_LEN = 32 * 1024 * 1024; // Max 32 MB, 1024 is scale
40 const uint16_t MIN_FRAGMENT_COUNT = 2; // At least a frame will be splited into 2 parts
41 // LabelExchange(Ack) Frame Field Length
42 const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
43 const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
44 const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
45 // Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
46 const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
47 // Local func to set and get frame Type from packet Type field
SetFrameType(FrameType inFrameType,uint8_t & inPacketType)48 void SetFrameType(FrameType inFrameType, uint8_t &inPacketType)
49 {
50     inPacketType &= 0x0F; // Use 0x0F to clear high four bits
51     inPacketType |= (static_cast<uint8_t>(inFrameType) << 4); // frame type is on high 4 bits
52 }
GetFrameType(uint8_t inPacketType)53 FrameType GetFrameType(uint8_t inPacketType)
54 {
55     uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0xF0 to get high 4 bits
56     if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
57         return FrameType::INVALID_MAX_FRAME_TYPE;
58     }
59     return static_cast<FrameType>(frameType);
60 }
IsSendLabelExchange(uint8_t inPacketType)61 bool IsSendLabelExchange(uint8_t inPacketType)
62 {
63     return ((inPacketType & 0x08) >> 3) == 0; // Use 0x08 and remove low 3 bit, it is Communication negotiation mark
64 }
SetSendLabelExchange(uint8_t & inPacketType,bool sendLabelExchange)65 void SetSendLabelExchange(uint8_t &inPacketType, bool sendLabelExchange)
66 {
67     if (!sendLabelExchange) {
68         inPacketType |= 0x08; // mark 0x08 when not support communication
69     }
70 }
71 }
72 
73 std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
74 std::shared_mutex ProtocolProto::msgIdMutex_;
75 
GetAppLayerFrameHeaderLength()76 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
77 {
78     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
79     return length;
80 }
81 
GetLengthBeforeSerializedData()82 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
83 {
84     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
85     return length;
86 }
87 
GetCommLayerFrameHeaderLength()88 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
89 {
90     uint32_t length = sizeof(CommPhyHeader);
91     return length;
92 }
93 
ToSerialBuffer(const Message * inMsg,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader,int & outErrorNo)94 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg,
95     std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader, int &outErrorNo)
96 {
97     if (inMsg == nullptr) {
98         outErrorNo = -E_INVALID_ARGS;
99         return nullptr;
100     }
101 
102     uint32_t serializeLen = 0;
103     if (!onlyMsgHeader) {
104         int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
105         if (errCode != E_OK) {
106             outErrorNo = errCode;
107             return nullptr;
108         }
109     }
110     uint32_t headSize = 0;
111     int errCode = GetExtendHeadDataSize(extendHandle, headSize);
112     if (errCode != E_OK) {
113         outErrorNo = errCode;
114         return nullptr;
115     }
116 
117     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
118     if (buffer == nullptr) {
119         outErrorNo = -E_OUT_OF_MEMORY;
120         return nullptr;
121     }
122     if (headSize > 0) {
123         buffer->SetExtendHeadLength(headSize);
124     }
125     // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
126     uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
127     errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
128     if (errCode != E_OK) {
129         LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
130         goto ERROR_HANDLE;
131     }
132     errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
133     if (errCode != E_OK) {
134         goto ERROR_HANDLE;
135     }
136 
137     // Serialize the MessageHeader and data if need
138     errCode = SerializeMessage(buffer, inMsg);
139     if (errCode != E_OK) {
140         LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
141         goto ERROR_HANDLE;
142     }
143     outErrorNo = E_OK;
144     return buffer;
145 ERROR_HANDLE:
146     outErrorNo = errCode;
147     delete buffer;
148     buffer = nullptr;
149     return nullptr;
150 }
151 
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)152 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
153 {
154     if (inBuff == nullptr) {
155         outErrorNo = -E_INVALID_ARGS;
156         return nullptr;
157     }
158     Message *outMsg = new (std::nothrow) Message();
159     if (outMsg == nullptr) {
160         outErrorNo = -E_OUT_OF_MEMORY;
161         return nullptr;
162     }
163     int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
164     if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
165         LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
166         outErrorNo = errCode;
167         delete outMsg;
168         outMsg = nullptr;
169         return nullptr;
170     }
171     // If messageId not register in this software version, we return errCode and the Message without an object.
172     outErrorNo = errCode;
173     return outMsg;
174 }
175 
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)176 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
177 {
178     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
179     if (buffer == nullptr) {
180         outErrorNo = -E_OUT_OF_MEMORY;
181         return nullptr;
182     }
183 
184     // Empty frame has no payload, only header
185     int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
186     if (errCode != E_OK) {
187         LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
188         outErrorNo = errCode;
189         delete buffer;
190         buffer = nullptr;
191         return nullptr;
192     }
193     outErrorNo = E_OK;
194     return buffer;
195 }
196 
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)197 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
198     int &outErrorNo)
199 {
200     std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
201     SerialBuffer *buffer = ToSerialBuffer(inMsg, extendHandle, true, outErrorNo);
202     if (buffer == nullptr) {
203         // outErrorNo had already been set in ToSerialBuffer
204         return nullptr;
205     }
206     int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
207     if (errCode != E_OK) {
208         LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%.3s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
209         outErrorNo = errCode;
210         delete buffer;
211         buffer = nullptr;
212         return nullptr;
213     }
214     outErrorNo = E_OK;
215     return buffer;
216 }
217 
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)218 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
219     const std::set<LabelType> &inLabels, int &outErrorNo)
220 {
221     // Size of inLabels won't be too large.
222     // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
223     uint64_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
224         inLabels.size() * COMM_LABEL_LENGTH;
225     if (payloadLen > INT32_MAX) {
226         outErrorNo = -E_INVALID_ARGS;
227         return nullptr;
228     }
229     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
230     if (buffer == nullptr) {
231         outErrorNo = -E_OUT_OF_MEMORY;
232         return nullptr;
233     }
234     int errCode = buffer->AllocBufferByPayloadLength(static_cast<uint32_t>(payloadLen),
235         GetCommLayerFrameHeaderLength());
236     if (errCode != E_OK) {
237         LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
238         outErrorNo = errCode;
239         delete buffer;
240         buffer = nullptr;
241         return nullptr;
242     }
243 
244     auto payloadByteLen = buffer->GetWritableBytesForPayload();
245     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
246     *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
247     fieldPtr++;
248     *fieldPtr = HostToNet(inDistinctValue);
249     fieldPtr++;
250     *fieldPtr = HostToNet(inSequenceId);
251     fieldPtr++;
252     *fieldPtr = HostToNet(static_cast<uint64_t>(inLabels.size()));
253     fieldPtr++;
254     // Note: don't worry, memory length had been carefully calculated above
255     auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
256     for (const auto &eachLabel : inLabels) {
257         for (const auto &eachByte : eachLabel) {
258             *bytePtr++ = eachByte;
259         }
260     }
261     outErrorNo = E_OK;
262     return buffer;
263 }
264 
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)265 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
266 {
267     uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
268     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
269     if (buffer == nullptr) {
270         outErrorNo = -E_OUT_OF_MEMORY;
271         return nullptr;
272     }
273     int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
274     if (errCode != E_OK) {
275         LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
276         outErrorNo = errCode;
277         delete buffer;
278         buffer = nullptr;
279         return nullptr;
280     }
281 
282     auto payloadByteLen = buffer->GetWritableBytesForPayload();
283     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
284     *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
285     fieldPtr++;
286     *fieldPtr = HostToNet(inDistinctValue);
287     fieldPtr++;
288     *fieldPtr = HostToNet(inSequenceId);
289     fieldPtr++;
290     outErrorNo = E_OK;
291     return buffer;
292 }
293 
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)294 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
295     std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
296 {
297     auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
298     if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
299         return E_OK;
300     }
301     uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
302     // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
303     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
304     uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
305     // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
306     uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
307     // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
308     // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
309     uint16_t quotient = lengthToSplit / maxFragmentLen;
310     uint32_t remainder = lengthToSplit % maxFragmentLen;
311     // Finally we get the fragCount for this frame
312     uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
313     // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
314     auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
315     FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
316     return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
317 }
318 
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)319 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
320 {
321     uint32_t frameLen = inResult.GetFrameLen();
322     uint16_t fragCount = inResult.GetFragCount();
323     uint16_t fragNo = inResult.GetFragNo();
324 
325     // Firstly: Check frameLen
326     if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
327         LOGE("[Proto][ParsePhyOpt] FrameLen=%" PRIu32 " illegal.", frameLen);
328         return -E_PARSE_FAIL;
329     }
330 
331     // Secondly: Check fragCount and fragNo
332     uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
333     if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
334         LOGE("[Proto][ParsePhyOpt] FragCount=%" PRIu32 " or fragNo=%" PRIu32 " illegal.", fragCount, fragNo);
335         return -E_PARSE_FAIL;
336     }
337 
338     // Finally: Check length relation deeply
339     uint32_t quotient = lengthBeSplit / fragCount;
340     uint16_t remainder = lengthBeSplit % fragCount;
341     outFragLen = quotient;
342     outLastFragLen = quotient + remainder;
343     uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
344     if ((sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
345         inResult.GetPaddingLen()) != inResult.GetPacketLen()) {
346         LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%" PRIu32 ", FragCount=%" PRIu32 ", fragNo=%" PRIu32
347             ", PaddingLen=%" PRIu32 ", PacketLen=%" PRIu32, frameLen, fragCount, fragNo, inResult.GetPaddingLen(),
348             inResult.GetPacketLen());
349         return -E_PARSE_FAIL;
350     }
351 
352     return E_OK;
353 }
354 
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)355 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
356     uint32_t fragOffset, uint32_t fragLength)
357 {
358     // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
359     // Firstly: Check the length relation of source, even this check is not supposed to fail
360     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength) {
361         return -E_LENGTH_ERROR;
362     }
363     // Secondly: Check the length relation of destination, even this check is not supposed to fail
364     auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
365     if (sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second) {
366         return -E_LENGTH_ERROR;
367     }
368     // Finally: Do Combination!
369     const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
370     uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
371     uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
372     errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
373     if (errCode != EOK) {
374         return -E_SECUREC_ERROR;
375     }
376     return E_OK;
377 }
378 
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)379 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
380 {
381     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
382     if (msgIdMapFunc_.count(msgId) != 0) {
383         return -E_ALREADY_REGISTER;
384     }
385     if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
386         return -E_INVALID_ARGS;
387     }
388     msgIdMapFunc_[msgId] = inFunc;
389     return E_OK;
390 }
391 
UnRegTransformFunction(uint32_t msgId)392 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
393 {
394     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
395     if (msgIdMapFunc_.count(msgId) != 0) {
396         msgIdMapFunc_.erase(msgId);
397     }
398 }
399 
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)400 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
401 {
402     if (inBuff == nullptr) {
403         return -E_INVALID_ARGS;
404     }
405     auto headerByteLen = inBuff->GetWritableBytesForHeader();
406     if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
407         return -E_INVALID_ARGS;
408     }
409     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
410 
411     CommDivergeHeader divergeHeader;
412     divergeHeader.version = PROTOCOL_VERSION;
413     divergeHeader.reserved = 0;
414     divergeHeader.payLoadLen = payloadByteLen.second;
415     // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
416     for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
417         divergeHeader.commLabel[i] = inCommLabel[i];
418     }
419     HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
420 
421     errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
422         headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
423     if (errCode != EOK) {
424         return -E_SECUREC_ERROR;
425     }
426     return E_OK;
427 }
428 
429 namespace {
FillPhyHeaderLenInfo(uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen,CommPhyHeader & header)430 void FillPhyHeaderLenInfo(uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen, CommPhyHeader &header)
431 {
432     header.packetLen = packetLen;
433     header.checkSum = sum;
434     header.packetType |= type;
435     header.paddingLen = paddingLen;
436 }
437 }
438 
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)439 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
440 {
441     if (inBuff == nullptr) {
442         return -E_INVALID_ARGS;
443     }
444     auto headerByteLen = inBuff->GetWritableBytesForHeader();
445     if (headerByteLen.second < sizeof(CommPhyHeader)) {
446         return -E_INVALID_ARGS;
447     }
448     auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
449     auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
450 
451     uint32_t packetLen = bufferByteLen.second;
452     uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
453     uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
454     if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
455         SetFrameType(inInfo.frameType, packetType);
456     } else {
457         return -E_INVALID_ARGS;
458     }
459     SetSendLabelExchange(packetType, inInfo.sendLabelExchange);
460 
461     CommPhyHeader phyHeader;
462     phyHeader.magic = MAGIC_CODE;
463     phyHeader.version = PROTOCOL_VERSION;
464     phyHeader.sourceId = inInfo.sourceId;
465     phyHeader.frameId = inInfo.frameId;
466     phyHeader.packetType = 0;
467     phyHeader.dbIntVer = DB_GLOBAL_VERSION;
468     FillPhyHeaderLenInfo(packetLen, 0, packetType, paddingLen, phyHeader); // Sum is calculated afterwards
469     HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
470 
471     errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
472     if (retCode != EOK) {
473         return -E_SECUREC_ERROR;
474     }
475 
476     uint64_t sumResult = 0;
477     int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
478         bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
479     if (errCode != E_OK) {
480         return -E_SUM_CALCULATE_FAIL;
481     }
482 
483     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
484     ptrPhyHeader->checkSum = HostToNet(sumResult);
485 
486     return E_OK;
487 }
488 
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)489 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
490     ParseResult &outResult)
491 {
492     if (bytes == nullptr || length > MAX_TOTAL_LEN) {
493         return -E_INVALID_ARGS;
494     }
495     int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
496     if (errCode != E_OK) {
497         LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
498         return errCode;
499     }
500 
501     if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
502         return E_OK; // Do nothing more for empty frame
503     }
504 
505     if (outResult.IsFragment()) {
506         errCode = ParseCommPhyOptHeader(bytes, length, outResult);
507         if (errCode != E_OK) {
508             LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
509         }
510     } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
511         errCode = ParseCommLayerPayload(bytes, length, outResult);
512         if (errCode != E_OK) {
513             LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
514         }
515     } else {
516         errCode = ParseCommDivergeHeader(bytes, length, outResult);
517         if (errCode != E_OK) {
518             LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
519         }
520     }
521     return errCode;
522 }
523 
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)524 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
525 {
526     if (inBuff == nullptr || outResult.IsFragment()) {
527         return -E_INTERNAL_ERROR;
528     }
529     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
530     if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
531         int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
532         if (errCode != E_OK) {
533             LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
534             return errCode;
535         }
536     } else {
537         int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
538         if (errCode != E_OK) {
539             LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
540             return errCode;
541         }
542     }
543     return E_OK;
544 }
545 
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)546 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
547 {
548     static const char *frameTypeStr[] = {
549         "EmptyFrame",
550         "AppLayerFrame",
551         "CommLayerFrame_LabelExchange",
552         "CommLayerFrame_LabelExchangeAck"
553     };
554 
555     if (length < sizeof(CommPhyHeader)) {
556         return;
557     }
558     auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
559     uint32_t frameId = NetToHost(phyHeader->frameId);
560     uint8_t pktType = NetToHost(phyHeader->packetType);
561     bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
562     FrameType frameType = GetFrameType(pktType);
563     if (frameType >= FrameType::INVALID_MAX_FRAME_TYPE) {
564         LOGW("[Proto][Display] This is unrecognized frame, pktType=%" PRIu8 ".", pktType);
565         return;
566     }
567     if (isFragment) {
568         if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
569             return;
570         }
571         auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
572         LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ", frameLen=%" PRIu32 ", fragCount=%" PRIu32
573             ", fragNo=%" PRIu32 ".", frameTypeStr[static_cast<int32_t>(frameType)],
574             frameId, NetToHost(phyOpt->frameLen),
575             NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
576     } else {
577         LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ".",
578             frameTypeStr[static_cast<int32_t>(frameType)], frameId);
579     }
580 }
581 
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)582 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
583 {
584     if ((length > INT32_MAX) || (length % sizeof(uint64_t) != 0)) {
585         LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight or larget than int32_max.", length);
586         return -E_LENGTH_ERROR;
587     }
588     int count = length / sizeof(uint64_t);
589     auto array = reinterpret_cast<const uint64_t *>(bytes);
590     outSum = 0;
591     for (int i = 0; i < count; i++) {
592         outSum ^= array[i];
593     }
594     return E_OK;
595 }
596 
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)597 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
598 {
599     uint32_t messageId = inMsg->GetMessageId();
600     TransformFunc function;
601     if (GetTransformFunc(messageId, function) != E_OK) {
602         LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%" PRIu32 ".", messageId);
603         return -E_NOT_REGISTER;
604     }
605 
606     uint32_t serializeLen = function.computeFunc(inMsg);
607     uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
608     // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
609     if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
610         LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%" PRIu32 ", serializeLen=%" PRIu32
611             ", alignedLen=%" PRIu32 ".", messageId, serializeLen, alignedLen);
612         return -E_LENGTH_ERROR;
613     }
614     // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
615     outLength = serializeLen;
616     return E_OK;
617 }
618 
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)619 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
620 {
621     auto payloadByteLen = inBuff->GetWritableBytesForPayload();
622     if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
623         LOGE("[Proto][Serialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
624         return -E_LENGTH_ERROR;
625     }
626     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
627 
628     auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
629     messageHdr->version = inMsg->GetVersion();
630     messageHdr->messageType = inMsg->GetMessageType();
631     messageHdr->messageId = inMsg->GetMessageId();
632     messageHdr->sessionId = inMsg->GetSessionId();
633     messageHdr->sequenceId = inMsg->GetSequenceId();
634     messageHdr->errorNo = inMsg->GetErrorNo();
635     messageHdr->dataLen = dataLen;
636     HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
637 
638     if (dataLen == 0) {
639         // For zero dataLen, we don't need to serialize data part
640         return E_OK;
641     }
642     // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
643     TransformFunc function;
644     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
645         LOGE("[Proto][Serialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
646         return -E_NOT_REGISTER;
647     }
648     int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
649     if (result != E_OK) {
650         LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
651         return -E_SERIALIZE_ERROR;
652     }
653     return E_OK;
654 }
655 
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)656 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
657 {
658     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
659     // Check version before parse field
660     if (payloadByteLen.second < sizeof(uint16_t)) {
661         return -E_LENGTH_ERROR;
662     }
663     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
664     if (!IsSupportMessageVersion(version)) {
665         LOGE("[Proto][DeSerialize] Version=%" PRIu32 " not support.", version);
666         return -E_VERSION_NOT_SUPPORT;
667     }
668 
669     if (payloadByteLen.second < sizeof(MessageHeader)) {
670         LOGE("[Proto][DeSerialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
671         return -E_LENGTH_ERROR;
672     }
673     auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
674     MessageHeader messageHdr;
675     HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
676     inMsg->SetVersion(version);
677     inMsg->SetMessageType(messageHdr.messageType);
678     inMsg->SetMessageId(messageHdr.messageId);
679     inMsg->SetSessionId(messageHdr.sessionId);
680     inMsg->SetSequenceId(messageHdr.sequenceId);
681     inMsg->SetErrorNo(messageHdr.errorNo);
682     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
683     if (dataLen != messageHdr.dataLen) {
684         LOGE("[Proto][DeSerialize] dataLen=%" PRIu32 ", msgDataLen=%" PRIu32 ".", dataLen, messageHdr.dataLen);
685         return -E_LENGTH_ERROR;
686     }
687     // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
688     if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
689         LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%" PRIu32 ".", messageHdr.errorNo);
690         return E_OK;
691     }
692     if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
693         return E_OK;
694     }
695     TransformFunc function;
696     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
697         LOGE("[Proto][DeSerialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
698         return -E_NOT_REGISTER;
699     }
700     int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
701     if (result != E_OK) {
702         LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
703         return -E_DESERIALIZE_ERROR;
704     }
705     return E_OK;
706 }
707 
IsSupportMessageVersion(uint16_t version)708 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
709 {
710     return (version == MSG_VERSION_BASE || version == MSG_VERSION_EXT);
711 }
712 
IsFeedbackErrorMessage(uint32_t errorNo)713 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
714 {
715     return (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND);
716 }
717 
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)718 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
719 {
720     // At least magic and version should exist
721     if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
722         LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
723         return -E_LENGTH_ERROR;
724     }
725     auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
726     uint16_t magic = NetToHost(*fieldPtr++);
727     uint16_t version = NetToHost(*fieldPtr++);
728 
729     if (magic != MAGIC_CODE) {
730         LOGE("[Proto][ParsePhyCheckVer] MagicCode=%" PRIu32 " Error.", magic);
731         return -E_PARSE_FAIL;
732     }
733     if (version != PROTOCOL_VERSION) {
734         LOGE("[Proto][ParsePhyCheckVer] Version=%" PRIu32 " Error.", version);
735         return -E_VERSION_NOT_SUPPORT;
736     }
737     return E_OK;
738 }
739 
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)740 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
741     const uint8_t *bytes, uint32_t length)
742 {
743     if (phyHeader.packetLen != length) {
744         LOGE("[Proto][ParsePhyCheck] PacketLen=%" PRIu32 " Mismatch length=%" PRIu32 ".", phyHeader.packetLen, length);
745         return -E_PARSE_FAIL;
746     }
747     if (phyHeader.paddingLen > MAX_PADDING_LEN) {
748         LOGE("[Proto][ParsePhyCheck] PaddingLen=%" PRIu32 " Error.", phyHeader.paddingLen);
749         return -E_PARSE_FAIL;
750     }
751     if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
752         LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
753         return -E_PARSE_FAIL;
754     }
755     uint64_t sumResult = 0;
756     int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
757     if (errCode != E_OK) {
758         LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
759         return -E_SUM_CALCULATE_FAIL;
760     }
761     if (phyHeader.checkSum != sumResult) {
762         LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%" PRIu64 ", sumResult=%" PRIu64 ".",
763             ULL(phyHeader.checkSum), ULL(sumResult));
764         return -E_SUM_MISMATCH;
765     }
766     return E_OK;
767 }
768 
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)769 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
770     ParseResult &inResult)
771 {
772     int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
773     if (errCode != E_OK) {
774         LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
775         return errCode;
776     }
777 
778     if (length < sizeof(CommPhyHeader)) {
779         LOGE("[Proto][ParsePhy] Length of Bytes Error.");
780         return -E_PARSE_FAIL;
781     }
782     auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
783     CommPhyHeader phyHeader;
784     HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
785     errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
786     if (errCode != E_OK) {
787         LOGE("[Proto][ParsePhy] Check Field Fail.");
788         return errCode;
789     }
790 
791     inResult.SetFrameId(phyHeader.frameId);
792     inResult.SetSourceId(phyHeader.sourceId);
793     inResult.SetPacketLen(phyHeader.packetLen);
794     inResult.SetPaddingLen(phyHeader.paddingLen);
795     inResult.SetDbVersion(phyHeader.dbIntVer);
796     if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
797         inResult.SetFragmentFlag(true);
798     } // FragmentFlag default is false
799     FrameType frameType = GetFrameType(phyHeader.packetType);
800     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
801         LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%" PRIu32 ".", phyHeader.packetType);
802         return -E_FRAME_TYPE_NOT_SUPPORT;
803     }
804     inResult.SetFrameTypeInfo(frameType);
805     inResult.SetSendLabelExchange(IsSendLabelExchange(phyHeader.packetType));
806     return E_OK;
807 }
808 
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)809 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
810 {
811     if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
812         LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
813         return -E_LENGTH_ERROR;
814     }
815     auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
816     CommPhyOptHeader phyOptHeader;
817     HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
818 
819     // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
820     inResult.SetFrameLen(phyOptHeader.frameLen);
821     inResult.SetFragCount(phyOptHeader.fragCount);
822     inResult.SetFragNo(phyOptHeader.fragNo);
823     return E_OK;
824 }
825 
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)826 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
827 {
828     // Check version before parse field
829     if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
830         return -E_LENGTH_ERROR;
831     }
832     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
833     if (version != PROTOCOL_VERSION) {
834         LOGE("[Proto][ParseDiverge] Version=%" PRIu16 " not support.", version);
835         return -E_VERSION_NOT_SUPPORT;
836     }
837 
838     if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
839         LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
840         return -E_PARSE_FAIL;
841     }
842     auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
843     CommDivergeHeader divergeHeader;
844     HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
845     if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
846         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
847         LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
848         return -E_PARSE_FAIL;
849     }
850     inResult.SetPayloadLen(divergeHeader.payLoadLen);
851     inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
852     return E_OK;
853 }
854 
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)855 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
856 {
857     if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
858         int errCode = ParseLabelExchangeAck(bytes, length, inResult);
859         if (errCode != E_OK) {
860             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
861             return errCode;
862         }
863     } else {
864         int errCode = ParseLabelExchange(bytes, length, inResult);
865         if (errCode != E_OK) {
866             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
867             return errCode;
868         }
869     }
870     return E_OK;
871 }
872 
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)873 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
874 {
875     // Check version at very first
876     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
877         return -E_LENGTH_ERROR;
878     }
879     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
880     uint64_t version = NetToHost(*fieldPtr++);
881     if (version != PROTOCOL_VERSION) {
882         LOGE("[Proto][ParseLabel] Version=%" PRIu64 " not support.", ULL(version));
883         return -E_VERSION_NOT_SUPPORT;
884     }
885 
886     // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
887     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
888         LOGE("[Proto][ParseLabel] Length of Bytes Error.");
889         return -E_LENGTH_ERROR;
890     }
891     uint64_t distinctValue = NetToHost(*fieldPtr++);
892     inResult.SetLabelExchangeDistinctValue(distinctValue);
893     uint64_t sequenceId = NetToHost(*fieldPtr++);
894     inResult.SetLabelExchangeSequenceId(sequenceId);
895     uint64_t commLabelCount = NetToHost(*fieldPtr++);
896     if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
897         LOGE("[Proto][ParseLabel] commLabelCount=%" PRIu64 " invalid.", ULL(commLabelCount));
898         return -E_PARSE_FAIL;
899     }
900     // commLabelCount is expected to be not very large
901     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
902         commLabelCount * COMM_LABEL_LENGTH) {
903         LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%" PRIu64, ULL(commLabelCount));
904         return -E_LENGTH_ERROR;
905     }
906 
907     // Get each commLabel
908     std::set<LabelType> commLabels;
909     auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
910     for (uint64_t i = 0; i < commLabelCount; i++) {
911         // the length is checked just above
912         LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
913         if (commLabels.count(commLabel) != 0) {
914             LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%.3s.", VEC_TO_STR(commLabel));
915         } else {
916             commLabels.insert(commLabel);
917         }
918     }
919     inResult.SetLatestCommLabels(commLabels);
920     return E_OK;
921 }
922 
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)923 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
924 {
925     // Check version at very first
926     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
927         return -E_LENGTH_ERROR;
928     }
929     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
930     uint64_t version = NetToHost(*fieldPtr++);
931     if (version != PROTOCOL_VERSION) {
932         LOGE("[Proto][ParseLabelAck] Version=%" PRIu64 " not support.", ULL(version));
933         return -E_VERSION_NOT_SUPPORT;
934     }
935 
936     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
937         LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
938         return -E_LENGTH_ERROR;
939     }
940     uint64_t distinctValue = NetToHost(*fieldPtr++);
941     inResult.SetLabelExchangeDistinctValue(distinctValue);
942     uint64_t sequenceId = NetToHost(*fieldPtr++);
943     inResult.SetLabelExchangeSequenceId(sequenceId);
944     return E_OK;
945 }
946 
947 // Note: framePhyHeader is in network endian
948 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)949 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
950     const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
951 {
952     // It can be guaranteed that fragCount >= 2 and also won't be too large
953     if (fragmentInfo.fragCount < MIN_FRAGMENT_COUNT) {
954         return -E_INVALID_ARGS;
955     }
956     outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
957     uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
958     uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
959     uint16_t fragNo = 0; // Fragment index start from 0
960     uint32_t byteOffset = 0;
961 
962     for (auto &entry : outPieces) {
963         // subtract 1 for index
964         uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
965         uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
966         uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
967 
968         // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
969         entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
970         if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
971             LOGE("[Proto][FrameFrag] Resize failed for length=%" PRIu32, pieceTotalLen);
972             return -E_OUT_OF_MEMORY;
973         }
974 
975         CommPhyHeader pktPhyHeader;
976         HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
977 
978         // The sum value need to be recalculated, and the packet is fragmented.
979         // The alignedFragLen is always larger than pieceFragLen
980         FillPhyHeaderLenInfo(pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen, pktPhyHeader);
981         HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
982 
983         CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
984             fragmentInfo.fragCount, fragNo};
985         HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
986         int err;
987         FragmentPacket packet;
988         uint8_t *ptrPacket = &(entry.first[0]);
989         if (fragmentInfo.extendHeadSize > 0) {
990             packet = {ptrPacket, fragmentInfo.extendHeadSize};
991             err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
992             if (err != E_OK) {
993                 return err;
994             }
995             ptrPacket += fragmentInfo.extendHeadSize;
996         }
997         packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
998         err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
999             pieceFragLen, packet);
1000         entry.second = fragmentInfo.extendHeadSize;
1001         if (err != E_OK) {
1002             LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%" PRIu16 ", fragNo=%" PRIu16, fragmentInfo.fragCount,
1003                 fragNo);
1004             return err;
1005         }
1006 
1007         fragNo++;
1008         byteOffset += pieceFragLen;
1009     }
1010 
1011     return E_OK;
1012 }
1013 
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)1014 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
1015 {
1016     if (headLen > outPacket.leftLength) {
1017         LOGE("[Proto][FrameFrag] headLen less than leftLength");
1018         return -E_INVALID_ARGS;
1019     }
1020     errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1021     if (retCode != EOK) {
1022         LOGE("memcpy error:%d", retCode);
1023         return -E_SECUREC_ERROR;
1024     }
1025     return E_OK;
1026 }
1027 
1028 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1029 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1030     const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1031 {
1032     if (outPacket.leftLength == 0) {
1033         return -E_INVALID_ARGS;
1034     }
1035     uint8_t *ptrPacket = outPacket.ptrPacket;
1036     uint32_t leftLength = outPacket.leftLength;
1037 
1038     // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1039     // So, there will be no redundant check during subtraction
1040     errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1041     if (retCode != EOK) {
1042         return -E_SECUREC_ERROR;
1043     }
1044     ptrPacket += sizeof(CommPhyHeader);
1045     leftLength -= sizeof(CommPhyHeader);
1046 
1047     retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1048     if (retCode != EOK) {
1049         return -E_SECUREC_ERROR;
1050     }
1051     ptrPacket += sizeof(CommPhyOptHeader);
1052     leftLength -= sizeof(CommPhyOptHeader);
1053 
1054     retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1055     if (retCode != EOK) {
1056         return -E_SECUREC_ERROR;
1057     }
1058 
1059     // Calculate sum and set sum field
1060     uint64_t sumResult = 0;
1061     int errCode  = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1062         outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1063     if (errCode != E_OK) {
1064         return -E_SUM_CALCULATE_FAIL;
1065     }
1066     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1067     if (ptrPhyHeader == nullptr) {
1068         return -E_INVALID_ARGS;
1069     }
1070     ptrPhyHeader->checkSum = HostToNet(sumResult);
1071 
1072     return E_OK;
1073 }
1074 
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1075 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1076 {
1077     if (extendHandle != nullptr) {
1078         DBStatus status = extendHandle->GetHeadDataSize(headSize);
1079         if (status != DBStatus::OK) {
1080             LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1081             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1082         }
1083         if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1084             LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1085             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1086         }
1087         return E_OK;
1088     }
1089     return E_OK;
1090 }
1091 
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1092 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1093     uint32_t headSize)
1094 {
1095     if (extendHandle != nullptr && headSize > 0) {
1096         if (buffer == nullptr) {
1097             return -E_INVALID_ARGS;
1098         }
1099         DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1100             buffer->GetSize() + headSize);
1101         if (status != DBStatus::OK) {
1102             LOGI("[Proto][ToSerial] fill head data failed");
1103             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1104         }
1105     }
1106     return E_OK;
1107 }
1108 
GetTransformFunc(uint32_t messageId,DistributedDB::TransformFunc & function)1109 int ProtocolProto::GetTransformFunc(uint32_t messageId, DistributedDB::TransformFunc &function)
1110 {
1111     std::shared_lock<std::shared_mutex> autoLock(msgIdMutex_);
1112     const auto &entry = msgIdMapFunc_.find(messageId);
1113     if (entry == msgIdMapFunc_.end()) {
1114         return -E_NOT_REGISTER;
1115     }
1116     function = entry->second;
1117     return E_OK;
1118 }
1119 } // namespace DistributedDB
1120