1 /*
2 * Copyright (c) 2020, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "common/encoding.hpp"
30 #include "common/instance.hpp"
31 #include "common/message.hpp"
32 #include "common/random.hpp"
33 #include "net/checksum.hpp"
34 #include "net/icmp6.hpp"
35 #include "net/ip4_types.hpp"
36 #include "net/udp6.hpp"
37
38 #include "test_platform.h"
39 #include "test_util.hpp"
40
41 namespace ot {
42
uint16_t CalculateChecksum(const void *aBuffer, uint16_t aLength)
{
    // Reference one's-complement sum in the style of RFC 1071, used to
    // cross-check the production `Checksum` class.
    //
    // The `aLength` bytes at `aBuffer` are consumed as big-endian 16-bit
    // words. An odd trailing byte is treated as the high byte of a final word
    // whose low byte is zero. The returned value is the folded 16-bit sum
    // itself (not its complement).

    const uint8_t *cursor = static_cast<const uint8_t *>(aBuffer);
    const uint8_t *end    = cursor + aLength;
    uint32_t       total  = 0;

    for (; end - cursor >= 2; cursor += 2)
    {
        total += static_cast<uint32_t>((static_cast<uint16_t>(cursor[0]) << 8) | cursor[1]);
    }

    if (cursor != end)
    {
        // Odd length: pad the last byte with a zero low byte.
        total += static_cast<uint32_t>(cursor[0]) << 8;
    }

    // Add the carries above bit 15 back into the low 16 bits until none remain.
    for (uint32_t carry = total >> 16; carry != 0; carry = total >> 16)
    {
        total = (total & 0xffff) + carry;
    }

    return static_cast<uint16_t>(total);
}
72
CalculateChecksum(const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto,const Message & aMessage)73 uint16_t CalculateChecksum(const Ip6::Address &aSource,
74 const Ip6::Address &aDestination,
75 uint8_t aIpProto,
76 const Message & aMessage)
77 {
78 // This method calculates the checksum over an IPv6 message.
79 constexpr uint16_t kMaxPayload = 1024;
80
81 OT_TOOL_PACKED_BEGIN
82 struct PseudoHeader
83 {
84 Ip6::Address mSource;
85 Ip6::Address mDestination;
86 uint32_t mPayloadLength;
87 uint32_t mProtocol;
88 } OT_TOOL_PACKED_END;
89
90 OT_TOOL_PACKED_BEGIN
91 struct ChecksumData
92 {
93 PseudoHeader mPseudoHeader;
94 uint8_t mPayload[kMaxPayload];
95 } OT_TOOL_PACKED_END;
96
97 ChecksumData data;
98 uint16_t payloadLength;
99
100 payloadLength = aMessage.GetLength() - aMessage.GetOffset();
101
102 data.mPseudoHeader.mSource = aSource;
103 data.mPseudoHeader.mDestination = aDestination;
104 data.mPseudoHeader.mProtocol = Encoding::BigEndian::HostSwap32(aIpProto);
105 data.mPseudoHeader.mPayloadLength = Encoding::BigEndian::HostSwap32(payloadLength);
106
107 SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
108
109 return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
110 }
111
CalculateChecksum(const Ip4::Address & aSource,const Ip4::Address & aDestination,uint8_t aIpProto,const Message & aMessage)112 uint16_t CalculateChecksum(const Ip4::Address &aSource,
113 const Ip4::Address &aDestination,
114 uint8_t aIpProto,
115 const Message & aMessage)
116 {
117 // This method calculates the checksum over an IPv4 message.
118 constexpr uint16_t kMaxPayload = 1024;
119
120 OT_TOOL_PACKED_BEGIN
121 struct PseudoHeader
122 {
123 Ip4::Address mSource;
124 Ip4::Address mDestination;
125 uint16_t mPayloadLength;
126 uint16_t mProtocol;
127 } OT_TOOL_PACKED_END;
128
129 OT_TOOL_PACKED_BEGIN
130 struct ChecksumData
131 {
132 PseudoHeader mPseudoHeader;
133 uint8_t mPayload[kMaxPayload];
134 } OT_TOOL_PACKED_END;
135
136 ChecksumData data;
137 uint16_t payloadLength;
138
139 payloadLength = aMessage.GetLength() - aMessage.GetOffset();
140
141 data.mPseudoHeader.mSource = aSource;
142 data.mPseudoHeader.mDestination = aDestination;
143 data.mPseudoHeader.mProtocol = Encoding::BigEndian::HostSwap16(aIpProto);
144 data.mPseudoHeader.mPayloadLength = Encoding::BigEndian::HostSwap16(payloadLength);
145
146 SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
147
148 return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
149 }
150
CorruptMessage(Message & aMessage)151 void CorruptMessage(Message &aMessage)
152 {
153 // Change a random bit in the message.
154
155 uint16_t byteOffset;
156 uint8_t bitOffset;
157 uint8_t byte;
158
159 byteOffset = Random::NonCrypto::GetUint16InRange(0, aMessage.GetLength());
160
161 SuccessOrQuit(aMessage.Read(byteOffset, byte));
162
163 bitOffset = Random::NonCrypto::GetUint8InRange(0, CHAR_BIT);
164
165 byte ^= (1 << bitOffset);
166
167 aMessage.Write(byteOffset, byte);
168 }
169
TestUdpMessageChecksum(void)170 void TestUdpMessageChecksum(void)
171 {
172 constexpr uint16_t kMinSize = sizeof(Ip6::Udp::Header);
173 constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
174
175 const char *kSourceAddress = "fd00:1122:3344:5566:7788:99aa:bbcc:ddee";
176 const char *kDestAddress = "fd01:2345:6789:abcd:ef01:2345:6789:abcd";
177
178 Instance *instance = static_cast<Instance *>(testInitInstance());
179
180 VerifyOrQuit(instance != nullptr);
181
182 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
183 {
184 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Udp::Header));
185 Ip6::Udp::Header udpHeader;
186 Ip6::MessageInfo messageInfo;
187
188 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
189 SuccessOrQuit(message->SetLength(size));
190
191 // Write UDP header with a random payload.
192
193 Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&udpHeader), sizeof(udpHeader));
194 udpHeader.SetChecksum(0);
195 message->Write(0, udpHeader);
196
197 if (size > sizeof(udpHeader))
198 {
199 uint8_t buffer[kMaxSize];
200 uint16_t payloadSize = size - sizeof(udpHeader);
201
202 Random::NonCrypto::FillBuffer(buffer, payloadSize);
203 message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
204 }
205
206 SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
207 SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
208
209 // Verify that the `Checksum::UpdateMessageChecksum` correctly
210 // updates the checksum field in the UDP header on the message.
211
212 Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp);
213
214 SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
215 VerifyOrQuit(udpHeader.GetChecksum() != 0);
216
217 // Verify that the calculated UDP checksum is valid.
218
219 VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp,
220 *message) == 0xffff);
221
222 // Verify that `Checksum::VerifyMessageChecksum()` accepts the
223 // message and its calculated checksum.
224
225 SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp));
226
227 // Corrupt the message and verify that checksum is no longer accepted.
228
229 CorruptMessage(*message);
230
231 VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp) != kErrorNone,
232 "Checksum passed on corrupted message");
233
234 message->Free();
235 }
236 }
237
TestIcmp6MessageChecksum(void)238 void TestIcmp6MessageChecksum(void)
239 {
240 constexpr uint16_t kMinSize = sizeof(Ip6::Icmp::Header);
241 constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
242
243 const char *kSourceAddress = "fd00:feef:dccd:baab:9889:7667:5444:3223";
244 const char *kDestAddress = "fd01:abab:beef:cafe:1234:5678:9abc:0";
245
246 Instance *instance = static_cast<Instance *>(testInitInstance());
247
248 VerifyOrQuit(instance != nullptr, "Null OpenThread instance\n");
249
250 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
251 {
252 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Icmp::Header));
253 Ip6::Icmp::Header icmp6Header;
254 Ip6::MessageInfo messageInfo;
255
256 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
257 SuccessOrQuit(message->SetLength(size));
258
259 // Write ICMP6 header with a random payload.
260
261 Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&icmp6Header), sizeof(icmp6Header));
262 icmp6Header.SetChecksum(0);
263 message->Write(0, icmp6Header);
264
265 if (size > sizeof(icmp6Header))
266 {
267 uint8_t buffer[kMaxSize];
268 uint16_t payloadSize = size - sizeof(icmp6Header);
269
270 Random::NonCrypto::FillBuffer(buffer, payloadSize);
271 message->WriteBytes(sizeof(icmp6Header), &buffer[0], payloadSize);
272 }
273
274 SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
275 SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
276
277 // Verify that the `Checksum::UpdateMessageChecksum` correctly
278 // updates the checksum field in the ICMP6 header on the message.
279
280 Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(),
281 Ip6::kProtoIcmp6);
282
283 SuccessOrQuit(message->Read(message->GetOffset(), icmp6Header));
284 VerifyOrQuit(icmp6Header.GetChecksum() != 0, "Failed to update checksum");
285
286 // Verify that the calculated ICMP6 checksum is valid.
287
288 VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoIcmp6,
289 *message) == 0xffff);
290
291 // Verify that `Checksum::VerifyMessageChecksum()` accepts the
292 // message and its calculated checksum.
293
294 SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6));
295
296 // Corrupt the message and verify that checksum is no longer accepted.
297
298 CorruptMessage(*message);
299
300 VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6) != kErrorNone,
301 "Checksum passed on corrupted message");
302
303 message->Free();
304 }
305 }
306
TestTcp4MessageChecksum(void)307 void TestTcp4MessageChecksum(void)
308 {
309 constexpr size_t kMinSize = sizeof(Ip4::Tcp::Header);
310 constexpr size_t kMaxSize = kBufferSize * 3 + 24;
311
312 const char *kSourceAddress = "12.34.56.78";
313 const char *kDestAddress = "87.65.43.21";
314
315 Ip4::Address sourceAddress;
316 Ip4::Address destAddress;
317
318 Instance *instance = static_cast<Instance *>(testInitInstance());
319
320 VerifyOrQuit(instance != nullptr);
321
322 SuccessOrQuit(sourceAddress.FromString(kSourceAddress));
323 SuccessOrQuit(destAddress.FromString(kDestAddress));
324
325 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
326 {
327 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip4::Tcp::Header));
328 Ip4::Tcp::Header tcpHeader;
329
330 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
331 SuccessOrQuit(message->SetLength(size));
332
333 // Write TCP header with a random payload.
334
335 Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&tcpHeader), sizeof(tcpHeader));
336 message->Write(0, tcpHeader);
337
338 if (size > sizeof(tcpHeader))
339 {
340 uint8_t buffer[kMaxSize];
341 uint16_t payloadSize = size - sizeof(tcpHeader);
342
343 Random::NonCrypto::FillBuffer(buffer, payloadSize);
344 message->WriteBytes(sizeof(tcpHeader), &buffer[0], payloadSize);
345 }
346
347 // Verify that the `Checksum::UpdateMessageChecksum` correctly
348 // updates the checksum field in the UDP header on the message.
349
350 Checksum::UpdateMessageChecksum(*message, sourceAddress, destAddress, Ip4::kProtoTcp);
351
352 SuccessOrQuit(message->Read(message->GetOffset(), tcpHeader));
353 VerifyOrQuit(tcpHeader.GetChecksum() != 0);
354
355 // Verify that the calculated UDP checksum is valid.
356
357 VerifyOrQuit(CalculateChecksum(sourceAddress, destAddress, Ip4::kProtoTcp, *message) == 0xffff);
358 message->Free();
359 }
360 }
361
TestUdp4MessageChecksum(void)362 void TestUdp4MessageChecksum(void)
363 {
364 constexpr uint16_t kMinSize = sizeof(Ip4::Udp::Header);
365 constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
366
367 const char *kSourceAddress = "12.34.56.78";
368 const char *kDestAddress = "87.65.43.21";
369
370 Ip4::Address sourceAddress;
371 Ip4::Address destAddress;
372
373 Instance *instance = static_cast<Instance *>(testInitInstance());
374
375 SuccessOrQuit(sourceAddress.FromString(kSourceAddress));
376 SuccessOrQuit(destAddress.FromString(kDestAddress));
377
378 VerifyOrQuit(instance != nullptr);
379
380 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
381 {
382 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip4::Udp::Header));
383 Ip4::Udp::Header udpHeader;
384
385 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
386 SuccessOrQuit(message->SetLength(size));
387
388 // Write UDP header with a random payload.
389
390 Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&udpHeader), sizeof(udpHeader));
391 udpHeader.SetChecksum(0);
392 message->Write(0, udpHeader);
393
394 if (size > sizeof(udpHeader))
395 {
396 uint8_t buffer[kMaxSize];
397 uint16_t payloadSize = size - sizeof(udpHeader);
398
399 Random::NonCrypto::FillBuffer(buffer, payloadSize);
400 message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
401 }
402
403 // Verify that the `Checksum::UpdateMessageChecksum` correctly
404 // updates the checksum field in the UDP header on the message.
405
406 Checksum::UpdateMessageChecksum(*message, sourceAddress, destAddress, Ip4::kProtoUdp);
407
408 SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
409 VerifyOrQuit(udpHeader.GetChecksum() != 0);
410
411 // Verify that the calculated UDP checksum is valid.
412
413 VerifyOrQuit(CalculateChecksum(sourceAddress, destAddress, Ip4::kProtoUdp, *message) == 0xffff);
414 message->Free();
415 }
416 }
417
TestIcmp4MessageChecksum(void)418 void TestIcmp4MessageChecksum(void)
419 {
420 // A captured ICMP echo request (ping) message. Checksum field is set to zero.
421 const uint8_t kExampleIcmpMessage[] = "\x08\x00\x00\x00\x67\x2e\x00\x00\x62\xaf\xf1\x61\x00\x04\xfc\x24"
422 "\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17"
423 "\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27"
424 "\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37";
425 uint16_t kChecksumForExampleMessage = 0x5594;
426 Instance *instance = static_cast<Instance *>(testInitInstance());
427 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(kExampleIcmpMessage));
428
429 Ip4::Address source;
430 Ip4::Address dest;
431
432 uint8_t mPayload[sizeof(kExampleIcmpMessage)];
433 Ip4::Icmp::Header icmpHeader;
434
435 SuccessOrQuit(message->AppendBytes(kExampleIcmpMessage, sizeof(kExampleIcmpMessage)));
436
437 // Random IPv4 address, ICMP message checksum does not include a presudo header like TCP and UDP.
438 source.mFields.m32 = 0x12345678;
439 dest.mFields.m32 = 0x87654321;
440
441 Checksum::UpdateMessageChecksum(*message, source, dest, Ip4::kProtoIcmp);
442
443 SuccessOrQuit(message->Read(0, icmpHeader));
444 VerifyOrQuit(icmpHeader.GetChecksum() == kChecksumForExampleMessage);
445
446 SuccessOrQuit(message->Read(message->GetOffset(), mPayload, sizeof(mPayload)));
447 VerifyOrQuit(CalculateChecksum(mPayload, sizeof(mPayload)) == 0xffff);
448 }
449
450 class ChecksumTester
451 {
452 public:
TestExampleVector(void)453 static void TestExampleVector(void)
454 {
455 // Example from RFC 1071
456 const uint8_t kTestVector[] = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7};
457 const uint16_t kTestVectorChecksum = 0xddf2;
458
459 Checksum checksum;
460
461 VerifyOrQuit(checksum.GetValue() == 0, "Incorrect initial checksum value");
462
463 checksum.AddData(kTestVector, sizeof(kTestVector));
464 VerifyOrQuit(checksum.GetValue() == kTestVectorChecksum);
465 VerifyOrQuit(checksum.GetValue() == CalculateChecksum(kTestVector, sizeof(kTestVector)), );
466 }
467 };
468
469 } // namespace ot
470
main(void)471 int main(void)
472 {
473 ot::ChecksumTester::TestExampleVector();
474 ot::TestUdpMessageChecksum();
475 ot::TestIcmp6MessageChecksum();
476 ot::TestTcp4MessageChecksum();
477 ot::TestUdp4MessageChecksum();
478 ot::TestIcmp4MessageChecksum();
479 printf("All tests passed\n");
480 return 0;
481 }
482