• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "protocol_proto.h"
17 #include <iterator>
18 #include <mutex>
19 #include <new>
20 #include "db_common.h"
21 #include "endian_convert.h"
22 #include "hash.h"
23 #include "header_converter.h"
24 #include "log_print.h"
25 #include "macro_utils.h"
26 #include "securec.h"
27 #include "version.h"
28 
29 namespace DistributedDB {
30 namespace {
31 const uint16_t MAGIC_CODE = 0xAAAA;
32 const uint16_t PROTOCOL_VERSION = 0;
33 // Compatibility Final Method. 3 Correspond To Version 1.1.4(104)
34 const uint16_t DB_GLOBAL_VERSION = SOFTWARE_VERSION_CURRENT - SOFTWARE_VERSION_EARLIEST;
35 const uint8_t PACKET_TYPE_FRAGMENTED = BITX(0); // Use bit 0
36 const uint8_t PACKET_TYPE_NOT_FRAGMENTED = 0;
37 const uint8_t MAX_PADDING_LEN = 7;
38 const uint32_t LENGTH_BEFORE_SUM_RANGE = sizeof(uint64_t) + sizeof(uint64_t);
39 const uint32_t MAX_FRAME_LEN = 32 * 1024 * 1024; // Max 32 MB, 1024 is scale
40 const uint16_t MIN_FRAGMENT_COUNT = 2; // At least a frame will be splited into 2 parts
41 // LabelExchange(Ack) Frame Field Length
42 const uint32_t LABEL_VER_LEN = sizeof(uint64_t);
43 const uint32_t DISTINCT_VALUE_LEN = sizeof(uint64_t);
44 const uint32_t SEQUENCE_ID_LEN = sizeof(uint64_t);
45 // Note: COMM_LABEL_LENGTH is defined in communicator_type_define.h
46 const uint32_t COMM_LABEL_COUNT_LEN = sizeof(uint64_t);
47 // Local func to set and get frame Type from packet Type field
SetFrameType(FrameType inFrameType,uint8_t & inPacketType)48 void SetFrameType(FrameType inFrameType, uint8_t &inPacketType)
49 {
50     inPacketType &= 0x0F; // Use 0x0F to clear high four bits
51     inPacketType |= (static_cast<uint8_t>(inFrameType) << 4); // frame type is on high 4 bits
52 }
GetFrameType(uint8_t inPacketType)53 FrameType GetFrameType(uint8_t inPacketType)
54 {
55     uint8_t frameType = ((inPacketType & 0xF0) >> 4); // Use 0xF0 to get high 4 bits
56     if (frameType >= static_cast<uint8_t>(FrameType::INVALID_MAX_FRAME_TYPE)) {
57         return FrameType::INVALID_MAX_FRAME_TYPE;
58     }
59     return static_cast<FrameType>(frameType);
60 }
61 }
62 
63 std::map<uint32_t, TransformFunc> ProtocolProto::msgIdMapFunc_;
64 std::shared_mutex ProtocolProto::msgIdMutex_;
65 
GetAppLayerFrameHeaderLength()66 uint32_t ProtocolProto::GetAppLayerFrameHeaderLength()
67 {
68     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader);
69     return length;
70 }
71 
GetLengthBeforeSerializedData()72 uint32_t ProtocolProto::GetLengthBeforeSerializedData()
73 {
74     uint32_t length = sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + sizeof(MessageHeader);
75     return length;
76 }
77 
GetCommLayerFrameHeaderLength()78 uint32_t ProtocolProto::GetCommLayerFrameHeaderLength()
79 {
80     uint32_t length = sizeof(CommPhyHeader);
81     return length;
82 }
83 
ToSerialBuffer(const Message * inMsg,std::shared_ptr<ExtendHeaderHandle> & extendHandle,bool onlyMsgHeader,int & outErrorNo)84 SerialBuffer *ProtocolProto::ToSerialBuffer(const Message *inMsg,
85     std::shared_ptr<ExtendHeaderHandle> &extendHandle, bool onlyMsgHeader, int &outErrorNo)
86 {
87     if (inMsg == nullptr) {
88         outErrorNo = -E_INVALID_ARGS;
89         return nullptr;
90     }
91 
92     uint32_t serializeLen = 0;
93     if (!onlyMsgHeader) {
94         int errCode = CalculateDataSerializeLength(inMsg, serializeLen);
95         if (errCode != E_OK) {
96             outErrorNo = errCode;
97             return nullptr;
98         }
99     }
100     uint32_t headSize = 0;
101     int errCode = GetExtendHeadDataSize(extendHandle, headSize);
102     if (errCode != E_OK) {
103         outErrorNo = errCode;
104         return nullptr;
105     }
106 
107     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
108     if (buffer == nullptr) {
109         outErrorNo = -E_OUT_OF_MEMORY;
110         return nullptr;
111     }
112     if (headSize > 0) {
113         buffer->SetExtendHeadLength(headSize);
114     }
115     // serializeLen maybe not 8-bytes aligned, let SerialBuffer deal with the padding.
116     uint32_t payLoadLength = serializeLen + sizeof(MessageHeader);
117     errCode = buffer->AllocBufferByPayloadLength(payLoadLength, GetAppLayerFrameHeaderLength());
118     if (errCode != E_OK) {
119         LOGE("[Proto][ToSerial] Alloc Fail, errCode=%d.", errCode);
120         goto ERROR_HANDLE;
121     }
122     errCode = FillExtendHeadDataIfNeed(extendHandle, buffer, headSize);
123     if (errCode != E_OK) {
124         goto ERROR_HANDLE;
125     }
126 
127     // Serialize the MessageHeader and data if need
128     errCode = SerializeMessage(buffer, inMsg);
129     if (errCode != E_OK) {
130         LOGE("[Proto][ToSerial] Serialize Fail, errCode=%d.", errCode);
131         goto ERROR_HANDLE;
132     }
133     outErrorNo = E_OK;
134     return buffer;
135 ERROR_HANDLE:
136     outErrorNo = errCode;
137     delete buffer;
138     buffer = nullptr;
139     return nullptr;
140 }
141 
ToMessage(const SerialBuffer * inBuff,int & outErrorNo,bool onlyMsgHeader)142 Message *ProtocolProto::ToMessage(const SerialBuffer *inBuff, int &outErrorNo, bool onlyMsgHeader)
143 {
144     if (inBuff == nullptr) {
145         outErrorNo = -E_INVALID_ARGS;
146         return nullptr;
147     }
148     Message *outMsg = new (std::nothrow) Message();
149     if (outMsg == nullptr) {
150         outErrorNo = -E_OUT_OF_MEMORY;
151         return nullptr;
152     }
153     int errCode = DeSerializeMessage(inBuff, outMsg, onlyMsgHeader);
154     if (errCode != E_OK && errCode != -E_NOT_REGISTER) {
155         LOGE("[Proto][ToMessage] DeSerialize Fail, errCode=%d.", errCode);
156         outErrorNo = errCode;
157         delete outMsg;
158         outMsg = nullptr;
159         return nullptr;
160     }
161     // If messageId not register in this software version, we return errCode and the Message without an object.
162     outErrorNo = errCode;
163     return outMsg;
164 }
165 
BuildEmptyFrameForVersionNegotiate(int & outErrorNo)166 SerialBuffer *ProtocolProto::BuildEmptyFrameForVersionNegotiate(int &outErrorNo)
167 {
168     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
169     if (buffer == nullptr) {
170         outErrorNo = -E_OUT_OF_MEMORY;
171         return nullptr;
172     }
173 
174     // Empty frame has no payload, only header
175     int errCode = buffer->AllocBufferByPayloadLength(0, GetCommLayerFrameHeaderLength());
176     if (errCode != E_OK) {
177         LOGE("[Proto][BuildEmpty] Alloc Fail, errCode=%d.", errCode);
178         outErrorNo = errCode;
179         delete buffer;
180         buffer = nullptr;
181         return nullptr;
182     }
183     outErrorNo = E_OK;
184     return buffer;
185 }
186 
BuildFeedbackMessageFrame(const Message * inMsg,const LabelType & inLabel,int & outErrorNo)187 SerialBuffer *ProtocolProto::BuildFeedbackMessageFrame(const Message *inMsg, const LabelType &inLabel,
188     int &outErrorNo)
189 {
190     std::shared_ptr<ExtendHeaderHandle> extendHandle = nullptr;
191     SerialBuffer *buffer = ToSerialBuffer(inMsg, extendHandle, true, outErrorNo);
192     if (buffer == nullptr) {
193         // outErrorNo had already been set in ToSerialBuffer
194         return nullptr;
195     }
196     int errCode = ProtocolProto::SetDivergeHeader(buffer, inLabel);
197     if (errCode != E_OK) {
198         LOGE("[Proto][BuildFeedback] Set DivergeHeader fail, label=%.3s, errCode=%d.", VEC_TO_STR(inLabel), errCode);
199         outErrorNo = errCode;
200         delete buffer;
201         buffer = nullptr;
202         return nullptr;
203     }
204     outErrorNo = E_OK;
205     return buffer;
206 }
207 
BuildLabelExchange(uint64_t inDistinctValue,uint64_t inSequenceId,const std::set<LabelType> & inLabels,int & outErrorNo)208 SerialBuffer *ProtocolProto::BuildLabelExchange(uint64_t inDistinctValue, uint64_t inSequenceId,
209     const std::set<LabelType> &inLabels, int &outErrorNo)
210 {
211     // Size of inLabels won't be too large.
212     // The upper layer code(inside this communicator module) guarantee that size of each Label equals COMM_LABEL_LENGTH
213     uint64_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
214         inLabels.size() * COMM_LABEL_LENGTH;
215     if (payloadLen > INT32_MAX) {
216         outErrorNo = -E_INVALID_ARGS;
217         return nullptr;
218     }
219     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
220     if (buffer == nullptr) {
221         outErrorNo = -E_OUT_OF_MEMORY;
222         return nullptr;
223     }
224     int errCode = buffer->AllocBufferByPayloadLength(static_cast<uint32_t>(payloadLen),
225         GetCommLayerFrameHeaderLength());
226     if (errCode != E_OK) {
227         LOGE("[Proto][BuildLabel] Alloc Fail, errCode=%d.", errCode);
228         outErrorNo = errCode;
229         delete buffer;
230         buffer = nullptr;
231         return nullptr;
232     }
233 
234     auto payloadByteLen = buffer->GetWritableBytesForPayload();
235     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
236     *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
237     fieldPtr++;
238     *fieldPtr = HostToNet(inDistinctValue);
239     fieldPtr++;
240     *fieldPtr = HostToNet(inSequenceId);
241     fieldPtr++;
242     *fieldPtr = HostToNet(static_cast<uint64_t>(inLabels.size()));
243     fieldPtr++;
244     // Note: don't worry, memory length had been carefully calculated above
245     auto bytePtr = reinterpret_cast<uint8_t *>(fieldPtr);
246     for (const auto &eachLabel : inLabels) {
247         for (const auto &eachByte : eachLabel) {
248             *bytePtr++ = eachByte;
249         }
250     }
251     outErrorNo = E_OK;
252     return buffer;
253 }
254 
BuildLabelExchangeAck(uint64_t inDistinctValue,uint64_t inSequenceId,int & outErrorNo)255 SerialBuffer *ProtocolProto::BuildLabelExchangeAck(uint64_t inDistinctValue, uint64_t inSequenceId, int &outErrorNo)
256 {
257     uint32_t payloadLen = LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN;
258     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
259     if (buffer == nullptr) {
260         outErrorNo = -E_OUT_OF_MEMORY;
261         return nullptr;
262     }
263     int errCode = buffer->AllocBufferByPayloadLength(payloadLen, GetCommLayerFrameHeaderLength());
264     if (errCode != E_OK) {
265         LOGE("[Proto][BuildLabelAck] Alloc Fail, errCode=%d.", errCode);
266         outErrorNo = errCode;
267         delete buffer;
268         buffer = nullptr;
269         return nullptr;
270     }
271 
272     auto payloadByteLen = buffer->GetWritableBytesForPayload();
273     auto fieldPtr = reinterpret_cast<uint64_t *>(payloadByteLen.first);
274     *fieldPtr = HostToNet(static_cast<uint64_t>(PROTOCOL_VERSION));
275     fieldPtr++;
276     *fieldPtr = HostToNet(inDistinctValue);
277     fieldPtr++;
278     *fieldPtr = HostToNet(inSequenceId);
279     fieldPtr++;
280     outErrorNo = E_OK;
281     return buffer;
282 }
283 
SplitFrameIntoPacketsIfNeed(const SerialBuffer * inBuff,uint32_t inMtuSize,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)284 int ProtocolProto::SplitFrameIntoPacketsIfNeed(const SerialBuffer *inBuff, uint32_t inMtuSize,
285     std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
286 {
287     auto bufferBytesLen = inBuff->GetReadOnlyBytesForEntireBuffer();
288     if ((bufferBytesLen.second + inBuff->GetExtendHeadLength()) <= inMtuSize) {
289         return E_OK;
290     }
291     uint32_t modifyMtuSize = inMtuSize - inBuff->GetExtendHeadLength();
292     // Do Fragmentaion! This function aims at calculate how many fragments to be split into.
293     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame(); // Padding not in the range of fragmentation.
294     uint32_t lengthToSplit = frameBytesLen.second - sizeof(CommPhyHeader); // The former is always larger than latter.
295     // The inMtuSize pass from CommunicatorAggregator is large enough to be subtract by the latter two.
296     uint32_t maxFragmentLen = modifyMtuSize - sizeof(CommPhyHeader) - sizeof(CommPhyOptHeader);
297     // It can be proved that lengthToSplit is always larger than maxFragmentLen, so quotient won't be zero.
298     // The maxFragmentLen won't be zero and in fact large enough to make sure no precision loss during division
299     uint16_t quotient = lengthToSplit / maxFragmentLen;
300     uint32_t remainder = lengthToSplit % maxFragmentLen;
301     // Finally we get the fragCount for this frame
302     uint16_t fragCount = ((remainder == 0) ? quotient : (quotient + 1));
303     // Get CommPhyHeader of this frame to be modified for each packets (Header in network endian)
304     auto oriPhyHeader = reinterpret_cast<const CommPhyHeader *>(frameBytesLen.first);
305     FrameFragmentInfo fragInfo = {inBuff->GetOringinalAddr(), inBuff->GetExtendHeadLength(), lengthToSplit, fragCount};
306     return FrameFragmentation(frameBytesLen.first + sizeof(CommPhyHeader), fragInfo, *oriPhyHeader, outPieces);
307 }
308 
AnalyzeSplitStructure(const ParseResult & inResult,uint32_t & outFragLen,uint32_t & outLastFragLen)309 int ProtocolProto::AnalyzeSplitStructure(const ParseResult &inResult, uint32_t &outFragLen, uint32_t &outLastFragLen)
310 {
311     uint32_t frameLen = inResult.GetFrameLen();
312     uint16_t fragCount = inResult.GetFragCount();
313     uint16_t fragNo = inResult.GetFragNo();
314 
315     // Firstly: Check frameLen
316     if (frameLen <= sizeof(CommPhyHeader) || frameLen > MAX_FRAME_LEN) {
317         LOGE("[Proto][ParsePhyOpt] FrameLen=%" PRIu32 " illegal.", frameLen);
318         return -E_PARSE_FAIL;
319     }
320 
321     // Secondly: Check fragCount and fragNo
322     uint32_t lengthBeSplit = frameLen - sizeof(CommPhyHeader);
323     if (fragCount == 0 || fragCount < MIN_FRAGMENT_COUNT || fragCount > lengthBeSplit || fragNo >= fragCount) {
324         LOGE("[Proto][ParsePhyOpt] FragCount=%" PRIu32 " or fragNo=%" PRIu32 " illegal.", fragCount, fragNo);
325         return -E_PARSE_FAIL;
326     }
327 
328     // Finally: Check length relation deeply
329     uint32_t quotient = lengthBeSplit / fragCount;
330     uint16_t remainder = lengthBeSplit % fragCount;
331     outFragLen = quotient;
332     outLastFragLen = quotient + remainder;
333     uint32_t thisFragLen = ((fragNo != fragCount - 1) ? outFragLen : outLastFragLen); // subtract by 1 for index
334     if ((sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + thisFragLen +
335         inResult.GetPaddingLen()) != inResult.GetPacketLen()) {
336         LOGE("[Proto][ParsePhyOpt] Length Error: FrameLen=%" PRIu32 ", FragCount=%" PRIu32 ", fragNo=%" PRIu32
337             ", PaddingLen=%" PRIu32 ", PacketLen=%" PRIu32, frameLen, fragCount, fragNo, inResult.GetPaddingLen(),
338             inResult.GetPacketLen());
339         return -E_PARSE_FAIL;
340     }
341 
342     return E_OK;
343 }
344 
CombinePacketIntoFrame(SerialBuffer * inFrame,const uint8_t * pktBytes,uint32_t pktLength,uint32_t fragOffset,uint32_t fragLength)345 int ProtocolProto::CombinePacketIntoFrame(SerialBuffer *inFrame, const uint8_t *pktBytes, uint32_t pktLength,
346     uint32_t fragOffset, uint32_t fragLength)
347 {
348     // inFrame is the destination, pktBytes and pktLength are the source, fragOffset and fragLength give the boundary
349     // Firstly: Check the length relation of source, even this check is not supposed to fail
350     if (sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader) + fragLength > pktLength) {
351         return -E_LENGTH_ERROR;
352     }
353     // Secondly: Check the length relation of destination, even this check is not supposed to fail
354     auto frameByteLen = inFrame->GetWritableBytesForEntireFrame();
355     if (sizeof(CommPhyHeader) + fragOffset + fragLength > frameByteLen.second) {
356         return -E_LENGTH_ERROR;
357     }
358     // Finally: Do Combination!
359     const uint8_t *srcByteHead = pktBytes + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
360     uint8_t *dstByteHead = frameByteLen.first + sizeof(CommPhyHeader) + fragOffset;
361     uint32_t dstLeftLen = frameByteLen.second - sizeof(CommPhyHeader) - fragOffset;
362     errno_t errCode = memcpy_s(dstByteHead, dstLeftLen, srcByteHead, fragLength);
363     if (errCode != EOK) {
364         return -E_SECUREC_ERROR;
365     }
366     return E_OK;
367 }
368 
RegTransformFunction(uint32_t msgId,const TransformFunc & inFunc)369 int ProtocolProto::RegTransformFunction(uint32_t msgId, const TransformFunc &inFunc)
370 {
371     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
372     if (msgIdMapFunc_.count(msgId) != 0) {
373         return -E_ALREADY_REGISTER;
374     }
375     if (!inFunc.computeFunc || !inFunc.serializeFunc || !inFunc.deserializeFunc) {
376         return -E_INVALID_ARGS;
377     }
378     msgIdMapFunc_[msgId] = inFunc;
379     return E_OK;
380 }
381 
UnRegTransformFunction(uint32_t msgId)382 void ProtocolProto::UnRegTransformFunction(uint32_t msgId)
383 {
384     std::unique_lock<std::shared_mutex> autoLock(msgIdMutex_);
385     if (msgIdMapFunc_.count(msgId) != 0) {
386         msgIdMapFunc_.erase(msgId);
387     }
388 }
389 
SetDivergeHeader(SerialBuffer * inBuff,const LabelType & inCommLabel)390 int ProtocolProto::SetDivergeHeader(SerialBuffer *inBuff, const LabelType &inCommLabel)
391 {
392     if (inBuff == nullptr) {
393         return -E_INVALID_ARGS;
394     }
395     auto headerByteLen = inBuff->GetWritableBytesForHeader();
396     if (headerByteLen.second != GetAppLayerFrameHeaderLength()) {
397         return -E_INVALID_ARGS;
398     }
399     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
400 
401     CommDivergeHeader divergeHeader;
402     divergeHeader.version = PROTOCOL_VERSION;
403     divergeHeader.reserved = 0;
404     divergeHeader.payLoadLen = payloadByteLen.second;
405     // The upper layer code(inside this communicator module) guarantee that size of inCommLabel equal COMM_LABEL_LENGTH
406     for (unsigned int i = 0; i < COMM_LABEL_LENGTH; i++) {
407         divergeHeader.commLabel[i] = inCommLabel[i];
408     }
409     HeaderConverter::ConvertHostToNet(divergeHeader, divergeHeader);
410 
411     errno_t errCode = memcpy_s(headerByteLen.first + sizeof(CommPhyHeader),
412         headerByteLen.second - sizeof(CommPhyHeader), &divergeHeader, sizeof(CommDivergeHeader));
413     if (errCode != EOK) {
414         return -E_SECUREC_ERROR;
415     }
416     return E_OK;
417 }
418 
419 namespace {
FillPhyHeaderLenInfo(uint32_t packetLen,uint64_t sum,uint8_t type,uint8_t paddingLen,CommPhyHeader & header)420 void FillPhyHeaderLenInfo(uint32_t packetLen, uint64_t sum, uint8_t type, uint8_t paddingLen, CommPhyHeader &header)
421 {
422     header.packetLen = packetLen;
423     header.checkSum = sum;
424     header.packetType |= type;
425     header.paddingLen = paddingLen;
426 }
427 }
428 
SetPhyHeader(SerialBuffer * inBuff,const PhyHeaderInfo & inInfo)429 int ProtocolProto::SetPhyHeader(SerialBuffer *inBuff, const PhyHeaderInfo &inInfo)
430 {
431     if (inBuff == nullptr) {
432         return -E_INVALID_ARGS;
433     }
434     auto headerByteLen = inBuff->GetWritableBytesForHeader();
435     if (headerByteLen.second < sizeof(CommPhyHeader)) {
436         return -E_INVALID_ARGS;
437     }
438     auto bufferByteLen = inBuff->GetReadOnlyBytesForEntireBuffer();
439     auto frameByteLen = inBuff->GetReadOnlyBytesForEntireFrame();
440 
441     uint32_t packetLen = bufferByteLen.second;
442     uint8_t paddingLen = static_cast<uint8_t>(bufferByteLen.second - frameByteLen.second);
443     uint8_t packetType = PACKET_TYPE_NOT_FRAGMENTED;
444     if (inInfo.frameType != FrameType::INVALID_MAX_FRAME_TYPE) {
445         SetFrameType(inInfo.frameType, packetType);
446     } else {
447         return -E_INVALID_ARGS;
448     }
449 
450     CommPhyHeader phyHeader;
451     phyHeader.magic = MAGIC_CODE;
452     phyHeader.version = PROTOCOL_VERSION;
453     phyHeader.sourceId = inInfo.sourceId;
454     phyHeader.frameId = inInfo.frameId;
455     phyHeader.packetType = 0;
456     phyHeader.dbIntVer = DB_GLOBAL_VERSION;
457     FillPhyHeaderLenInfo(packetLen, 0, packetType, paddingLen, phyHeader); // Sum is calculated afterwards
458     HeaderConverter::ConvertHostToNet(phyHeader, phyHeader);
459 
460     errno_t retCode = memcpy_s(headerByteLen.first, headerByteLen.second, &phyHeader, sizeof(CommPhyHeader));
461     if (retCode != EOK) {
462         return -E_SECUREC_ERROR;
463     }
464 
465     uint64_t sumResult = 0;
466     int errCode = CalculateXorSum(bufferByteLen.first + LENGTH_BEFORE_SUM_RANGE,
467         bufferByteLen.second - LENGTH_BEFORE_SUM_RANGE, sumResult);
468     if (errCode != E_OK) {
469         return -E_SUM_CALCULATE_FAIL;
470     }
471 
472     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(headerByteLen.first);
473     ptrPhyHeader->checkSum = HostToNet(sumResult);
474 
475     return E_OK;
476 }
477 
CheckAndParsePacket(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & outResult)478 int ProtocolProto::CheckAndParsePacket(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
479     ParseResult &outResult)
480 {
481     if (bytes == nullptr || length > MAX_TOTAL_LEN) {
482         return -E_INVALID_ARGS;
483     }
484     int errCode = ParseCommPhyHeader(srcTarget, bytes, length, outResult);
485     if (errCode != E_OK) {
486         LOGE("[Proto][ParsePacket] Parse PhyHeader Fail, errCode=%d.", errCode);
487         return errCode;
488     }
489 
490     if (outResult.GetFrameTypeInfo() == FrameType::EMPTY) {
491         return E_OK; // Do nothing more for empty frame
492     }
493 
494     if (outResult.IsFragment()) {
495         errCode = ParseCommPhyOptHeader(bytes, length, outResult);
496         if (errCode != E_OK) {
497             LOGE("[Proto][ParsePacket] Parse CommPhyOptHeader Fail, errCode=%d.", errCode);
498         }
499     } else if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
500         errCode = ParseCommLayerPayload(bytes, length, outResult);
501         if (errCode != E_OK) {
502             LOGE("[Proto][ParsePacket] Parse CommLayerPayload Fail, errCode=%d.", errCode);
503         }
504     } else {
505         errCode = ParseCommDivergeHeader(bytes, length, outResult);
506         if (errCode != E_OK) {
507             LOGE("[Proto][ParsePacket] Parse DivergeHeader Fail, errCode=%d.", errCode);
508         }
509     }
510     return errCode;
511 }
512 
CheckAndParseFrame(const SerialBuffer * inBuff,ParseResult & outResult)513 int ProtocolProto::CheckAndParseFrame(const SerialBuffer *inBuff, ParseResult &outResult)
514 {
515     if (inBuff == nullptr || outResult.IsFragment()) {
516         return -E_INTERNAL_ERROR;
517     }
518     auto frameBytesLen = inBuff->GetReadOnlyBytesForEntireFrame();
519     if (outResult.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) {
520         int errCode = ParseCommLayerPayload(frameBytesLen.first, frameBytesLen.second, outResult);
521         if (errCode != E_OK) {
522             LOGE("[Proto][ParseFrame] Parse CommLayerPayload Fail, errCode=%d.", errCode);
523             return errCode;
524         }
525     } else {
526         int errCode = ParseCommDivergeHeader(frameBytesLen.first, frameBytesLen.second, outResult);
527         if (errCode != E_OK) {
528             LOGE("[Proto][ParseFrame] Parse DivergeHeader Fail, errCode=%d.", errCode);
529             return errCode;
530         }
531     }
532     return E_OK;
533 }
534 
DisplayPacketInformation(const uint8_t * bytes,uint32_t length)535 void ProtocolProto::DisplayPacketInformation(const uint8_t *bytes, uint32_t length)
536 {
537     static const char *frameTypeStr[] = {
538         "EmptyFrame",
539         "AppLayerFrame",
540         "CommLayerFrame_LabelExchange",
541         "CommLayerFrame_LabelExchangeAck"
542     };
543 
544     if (length < sizeof(CommPhyHeader)) {
545         return;
546     }
547     auto phyHeader = reinterpret_cast<const CommPhyHeader *>(bytes);
548     uint32_t frameId = NetToHost(phyHeader->frameId);
549     uint8_t pktType = NetToHost(phyHeader->packetType);
550     bool isFragment = ((pktType & PACKET_TYPE_FRAGMENTED) != 0);
551     FrameType frameType = GetFrameType(pktType);
552     if (frameType >= FrameType::INVALID_MAX_FRAME_TYPE) {
553         LOGW("[Proto][Display] This is unrecognized frame, pktType=%" PRIu8 ".", pktType);
554         return;
555     }
556     if (isFragment) {
557         if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
558             return;
559         }
560         auto phyOpt = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
561         LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ", frameLen=%" PRIu32 ", fragCount=%" PRIu32
562             ", fragNo=%" PRIu32 ".", frameTypeStr[static_cast<int32_t>(frameType)],
563             frameId, NetToHost(phyOpt->frameLen),
564             NetToHost(phyOpt->fragCount), NetToHost(phyOpt->fragNo));
565     } else {
566         LOGI("[Proto][Display] This is %s, frameId=%" PRIu32 ".",
567             frameTypeStr[static_cast<int32_t>(frameType)], frameId);
568     }
569 }
570 
CalculateXorSum(const uint8_t * bytes,uint32_t length,uint64_t & outSum)571 int ProtocolProto::CalculateXorSum(const uint8_t *bytes, uint32_t length, uint64_t &outSum)
572 {
573     if ((length > INT32_MAX) || (length % sizeof(uint64_t) != 0)) {
574         LOGE("[Proto][CalcuXorSum] Length=%d not multiple of eight or larget than int32_max.", length);
575         return -E_LENGTH_ERROR;
576     }
577     int count = length / sizeof(uint64_t);
578     auto array = reinterpret_cast<const uint64_t *>(bytes);
579     outSum = 0;
580     for (int i = 0; i < count; i++) {
581         outSum ^= array[i];
582     }
583     return E_OK;
584 }
585 
CalculateDataSerializeLength(const Message * inMsg,uint32_t & outLength)586 int ProtocolProto::CalculateDataSerializeLength(const Message *inMsg, uint32_t &outLength)
587 {
588     uint32_t messageId = inMsg->GetMessageId();
589     TransformFunc function;
590     if (GetTransformFunc(messageId, function) != E_OK) {
591         LOGE("[Proto][CalcuDataSerialLen] Not registered for messageId=%" PRIu32 ".", messageId);
592         return -E_NOT_REGISTER;
593     }
594 
595     uint32_t serializeLen = function.computeFunc(inMsg);
596     uint32_t alignedLen = BYTE_8_ALIGN(serializeLen);
597     // Currently not allowed the upper module to send a message without data. Regard serializeLen zero as abnormal.
598     if (serializeLen == 0 || alignedLen > MAX_FRAME_LEN - GetLengthBeforeSerializedData()) {
599         LOGE("[Proto][CalcuDataSerialLen] Length too large, msgId=%" PRIu32 ", serializeLen=%" PRIu32
600             ", alignedLen=%" PRIu32 ".", messageId, serializeLen, alignedLen);
601         return -E_LENGTH_ERROR;
602     }
603     // Attention: return the serializeLen nor the alignedLen. Let SerialBuffer to deal with the padding
604     outLength = serializeLen;
605     return E_OK;
606 }
607 
SerializeMessage(SerialBuffer * inBuff,const Message * inMsg)608 int ProtocolProto::SerializeMessage(SerialBuffer *inBuff, const Message *inMsg)
609 {
610     auto payloadByteLen = inBuff->GetWritableBytesForPayload();
611     if (payloadByteLen.second < sizeof(MessageHeader)) { // For equal, only msgHeader case
612         LOGE("[Proto][Serialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
613         return -E_LENGTH_ERROR;
614     }
615     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
616 
617     auto messageHdr = reinterpret_cast<MessageHeader *>(payloadByteLen.first);
618     messageHdr->version = inMsg->GetVersion();
619     messageHdr->messageType = inMsg->GetMessageType();
620     messageHdr->messageId = inMsg->GetMessageId();
621     messageHdr->sessionId = inMsg->GetSessionId();
622     messageHdr->sequenceId = inMsg->GetSequenceId();
623     messageHdr->errorNo = inMsg->GetErrorNo();
624     messageHdr->dataLen = dataLen;
625     HeaderConverter::ConvertHostToNet(*messageHdr, *messageHdr);
626 
627     if (dataLen == 0) {
628         // For zero dataLen, we don't need to serialize data part
629         return E_OK;
630     }
631     // If dataLen not zero, the TransformFunc of this messageId must exist, the caller's logic guarantee it
632     TransformFunc function;
633     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
634         LOGE("[Proto][Serialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
635         return -E_NOT_REGISTER;
636     }
637     int result = function.serializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
638     if (result != E_OK) {
639         LOGE("[Proto][Serialize] SerializeFunc Fail, result=%d.", result);
640         return -E_SERIALIZE_ERROR;
641     }
642     return E_OK;
643 }
644 
DeSerializeMessage(const SerialBuffer * inBuff,Message * inMsg,bool onlyMsgHeader)645 int ProtocolProto::DeSerializeMessage(const SerialBuffer *inBuff, Message *inMsg, bool onlyMsgHeader)
646 {
647     auto payloadByteLen = inBuff->GetReadOnlyBytesForPayload();
648     // Check version before parse field
649     if (payloadByteLen.second < sizeof(uint16_t)) {
650         return -E_LENGTH_ERROR;
651     }
652     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(payloadByteLen.first)));
653     if (!IsSupportMessageVersion(version)) {
654         LOGE("[Proto][DeSerialize] Version=%" PRIu32 " not support.", version);
655         return -E_VERSION_NOT_SUPPORT;
656     }
657 
658     if (payloadByteLen.second < sizeof(MessageHeader)) {
659         LOGE("[Proto][DeSerialize] Length error, payload length=%" PRIu32 ".", payloadByteLen.second);
660         return -E_LENGTH_ERROR;
661     }
662     auto oriMsgHeader = reinterpret_cast<const MessageHeader *>(payloadByteLen.first);
663     MessageHeader messageHdr;
664     HeaderConverter::ConvertNetToHost(*oriMsgHeader, messageHdr);
665     inMsg->SetVersion(version);
666     inMsg->SetMessageType(messageHdr.messageType);
667     inMsg->SetMessageId(messageHdr.messageId);
668     inMsg->SetSessionId(messageHdr.sessionId);
669     inMsg->SetSequenceId(messageHdr.sequenceId);
670     inMsg->SetErrorNo(messageHdr.errorNo);
671     uint32_t dataLen = payloadByteLen.second - sizeof(MessageHeader);
672     if (dataLen != messageHdr.dataLen) {
673         LOGE("[Proto][DeSerialize] dataLen=%" PRIu32 ", msgDataLen=%" PRIu32 ".", dataLen, messageHdr.dataLen);
674         return -E_LENGTH_ERROR;
675     }
676     // It is better to check FeedbackMessage first and check onlyMsgHeader flag later
677     if (IsFeedbackErrorMessage(messageHdr.errorNo)) {
678         LOGI("[Proto][DeSerialize] Feedback Message with errorNo=%" PRIu32 ".", messageHdr.errorNo);
679         return E_OK;
680     }
681     if (onlyMsgHeader || dataLen == 0) { // Do not need to deserialize data
682         return E_OK;
683     }
684     TransformFunc function;
685     if (GetTransformFunc(inMsg->GetMessageId(), function) != E_OK) {
686         LOGE("[Proto][DeSerialize] Not register, messageId=%" PRIu32 ".", inMsg->GetMessageId());
687         return -E_NOT_REGISTER;
688     }
689     int result = function.deserializeFunc(payloadByteLen.first + sizeof(MessageHeader), dataLen, inMsg);
690     if (result != E_OK) {
691         LOGE("[Proto][DeSerialize] DeserializeFunc Fail, result=%d.", result);
692         return -E_DESERIALIZE_ERROR;
693     }
694     return E_OK;
695 }
696 
IsSupportMessageVersion(uint16_t version)697 bool ProtocolProto::IsSupportMessageVersion(uint16_t version)
698 {
699     return (version == MSG_VERSION_BASE || version == MSG_VERSION_EXT);
700 }
701 
IsFeedbackErrorMessage(uint32_t errorNo)702 bool ProtocolProto::IsFeedbackErrorMessage(uint32_t errorNo)
703 {
704     return (errorNo == E_FEEDBACK_UNKNOWN_MESSAGE || errorNo == E_FEEDBACK_COMMUNICATOR_NOT_FOUND);
705 }
706 
ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t * bytes,uint32_t length)707 int ProtocolProto::ParseCommPhyHeaderCheckMagicAndVersion(const uint8_t *bytes, uint32_t length)
708 {
709     // At least magic and version should exist
710     if (length < sizeof(uint16_t) + sizeof(uint16_t)) {
711         LOGE("[Proto][ParsePhyCheckVer] Length of Bytes Error.");
712         return -E_LENGTH_ERROR;
713     }
714     auto fieldPtr = reinterpret_cast<const uint16_t *>(bytes);
715     uint16_t magic = NetToHost(*fieldPtr++);
716     uint16_t version = NetToHost(*fieldPtr++);
717 
718     if (magic != MAGIC_CODE) {
719         LOGE("[Proto][ParsePhyCheckVer] MagicCode=%" PRIu32 " Error.", magic);
720         return -E_PARSE_FAIL;
721     }
722     if (version != PROTOCOL_VERSION) {
723         LOGE("[Proto][ParsePhyCheckVer] Version=%" PRIu32 " Error.", version);
724         return -E_VERSION_NOT_SUPPORT;
725     }
726     return E_OK;
727 }
728 
ParseCommPhyHeaderCheckField(const std::string & srcTarget,const CommPhyHeader & phyHeader,const uint8_t * bytes,uint32_t length)729 int ProtocolProto::ParseCommPhyHeaderCheckField(const std::string &srcTarget, const CommPhyHeader &phyHeader,
730     const uint8_t *bytes, uint32_t length)
731 {
732     if (phyHeader.packetLen != length) {
733         LOGE("[Proto][ParsePhyCheck] PacketLen=%" PRIu32 " Mismatch length=%" PRIu32 ".", phyHeader.packetLen, length);
734         return -E_PARSE_FAIL;
735     }
736     if (phyHeader.paddingLen > MAX_PADDING_LEN) {
737         LOGE("[Proto][ParsePhyCheck] PaddingLen=%" PRIu32 " Error.", phyHeader.paddingLen);
738         return -E_PARSE_FAIL;
739     }
740     if (sizeof(CommPhyHeader) + phyHeader.paddingLen > phyHeader.packetLen) {
741         LOGE("[Proto][ParsePhyCheck] PaddingLen Add PhyHeader Greater Than PacketLen.");
742         return -E_PARSE_FAIL;
743     }
744     uint64_t sumResult = 0;
745     int errCode = CalculateXorSum(bytes + LENGTH_BEFORE_SUM_RANGE, length - LENGTH_BEFORE_SUM_RANGE, sumResult);
746     if (errCode != E_OK) {
747         LOGE("[Proto][ParsePhyCheck] Calculate Sum Fail.");
748         return -E_SUM_CALCULATE_FAIL;
749     }
750     if (phyHeader.checkSum != sumResult) {
751         LOGE("[Proto][ParsePhyCheck] Sum Mismatch, checkSum=%" PRIu64 ", sumResult=%" PRIu64 ".",
752             ULL(phyHeader.checkSum), ULL(sumResult));
753         return -E_SUM_MISMATCH;
754     }
755     return E_OK;
756 }
757 
ParseCommPhyHeader(const std::string & srcTarget,const uint8_t * bytes,uint32_t length,ParseResult & inResult)758 int ProtocolProto::ParseCommPhyHeader(const std::string &srcTarget, const uint8_t *bytes, uint32_t length,
759     ParseResult &inResult)
760 {
761     int errCode = ParseCommPhyHeaderCheckMagicAndVersion(bytes, length);
762     if (errCode != E_OK) {
763         LOGE("[Proto][ParsePhy] Check Magic And Version Fail.");
764         return errCode;
765     }
766 
767     if (length < sizeof(CommPhyHeader)) {
768         LOGE("[Proto][ParsePhy] Length of Bytes Error.");
769         return -E_PARSE_FAIL;
770     }
771     auto phyHeaderOri = reinterpret_cast<const CommPhyHeader *>(bytes);
772     CommPhyHeader phyHeader;
773     HeaderConverter::ConvertNetToHost(*phyHeaderOri, phyHeader);
774     errCode = ParseCommPhyHeaderCheckField(srcTarget, phyHeader, bytes, length);
775     if (errCode != E_OK) {
776         LOGE("[Proto][ParsePhy] Check Field Fail.");
777         return errCode;
778     }
779 
780     inResult.SetFrameId(phyHeader.frameId);
781     inResult.SetSourceId(phyHeader.sourceId);
782     inResult.SetPacketLen(phyHeader.packetLen);
783     inResult.SetPaddingLen(phyHeader.paddingLen);
784     inResult.SetDbVersion(phyHeader.dbIntVer);
785     if ((phyHeader.packetType & PACKET_TYPE_FRAGMENTED) != 0) {
786         inResult.SetFragmentFlag(true);
787     } // FragmentFlag default is false
788     FrameType frameType = GetFrameType(phyHeader.packetType);
789     if (frameType == FrameType::INVALID_MAX_FRAME_TYPE) {
790         LOGW("[Proto][ParsePhy] Unrecognized frame, pktType=%" PRIu32 ".", phyHeader.packetType);
791         return -E_FRAME_TYPE_NOT_SUPPORT;
792     }
793     inResult.SetFrameTypeInfo(frameType);
794     return E_OK;
795 }
796 
ParseCommPhyOptHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)797 int ProtocolProto::ParseCommPhyOptHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
798 {
799     if (length < sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader)) {
800         LOGE("[Proto][ParsePhyOpt] Length of Bytes Error.");
801         return -E_LENGTH_ERROR;
802     }
803     auto headerOri = reinterpret_cast<const CommPhyOptHeader *>(bytes + sizeof(CommPhyHeader));
804     CommPhyOptHeader phyOptHeader;
805     HeaderConverter::ConvertNetToHost(*headerOri, phyOptHeader);
806 
807     // Check of CommPhyOptHeader field will be done in the procedure of FrameCombiner
808     inResult.SetFrameLen(phyOptHeader.frameLen);
809     inResult.SetFragCount(phyOptHeader.fragCount);
810     inResult.SetFragNo(phyOptHeader.fragNo);
811     return E_OK;
812 }
813 
ParseCommDivergeHeader(const uint8_t * bytes,uint32_t length,ParseResult & inResult)814 int ProtocolProto::ParseCommDivergeHeader(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
815 {
816     // Check version before parse field
817     if (length < sizeof(CommPhyHeader) + sizeof(uint16_t)) {
818         return -E_LENGTH_ERROR;
819     }
820     uint16_t version = NetToHost(*(reinterpret_cast<const uint16_t *>(bytes + sizeof(CommPhyHeader))));
821     if (version != PROTOCOL_VERSION) {
822         LOGE("[Proto][ParseDiverge] Version=%" PRIu16 " not support.", version);
823         return -E_VERSION_NOT_SUPPORT;
824     }
825 
826     if (length < sizeof(CommPhyHeader) + sizeof(CommDivergeHeader)) {
827         LOGE("[Proto][ParseDiverge] Length of Bytes Error.");
828         return -E_PARSE_FAIL;
829     }
830     auto headerOri = reinterpret_cast<const CommDivergeHeader *>(bytes + sizeof(CommPhyHeader));
831     CommDivergeHeader divergeHeader;
832     HeaderConverter::ConvertNetToHost(*headerOri, divergeHeader);
833     if (sizeof(CommPhyHeader) + sizeof(CommDivergeHeader) + divergeHeader.payLoadLen +
834         inResult.GetPaddingLen() != inResult.GetPacketLen()) {
835         LOGE("[Proto][ParseDiverge] Total Length Mismatch.");
836         return -E_PARSE_FAIL;
837     }
838     inResult.SetPayloadLen(divergeHeader.payLoadLen);
839     inResult.SetCommLabel(LabelType(std::begin(divergeHeader.commLabel), std::end(divergeHeader.commLabel)));
840     return E_OK;
841 }
842 
ParseCommLayerPayload(const uint8_t * bytes,uint32_t length,ParseResult & inResult)843 int ProtocolProto::ParseCommLayerPayload(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
844 {
845     if (inResult.GetFrameTypeInfo() == FrameType::COMMUNICATION_LABEL_EXCHANGE_ACK) {
846         int errCode = ParseLabelExchangeAck(bytes, length, inResult);
847         if (errCode != E_OK) {
848             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
849             return errCode;
850         }
851     } else {
852         int errCode = ParseLabelExchange(bytes, length, inResult);
853         if (errCode != E_OK) {
854             LOGE("[Proto][ParseCommPayload] Total Length Mismatch.");
855             return errCode;
856         }
857     }
858     return E_OK;
859 }
860 
ParseLabelExchange(const uint8_t * bytes,uint32_t length,ParseResult & inResult)861 int ProtocolProto::ParseLabelExchange(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
862 {
863     // Check version at very first
864     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
865         return -E_LENGTH_ERROR;
866     }
867     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
868     uint64_t version = NetToHost(*fieldPtr++);
869     if (version != PROTOCOL_VERSION) {
870         LOGE("[Proto][ParseLabel] Version=%" PRIu64 " not support.", ULL(version));
871         return -E_VERSION_NOT_SUPPORT;
872     }
873 
874     // Version, DistinctValue, SequenceId and CommLabelCount field must be exist.
875     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN) {
876         LOGE("[Proto][ParseLabel] Length of Bytes Error.");
877         return -E_LENGTH_ERROR;
878     }
879     uint64_t distinctValue = NetToHost(*fieldPtr++);
880     inResult.SetLabelExchangeDistinctValue(distinctValue);
881     uint64_t sequenceId = NetToHost(*fieldPtr++);
882     inResult.SetLabelExchangeSequenceId(sequenceId);
883     uint64_t commLabelCount = NetToHost(*fieldPtr++);
884     if (length < commLabelCount || (UINT32_MAX / COMM_LABEL_LENGTH) < commLabelCount) {
885         LOGE("[Proto][ParseLabel] commLabelCount=%" PRIu64 " invalid.", ULL(commLabelCount));
886         return -E_PARSE_FAIL;
887     }
888     // commLabelCount is expected to be not very large
889     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN + COMM_LABEL_COUNT_LEN +
890         commLabelCount * COMM_LABEL_LENGTH) {
891         LOGE("[Proto][ParseLabel] Length of Bytes Error, commLabelCount=%" PRIu64, ULL(commLabelCount));
892         return -E_LENGTH_ERROR;
893     }
894 
895     // Get each commLabel
896     std::set<LabelType> commLabels;
897     auto bytePtr = reinterpret_cast<const uint8_t *>(fieldPtr);
898     for (uint64_t i = 0; i < commLabelCount; i++) {
899         // the length is checked just above
900         LabelType commLabel(bytePtr + i * COMM_LABEL_LENGTH, bytePtr + (i + 1) * COMM_LABEL_LENGTH);
901         if (commLabels.count(commLabel) != 0) {
902             LOGW("[Proto][ParseLabel] Duplicate Label Detected, commLabel=%.3s.", VEC_TO_STR(commLabel));
903         } else {
904             commLabels.insert(commLabel);
905         }
906     }
907     inResult.SetLatestCommLabels(commLabels);
908     return E_OK;
909 }
910 
ParseLabelExchangeAck(const uint8_t * bytes,uint32_t length,ParseResult & inResult)911 int ProtocolProto::ParseLabelExchangeAck(const uint8_t *bytes, uint32_t length, ParseResult &inResult)
912 {
913     // Check version at very first
914     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN) {
915         return -E_LENGTH_ERROR;
916     }
917     auto fieldPtr = reinterpret_cast<const uint64_t *>(bytes + sizeof(CommPhyHeader));
918     uint64_t version = NetToHost(*fieldPtr++);
919     if (version != PROTOCOL_VERSION) {
920         LOGE("[Proto][ParseLabelAck] Version=%" PRIu64 " not support.", ULL(version));
921         return -E_VERSION_NOT_SUPPORT;
922     }
923 
924     if (length < sizeof(CommPhyHeader) + LABEL_VER_LEN + DISTINCT_VALUE_LEN + SEQUENCE_ID_LEN) {
925         LOGE("[Proto][ParseLabelAck] Length of Bytes Error.");
926         return -E_LENGTH_ERROR;
927     }
928     uint64_t distinctValue = NetToHost(*fieldPtr++);
929     inResult.SetLabelExchangeDistinctValue(distinctValue);
930     uint64_t sequenceId = NetToHost(*fieldPtr++);
931     inResult.SetLabelExchangeSequenceId(sequenceId);
932     return E_OK;
933 }
934 
935 // Note: framePhyHeader is in network endian
936 // This function aims at calculating and preparing each part of each packets
FrameFragmentation(const uint8_t * splitStartBytes,const FrameFragmentInfo & fragmentInfo,const CommPhyHeader & framePhyHeader,std::vector<std::pair<std::vector<uint8_t>,uint32_t>> & outPieces)937 int ProtocolProto::FrameFragmentation(const uint8_t *splitStartBytes, const FrameFragmentInfo &fragmentInfo,
938     const CommPhyHeader &framePhyHeader, std::vector<std::pair<std::vector<uint8_t>, uint32_t>> &outPieces)
939 {
940     // It can be guaranteed that fragCount >= 2 and also won't be too large
941     if (fragmentInfo.fragCount < MIN_FRAGMENT_COUNT) {
942         return -E_INVALID_ARGS;
943     }
944     outPieces.resize(fragmentInfo.fragCount); // Note: should use resize other than reserve
945     uint32_t quotient = fragmentInfo.splitLength / fragmentInfo.fragCount;
946     uint16_t remainder = fragmentInfo.splitLength % fragmentInfo.fragCount;
947     uint16_t fragNo = 0; // Fragment index start from 0
948     uint32_t byteOffset = 0;
949 
950     for (auto &entry : outPieces) {
951         // subtract 1 for index
952         uint32_t pieceFragLen = (fragNo != fragmentInfo.fragCount - 1) ? quotient : (quotient + remainder);
953         uint32_t alignedFragLen = BYTE_8_ALIGN(pieceFragLen); // Add padding length
954         uint32_t pieceTotalLen = alignedFragLen + sizeof(CommPhyHeader) + sizeof(CommPhyOptHeader);
955 
956         // Since exception is disabled, we have to check the vector size to assure that memory is truly allocated
957         entry.first.resize(pieceTotalLen + fragmentInfo.extendHeadSize); // Note: should use resize other than reserve
958         if (entry.first.size() != (pieceTotalLen + fragmentInfo.extendHeadSize)) {
959             LOGE("[Proto][FrameFrag] Resize failed for length=%" PRIu32, pieceTotalLen);
960             return -E_OUT_OF_MEMORY;
961         }
962 
963         CommPhyHeader pktPhyHeader;
964         HeaderConverter::ConvertNetToHost(framePhyHeader, pktPhyHeader); // Restore to host endian
965 
966         // The sum value need to be recalculated, and the packet is fragmented.
967         // The alignedFragLen is always larger than pieceFragLen
968         FillPhyHeaderLenInfo(pieceTotalLen, 0, PACKET_TYPE_FRAGMENTED, alignedFragLen - pieceFragLen, pktPhyHeader);
969         HeaderConverter::ConvertHostToNet(pktPhyHeader, pktPhyHeader);
970 
971         CommPhyOptHeader pktPhyOptHeader = {static_cast<uint32_t>(fragmentInfo.splitLength + sizeof(CommPhyHeader)),
972             fragmentInfo.fragCount, fragNo};
973         HeaderConverter::ConvertHostToNet(pktPhyOptHeader, pktPhyOptHeader);
974         int err;
975         FragmentPacket packet;
976         uint8_t *ptrPacket = &(entry.first[0]);
977         if (fragmentInfo.extendHeadSize > 0) {
978             packet = {ptrPacket, fragmentInfo.extendHeadSize};
979             err = FillFragmentPacketExtendHead(fragmentInfo.oringinalBytesAddr, fragmentInfo.extendHeadSize, packet);
980             if (err != E_OK) {
981                 return err;
982             }
983             ptrPacket += fragmentInfo.extendHeadSize;
984         }
985         packet = {ptrPacket, static_cast<uint32_t>(entry.first.size()) - fragmentInfo.extendHeadSize};
986         err = FillFragmentPacket(pktPhyHeader, pktPhyOptHeader, splitStartBytes + byteOffset,
987             pieceFragLen, packet);
988         entry.second = fragmentInfo.extendHeadSize;
989         if (err != E_OK) {
990             LOGE("[Proto][FrameFrag] Fill packet fail, fragCount=%" PRIu16 ", fragNo=%" PRIu16, fragmentInfo.fragCount,
991                 fragNo);
992             return err;
993         }
994 
995         fragNo++;
996         byteOffset += pieceFragLen;
997     }
998 
999     return E_OK;
1000 }
1001 
FillFragmentPacketExtendHead(uint8_t * headBytesAddr,uint32_t headLen,FragmentPacket & outPacket)1002 int ProtocolProto::FillFragmentPacketExtendHead(uint8_t *headBytesAddr, uint32_t headLen, FragmentPacket &outPacket)
1003 {
1004     if (headLen > outPacket.leftLength) {
1005         LOGE("[Proto][FrameFrag] headLen less than leftLength");
1006         return -E_INVALID_ARGS;
1007     }
1008     errno_t retCode = memcpy_s(outPacket.ptrPacket, outPacket.leftLength, headBytesAddr, headLen);
1009     if (retCode != EOK) {
1010         LOGE("memcpy error:%d", retCode);
1011         return -E_SECUREC_ERROR;
1012     }
1013     return E_OK;
1014 }
1015 
1016 // Note: phyHeader and phyOptHeader is in network endian
FillFragmentPacket(const CommPhyHeader & phyHeader,const CommPhyOptHeader & phyOptHeader,const uint8_t * fragBytes,uint32_t fragLen,FragmentPacket & outPacket)1017 int ProtocolProto::FillFragmentPacket(const CommPhyHeader &phyHeader, const CommPhyOptHeader &phyOptHeader,
1018     const uint8_t *fragBytes, uint32_t fragLen, FragmentPacket &outPacket)
1019 {
1020     if (outPacket.leftLength == 0) {
1021         return -E_INVALID_ARGS;
1022     }
1023     uint8_t *ptrPacket = outPacket.ptrPacket;
1024     uint32_t leftLength = outPacket.leftLength;
1025 
1026     // leftLength is guaranteed to be no smaller than the sum of phyHeaderLen + phyOptHeaderLen + fragLen
1027     // So, there will be no redundant check during subtraction
1028     errno_t retCode = memcpy_s(ptrPacket, leftLength, &phyHeader, sizeof(CommPhyHeader));
1029     if (retCode != EOK) {
1030         return -E_SECUREC_ERROR;
1031     }
1032     ptrPacket += sizeof(CommPhyHeader);
1033     leftLength -= sizeof(CommPhyHeader);
1034 
1035     retCode = memcpy_s(ptrPacket, leftLength, &phyOptHeader, sizeof(CommPhyOptHeader));
1036     if (retCode != EOK) {
1037         return -E_SECUREC_ERROR;
1038     }
1039     ptrPacket += sizeof(CommPhyOptHeader);
1040     leftLength -= sizeof(CommPhyOptHeader);
1041 
1042     retCode = memcpy_s(ptrPacket, leftLength, fragBytes, fragLen);
1043     if (retCode != EOK) {
1044         return -E_SECUREC_ERROR;
1045     }
1046 
1047     // Calculate sum and set sum field
1048     uint64_t sumResult = 0;
1049     int errCode  = CalculateXorSum(outPacket.ptrPacket + LENGTH_BEFORE_SUM_RANGE,
1050         outPacket.leftLength - LENGTH_BEFORE_SUM_RANGE, sumResult);
1051     if (errCode != E_OK) {
1052         return -E_SUM_CALCULATE_FAIL;
1053     }
1054     auto ptrPhyHeader = reinterpret_cast<CommPhyHeader *>(outPacket.ptrPacket);
1055     if (ptrPhyHeader == nullptr) {
1056         return -E_INVALID_ARGS;
1057     }
1058     ptrPhyHeader->checkSum = HostToNet(sumResult);
1059 
1060     return E_OK;
1061 }
1062 
GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> & extendHandle,uint32_t & headSize)1063 int ProtocolProto::GetExtendHeadDataSize(std::shared_ptr<ExtendHeaderHandle> &extendHandle, uint32_t &headSize)
1064 {
1065     if (extendHandle != nullptr) {
1066         DBStatus status = extendHandle->GetHeadDataSize(headSize);
1067         if (status != DBStatus::OK) {
1068             LOGI("[Proto][ToSerial] get head data size failed,not permit to send");
1069             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1070         }
1071         if (headSize > SerialBuffer::MAX_EXTEND_HEAD_LENGTH || headSize != BYTE_8_ALIGN(headSize)) {
1072             LOGI("[Proto][ToSerial] head data size is larger than 512 or not 8 byte align");
1073             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1074         }
1075         return E_OK;
1076     }
1077     return E_OK;
1078 }
1079 
FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> & extendHandle,SerialBuffer * buffer,uint32_t headSize)1080 int ProtocolProto::FillExtendHeadDataIfNeed(std::shared_ptr<ExtendHeaderHandle> &extendHandle, SerialBuffer *buffer,
1081     uint32_t headSize)
1082 {
1083     if (extendHandle != nullptr && headSize > 0) {
1084         if (buffer == nullptr) {
1085             return -E_INVALID_ARGS;
1086         }
1087         DBStatus status = extendHandle->FillHeadData(buffer->GetOringinalAddr(), headSize,
1088             buffer->GetSize() + headSize);
1089         if (status != DBStatus::OK) {
1090             LOGI("[Proto][ToSerial] fill head data failed");
1091             return -E_FEEDBACK_COMMUNICATOR_NOT_FOUND;
1092         }
1093     }
1094     return E_OK;
1095 }
1096 
GetTransformFunc(uint32_t messageId,DistributedDB::TransformFunc & function)1097 int ProtocolProto::GetTransformFunc(uint32_t messageId, DistributedDB::TransformFunc &function)
1098 {
1099     std::shared_lock<std::shared_mutex> autoLock(msgIdMutex_);
1100     const auto &entry = msgIdMapFunc_.find(messageId);
1101     if (entry == msgIdMapFunc_.end()) {
1102         return -E_NOT_REGISTER;
1103     }
1104     function = entry->second;
1105     return E_OK;
1106 }
1107 } // namespace DistributedDB
1108