1// -*- c++ -*- 2// Protocol Buffers - Google's data interchange format 3// Copyright 2008 Google Inc. All rights reserved. 4// 5// Use of this source code is governed by a BSD-style 6// license that can be found in the LICENSE file or at 7// https://developers.google.com/open-source/licenses/bsd 8 9// Author: kenton@google.com (Kenton Varda) 10// Based on original Protocol Buffers design by 11// Sanjay Ghemawat, Jeff Dean, and others. 12// 13// This file needs to be included as .inc as it depends on certain macros being 14// defined prior to its inclusion. 15 16#include <fcntl.h> 17#include <sys/stat.h> 18#include <sys/types.h> 19 20#include <cmath> 21#include <cstddef> 22#include <cstdint> 23#include <limits> 24#include <string> 25 26#ifndef _MSC_VER 27#include <unistd.h> 28#endif 29#include <fstream> 30#include <sstream> 31 32#include "google/protobuf/testing/file.h" 33#include "google/protobuf/testing/file.h" 34#include "google/protobuf/descriptor.pb.h" 35#include <gmock/gmock.h> 36#include "google/protobuf/testing/googletest.h" 37#include <gtest/gtest.h> 38#include "absl/log/absl_check.h" 39#include "absl/log/scoped_mock_log.h" 40#include "absl/strings/cord.h" 41#include "absl/strings/substitute.h" 42#include "google/protobuf/arena.h" 43#include "google/protobuf/descriptor.h" 44#include "google/protobuf/dynamic_message.h" 45#include "google/protobuf/generated_message_reflection.h" 46#include "google/protobuf/generated_message_tctable_impl.h" 47#include "google/protobuf/io/coded_stream.h" 48#include "google/protobuf/io/io_win32.h" 49#include "google/protobuf/io/zero_copy_stream.h" 50#include "google/protobuf/io/zero_copy_stream_impl.h" 51#include "google/protobuf/message.h" 52#include "google/protobuf/reflection_ops.h" 53#include "google/protobuf/test_util2.h" 54 55 56// Must be included last. 57#include "google/protobuf/port_def.inc" 58 59namespace google { 60namespace protobuf { 61 62#if defined(_WIN32) 63// DO NOT include <io.h>, instead create functions in io_win32.{h,cc} and import 64// them like we do below. 65using google::protobuf::io::win32::close; 66using google::protobuf::io::win32::open; 67#endif 68 69#ifndef O_BINARY 70#ifdef _O_BINARY 71#define O_BINARY _O_BINARY 72#else 73#define O_BINARY 0 // If this isn't defined, the platform doesn't need it. 74#endif 75#endif 76 77namespace { 78 79UNITTEST::NestedTestAllTypes InitNestedProto(int depth) { 80 UNITTEST::NestedTestAllTypes p; 81 auto* child = p.mutable_child(); 82 for (int i = 0; i < depth; i++) { 83 child->mutable_payload()->set_optional_int32(i); 84 child = child->mutable_child(); 85 } 86 // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1 87 child->mutable_payload()->set_optional_int32(-1); 88 return p; 89} 90} // namespace 91 92TEST(MESSAGE_TEST_NAME, SerializeHelpers) { 93 // TODO: Test more helpers? They're all two-liners so it seems 94 // like a waste of time. 95 96 UNITTEST::TestAllTypes message; 97 TestUtil::SetAllFields(&message); 98 std::stringstream stream; 99 100 std::string str1("foo"); 101 std::string str2("bar"); 102 103 EXPECT_TRUE(message.SerializeToString(&str1)); 104 EXPECT_TRUE(message.AppendToString(&str2)); 105 EXPECT_TRUE(message.SerializeToOstream(&stream)); 106 107 EXPECT_EQ(str1.size() + 3, str2.size()); 108 EXPECT_EQ("bar", str2.substr(0, 3)); 109 // Don't use EXPECT_EQ because we don't want to dump raw binary data to 110 // stdout. 111 EXPECT_TRUE(str2.substr(3) == str1); 112 113 // GCC gives some sort of error if we try to just do stream.str() == str1. 114 std::string temp = stream.str(); 115 EXPECT_TRUE(temp == str1); 116 117 EXPECT_TRUE(message.SerializeAsString() == str1); 118 119} 120 121TEST(MESSAGE_TEST_NAME, RoundTrip) { 122 UNITTEST::TestAllTypes message; 123 TestUtil::SetAllFields(&message); 124 TestUtil::ExpectAllFieldsSet(message); 125 126 UNITTEST::TestAllTypes copied, merged, parsed; 127 copied = message; 128 TestUtil::ExpectAllFieldsSet(copied); 129 130 merged.MergeFrom(message); 131 TestUtil::ExpectAllFieldsSet(merged); 132 133 std::string data; 134 ASSERT_TRUE(message.SerializeToString(&data)); 135 ASSERT_TRUE(parsed.ParseFromString(data)); 136 TestUtil::ExpectAllFieldsSet(parsed); 137} 138 139TEST(MESSAGE_TEST_NAME, SerializeToBrokenOstream) { 140 std::ofstream out; 141 UNITTEST::TestAllTypes message; 142 message.set_optional_int32(123); 143 144 EXPECT_FALSE(message.SerializeToOstream(&out)); 145} 146 147TEST(MESSAGE_TEST_NAME, ParseFromFileDescriptor) { 148 std::string filename = absl::StrCat(TestTempDir(), "/golden_message"); 149 UNITTEST::TestAllTypes expected_message; 150 TestUtil::SetAllFields(&expected_message); 151 ABSL_CHECK_OK(File::SetContents( 152 filename, expected_message.SerializeAsString(), true)); 153 154 int file = open(filename.c_str(), O_RDONLY | O_BINARY); 155 ASSERT_GE(file, 0); 156 157 UNITTEST::TestAllTypes message; 158 EXPECT_TRUE(message.ParseFromFileDescriptor(file)); 159 TestUtil::ExpectAllFieldsSet(message); 160 161 EXPECT_GE(close(file), 0); 162} 163 164TEST(MESSAGE_TEST_NAME, ParsePackedFromFileDescriptor) { 165 std::string filename = absl::StrCat(TestTempDir(), "/golden_message"); 166 UNITTEST::TestPackedTypes expected_message; 167 TestUtil::SetPackedFields(&expected_message); 168 ABSL_CHECK_OK(File::SetContents( 169 filename, expected_message.SerializeAsString(), true)); 170 171 int file = open(filename.c_str(), O_RDONLY | O_BINARY); 172 ASSERT_GE(file, 0); 173 174 UNITTEST::TestPackedTypes message; 175 EXPECT_TRUE(message.ParseFromFileDescriptor(file)); 176 TestUtil::ExpectPackedFieldsSet(message); 177 178 EXPECT_GE(close(file), 0); 179} 180 181TEST(MESSAGE_TEST_NAME, ParseHelpers) { 182 // TODO: Test more helpers? They're all two-liners so it seems 183 // like a waste of time. 184 std::string data; 185 186 { 187 // Set up. 188 UNITTEST::TestAllTypes message; 189 TestUtil::SetAllFields(&message); 190 message.SerializeToString(&data); 191 } 192 193 { 194 // Test ParseFromString. 195 UNITTEST::TestAllTypes message; 196 EXPECT_TRUE(message.ParseFromString(data)); 197 TestUtil::ExpectAllFieldsSet(message); 198 } 199 200 { 201 // Test ParseFromIstream. 202 UNITTEST::TestAllTypes message; 203 std::stringstream stream(data); 204 EXPECT_TRUE(message.ParseFromIstream(&stream)); 205 EXPECT_TRUE(stream.eof()); 206 TestUtil::ExpectAllFieldsSet(message); 207 } 208 209 { 210 // Test ParseFromBoundedZeroCopyStream. 211 std::string data_with_junk(data); 212 data_with_junk.append("some junk on the end"); 213 io::ArrayInputStream stream(data_with_junk.data(), data_with_junk.size()); 214 UNITTEST::TestAllTypes message; 215 EXPECT_TRUE(message.ParseFromBoundedZeroCopyStream(&stream, data.size())); 216 TestUtil::ExpectAllFieldsSet(message); 217 } 218 219 { 220 // Test that ParseFromBoundedZeroCopyStream fails (but doesn't crash) if 221 // EOF is reached before the expected number of bytes. 222 io::ArrayInputStream stream(data.data(), data.size()); 223 UNITTEST::TestAllTypes message; 224 EXPECT_FALSE( 225 message.ParseFromBoundedZeroCopyStream(&stream, data.size() + 1)); 226 } 227 228 // Test bytes cord 229 { 230 UNITTEST::TestCord cord_message; 231 cord_message.set_optional_bytes_cord("bytes_cord"); 232 EXPECT_TRUE(cord_message.SerializeToString(&data)); 233 EXPECT_TRUE(cord_message.SerializeAsString() == data); 234 } 235 { 236 UNITTEST::TestCord cord_message; 237 EXPECT_TRUE(cord_message.ParseFromString(data)); 238 EXPECT_EQ("bytes_cord", cord_message.optional_bytes_cord()); 239 } 240} 241 242TEST(MESSAGE_TEST_NAME, ParseFailsIfNotInitialized) { 243 UNITTEST::TestRequired message; 244 245 { 246 absl::ScopedMockLog log(absl::MockLogDefault::kDisallowUnexpected); 247 EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, absl::StrCat( 248 "Can't parse message of type \"", UNITTEST_PACKAGE_NAME, 249 ".TestRequired\" because it is missing required fields: a, b, c"))); 250 log.StartCapturingLogs(); 251 EXPECT_FALSE(message.ParseFromString("")); 252 } 253} 254 255TEST(MESSAGE_TEST_NAME, ParseFailsIfSubmessageNotInitialized) { 256 UNITTEST::TestRequiredForeign source, message; 257 source.mutable_optional_message()->set_dummy2(100); 258 std::string serialized = source.SerializePartialAsString(); 259 260 EXPECT_TRUE(message.ParsePartialFromString(serialized)); 261 EXPECT_FALSE(message.IsInitialized()); 262 263 { 264 absl::ScopedMockLog log(absl::MockLogDefault::kDisallowUnexpected); 265 EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, absl::StrCat( 266 "Can't parse message of type \"", UNITTEST_PACKAGE_NAME, 267 ".TestRequiredForeign\" because it is missing required fields: " 268 "optional_message.a, optional_message.b, optional_message.c"))); 269 log.StartCapturingLogs(); 270 EXPECT_FALSE(message.ParseFromString(source.SerializePartialAsString())); 271 } 272} 273 274TEST(MESSAGE_TEST_NAME, ParseFailsIfExtensionNotInitialized) { 275 UNITTEST::TestChildExtension source, message; 276 auto* r = source.mutable_optional_extension()->MutableExtension( 277 UNITTEST::TestRequired::single); 278 r->set_dummy2(100); 279 std::string serialized = source.SerializePartialAsString(); 280 281 EXPECT_TRUE(message.ParsePartialFromString(serialized)); 282 EXPECT_FALSE(message.IsInitialized()); 283 284{ 285 absl::ScopedMockLog log(absl::MockLogDefault::kDisallowUnexpected); 286 EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, absl::Substitute( 287 "Can't parse message of type \"$0.TestChildExtension\" " 288 "because it is missing required fields: " 289 "optional_extension.($0.TestRequired.single).a, " 290 "optional_extension.($0.TestRequired.single).b, " 291 "optional_extension.($0.TestRequired.single).c", 292 UNITTEST_PACKAGE_NAME))); 293 log.StartCapturingLogs(); 294 EXPECT_FALSE(message.ParseFromString(source.SerializePartialAsString())); 295 } 296} 297 298TEST(MESSAGE_TEST_NAME, MergeFromUninitialized) { 299 UNITTEST::TestNestedRequiredForeign o, p, q; 300 UNITTEST::TestNestedRequiredForeign* child = o.mutable_child(); 301 constexpr int kDepth = 2; 302 for (int i = 0; i < kDepth; i++) { 303 child->set_dummy(i); 304 child = child->mutable_child(); 305 } 306 UNITTEST::TestRequiredForeign* payload = child->mutable_payload(); 307 payload->mutable_optional_message()->set_a(1); 308 payload->mutable_optional_message()->set_dummy2(100); 309 payload->mutable_optional_message()->set_dummy4(200); 310 ASSERT_TRUE(p.ParsePartialFromString(o.SerializePartialAsString())); 311 312 q.mutable_child()->set_dummy(500); 313 q = p; 314 q.ParsePartialFromString(q.SerializePartialAsString()); 315 EXPECT_TRUE(TestUtil::EqualsToSerialized(q, o.SerializePartialAsString())); 316 EXPECT_TRUE(TestUtil::EqualsToSerialized(q, p.SerializePartialAsString())); 317} 318 319TEST(MESSAGE_TEST_NAME, UninitializedAndTooDeep) { 320 UNITTEST::TestRequiredForeign original; 321 original.mutable_optional_message()->set_a(1); 322 original.mutable_optional_lazy_message() 323 ->mutable_child() 324 ->mutable_payload() 325 ->set_optional_int64(0); 326 327 std::string data; 328 ASSERT_TRUE(original.SerializePartialToString(&data)); 329 330 UNITTEST::TestRequiredForeign pass; 331 ASSERT_TRUE(pass.ParsePartialFromString(data)); 332 ASSERT_FALSE(pass.IsInitialized()); 333 334 io::ArrayInputStream array_stream(data.data(), data.size()); 335 io::CodedInputStream input_stream(&array_stream); 336 input_stream.SetRecursionLimit(2); 337 338 UNITTEST::TestRequiredForeign fail; 339 EXPECT_FALSE(fail.ParsePartialFromCodedStream(&input_stream)); 340 341 UNITTEST::TestRequiredForeign fail_uninitialized; 342 EXPECT_FALSE(fail_uninitialized.ParseFromString(data)); 343} 344 345TEST(MESSAGE_TEST_NAME, ExplicitLazyExceedRecursionLimit) { 346 UNITTEST::NestedTestAllTypes original, parsed; 347 // Build proto with recursion depth of 3. 348 original.mutable_lazy_child() 349 ->mutable_child() 350 ->mutable_payload() 351 ->set_optional_int32(-1); 352 std::string serialized; 353 ASSERT_TRUE(original.SerializeToString(&serialized)); 354 355 // User annotated LazyField ([lazy = true]) is eagerly verified and should 356 // catch the recursion limit violation. 357 io::ArrayInputStream array_stream(serialized.data(), serialized.size()); 358 io::CodedInputStream input_stream(&array_stream); 359 input_stream.SetRecursionLimit(2); 360 EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream)); 361 362 // Lazy read results in parsing error which can be verified by not having 363 // expected value. 364 EXPECT_NE(parsed.lazy_child().child().payload().optional_int32(), -1); 365} 366 367TEST(MESSAGE_TEST_NAME, NestedLazyRecursionLimit) { 368 UNITTEST::NestedTestAllTypes original, parsed; 369 original.mutable_lazy_child() 370 ->mutable_lazy_child() 371 ->mutable_lazy_child() 372 ->mutable_payload() 373 ->set_optional_int32(-1); 374 std::string serialized; 375 ASSERT_TRUE(original.SerializeToString(&serialized)); 376 ASSERT_TRUE(parsed.ParseFromString(serialized)); 377 378 io::ArrayInputStream array_stream(serialized.data(), serialized.size()); 379 io::CodedInputStream input_stream(&array_stream); 380 input_stream.SetRecursionLimit(2); 381 EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream)); 382 EXPECT_TRUE(parsed.has_lazy_child()); 383 EXPECT_TRUE(parsed.lazy_child().has_lazy_child()); 384 EXPECT_TRUE(parsed.lazy_child().lazy_child().has_lazy_child()); 385 EXPECT_FALSE(parsed.lazy_child().lazy_child().lazy_child().has_payload()); 386} 387 388TEST(MESSAGE_TEST_NAME, UnparsedEmpty) { 389 // lazy_child, LEN=100 with no payload. 390 const char encoded[] = {'\042', 100}; 391 UNITTEST::NestedTestAllTypes message; 392 393 EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded))); 394 EXPECT_TRUE(message.has_lazy_child()); 395 EXPECT_EQ(message.lazy_child().ByteSizeLong(), 0); 396} 397 398TEST(MESSAGE_TEST_NAME, ParseFailNonCanonicalZeroTag) { 399 const char encoded[] = {"\n\x3\x80\0\0"}; 400 UNITTEST::NestedTestAllTypes parsed; 401 EXPECT_FALSE(parsed.ParsePartialFromString( 402 absl::string_view{encoded, sizeof(encoded) - 1})); 403} 404 405TEST(MESSAGE_TEST_NAME, ParseFailNonCanonicalZeroField) { 406 const char encoded[] = {"\012\x6\205\0\0\0\0\0"}; 407 UNITTEST::NestedTestAllTypes parsed; 408 EXPECT_FALSE(parsed.ParsePartialFromString( 409 absl::string_view{encoded, sizeof(encoded) - 1})); 410} 411 412TEST(MESSAGE_TEST_NAME, NestedExplicitLazyExceedRecursionLimit) { 413 UNITTEST::NestedTestAllTypes original, parsed; 414 // Build proto with recursion depth of 5, with nested annotated LazyField. 415 original.mutable_lazy_child() 416 ->mutable_child() 417 ->mutable_lazy_child() 418 ->mutable_child() 419 ->mutable_payload() 420 ->set_optional_int32(-1); 421 std::string serialized; 422 EXPECT_TRUE(original.SerializeToString(&serialized)); 423 424 // User annotated LazyField ([lazy = true]) is eagerly verified and should 425 // catch the recursion limit violation. 426 io::ArrayInputStream array_stream(serialized.data(), serialized.size()); 427 io::CodedInputStream input_stream(&array_stream); 428 input_stream.SetRecursionLimit(4); 429 EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream)); 430 431 // Lazy read results in parsing error which can be verified by not having 432 // expected value. 433 EXPECT_NE(parsed.lazy_child() 434 .child() 435 .lazy_child() 436 .child() 437 .payload() 438 .optional_int32(), 439 -1); 440} 441 442TEST(MESSAGE_TEST_NAME, ParseFailsIfSubmessageTruncated) { 443 UNITTEST::NestedTestAllTypes o, p; 444 constexpr int kDepth = 5; 445 auto* child = o.mutable_child(); 446 for (int i = 0; i < kDepth; i++) { 447 child = child->mutable_child(); 448 } 449 TestUtil::SetAllFields(child->mutable_payload()); 450 451 std::string serialized; 452 EXPECT_TRUE(o.SerializeToString(&serialized)); 453 454 // Should parse correctly. 455 EXPECT_TRUE(p.ParseFromString(serialized)); 456 457 constexpr int kMaxTruncate = 50; 458 ASSERT_GT(serialized.size(), kMaxTruncate); 459 460 for (int i = 1; i < kMaxTruncate; i += 3) { 461 EXPECT_FALSE( 462 p.ParseFromString(serialized.substr(0, serialized.size() - i))); 463 } 464} 465 466TEST(MESSAGE_TEST_NAME, ParseFailsIfWireMalformed) { 467 UNITTEST::NestedTestAllTypes o, p; 468 constexpr int kDepth = 5; 469 auto* child = o.mutable_child(); 470 for (int i = 0; i < kDepth; i++) { 471 child = child->mutable_child(); 472 } 473 // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1 474 child->mutable_payload()->set_optional_int32(-1); 475 476 std::string serialized; 477 EXPECT_TRUE(o.SerializeToString(&serialized)); 478 479 // Should parse correctly. 480 EXPECT_TRUE(p.ParseFromString(serialized)); 481 482 // Overwriting the last byte to 0xFF results in malformed wire. 483 serialized[serialized.size() - 1] = 0xFF; 484 EXPECT_FALSE(p.ParseFromString(serialized)); 485} 486 487TEST(MESSAGE_TEST_NAME, ParseFailsIfOneofWireMalformed) { 488 UNITTEST::NestedTestAllTypes o, p; 489 constexpr int kDepth = 5; 490 auto* child = o.mutable_child(); 491 for (int i = 0; i < kDepth; i++) { 492 child = child->mutable_child(); 493 } 494 // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1 495 child->mutable_payload()->mutable_oneof_nested_message()->set_bb(-1); 496 497 std::string serialized; 498 EXPECT_TRUE(o.SerializeToString(&serialized)); 499 500 // Should parse correctly. 501 EXPECT_TRUE(p.ParseFromString(serialized)); 502 503 // Overwriting the last byte to 0xFF results in malformed wire. 504 serialized[serialized.size() - 1] = 0xFF; 505 EXPECT_FALSE(p.ParseFromString(serialized)); 506} 507 508TEST(MESSAGE_TEST_NAME, ParseFailsIfExtensionWireMalformed) { 509 UNITTEST::TestChildExtension o, p; 510 auto* m = o.mutable_optional_extension()->MutableExtension( 511 UNITTEST::optional_nested_message_extension); 512 513 // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1 514 m->set_bb(-1); 515 516 std::string serialized; 517 EXPECT_TRUE(o.SerializeToString(&serialized)); 518 519 // Should parse correctly. 520 EXPECT_TRUE(p.ParseFromString(serialized)); 521 522 // Overwriting the last byte to 0xFF results in malformed wire. 523 serialized[serialized.size() - 1] = 0xFF; 524 EXPECT_FALSE(p.ParseFromString(serialized)); 525} 526 527TEST(MESSAGE_TEST_NAME, ParseFailsIfGroupFieldMalformed) { 528 UNITTEST::TestMutualRecursionA original, parsed; 529 original.mutable_bb() 530 ->mutable_a() 531 ->mutable_subgroup() 532 ->mutable_sub_message() 533 ->mutable_b() 534 ->set_optional_int32(-1); 535 536 std::string data; 537 ASSERT_TRUE(original.SerializeToString(&data)); 538 // Should parse correctly. 539 ASSERT_TRUE(parsed.ParseFromString(data)); 540 // Overwriting the last byte of varint (-1) to 0xFF results in malformed wire. 541 data[data.size() - 2] = 0xFF; 542 543 EXPECT_FALSE(parsed.ParseFromString(data)); 544} 545 546TEST(MESSAGE_TEST_NAME, ParseFailsIfRepeatedGroupFieldMalformed) { 547 UNITTEST::TestMutualRecursionA original, parsed; 548 original.mutable_bb() 549 ->mutable_a() 550 ->add_subgroupr() 551 ->mutable_payload() 552 ->set_optional_int64(-1); 553 554 std::string data; 555 ASSERT_TRUE(original.SerializeToString(&data)); 556 // Should parse correctly. 557 ASSERT_TRUE(parsed.ParseFromString(data)); 558 // Overwriting the last byte of varint (-1) to 0xFF results in malformed wire. 559 data[data.size() - 2] = 0xFF; 560 561 EXPECT_FALSE(parsed.ParseFromString(data)); 562} 563 564TEST(MESSAGE_TEST_NAME, UninitializedAndMalformed) { 565 UNITTEST::TestRequiredForeign o, p1, p2; 566 o.mutable_optional_message()->set_a(-1); 567 568 // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1 569 std::string serialized; 570 EXPECT_TRUE(o.SerializePartialToString(&serialized)); 571 572 // Should parse correctly. 573 EXPECT_TRUE(p1.ParsePartialFromString(serialized)); 574 EXPECT_FALSE(p1.IsInitialized()); 575 576 // Overwriting the last byte to 0xFF results in malformed wire. 577 serialized[serialized.size() - 1] = 0xFF; 578 EXPECT_FALSE(p2.ParseFromString(serialized)); 579 EXPECT_FALSE(p2.IsInitialized()); 580} 581 582// Parsing proto must not access beyond the bound. 583TEST(MESSAGE_TEST_NAME, ParseStrictlyBoundedStream) { 584 UNITTEST::NestedTestAllTypes o, p; 585 constexpr int kDepth = 2; 586 o = InitNestedProto(kDepth); 587 TestUtil::SetAllFields(o.mutable_child()->mutable_payload()); 588 o.mutable_child()->mutable_child()->mutable_payload()->set_optional_string( 589 std::string(1024, 'a')); 590 591 std::string data; 592 EXPECT_TRUE(o.SerializeToString(&data)); 593 594 TestUtil::BoundedArrayInputStream stream(data.data(), data.size()); 595 EXPECT_TRUE(p.ParseFromBoundedZeroCopyStream(&stream, data.size())); 596 TestUtil::ExpectAllFieldsSet(p.child().payload()); 597} 598 599// Helper functions to touch any nested lazy field 600void TouchLazy(UNITTEST::NestedTestAllTypes* msg); 601void TouchLazy(UNITTEST::TestAllTypes* msg); 602void TouchLazy(UNITTEST::TestAllTypes::NestedMessage* msg) {} 603 604void TouchLazy(UNITTEST::TestAllTypes* msg) { 605 if (msg->has_optional_lazy_message()) { 606 TouchLazy(msg->mutable_optional_lazy_message()); 607 } 608 if (msg->has_optional_unverified_lazy_message()) { 609 TouchLazy(msg->mutable_optional_unverified_lazy_message()); 610 } 611 for (auto& child : *msg->mutable_repeated_lazy_message()) { 612 TouchLazy(&child); 613 } 614} 615 616void TouchLazy(UNITTEST::NestedTestAllTypes* msg) { 617 if (msg->has_child()) TouchLazy(msg->mutable_child()); 618 if (msg->has_payload()) TouchLazy(msg->mutable_payload()); 619 for (auto& child : *msg->mutable_repeated_child()) { 620 TouchLazy(&child); 621 } 622 if (msg->has_lazy_child()) TouchLazy(msg->mutable_lazy_child()); 623 if (msg->has_eager_child()) TouchLazy(msg->mutable_eager_child()); 624} 625 626TEST(MESSAGE_TEST_NAME, SuccessAfterParsingFailure) { 627 UNITTEST::NestedTestAllTypes o, p, q; 628 constexpr int kDepth = 5; 629 o = InitNestedProto(kDepth); 630 std::string serialized; 631 EXPECT_TRUE(o.SerializeToString(&serialized)); 632 633 // Should parse correctly. 634 EXPECT_TRUE(p.ParseFromString(serialized)); 635 636 // Overwriting the last byte to 0xFF results in malformed wire. 637 serialized[serialized.size() - 1] = 0xFF; 638 EXPECT_FALSE(p.ParseFromString(serialized)); 639 640 // If the affected byte is inside a lazy message, we have no guarantee that it 641 // serializes into error free data because serialization needs to preserve 642 // const correctness on lazy fields: `touch` all lazy fields. 643 TouchLazy(&p); 644 EXPECT_TRUE(q.ParseFromString(p.SerializeAsString())); 645} 646 647TEST(MESSAGE_TEST_NAME, ExceedRecursionLimit) { 648 UNITTEST::NestedTestAllTypes o, p; 649 const int kDepth = io::CodedInputStream::GetDefaultRecursionLimit() + 10; 650 o = InitNestedProto(kDepth); 651 std::string serialized; 652 EXPECT_TRUE(o.SerializeToString(&serialized)); 653 654 // Recursion level deeper than the default. 655 EXPECT_FALSE(p.ParseFromString(serialized)); 656} 657 658TEST(MESSAGE_TEST_NAME, SupportCustomRecursionLimitRead) { 659 UNITTEST::NestedTestAllTypes o, p; 660 const int kDepth = io::CodedInputStream::GetDefaultRecursionLimit() + 10; 661 o = InitNestedProto(kDepth); 662 std::string serialized; 663 EXPECT_TRUE(o.SerializeToString(&serialized)); 664 665 // Should pass with custom limit + reads. 666 io::ArrayInputStream raw_input(serialized.data(), serialized.size()); 667 io::CodedInputStream input(&raw_input); 668 input.SetRecursionLimit(kDepth + 10); 669 EXPECT_TRUE(p.ParseFromCodedStream(&input)); 670 671 EXPECT_EQ(p.child().payload().optional_int32(), 0); 672 EXPECT_EQ(p.child().child().payload().optional_int32(), 1); 673 674 // Verify p serializes successfully (survives VerifyConsistency). 675 std::string result; 676 EXPECT_TRUE(p.SerializeToString(&result)); 677} 678 679TEST(MESSAGE_TEST_NAME, SupportCustomRecursionLimitWrite) { 680 UNITTEST::NestedTestAllTypes o, p; 681 const int kDepth = io::CodedInputStream::GetDefaultRecursionLimit() + 10; 682 o = InitNestedProto(kDepth); 683 std::string serialized; 684 EXPECT_TRUE(o.SerializeToString(&serialized)); 685 686 // Should pass with custom limit + writes. 687 io::ArrayInputStream raw_input(serialized.data(), serialized.size()); 688 io::CodedInputStream input(&raw_input); 689 input.SetRecursionLimit(kDepth + 10); 690 EXPECT_TRUE(p.ParseFromCodedStream(&input)); 691 692 EXPECT_EQ(p.mutable_child()->mutable_payload()->optional_int32(), 0); 693 EXPECT_EQ( 694 p.mutable_child()->mutable_child()->mutable_payload()->optional_int32(), 695 1); 696} 697 698// While deep recursion is never guaranteed, this test aims to catch potential 699// issues with very deep recursion. 700TEST(MESSAGE_TEST_NAME, SupportDeepRecursionLimit) { 701 UNITTEST::NestedTestAllTypes o, p; 702 constexpr int kDepth = 1000; 703 auto* child = o.mutable_child(); 704 for (int i = 0; i < kDepth; i++) { 705 child = child->mutable_child(); 706 } 707 child->mutable_payload()->set_optional_int32(100); 708 709 std::string serialized; 710 EXPECT_TRUE(o.SerializeToString(&serialized)); 711 712 io::ArrayInputStream raw_input(serialized.data(), serialized.size()); 713 io::CodedInputStream input(&raw_input); 714 input.SetRecursionLimit(1100); 715 EXPECT_TRUE(p.ParseFromCodedStream(&input)); 716} 717 718inline bool IsOptimizeForCodeSize(const Descriptor* descriptor) { 719 return descriptor->file()->options().optimize_for() == FileOptions::CODE_SIZE; 720} 721 722 723TEST(MESSAGE_TEST_NAME, Swap) { 724 UNITTEST::NestedTestAllTypes o; 725 constexpr int kDepth = 5; 726 auto* child = o.mutable_child(); 727 for (int i = 0; i < kDepth; i++) { 728 child = child->mutable_child(); 729 } 730 TestUtil::SetAllFields(child->mutable_payload()); 731 732 std::string serialized; 733 EXPECT_TRUE(o.SerializeToString(&serialized)); 734 735 { 736 Arena arena; 737 UNITTEST::NestedTestAllTypes* p1 = 738 Arena::Create<UNITTEST::NestedTestAllTypes>(&arena); 739 740 // Should parse correctly. 741 EXPECT_TRUE(p1->ParseFromString(serialized)); 742 743 UNITTEST::NestedTestAllTypes* p2 = 744 Arena::Create<UNITTEST::NestedTestAllTypes>(&arena); 745 746 p1->Swap(p2); 747 748 EXPECT_EQ(o.SerializeAsString(), p2->SerializeAsString()); 749 } 750} 751 752TEST(MESSAGE_TEST_NAME, BypassInitializationCheckOnParse) { 753 UNITTEST::TestRequired message; 754 io::ArrayInputStream raw_input(nullptr, 0); 755 io::CodedInputStream input(&raw_input); 756 EXPECT_TRUE(message.MergePartialFromCodedStream(&input)); 757} 758 759TEST(MESSAGE_TEST_NAME, InitializationErrorString) { 760 UNITTEST::TestRequired message; 761 EXPECT_EQ("a, b, c", message.InitializationErrorString()); 762} 763 764TEST(MESSAGE_TEST_NAME, DynamicCastMessage) { 765 UNITTEST::TestAllTypes test_all_types; 766 767 MessageLite* test_all_types_pointer = &test_all_types; 768 EXPECT_EQ(&test_all_types, 769 DynamicCastMessage<UNITTEST::TestAllTypes>(test_all_types_pointer)); 770 EXPECT_EQ(nullptr, 771 DynamicCastMessage<UNITTEST::TestRequired>(test_all_types_pointer)); 772 773 const MessageLite* test_all_types_pointer_const = &test_all_types; 774 EXPECT_EQ(&test_all_types, DynamicCastMessage<const UNITTEST::TestAllTypes>( 775 test_all_types_pointer_const)); 776 EXPECT_EQ(nullptr, DynamicCastMessage<const UNITTEST::TestRequired>( 777 test_all_types_pointer_const)); 778 779 MessageLite* test_all_types_pointer_nullptr = nullptr; 780 EXPECT_EQ(nullptr, DynamicCastMessage<UNITTEST::TestAllTypes>( 781 test_all_types_pointer_nullptr)); 782 783 MessageLite& test_all_types_pointer_ref = test_all_types; 784 EXPECT_EQ(&test_all_types, &DynamicCastMessage<UNITTEST::TestAllTypes>( 785 test_all_types_pointer_ref)); 786 787 const MessageLite& test_all_types_pointer_const_ref = test_all_types; 788 EXPECT_EQ(&test_all_types, &DynamicCastMessage<UNITTEST::TestAllTypes>( 789 test_all_types_pointer_const_ref)); 790} 791 792TEST(MESSAGE_TEST_NAME, DynamicCastMessageInvalidReferenceType) { 793 UNITTEST::TestAllTypes test_all_types; 794 const MessageLite& test_all_types_pointer_const_ref = test_all_types; 795 ASSERT_DEATH( 796 DynamicCastMessage<UNITTEST::TestRequired>( 797 test_all_types_pointer_const_ref), 798 absl::StrCat("Cannot downcast ", test_all_types.GetTypeName(), " to ", 799 UNITTEST::TestRequired::default_instance().GetTypeName())); 800} 801 802TEST(MESSAGE_TEST_NAME, DownCastMessageValidType) { 803 UNITTEST::TestAllTypes test_all_types; 804 805 MessageLite* test_all_types_pointer = &test_all_types; 806 EXPECT_EQ(&test_all_types, 807 DownCastMessage<UNITTEST::TestAllTypes>(test_all_types_pointer)); 808 809 const MessageLite* test_all_types_pointer_const = &test_all_types; 810 EXPECT_EQ(&test_all_types, DownCastMessage<const UNITTEST::TestAllTypes>( 811 test_all_types_pointer_const)); 812 813 MessageLite* test_all_types_pointer_nullptr = nullptr; 814 EXPECT_EQ(nullptr, DownCastMessage<UNITTEST::TestAllTypes>( 815 test_all_types_pointer_nullptr)); 816 817 MessageLite& test_all_types_pointer_ref = test_all_types; 818 EXPECT_EQ(&test_all_types, &DownCastMessage<UNITTEST::TestAllTypes>( 819 test_all_types_pointer_ref)); 820 821 const MessageLite& test_all_types_pointer_const_ref = test_all_types; 822 EXPECT_EQ(&test_all_types, &DownCastMessage<UNITTEST::TestAllTypes>( 823 test_all_types_pointer_const_ref)); 824} 825 826TEST(MESSAGE_TEST_NAME, DownCastMessageInvalidPointerType) { 827 UNITTEST::TestAllTypes test_all_types; 828 829 MessageLite* test_all_types_pointer = &test_all_types; 830 831 ASSERT_DEBUG_DEATH( 832 DownCastMessage<UNITTEST::TestRequired>(test_all_types_pointer), 833 absl::StrCat("Cannot downcast ", test_all_types.GetTypeName(), " to ", 834 UNITTEST::TestRequired::default_instance().GetTypeName())); 835} 836 837TEST(MESSAGE_TEST_NAME, DownCastMessageInvalidReferenceType) { 838 UNITTEST::TestAllTypes test_all_types; 839 840 MessageLite& test_all_types_ref = test_all_types; 841 842 ASSERT_DEBUG_DEATH( 843 DownCastMessage<UNITTEST::TestRequired>(test_all_types_ref), 844 absl::StrCat("Cannot downcast ", test_all_types.GetTypeName(), " to ", 845 UNITTEST::TestRequired::default_instance().GetTypeName())); 846} 847 848TEST(MESSAGE_TEST_NAME, MessageDebugStringMatchesBehindPointerAndLitePointer) { 849 UNITTEST::TestAllTypes test_all_types; 850 test_all_types.set_optional_string("foo"); 851 Message* msg_full_pointer = &test_all_types; 852 MessageLite* msg_lite_pointer = &test_all_types; 853 ASSERT_EQ(test_all_types.DebugString(), msg_full_pointer->DebugString()); 854 ASSERT_EQ(test_all_types.DebugString(), msg_lite_pointer->DebugString()); 855} 856 857#if GTEST_HAS_DEATH_TEST // death tests do not work on Windows yet. 858 859TEST(MESSAGE_TEST_NAME, SerializeFailsIfNotInitialized) { 860 UNITTEST::TestRequired message; 861 std::string data; 862 EXPECT_DEBUG_DEATH( 863 EXPECT_TRUE(message.SerializeToString(&data)), 864 absl::StrCat("Can't serialize message of type \"", UNITTEST_PACKAGE_NAME, 865 ".TestRequired\" because " 866 "it is missing required fields: a, b, c")); 867} 868 869TEST(MESSAGE_TEST_NAME, CheckInitialized) { 870 UNITTEST::TestRequired message; 871 EXPECT_DEATH(message.CheckInitialized(), 872 absl::StrCat("Message of type \"", UNITTEST_PACKAGE_NAME, 873 ".TestRequired\" is missing required " 874 "fields: a, b, c")); 875} 876 877#endif // GTEST_HAS_DEATH_TEST 878 879namespace { 880// An input stream that repeats a std::string's content for a number of times. 881// It helps us create a really large input without consuming too much memory. 882// Used to test the parsing behavior when the input size exceeds 2G or close to 883// it. 884class RepeatedInputStream : public io::ZeroCopyInputStream { 885 public: 886 RepeatedInputStream(const std::string& data, size_t count) 887 : data_(data), count_(count), position_(0), total_byte_count_(0) {} 888 889 bool Next(const void** data, int* size) override { 890 if (position_ == data_.size()) { 891 if (--count_ == 0) { 892 return false; 893 } 894 position_ = 0; 895 } 896 *data = &data_[position_]; 897 *size = static_cast<int>(data_.size() - position_); 898 position_ = data_.size(); 899 total_byte_count_ += *size; 900 return true; 901 } 902 903 void BackUp(int count) override { 904 position_ -= static_cast<size_t>(count); 905 total_byte_count_ -= count; 906 } 907 908 bool Skip(int count) override { 909 while (count > 0) { 910 const void* data; 911 int size; 912 if (!Next(&data, &size)) { 913 break; 914 } 915 if (size >= count) { 916 BackUp(size - count); 917 return true; 918 } else { 919 count -= size; 920 } 921 } 922 return false; 923 } 924 925 int64_t ByteCount() const override { return total_byte_count_; } 926 927 private: 928 std::string data_; 929 size_t count_; // The number of strings that haven't been consumed. 930 size_t position_; // Position in the std::string for the next read. 931 int64_t total_byte_count_; 932}; 933} // namespace 934 935TEST(MESSAGE_TEST_NAME, TestParseMessagesCloseTo2G) { 936 constexpr int32_t kint32max = std::numeric_limits<int32_t>::max(); 937 938 // Create a message with a large std::string field. 939 std::string value = std::string(64 * 1024 * 1024, 'x'); 940 UNITTEST::TestAllTypes message; 941 message.set_optional_string(value); 942 943 // Repeat this message in the input stream to make the total input size 944 // close to 2G. 945 std::string data = message.SerializeAsString(); 946 size_t count = static_cast<size_t>(kint32max) / data.size(); 947 RepeatedInputStream input(data, count); 948 949 // The parsing should succeed. 950 UNITTEST::TestAllTypes result; 951 EXPECT_TRUE(result.ParseFromZeroCopyStream(&input)); 952 953 // When there are multiple occurrences of a singular field, the last one 954 // should win. 955 EXPECT_EQ(value, result.optional_string()); 956} 957 958TEST(MESSAGE_TEST_NAME, TestParseMessagesOver2G) { 959 constexpr int32_t kint32max = std::numeric_limits<int32_t>::max(); 960 961 // Create a message with a large std::string field. 962 std::string value = std::string(64 * 1024 * 1024, 'x'); 963 UNITTEST::TestAllTypes message; 964 message.set_optional_string(value); 965 966 // Repeat this message in the input stream to make the total input size 967 // larger than 2G. 968 std::string data = message.SerializeAsString(); 969 size_t count = static_cast<size_t>(kint32max) / data.size() + 1; 970 RepeatedInputStream input(data, count); 971 972 // The parsing should fail. 973 UNITTEST::TestAllTypes result; 974 EXPECT_FALSE(result.ParseFromZeroCopyStream(&input)); 975} 976 977TEST(MESSAGE_TEST_NAME, BypassInitializationCheckOnSerialize) { 978 UNITTEST::TestRequired message; 979 io::ArrayOutputStream raw_output(nullptr, 0); 980 io::CodedOutputStream output(&raw_output); 981 EXPECT_TRUE(message.SerializePartialToCodedStream(&output)); 982} 983 984TEST(MESSAGE_TEST_NAME, FindInitializationErrors) { 985 UNITTEST::TestRequired message; 986 std::vector<std::string> errors; 987 message.FindInitializationErrors(&errors); 988 ASSERT_EQ(3, errors.size()); 989 EXPECT_EQ("a", errors[0]); 990 EXPECT_EQ("b", errors[1]); 991 EXPECT_EQ("c", errors[2]); 992} 993 994TEST(MESSAGE_TEST_NAME, ReleaseMustUseResult) { 995 UNITTEST::TestAllTypes message; 996 auto* f = new UNITTEST::ForeignMessage(); 997 f->set_c(1000); 998 message.set_allocated_optional_foreign_message(f); 999 auto* mf = message.mutable_optional_foreign_message(); 1000 EXPECT_EQ(mf, f); 1001 std::unique_ptr<UNITTEST::ForeignMessage> rf( 1002 message.release_optional_foreign_message()); 1003 EXPECT_NE(rf.get(), nullptr); 1004} 1005 1006TEST(MESSAGE_TEST_NAME, ParseFailsOnInvalidMessageEnd) { 1007 UNITTEST::TestAllTypes message; 1008 1009 // Control case. 1010 EXPECT_TRUE(message.ParseFromArray("", 0)); 1011 1012 // The byte is a valid varint, but not a valid tag (zero). 1013 EXPECT_FALSE(message.ParseFromArray("\0", 1)); 1014 1015 // The byte is a malformed varint. 1016 EXPECT_FALSE(message.ParseFromArray("\200", 1)); 1017 1018 // The byte is an endgroup tag, but we aren't parsing a group. 1019 EXPECT_FALSE(message.ParseFromArray("\014", 1)); 1020} 1021 1022// Regression test for b/23630858 1023TEST(MESSAGE_TEST_NAME, MessageIsStillValidAfterParseFails) { 1024 UNITTEST::TestAllTypes message; 1025 1026 // 9 0xFFs for the "optional_uint64" field. 1027 std::string invalid_data = "\x20\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"; 1028 1029 EXPECT_FALSE(message.ParseFromString(invalid_data)); 1030 message.Clear(); 1031 EXPECT_EQ(0, message.optional_uint64()); 1032 1033 // invalid data for field "optional_string". Length prefix is 1 but no 1034 // payload. 1035 std::string invalid_string_data = "\x72\x01"; 1036 { 1037 Arena arena; 1038 UNITTEST::TestAllTypes* arena_message = 1039 Arena::Create<UNITTEST::TestAllTypes>(&arena); 1040 EXPECT_FALSE(arena_message->ParseFromString(invalid_string_data)); 1041 arena_message->Clear(); 1042 EXPECT_EQ("", arena_message->optional_string()); 1043 } 1044} 1045 1046TEST(MESSAGE_TEST_NAME, NonCanonicalTag) { 1047 UNITTEST::TestAllTypes message; 1048 // optional_lazy_message (27) LEN(3) with non canonical tag: (1). 1049 const char encoded[] = {'\332', 1, 3, '\210', 0, 0}; 1050 EXPECT_TRUE(message.ParseFromArray(encoded, sizeof(encoded))); 1051} 1052 1053TEST(MESSAGE_TEST_NAME, Zero5BTag) { 1054 UNITTEST::TestAllTypes message; 1055 // optional_nested_message (18) LEN(6) with 5B but zero tag. 1056 const char encoded[] = {'\222', 1, 6, '\200', '\200', 1057 '\200', '\200', '\020', 0}; 1058 EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded))); 1059} 1060 1061TEST(MESSAGE_TEST_NAME, Zero5BTagLazy) { 1062 UNITTEST::TestAllTypes message; 1063 // optional_lazy_message (27) LEN(6) with 5B but zero tag. 1064 const char encoded[] = {'\332', 1, 6, '\200', '\200', 1065 '\200', '\200', '\020', 0}; 1066 EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded))); 1067} 1068 1069namespace { 1070void ExpectMessageMerged(const UNITTEST::TestAllTypes& message) { 1071 EXPECT_EQ(3, message.optional_int32()); 1072 EXPECT_EQ(2, message.optional_int64()); 1073 EXPECT_EQ("hello", message.optional_string()); 1074} 1075 1076void AssignParsingMergeMessages(UNITTEST::TestAllTypes* msg1, 1077 UNITTEST::TestAllTypes* msg2, 1078 UNITTEST::TestAllTypes* msg3) { 1079 msg1->set_optional_int32(1); 1080 msg2->set_optional_int64(2); 1081 msg3->set_optional_int32(3); 1082 msg3->set_optional_string("hello"); 1083} 1084} // namespace 1085 1086// Test that if an optional or required message/group field appears multiple 1087// times in the input, they need to be merged. 1088TEST(MESSAGE_TEST_NAME, ParsingMerge) { 1089 UNITTEST::TestParsingMerge::RepeatedFieldsGenerator generator; 1090 UNITTEST::TestAllTypes* msg1; 1091 UNITTEST::TestAllTypes* msg2; 1092 UNITTEST::TestAllTypes* msg3; 1093 1094#define ASSIGN_REPEATED_FIELD(FIELD) \ 1095 msg1 = generator.add_##FIELD(); \ 1096 msg2 = generator.add_##FIELD(); \ 1097 msg3 = generator.add_##FIELD(); \ 1098 AssignParsingMergeMessages(msg1, msg2, msg3) 1099 1100 ASSIGN_REPEATED_FIELD(field1); 1101 ASSIGN_REPEATED_FIELD(field2); 1102 ASSIGN_REPEATED_FIELD(field3); 1103 ASSIGN_REPEATED_FIELD(ext1); 1104 ASSIGN_REPEATED_FIELD(ext2); 1105 1106#undef ASSIGN_REPEATED_FIELD 1107#define ASSIGN_REPEATED_GROUP(FIELD) \ 1108 msg1 = generator.add_##FIELD()->mutable_field1(); \ 1109 msg2 = generator.add_##FIELD()->mutable_field1(); \ 1110 msg3 = generator.add_##FIELD()->mutable_field1(); \ 1111 AssignParsingMergeMessages(msg1, msg2, msg3) 1112 1113 ASSIGN_REPEATED_GROUP(group1); 1114 ASSIGN_REPEATED_GROUP(group2); 1115 1116#undef ASSIGN_REPEATED_GROUP 1117 1118 std::string buffer; 1119 generator.SerializeToString(&buffer); 1120 UNITTEST::TestParsingMerge parsing_merge; 1121 parsing_merge.ParseFromString(buffer); 1122 1123 // Required and optional fields should be merged. 1124 ExpectMessageMerged(parsing_merge.required_all_types()); 1125 ExpectMessageMerged(parsing_merge.optional_all_types()); 1126 ExpectMessageMerged(parsing_merge.optionalgroup().optional_group_all_types()); 1127 ExpectMessageMerged( 1128 parsing_merge.GetExtension(UNITTEST::TestParsingMerge::optional_ext)); 1129 1130 // Repeated fields should not be merged. 1131 EXPECT_EQ(3, parsing_merge.repeated_all_types_size()); 1132 EXPECT_EQ(3, parsing_merge.repeatedgroup_size()); 1133 EXPECT_EQ( 1134 3, parsing_merge.ExtensionSize(UNITTEST::TestParsingMerge::repeated_ext)); 1135} 1136 1137TEST(MESSAGE_TEST_NAME, MergeFrom) { 1138 UNITTEST::TestAllTypes source, dest; 1139 1140 // Optional fields 1141 source.set_optional_int32(1); // only source 1142 source.set_optional_int64(2); // both source and dest 1143 dest.set_optional_int64(3); 1144 dest.set_optional_uint32(4); // only dest 1145 1146 // Optional fields with defaults 1147 source.set_default_int32(13); // only source 1148 source.set_default_int64(14); // both source and dest 1149 dest.set_default_int64(15); 1150 dest.set_default_uint32(16); // only dest 1151 1152 // Repeated fields 1153 source.add_repeated_int32(5); // only source 1154 source.add_repeated_int32(6); 1155 source.add_repeated_int64(7); // both source and dest 1156 source.add_repeated_int64(8); 1157 dest.add_repeated_int64(9); 1158 dest.add_repeated_int64(10); 1159 dest.add_repeated_uint32(11); // only dest 1160 dest.add_repeated_uint32(12); 1161 1162 dest.MergeFrom(source); 1163 1164 // Optional fields: source overwrites dest if source is specified 1165 EXPECT_EQ(1, dest.optional_int32()); // only source: use source 1166 EXPECT_EQ(2, dest.optional_int64()); // source and dest: use source 1167 EXPECT_EQ(4, dest.optional_uint32()); // only dest: use dest 1168 EXPECT_EQ(0, dest.optional_uint64()); // neither: use default 1169 1170 // Optional fields with defaults 1171 EXPECT_EQ(13, dest.default_int32()); // only source: use source 1172 EXPECT_EQ(14, dest.default_int64()); // source and dest: use source 1173 EXPECT_EQ(16, dest.default_uint32()); // only dest: use dest 1174 EXPECT_EQ(44, dest.default_uint64()); // neither: use default 1175 1176 // Repeated fields: concatenate source onto the end of dest 1177 ASSERT_EQ(2, dest.repeated_int32_size()); 1178 EXPECT_EQ(5, dest.repeated_int32(0)); 1179 EXPECT_EQ(6, dest.repeated_int32(1)); 1180 ASSERT_EQ(4, dest.repeated_int64_size()); 1181 EXPECT_EQ(9, dest.repeated_int64(0)); 1182 EXPECT_EQ(10, dest.repeated_int64(1)); 1183 EXPECT_EQ(7, dest.repeated_int64(2)); 1184 EXPECT_EQ(8, dest.repeated_int64(3)); 1185 ASSERT_EQ(2, dest.repeated_uint32_size()); 1186 EXPECT_EQ(11, dest.repeated_uint32(0)); 1187 EXPECT_EQ(12, dest.repeated_uint32(1)); 1188 ASSERT_EQ(0, dest.repeated_uint64_size()); 1189} 1190 1191TEST(MESSAGE_TEST_NAME, IsInitialized) { 1192 UNITTEST::TestIsInitialized msg; 1193 EXPECT_TRUE(msg.IsInitialized()); 1194 UNITTEST::TestIsInitialized::SubMessage* sub_message = 1195 msg.mutable_sub_message(); 1196 EXPECT_TRUE(msg.IsInitialized()); 1197 UNITTEST::TestIsInitialized::SubMessage::SubGroup* sub_group = 1198 sub_message->mutable_subgroup(); 1199 EXPECT_FALSE(msg.IsInitialized()); 1200 sub_group->set_i(1); 1201 EXPECT_TRUE(msg.IsInitialized()); 1202} 1203 1204TEST(MESSAGE_TEST_NAME, IsInitializedSplitBytestream) { 1205 UNITTEST::TestRequired ab, c; 1206 ab.set_a(1); 1207 ab.set_b(2); 1208 c.set_c(3); 1209 1210 // The protobuf represented by the concatenated string has all required 1211 // fields (a,b,c) set. 1212 std::string bytes = 1213 ab.SerializePartialAsString() + c.SerializePartialAsString(); 1214 1215 UNITTEST::TestRequired concatenated; 1216 EXPECT_TRUE(concatenated.ParsePartialFromString(bytes)); 1217 EXPECT_TRUE(concatenated.IsInitialized()); 1218 1219 UNITTEST::TestRequiredForeign fab, fc; 1220 fab.mutable_optional_message()->set_a(1); 1221 fab.mutable_optional_message()->set_b(2); 1222 fc.mutable_optional_message()->set_c(3); 1223 1224 bytes = fab.SerializePartialAsString() + fc.SerializePartialAsString(); 1225 1226 UNITTEST::TestRequiredForeign fconcatenated; 1227 EXPECT_TRUE(fconcatenated.ParsePartialFromString(bytes)); 1228 EXPECT_TRUE(fconcatenated.IsInitialized()); 1229} 1230 1231TEST(MESSAGE_FACTORY_TEST_NAME, GeneratedFactoryLookup) { 1232 EXPECT_EQ(MessageFactory::generated_factory()->GetPrototype( 1233 UNITTEST::TestAllTypes::descriptor()), 1234 &UNITTEST::TestAllTypes::default_instance()); 1235} 1236 1237TEST(MESSAGE_FACTORY_TEST_NAME, GeneratedFactoryUnknownType) { 1238 // Construct a new descriptor. 1239 DescriptorPool pool; 1240 FileDescriptorProto file; 1241 file.set_name("foo.proto"); 1242 file.add_message_type()->set_name("Foo"); 1243 const Descriptor* descriptor = pool.BuildFile(file)->message_type(0); 1244 1245 // Trying to construct it should return nullptr. 1246 EXPECT_TRUE(MessageFactory::generated_factory()->GetPrototype(descriptor) == 1247 nullptr); 1248} 1249 1250TEST(MESSAGE_TEST_NAME, MOMIParserEdgeCases) { 1251 { 1252 UNITTEST::TestAllTypes msg; 1253 // Parser ends in last 16 bytes of buffer due to a 0. 1254 std::string data; 1255 // 12 bytes of data 1256 for (int i = 0; i < 4; i++) absl::StrAppend(&data, "\370\1\1"); 1257 // 13 byte is terminator 1258 data += '\0'; // Terminator 1259 // followed by the rest of the stream 1260 // space is ascii 32 so no end group 1261 data += std::string(30, ' '); 1262 io::ArrayInputStream zcis(data.data(), data.size(), 17); 1263 io::CodedInputStream cis(&zcis); 1264 EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis)); 1265 EXPECT_EQ(cis.CurrentPosition(), 3 * 4 + 1); 1266 } 1267 { 1268 // Parser ends in last 16 bytes of buffer due to a end-group. 1269 // Must use a message that is a group. Otherwise ending on a group end is 1270 // a failure. 1271 UNITTEST::TestAllTypes::OptionalGroup msg; 1272 std::string data; 1273 for (int i = 0; i < 3; i++) absl::StrAppend(&data, "\370\1\1"); 1274 data += '\14'; // Octal end-group tag 12 (1 * 8 + 4( 1275 data += std::string(30, ' '); 1276 io::ArrayInputStream zcis(data.data(), data.size(), 17); 1277 io::CodedInputStream cis(&zcis); 1278 EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis)); 1279 EXPECT_EQ(cis.CurrentPosition(), 3 * 3 + 1); 1280 EXPECT_TRUE(cis.LastTagWas(12)); 1281 } 1282 { 1283 // Parser ends in last 16 bytes of buffer due to a end-group. But is inside 1284 // a length delimited field. 1285 // a failure. 1286 UNITTEST::TestAllTypes::OptionalGroup msg; 1287 std::string data = "\22\3foo"; 1288 data += '\14'; // Octal end-group tag 12 (1 * 8 + 4( 1289 data += std::string(30, ' '); 1290 io::ArrayInputStream zcis(data.data(), data.size(), 17); 1291 io::CodedInputStream cis(&zcis); 1292 EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis)); 1293 EXPECT_EQ(cis.CurrentPosition(), 6); 1294 EXPECT_TRUE(cis.LastTagWas(12)); 1295 } 1296 { 1297 // Parser fails when ending on 0 if from ZeroCopyInputStream 1298 UNITTEST::TestAllTypes msg; 1299 std::string data; 1300 // 12 bytes of data 1301 for (int i = 0; i < 4; i++) absl::StrAppend(&data, "\370\1\1"); 1302 // 13 byte is terminator 1303 data += '\0'; // Terminator 1304 data += std::string(30, ' '); 1305 io::ArrayInputStream zcis(data.data(), data.size(), 17); 1306 EXPECT_FALSE(msg.ParsePartialFromZeroCopyStream(&zcis)); 1307 } 1308} 1309 1310 1311TEST(MESSAGE_TEST_NAME, CheckSerializationWhenInterleavedExtensions) { 1312 UNITTEST::TestExtensionRangeSerialize in_message; 1313 1314 in_message.set_foo_one(1); 1315 in_message.set_foo_two(2); 1316 in_message.set_foo_three(3); 1317 in_message.set_foo_four(4); 1318 1319 in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_one, 1); 1320 in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_two, 2); 1321 in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_three, 3); 1322 in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_four, 4); 1323 in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_five, 5); 1324 1325 std::string buffer; 1326 in_message.SerializeToString(&buffer); 1327 1328 UNITTEST::TestExtensionRangeSerialize out_message; 1329 out_message.ParseFromString(buffer); 1330 1331 EXPECT_EQ(1, out_message.foo_one()); 1332 EXPECT_EQ(2, out_message.foo_two()); 1333 EXPECT_EQ(3, out_message.foo_three()); 1334 EXPECT_EQ(4, out_message.foo_four()); 1335 1336 EXPECT_EQ(1, out_message.GetExtension( 1337 UNITTEST::TestExtensionRangeSerialize::bar_one)); 1338 EXPECT_EQ(2, out_message.GetExtension( 1339 UNITTEST::TestExtensionRangeSerialize::bar_two)); 1340 EXPECT_EQ(3, out_message.GetExtension( 1341 UNITTEST::TestExtensionRangeSerialize::bar_three)); 1342 EXPECT_EQ(4, out_message.GetExtension( 1343 UNITTEST::TestExtensionRangeSerialize::bar_four)); 1344 EXPECT_EQ(5, out_message.GetExtension( 1345 UNITTEST::TestExtensionRangeSerialize::bar_five)); 1346} 1347 1348TEST(MESSAGE_TEST_NAME, PreservesFloatingPointNegative0) { 1349 UNITTEST::TestAllTypes in_message; 1350 in_message.set_optional_float(-0.0f); 1351 in_message.set_optional_double(-0.0); 1352 std::string serialized; 1353 EXPECT_TRUE(in_message.SerializeToString(&serialized)); 1354 UNITTEST::TestAllTypes out_message; 1355 EXPECT_TRUE(out_message.ParseFromString(serialized)); 1356 EXPECT_EQ(in_message.optional_float(), out_message.optional_float()); 1357 EXPECT_EQ(std::signbit(in_message.optional_float()), 1358 std::signbit(out_message.optional_float())); 1359 EXPECT_EQ(in_message.optional_double(), out_message.optional_double()); 1360 EXPECT_EQ(std::signbit(in_message.optional_double()), 1361 std::signbit(out_message.optional_double())); 1362} 1363 1364TEST(MESSAGE_TEST_NAME, 1365 RegressionTestForParseMessageReadingUninitializedLimit) { 1366 UNITTEST::TestAllTypes in_message; 1367 in_message.mutable_optional_nested_message(); 1368 std::string serialized = in_message.SerializeAsString(); 1369 // We expect this to have 3 bytes: two for the tag, and one for the zero size. 1370 // Break the size by making it overlong. 1371 ASSERT_EQ(serialized.size(), 3); 1372 serialized.back() = '\200'; 1373 serialized += std::string(10, '\200'); 1374 EXPECT_FALSE(in_message.ParseFromString(serialized)); 1375} 1376 1377TEST(MESSAGE_TEST_NAME, 1378 RegressionTestForParseMessageWithSizeBeyondInputFailsToPopLimit) { 1379 UNITTEST::TestAllTypes in_message; 1380 in_message.mutable_optional_nested_message(); 1381 std::string serialized = in_message.SerializeAsString(); 1382 // We expect this to have 3 bytes: two for the tag, and one for the zero size. 1383 // Make the size a valid varint, but it overflows in the input. 1384 ASSERT_EQ(serialized.size(), 3); 1385 serialized.back() = 10; 1386 EXPECT_FALSE(in_message.ParseFromString(serialized)); 1387} 1388 1389namespace { 1390const uint8_t* SkipTag(const uint8_t* buf) { 1391 while (*buf & 0x80) ++buf; 1392 ++buf; 1393 return buf; 1394} 1395 1396// Adds `non_canonical_bytes` bytes to the varint representation at the tail of 1397// the buffer. 1398// `buf` points to the start of the buffer, `p` points to one-past-the-end. 1399// Returns the new one-past-the-end pointer. 1400uint8_t* AddNonCanonicalBytes(const uint8_t* buf, uint8_t* p, 1401 int non_canonical_bytes) { 1402 // varint can have a max of 10 bytes. 1403 while (non_canonical_bytes-- > 0 && p - buf < 10) { 1404 // Add a dummy byte at the end. 1405 p[-1] |= 0x80; 1406 p[0] = 0; 1407 ++p; 1408 } 1409 return p; 1410} 1411 1412std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) { 1413 uint8_t buf[100]; 1414 uint8_t* p = buf; 1415 1416 p = internal::WireFormatLite::WriteBoolToArray(number, value, p); 1417 p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); 1418 return std::string(buf, p); 1419} 1420 1421std::string EncodeEnumValue(int number, int value, int non_canonical_bytes, 1422 bool use_packed) { 1423 uint8_t buf[100]; 1424 uint8_t* p = buf; 1425 1426 if (use_packed) { 1427 p = internal::WireFormatLite::WriteEnumNoTagToArray(value, p); 1428 p = AddNonCanonicalBytes(buf, p, non_canonical_bytes); 1429 1430 std::string payload(buf, p); 1431 p = buf; 1432 p = internal::WireFormatLite::WriteStringToArray(number, payload, p); 1433 return std::string(buf, p); 1434 1435 } else { 1436 p = internal::WireFormatLite::WriteEnumToArray(number, value, p); 1437 p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); 1438 return std::string(buf, p); 1439 } 1440} 1441 1442std::string EncodeOverlongEnum(int number, bool use_packed) { 1443 uint8_t buf[100]; 1444 uint8_t* p = buf; 1445 1446 std::string overlong(16, static_cast<char>(0x80)); 1447 if (use_packed) { 1448 p = internal::WireFormatLite::WriteStringToArray(number, overlong, p); 1449 return std::string(buf, p); 1450 } else { 1451 p = internal::WireFormatLite::WriteTagToArray( 1452 number, internal::WireFormatLite::WIRETYPE_VARINT, p); 1453 p = std::copy(overlong.begin(), overlong.end(), p); 1454 return std::string(buf, p); 1455 } 1456} 1457 1458std::string EncodeInt32Value(int number, int32_t value, 1459 int non_canonical_bytes) { 1460 uint8_t buf[100]; 1461 uint8_t* p = buf; 1462 1463 p = internal::WireFormatLite::WriteInt32ToArray(number, value, p); 1464 p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); 1465 return std::string(buf, p); 1466} 1467 1468std::string EncodeInt64Value(int number, int64_t value, int non_canonical_bytes, 1469 bool use_packed = false) { 1470 uint8_t buf[100]; 1471 uint8_t* p = buf; 1472 1473 if (use_packed) { 1474 p = internal::WireFormatLite::WriteInt64NoTagToArray(value, p); 1475 p = AddNonCanonicalBytes(buf, p, non_canonical_bytes); 1476 1477 std::string payload(buf, p); 1478 p = buf; 1479 p = internal::WireFormatLite::WriteStringToArray(number, payload, p); 1480 return std::string(buf, p); 1481 1482 } else { 1483 p = internal::WireFormatLite::WriteInt64ToArray(number, value, p); 1484 p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); 1485 return std::string(buf, p); 1486 } 1487} 1488 1489std::string EncodeOtherField() { 1490 UNITTEST::EnumParseTester obj; 1491 obj.set_other_field(1); 1492 return obj.SerializeAsString(); 1493} 1494 1495template <typename T> 1496static std::vector<const FieldDescriptor*> GetFields() { 1497 auto* descriptor = T::descriptor(); 1498 std::vector<const FieldDescriptor*> fields; 1499 for (int i = 0; i < descriptor->field_count(); ++i) { 1500 fields.push_back(descriptor->field(i)); 1501 } 1502 for (int i = 0; i < descriptor->extension_count(); ++i) { 1503 fields.push_back(descriptor->extension(i)); 1504 } 1505 return fields; 1506} 1507} // namespace 1508 1509TEST(MESSAGE_TEST_NAME, TestEnumParsers) { 1510 UNITTEST::EnumParseTester obj; 1511 1512 const auto other_field = EncodeOtherField(); 1513 1514 // Encode an enum field for many different cases and verify that it can be 1515 // parsed as expected. 1516 // There are: 1517 // - optional/repeated/packed fields 1518 // - field tags that encode in 1/2/3 bytes 1519 // - canonical and non-canonical encodings of the varint 1520 // - last vs not last field 1521 // - label combinations to trigger different parsers: sequential, small 1522 // sequential, non-validated. 1523 1524 const std::vector<const FieldDescriptor*> fields = 1525 GetFields<UNITTEST::EnumParseTester>(); 1526 1527 constexpr int kInvalidValue = 0x900913; 1528 auto* ref = obj.GetReflection(); 1529 PROTOBUF_UNUSED auto* descriptor = obj.descriptor(); 1530 for (bool use_packed : {false, true}) { 1531 SCOPED_TRACE(use_packed); 1532 for (bool use_tail_field : {false, true}) { 1533 SCOPED_TRACE(use_tail_field); 1534 for (int non_canonical_bytes = 0; non_canonical_bytes < 9; 1535 ++non_canonical_bytes) { 1536 SCOPED_TRACE(non_canonical_bytes); 1537 for (bool add_garbage_bits : {false, true}) { 1538 if (add_garbage_bits && non_canonical_bytes != 9) { 1539 // We only add garbage on the 10th byte. 1540 continue; 1541 } 1542 SCOPED_TRACE(add_garbage_bits); 1543 for (auto field : fields) { 1544 if (field->name() == "other_field") continue; 1545 if (!field->is_repeated() && use_packed) continue; 1546 SCOPED_TRACE(field->full_name()); 1547 const auto* enum_desc = field->enum_type(); 1548 for (int e = 0; e < enum_desc->value_count(); ++e) { 1549 const auto* value_desc = enum_desc->value(e); 1550 if (value_desc->number() < 0 && non_canonical_bytes > 0) { 1551 // Negative numbers only have a canonical representation. 1552 continue; 1553 } 1554 SCOPED_TRACE(value_desc->number()); 1555 ABSL_CHECK_NE(value_desc->number(), kInvalidValue) 1556 << "Invalid value is a real label."; 1557 auto encoded = 1558 EncodeEnumValue(field->number(), value_desc->number(), 1559 non_canonical_bytes, use_packed); 1560 if (add_garbage_bits) { 1561 // These bits should be discarded even in the `false` case. 1562 encoded.back() |= 0b0111'1110; 1563 } 1564 if (use_tail_field) { 1565 // Make sure that fields after this one can be parsed too. ie 1566 // test that the "next" jump is correct too. 1567 encoded += other_field; 1568 } 1569 1570 EXPECT_TRUE(obj.ParseFromString(encoded)); 1571 if (field->is_repeated()) { 1572 ASSERT_EQ(ref->FieldSize(obj, field), 1); 1573 EXPECT_EQ(ref->GetRepeatedEnumValue(obj, field, 0), 1574 value_desc->number()); 1575 } else { 1576 EXPECT_TRUE(ref->HasField(obj, field)); 1577 EXPECT_EQ(ref->GetEnumValue(obj, field), value_desc->number()); 1578 } 1579 auto& unknown = ref->GetUnknownFields(obj); 1580 ASSERT_EQ(unknown.field_count(), 0); 1581 } 1582 1583 { 1584 SCOPED_TRACE("Invalid value"); 1585 // Try an invalid value, which should go to the unknown fields. 1586 EXPECT_TRUE(obj.ParseFromString( 1587 EncodeEnumValue(field->number(), kInvalidValue, 1588 non_canonical_bytes, use_packed))); 1589 if (field->is_repeated()) { 1590 ASSERT_EQ(ref->FieldSize(obj, field), 0); 1591 } else { 1592 EXPECT_FALSE(ref->HasField(obj, field)); 1593 EXPECT_EQ(ref->GetEnumValue(obj, field), 1594 enum_desc->value(0)->number()); 1595 } 1596 auto& unknown = ref->GetUnknownFields(obj); 1597 ASSERT_EQ(unknown.field_count(), 1); 1598 EXPECT_EQ(unknown.field(0).number(), field->number()); 1599 EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT); 1600 EXPECT_EQ(unknown.field(0).varint(), kInvalidValue); 1601 } 1602 { 1603 SCOPED_TRACE("Overlong varint"); 1604 // Try an overlong varint. It should fail parsing, but not trigger 1605 // any sanitizer warning. 1606 EXPECT_FALSE(obj.ParseFromString( 1607 EncodeOverlongEnum(field->number(), use_packed))); 1608 } 1609 } 1610 } 1611 } 1612 } 1613 } 1614} 1615 1616TEST(MESSAGE_TEST_NAME, TestEnumParserForUnknownEnumValue) { 1617 DynamicMessageFactory factory; 1618 std::unique_ptr<Message> dynamic( 1619 factory.GetPrototype(UNITTEST::EnumParseTester::descriptor())->New()); 1620 1621 UNITTEST::EnumParseTester non_dynamic; 1622 1623 // For unknown enum values, for consistency we must include the 1624 // int32_t enum value in the unknown field set, which might not be exactly the 1625 // same as the input. 1626 PROTOBUF_UNUSED auto* descriptor = non_dynamic.descriptor(); 1627 1628 const std::vector<const FieldDescriptor*> fields = 1629 GetFields<UNITTEST::EnumParseTester>(); 1630 1631 for (bool use_dynamic : {false, true}) { 1632 SCOPED_TRACE(use_dynamic); 1633 for (auto field : fields) { 1634 if (field->name() == "other_field") continue; 1635 SCOPED_TRACE(field->full_name()); 1636 for (bool use_packed : {false, true}) { 1637 SCOPED_TRACE(use_packed); 1638 if (!field->is_repeated() && use_packed) continue; 1639 1640 // -2 is an invalid enum value on all the tests here. 1641 // We will encode -2 as a positive int64 that is equivalent to 1642 // int32_t{-2} when truncated. 1643 constexpr int64_t minus_2_non_canonical = 1644 static_cast<int64_t>(static_cast<uint32_t>(int32_t{-2})); 1645 static_assert(minus_2_non_canonical != -2, ""); 1646 std::string encoded = EncodeInt64Value( 1647 field->number(), minus_2_non_canonical, 0, use_packed); 1648 1649 auto& obj = use_dynamic ? *dynamic : non_dynamic; 1650 ASSERT_TRUE(obj.ParseFromString(encoded)); 1651 1652 auto& unknown = obj.GetReflection()->GetUnknownFields(obj); 1653 ASSERT_EQ(unknown.field_count(), 1); 1654 EXPECT_EQ(unknown.field(0).number(), field->number()); 1655 EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT); 1656 EXPECT_EQ(unknown.field(0).varint(), int64_t{-2}); 1657 } 1658 } 1659 } 1660} 1661 1662TEST(MESSAGE_TEST_NAME, TestBoolParsers) { 1663 UNITTEST::BoolParseTester obj; 1664 1665 const auto other_field = EncodeOtherField(); 1666 1667 // Encode a boolean field for many different cases and verify that it can be 1668 // parsed as expected. 1669 // There are: 1670 // - optional/repeated/packed fields 1671 // - field tags that encode in 1/2/3 bytes 1672 // - canonical and non-canonical encodings of the varint 1673 // - last vs not last field 1674 1675 const std::vector<const FieldDescriptor*> fields = 1676 GetFields<UNITTEST::BoolParseTester>(); 1677 1678 auto* ref = obj.GetReflection(); 1679 PROTOBUF_UNUSED auto* descriptor = obj.descriptor(); 1680 for (bool use_tail_field : {false, true}) { 1681 SCOPED_TRACE(use_tail_field); 1682 for (int non_canonical_bytes = 0; non_canonical_bytes < 10; 1683 ++non_canonical_bytes) { 1684 SCOPED_TRACE(non_canonical_bytes); 1685 for (bool add_garbage_bits : {false, true}) { 1686 if (add_garbage_bits && non_canonical_bytes != 9) { 1687 // We only add garbage on the 10th byte. 1688 continue; 1689 } 1690 SCOPED_TRACE(add_garbage_bits); 1691 for (auto field : fields) { 1692 if (field->name() == "other_field") continue; 1693 SCOPED_TRACE(field->full_name()); 1694 for (bool value : {false, true}) { 1695 SCOPED_TRACE(value); 1696 auto encoded = 1697 EncodeBoolValue(field->number(), value, non_canonical_bytes); 1698 if (add_garbage_bits) { 1699 // These bits should be discarded even in the `false` case. 1700 encoded.back() |= 0b0111'1110; 1701 } 1702 if (use_tail_field) { 1703 // Make sure that fields after this one can be parsed too. ie test 1704 // that the "next" jump is correct too. 1705 encoded += other_field; 1706 } 1707 1708 EXPECT_TRUE(obj.ParseFromString(encoded)); 1709 if (field->is_repeated()) { 1710 ASSERT_EQ(ref->FieldSize(obj, field), 1); 1711 EXPECT_EQ(ref->GetRepeatedBool(obj, field, 0), value); 1712 } else { 1713 EXPECT_TRUE(ref->HasField(obj, field)); 1714 EXPECT_EQ(ref->GetBool(obj, field), value) 1715 << testing::PrintToString(encoded); 1716 } 1717 auto& unknown = ref->GetUnknownFields(obj); 1718 ASSERT_EQ(unknown.field_count(), 0); 1719 } 1720 } 1721 } 1722 } 1723 } 1724} 1725 1726TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { 1727 UNITTEST::Int32ParseTester obj; 1728 1729 const auto other_field = EncodeOtherField(); 1730 1731 // Encode an int32 field for many different cases and verify that it can be 1732 // parsed as expected. 1733 // There are: 1734 // - optional/repeated/packed fields 1735 // - field tags that encode in 1/2/3 bytes 1736 // - canonical and non-canonical encodings of the varint 1737 // - last vs not last field 1738 1739 const std::vector<const FieldDescriptor*> fields = 1740 GetFields<UNITTEST::Int32ParseTester>(); 1741 1742 auto* ref = obj.GetReflection(); 1743 PROTOBUF_UNUSED auto* descriptor = obj.descriptor(); 1744 for (bool use_tail_field : {false, true}) { 1745 SCOPED_TRACE(use_tail_field); 1746 for (int non_canonical_bytes = 0; non_canonical_bytes < 10; 1747 ++non_canonical_bytes) { 1748 SCOPED_TRACE(non_canonical_bytes); 1749 for (bool add_garbage_bits : {false, true}) { 1750 if (add_garbage_bits && non_canonical_bytes != 9) { 1751 // We only add garbage on the 10th byte. 1752 continue; 1753 } 1754 SCOPED_TRACE(add_garbage_bits); 1755 for (auto field : fields) { 1756 if (field->name() == "other_field") continue; 1757 SCOPED_TRACE(field->full_name()); 1758 for (int32_t value : {1, 0, -1, (std::numeric_limits<int32_t>::min)(), 1759 (std::numeric_limits<int32_t>::max)()}) { 1760 SCOPED_TRACE(value); 1761 auto encoded = 1762 EncodeInt32Value(field->number(), value, non_canonical_bytes); 1763 if (add_garbage_bits) { 1764 // These bits should be discarded even in the `false` case. 1765 encoded.back() |= 0b0111'1110; 1766 } 1767 if (use_tail_field) { 1768 // Make sure that fields after this one can be parsed too. ie test 1769 // that the "next" jump is correct too. 1770 encoded += other_field; 1771 } 1772 1773 EXPECT_TRUE(obj.ParseFromString(encoded)); 1774 if (field->is_repeated()) { 1775 ASSERT_EQ(ref->FieldSize(obj, field), 1); 1776 EXPECT_EQ(ref->GetRepeatedInt32(obj, field, 0), value); 1777 } else { 1778 EXPECT_TRUE(ref->HasField(obj, field)); 1779 EXPECT_EQ(ref->GetInt32(obj, field), value) 1780 << testing::PrintToString(encoded); 1781 } 1782 auto& unknown = ref->GetUnknownFields(obj); 1783 ASSERT_EQ(unknown.field_count(), 0); 1784 } 1785 } 1786 } 1787 } 1788 } 1789} 1790 1791TEST(MESSAGE_TEST_NAME, TestInt64Parsers) { 1792 UNITTEST::Int64ParseTester obj; 1793 1794 const auto other_field = EncodeOtherField(); 1795 1796 // Encode an int64 field for many different cases and verify that it can be 1797 // parsed as expected. 1798 // There are: 1799 // - optional/repeated/packed fields 1800 // - field tags that encode in 1/2/3 bytes 1801 // - canonical and non-canonical encodings of the varint 1802 // - last vs not last field 1803 1804 const std::vector<const FieldDescriptor*> fields = 1805 GetFields<UNITTEST::Int64ParseTester>(); 1806 1807 auto* ref = obj.GetReflection(); 1808 PROTOBUF_UNUSED auto* descriptor = obj.descriptor(); 1809 for (bool use_tail_field : {false, true}) { 1810 SCOPED_TRACE(use_tail_field); 1811 for (int non_canonical_bytes = 0; non_canonical_bytes < 10; 1812 ++non_canonical_bytes) { 1813 SCOPED_TRACE(non_canonical_bytes); 1814 for (bool add_garbage_bits : {false, true}) { 1815 if (add_garbage_bits && non_canonical_bytes != 9) { 1816 // We only add garbage on the 10th byte. 1817 continue; 1818 } 1819 SCOPED_TRACE(add_garbage_bits); 1820 for (auto field : fields) { 1821 if (field->name() == "other_field") continue; 1822 SCOPED_TRACE(field->full_name()); 1823 for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1}, 1824 (std::numeric_limits<int64_t>::min)(), 1825 (std::numeric_limits<int64_t>::max)()}) { 1826 SCOPED_TRACE(value); 1827 auto encoded = 1828 EncodeInt64Value(field->number(), value, non_canonical_bytes); 1829 if (add_garbage_bits) { 1830 // These bits should be discarded even in the `false` case. 1831 encoded.back() |= 0b0111'1110; 1832 } 1833 if (use_tail_field) { 1834 // Make sure that fields after this one can be parsed too. ie test 1835 // that the "next" jump is correct too. 1836 encoded += other_field; 1837 } 1838 1839 EXPECT_TRUE(obj.ParseFromString(encoded)); 1840 if (field->is_repeated()) { 1841 ASSERT_EQ(ref->FieldSize(obj, field), 1); 1842 EXPECT_EQ(ref->GetRepeatedInt64(obj, field, 0), value); 1843 } else { 1844 EXPECT_TRUE(ref->HasField(obj, field)); 1845 EXPECT_EQ(ref->GetInt64(obj, field), value) 1846 << testing::PrintToString(encoded); 1847 } 1848 auto& unknown = ref->GetUnknownFields(obj); 1849 ASSERT_EQ(unknown.field_count(), 0); 1850 } 1851 } 1852 } 1853 } 1854 } 1855} 1856 1857TEST(MESSAGE_TEST_NAME, IsDefaultInstance) { 1858 UNITTEST::TestAllTypes msg; 1859 const auto& default_msg = UNITTEST::TestAllTypes::default_instance(); 1860 const auto* r = msg.GetReflection(); 1861 EXPECT_TRUE(r->IsDefaultInstance(default_msg)); 1862 EXPECT_FALSE(r->IsDefaultInstance(msg)); 1863} 1864 1865namespace { 1866std::string EncodeStringValue(int number, const std::string& value) { 1867 uint8_t buf[100]; 1868 return std::string( 1869 buf, internal::WireFormatLite::WriteStringToArray(number, value, buf)); 1870} 1871 1872class TestInputStream final : public io::ZeroCopyInputStream { 1873 public: 1874 explicit TestInputStream(absl::string_view payload, size_t break_pos) 1875 : payload_(payload), break_pos_(break_pos) {} 1876 1877 bool Next(const void** data, int* size) override { 1878 if (payload_.empty()) return false; 1879 const auto to_consume = payload_.substr(0, break_pos_); 1880 *data = to_consume.data(); 1881 *size = to_consume.size(); 1882 payload_.remove_prefix(to_consume.size()); 1883 // The next time will consume the rest. 1884 break_pos_ = payload_.npos; 1885 1886 return true; 1887 } 1888 1889 void BackUp(int) override { ABSL_CHECK(false); } 1890 bool Skip(int) override { 1891 ABSL_CHECK(false); 1892 return false; 1893 } 1894 int64_t ByteCount() const override { 1895 ABSL_CHECK(false); 1896 return 0; 1897 } 1898 1899 private: 1900 absl::string_view payload_; 1901 size_t break_pos_; 1902}; 1903} // namespace 1904 1905TEST(MESSAGE_TEST_NAME, TestRepeatedStringParsers) { 1906 google::protobuf::Arena arena; 1907 1908 const std::string sample = 1909 "abcdefghijklmnopqrstuvwxyz" 1910 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; 1911 1912 PROTOBUF_UNUSED const auto* const descriptor = 1913 UNITTEST::StringParseTester::descriptor(); 1914 1915 const std::vector<const FieldDescriptor*> fields = 1916 GetFields<UNITTEST::StringParseTester>(); 1917 1918 static const size_t sso_capacity = std::string().capacity(); 1919 if (sso_capacity == 0) GTEST_SKIP(); 1920 // SSO, !SSO, and off-by-one just in case 1921 for (size_t size : 1922 {sso_capacity - 1, sso_capacity, sso_capacity + 1, sso_capacity + 2}) { 1923 SCOPED_TRACE(size); 1924 const std::string value = sample.substr(0, size); 1925 for (auto field : fields) { 1926 SCOPED_TRACE(field->full_name()); 1927 const auto encoded = EncodeStringValue(field->number(), sample) + 1928 EncodeStringValue(field->number(), value); 1929 // Check for different breaks in the input stream to test cases where 1930 // the payload can be read and can't be read in one go. 1931 for (size_t i = 1; i <= encoded.size(); ++i) { 1932 TestInputStream input_stream(encoded, i); 1933 1934 auto& obj = *arena.Create<UNITTEST::StringParseTester>(&arena); 1935 auto* ref = obj.GetReflection(); 1936 EXPECT_TRUE(obj.ParseFromZeroCopyStream(&input_stream)); 1937 if (field->is_repeated()) { 1938 ASSERT_EQ(ref->FieldSize(obj, field), 2); 1939 EXPECT_EQ(ref->GetRepeatedString(obj, field, 0), sample); 1940 EXPECT_EQ(ref->GetRepeatedString(obj, field, 1), value); 1941 } else { 1942 EXPECT_EQ(ref->GetString(obj, field), value); 1943 } 1944 } 1945 } 1946 } 1947} 1948 1949TEST(MESSAGE_TEST_NAME, TestRegressionOnParseFailureNotSettingHasBits) { 1950 std::string single_field; 1951 // We use blocks because we want fully new instances of the proto. We are 1952 // testing .Clear(), so we can't use it to set up the test. 1953 { 1954 UNITTEST::TestAllTypes message; 1955 message.set_optional_int32(17); 1956 single_field = message.SerializeAsString(); 1957 } 1958 const auto validate_message = [](auto& message) { 1959 if (!message.has_optional_int32()) { 1960 EXPECT_EQ(message.optional_int32(), 0); 1961 } 1962 message.Clear(); 1963 EXPECT_FALSE(message.has_optional_int32()); 1964 EXPECT_EQ(message.optional_int32(), 0); 1965 }; 1966 { 1967 // Verify the setup is correct. 1968 UNITTEST::TestAllTypes message; 1969 EXPECT_FALSE(message.has_optional_int32()); 1970 EXPECT_EQ(message.optional_int32(), 0); 1971 EXPECT_TRUE(message.ParseFromString(single_field)); 1972 validate_message(message); 1973 } 1974 { 1975 // Run the regression. 1976 // These are the steps: 1977 // - The stream contains a fast field, and then a failure in MiniParse 1978 // - The parsing fails. 1979 // - We call clear. 1980 // - The fast field should be reset. 1981 UNITTEST::TestAllTypes message; 1982 EXPECT_FALSE(message.has_optional_int32()); 1983 EXPECT_EQ(message.optional_int32(), 0); 1984 // The second tag will fail to parse because it has too many continuation 1985 // bits. 1986 auto with_error = 1987 absl::StrCat(single_field, std::string(100, static_cast<char>(0x80))); 1988 EXPECT_FALSE(message.ParseFromString(with_error)); 1989 validate_message(message); 1990 } 1991} 1992 1993TEST(MESSAGE_TEST_NAME, TestRegressionOverwrittenLazyOneofDoesNotLeak) { 1994 UNITTEST::TestAllTypes message; 1995 auto* lazy = message.mutable_oneof_lazy_nested_message(); 1996 // We need to add enough payload to make the lazy field overflow the SSO of 1997 // Cord. However, NestedMessage does not have enough fields for that. Just add 1998 // some unknown payload to it. Use something that the validator will allow to 1999 // stay as lazy. 2000 lazy->GetReflection()->MutableUnknownFields(lazy)->AddFixed64(10, 10); 2001 lazy->GetReflection()->MutableUnknownFields(lazy)->AddFixed64(11, 10); 2002 // Validate that the size is large enough. 2003 ASSERT_GT(lazy->ByteSizeLong(), 15); 2004 2005 // Append two instances of the oneof: first the lazy field, then any other to 2006 // cause a switch during parsing. 2007 std::string str; 2008 ASSERT_TRUE(message.AppendToString(&str)); 2009 message.set_oneof_uint32(7); 2010 ASSERT_TRUE(message.AppendToString(&str)); 2011 2012 EXPECT_TRUE(UNITTEST::TestAllTypes().ParseFromString(str)); 2013 Arena arena; 2014 // This call had a bug where the LazyField was not destroyed in any way 2015 // causing the Cord inside it to leak its contents. 2016 EXPECT_TRUE( 2017 Arena::Create<UNITTEST::TestAllTypes>(&arena)->ParseFromString(str)); 2018} 2019 2020} // namespace protobuf 2021} // namespace google 2022 2023#include "google/protobuf/port_undef.inc" 2024