• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// -*- c++ -*-
2// Protocol Buffers - Google's data interchange format
3// Copyright 2008 Google Inc.  All rights reserved.
4//
5// Use of this source code is governed by a BSD-style
6// license that can be found in the LICENSE file or at
7// https://developers.google.com/open-source/licenses/bsd
8
9// Author: kenton@google.com (Kenton Varda)
10//  Based on original Protocol Buffers design by
11//  Sanjay Ghemawat, Jeff Dean, and others.
12//
13// This file needs to be included as .inc as it depends on certain macros being
14// defined prior to its inclusion.
15
16#include <fcntl.h>
17#include <sys/stat.h>
18#include <sys/types.h>
19
20#include <cmath>
21#include <cstddef>
22#include <cstdint>
23#include <limits>
24#include <string>
25
26#ifndef _MSC_VER
27#include <unistd.h>
28#endif
29#include <fstream>
30#include <sstream>
31
32#include "google/protobuf/testing/file.h"
33#include "google/protobuf/testing/file.h"
34#include "google/protobuf/descriptor.pb.h"
35#include <gmock/gmock.h>
36#include "google/protobuf/testing/googletest.h"
37#include <gtest/gtest.h>
38#include "absl/log/absl_check.h"
39#include "absl/log/scoped_mock_log.h"
40#include "absl/strings/cord.h"
41#include "absl/strings/substitute.h"
42#include "google/protobuf/arena.h"
43#include "google/protobuf/descriptor.h"
44#include "google/protobuf/dynamic_message.h"
45#include "google/protobuf/generated_message_reflection.h"
46#include "google/protobuf/generated_message_tctable_impl.h"
47#include "google/protobuf/io/coded_stream.h"
48#include "google/protobuf/io/io_win32.h"
49#include "google/protobuf/io/zero_copy_stream.h"
50#include "google/protobuf/io/zero_copy_stream_impl.h"
51#include "google/protobuf/message.h"
52#include "google/protobuf/reflection_ops.h"
53#include "google/protobuf/test_util2.h"
54
55
56// Must be included last.
57#include "google/protobuf/port_def.inc"
58
59namespace google {
60namespace protobuf {
61
#ifdef _WIN32
// DO NOT include <io.h>, instead create functions in io_win32.{h,cc} and import
// them like we do below.
using google::protobuf::io::win32::close;
using google::protobuf::io::win32::open;
#endif  // _WIN32

// The file-descriptor tests below OR O_BINARY into their open() flags. Fall
// back to _O_BINARY when only that spelling exists, and to 0 otherwise: if
// neither is defined, the platform doesn't need it.
#if !defined(O_BINARY) && defined(_O_BINARY)
#define O_BINARY _O_BINARY
#elif !defined(O_BINARY)
#define O_BINARY 0
#endif
76
77namespace {
78
79UNITTEST::NestedTestAllTypes InitNestedProto(int depth) {
80  UNITTEST::NestedTestAllTypes p;
81  auto* child = p.mutable_child();
82  for (int i = 0; i < depth; i++) {
83    child->mutable_payload()->set_optional_int32(i);
84    child = child->mutable_child();
85  }
86  // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1
87  child->mutable_payload()->set_optional_int32(-1);
88  return p;
89}
90}  // namespace
91
92TEST(MESSAGE_TEST_NAME, SerializeHelpers) {
93  // TODO:  Test more helpers?  They're all two-liners so it seems
94  //   like a waste of time.
95
96  UNITTEST::TestAllTypes message;
97  TestUtil::SetAllFields(&message);
98  std::stringstream stream;
99
100  std::string str1("foo");
101  std::string str2("bar");
102
103  EXPECT_TRUE(message.SerializeToString(&str1));
104  EXPECT_TRUE(message.AppendToString(&str2));
105  EXPECT_TRUE(message.SerializeToOstream(&stream));
106
107  EXPECT_EQ(str1.size() + 3, str2.size());
108  EXPECT_EQ("bar", str2.substr(0, 3));
109  // Don't use EXPECT_EQ because we don't want to dump raw binary data to
110  // stdout.
111  EXPECT_TRUE(str2.substr(3) == str1);
112
113  // GCC gives some sort of error if we try to just do stream.str() == str1.
114  std::string temp = stream.str();
115  EXPECT_TRUE(temp == str1);
116
117  EXPECT_TRUE(message.SerializeAsString() == str1);
118
119}
120
121TEST(MESSAGE_TEST_NAME, RoundTrip) {
122  UNITTEST::TestAllTypes message;
123  TestUtil::SetAllFields(&message);
124  TestUtil::ExpectAllFieldsSet(message);
125
126  UNITTEST::TestAllTypes copied, merged, parsed;
127  copied = message;
128  TestUtil::ExpectAllFieldsSet(copied);
129
130  merged.MergeFrom(message);
131  TestUtil::ExpectAllFieldsSet(merged);
132
133  std::string data;
134  ASSERT_TRUE(message.SerializeToString(&data));
135  ASSERT_TRUE(parsed.ParseFromString(data));
136  TestUtil::ExpectAllFieldsSet(parsed);
137}
138
139TEST(MESSAGE_TEST_NAME, SerializeToBrokenOstream) {
140  std::ofstream out;
141  UNITTEST::TestAllTypes message;
142  message.set_optional_int32(123);
143
144  EXPECT_FALSE(message.SerializeToOstream(&out));
145}
146
147TEST(MESSAGE_TEST_NAME, ParseFromFileDescriptor) {
148  std::string filename = absl::StrCat(TestTempDir(), "/golden_message");
149  UNITTEST::TestAllTypes expected_message;
150  TestUtil::SetAllFields(&expected_message);
151  ABSL_CHECK_OK(File::SetContents(
152      filename, expected_message.SerializeAsString(), true));
153
154  int file = open(filename.c_str(), O_RDONLY | O_BINARY);
155  ASSERT_GE(file, 0);
156
157  UNITTEST::TestAllTypes message;
158  EXPECT_TRUE(message.ParseFromFileDescriptor(file));
159  TestUtil::ExpectAllFieldsSet(message);
160
161  EXPECT_GE(close(file), 0);
162}
163
164TEST(MESSAGE_TEST_NAME, ParsePackedFromFileDescriptor) {
165  std::string filename = absl::StrCat(TestTempDir(), "/golden_message");
166  UNITTEST::TestPackedTypes expected_message;
167  TestUtil::SetPackedFields(&expected_message);
168  ABSL_CHECK_OK(File::SetContents(
169      filename, expected_message.SerializeAsString(), true));
170
171  int file = open(filename.c_str(), O_RDONLY | O_BINARY);
172  ASSERT_GE(file, 0);
173
174  UNITTEST::TestPackedTypes message;
175  EXPECT_TRUE(message.ParseFromFileDescriptor(file));
176  TestUtil::ExpectPackedFieldsSet(message);
177
178  EXPECT_GE(close(file), 0);
179}
180
181TEST(MESSAGE_TEST_NAME, ParseHelpers) {
182  // TODO:  Test more helpers?  They're all two-liners so it seems
183  //   like a waste of time.
184  std::string data;
185
186  {
187    // Set up.
188    UNITTEST::TestAllTypes message;
189    TestUtil::SetAllFields(&message);
190    message.SerializeToString(&data);
191  }
192
193  {
194    // Test ParseFromString.
195    UNITTEST::TestAllTypes message;
196    EXPECT_TRUE(message.ParseFromString(data));
197    TestUtil::ExpectAllFieldsSet(message);
198  }
199
200  {
201    // Test ParseFromIstream.
202    UNITTEST::TestAllTypes message;
203    std::stringstream stream(data);
204    EXPECT_TRUE(message.ParseFromIstream(&stream));
205    EXPECT_TRUE(stream.eof());
206    TestUtil::ExpectAllFieldsSet(message);
207  }
208
209  {
210    // Test ParseFromBoundedZeroCopyStream.
211    std::string data_with_junk(data);
212    data_with_junk.append("some junk on the end");
213    io::ArrayInputStream stream(data_with_junk.data(), data_with_junk.size());
214    UNITTEST::TestAllTypes message;
215    EXPECT_TRUE(message.ParseFromBoundedZeroCopyStream(&stream, data.size()));
216    TestUtil::ExpectAllFieldsSet(message);
217  }
218
219  {
220    // Test that ParseFromBoundedZeroCopyStream fails (but doesn't crash) if
221    // EOF is reached before the expected number of bytes.
222    io::ArrayInputStream stream(data.data(), data.size());
223    UNITTEST::TestAllTypes message;
224    EXPECT_FALSE(
225        message.ParseFromBoundedZeroCopyStream(&stream, data.size() + 1));
226  }
227
228  // Test bytes cord
229  {
230    UNITTEST::TestCord cord_message;
231    cord_message.set_optional_bytes_cord("bytes_cord");
232    EXPECT_TRUE(cord_message.SerializeToString(&data));
233    EXPECT_TRUE(cord_message.SerializeAsString() == data);
234  }
235  {
236    UNITTEST::TestCord cord_message;
237    EXPECT_TRUE(cord_message.ParseFromString(data));
238    EXPECT_EQ("bytes_cord", cord_message.optional_bytes_cord());
239  }
240}
241
242TEST(MESSAGE_TEST_NAME, ParseFailsIfNotInitialized) {
243  UNITTEST::TestRequired message;
244
245  {
246    absl::ScopedMockLog log(absl::MockLogDefault::kDisallowUnexpected);
247    EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, absl::StrCat(
248            "Can't parse message of type \"", UNITTEST_PACKAGE_NAME,
249            ".TestRequired\" because it is missing required fields: a, b, c")));
250    log.StartCapturingLogs();
251    EXPECT_FALSE(message.ParseFromString(""));
252  }
253}
254
255TEST(MESSAGE_TEST_NAME, ParseFailsIfSubmessageNotInitialized) {
256  UNITTEST::TestRequiredForeign source, message;
257  source.mutable_optional_message()->set_dummy2(100);
258  std::string serialized = source.SerializePartialAsString();
259
260  EXPECT_TRUE(message.ParsePartialFromString(serialized));
261  EXPECT_FALSE(message.IsInitialized());
262
263  {
264    absl::ScopedMockLog log(absl::MockLogDefault::kDisallowUnexpected);
265    EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, absl::StrCat(
266          "Can't parse message of type \"", UNITTEST_PACKAGE_NAME,
267          ".TestRequiredForeign\" because it is missing required fields: "
268          "optional_message.a, optional_message.b, optional_message.c")));
269    log.StartCapturingLogs();
270    EXPECT_FALSE(message.ParseFromString(source.SerializePartialAsString()));
271  }
272}
273
274TEST(MESSAGE_TEST_NAME, ParseFailsIfExtensionNotInitialized) {
275  UNITTEST::TestChildExtension source, message;
276  auto* r = source.mutable_optional_extension()->MutableExtension(
277      UNITTEST::TestRequired::single);
278  r->set_dummy2(100);
279  std::string serialized = source.SerializePartialAsString();
280
281  EXPECT_TRUE(message.ParsePartialFromString(serialized));
282  EXPECT_FALSE(message.IsInitialized());
283
284{
285    absl::ScopedMockLog log(absl::MockLogDefault::kDisallowUnexpected);
286    EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, absl::Substitute(
287                  "Can't parse message of type \"$0.TestChildExtension\" "
288                  "because it is missing required fields: "
289                  "optional_extension.($0.TestRequired.single).a, "
290                  "optional_extension.($0.TestRequired.single).b, "
291                  "optional_extension.($0.TestRequired.single).c",
292                  UNITTEST_PACKAGE_NAME)));
293    log.StartCapturingLogs();
294    EXPECT_FALSE(message.ParseFromString(source.SerializePartialAsString()));
295  }
296}
297
298TEST(MESSAGE_TEST_NAME, MergeFromUninitialized) {
299  UNITTEST::TestNestedRequiredForeign o, p, q;
300  UNITTEST::TestNestedRequiredForeign* child = o.mutable_child();
301  constexpr int kDepth = 2;
302  for (int i = 0; i < kDepth; i++) {
303    child->set_dummy(i);
304    child = child->mutable_child();
305  }
306  UNITTEST::TestRequiredForeign* payload = child->mutable_payload();
307  payload->mutable_optional_message()->set_a(1);
308  payload->mutable_optional_message()->set_dummy2(100);
309  payload->mutable_optional_message()->set_dummy4(200);
310  ASSERT_TRUE(p.ParsePartialFromString(o.SerializePartialAsString()));
311
312  q.mutable_child()->set_dummy(500);
313  q = p;
314  q.ParsePartialFromString(q.SerializePartialAsString());
315  EXPECT_TRUE(TestUtil::EqualsToSerialized(q, o.SerializePartialAsString()));
316  EXPECT_TRUE(TestUtil::EqualsToSerialized(q, p.SerializePartialAsString()));
317}
318
319TEST(MESSAGE_TEST_NAME, UninitializedAndTooDeep) {
320  UNITTEST::TestRequiredForeign original;
321  original.mutable_optional_message()->set_a(1);
322  original.mutable_optional_lazy_message()
323      ->mutable_child()
324      ->mutable_payload()
325      ->set_optional_int64(0);
326
327  std::string data;
328  ASSERT_TRUE(original.SerializePartialToString(&data));
329
330  UNITTEST::TestRequiredForeign pass;
331  ASSERT_TRUE(pass.ParsePartialFromString(data));
332  ASSERT_FALSE(pass.IsInitialized());
333
334  io::ArrayInputStream array_stream(data.data(), data.size());
335  io::CodedInputStream input_stream(&array_stream);
336  input_stream.SetRecursionLimit(2);
337
338  UNITTEST::TestRequiredForeign fail;
339  EXPECT_FALSE(fail.ParsePartialFromCodedStream(&input_stream));
340
341  UNITTEST::TestRequiredForeign fail_uninitialized;
342  EXPECT_FALSE(fail_uninitialized.ParseFromString(data));
343}
344
345TEST(MESSAGE_TEST_NAME, ExplicitLazyExceedRecursionLimit) {
346  UNITTEST::NestedTestAllTypes original, parsed;
347  // Build proto with recursion depth of 3.
348  original.mutable_lazy_child()
349      ->mutable_child()
350      ->mutable_payload()
351      ->set_optional_int32(-1);
352  std::string serialized;
353  ASSERT_TRUE(original.SerializeToString(&serialized));
354
355  // User annotated LazyField ([lazy = true]) is eagerly verified and should
356  // catch the recursion limit violation.
357  io::ArrayInputStream array_stream(serialized.data(), serialized.size());
358  io::CodedInputStream input_stream(&array_stream);
359  input_stream.SetRecursionLimit(2);
360  EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream));
361
362  // Lazy read results in parsing error which can be verified by not having
363  // expected value.
364  EXPECT_NE(parsed.lazy_child().child().payload().optional_int32(), -1);
365}
366
367TEST(MESSAGE_TEST_NAME, NestedLazyRecursionLimit) {
368  UNITTEST::NestedTestAllTypes original, parsed;
369  original.mutable_lazy_child()
370      ->mutable_lazy_child()
371      ->mutable_lazy_child()
372      ->mutable_payload()
373      ->set_optional_int32(-1);
374  std::string serialized;
375  ASSERT_TRUE(original.SerializeToString(&serialized));
376  ASSERT_TRUE(parsed.ParseFromString(serialized));
377
378  io::ArrayInputStream array_stream(serialized.data(), serialized.size());
379  io::CodedInputStream input_stream(&array_stream);
380  input_stream.SetRecursionLimit(2);
381  EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream));
382  EXPECT_TRUE(parsed.has_lazy_child());
383  EXPECT_TRUE(parsed.lazy_child().has_lazy_child());
384  EXPECT_TRUE(parsed.lazy_child().lazy_child().has_lazy_child());
385  EXPECT_FALSE(parsed.lazy_child().lazy_child().lazy_child().has_payload());
386}
387
388TEST(MESSAGE_TEST_NAME, UnparsedEmpty) {
389  // lazy_child, LEN=100 with no payload.
390  const char encoded[] = {'\042', 100};
391  UNITTEST::NestedTestAllTypes message;
392
393  EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded)));
394  EXPECT_TRUE(message.has_lazy_child());
395  EXPECT_EQ(message.lazy_child().ByteSizeLong(), 0);
396}
397
398TEST(MESSAGE_TEST_NAME, ParseFailNonCanonicalZeroTag) {
399  const char encoded[] = {"\n\x3\x80\0\0"};
400  UNITTEST::NestedTestAllTypes parsed;
401  EXPECT_FALSE(parsed.ParsePartialFromString(
402      absl::string_view{encoded, sizeof(encoded) - 1}));
403}
404
405TEST(MESSAGE_TEST_NAME, ParseFailNonCanonicalZeroField) {
406  const char encoded[] = {"\012\x6\205\0\0\0\0\0"};
407  UNITTEST::NestedTestAllTypes parsed;
408  EXPECT_FALSE(parsed.ParsePartialFromString(
409      absl::string_view{encoded, sizeof(encoded) - 1}));
410}
411
412TEST(MESSAGE_TEST_NAME, NestedExplicitLazyExceedRecursionLimit) {
413  UNITTEST::NestedTestAllTypes original, parsed;
414  // Build proto with recursion depth of 5, with nested annotated LazyField.
415  original.mutable_lazy_child()
416      ->mutable_child()
417      ->mutable_lazy_child()
418      ->mutable_child()
419      ->mutable_payload()
420      ->set_optional_int32(-1);
421  std::string serialized;
422  EXPECT_TRUE(original.SerializeToString(&serialized));
423
424  // User annotated LazyField ([lazy = true]) is eagerly verified and should
425  // catch the recursion limit violation.
426  io::ArrayInputStream array_stream(serialized.data(), serialized.size());
427  io::CodedInputStream input_stream(&array_stream);
428  input_stream.SetRecursionLimit(4);
429  EXPECT_FALSE(parsed.ParseFromCodedStream(&input_stream));
430
431  // Lazy read results in parsing error which can be verified by not having
432  // expected value.
433  EXPECT_NE(parsed.lazy_child()
434                .child()
435                .lazy_child()
436                .child()
437                .payload()
438                .optional_int32(),
439            -1);
440}
441
442TEST(MESSAGE_TEST_NAME, ParseFailsIfSubmessageTruncated) {
443  UNITTEST::NestedTestAllTypes o, p;
444  constexpr int kDepth = 5;
445  auto* child = o.mutable_child();
446  for (int i = 0; i < kDepth; i++) {
447    child = child->mutable_child();
448  }
449  TestUtil::SetAllFields(child->mutable_payload());
450
451  std::string serialized;
452  EXPECT_TRUE(o.SerializeToString(&serialized));
453
454  // Should parse correctly.
455  EXPECT_TRUE(p.ParseFromString(serialized));
456
457  constexpr int kMaxTruncate = 50;
458  ASSERT_GT(serialized.size(), kMaxTruncate);
459
460  for (int i = 1; i < kMaxTruncate; i += 3) {
461    EXPECT_FALSE(
462        p.ParseFromString(serialized.substr(0, serialized.size() - i)));
463  }
464}
465
466TEST(MESSAGE_TEST_NAME, ParseFailsIfWireMalformed) {
467  UNITTEST::NestedTestAllTypes o, p;
468  constexpr int kDepth = 5;
469  auto* child = o.mutable_child();
470  for (int i = 0; i < kDepth; i++) {
471    child = child->mutable_child();
472  }
473  // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1
474  child->mutable_payload()->set_optional_int32(-1);
475
476  std::string serialized;
477  EXPECT_TRUE(o.SerializeToString(&serialized));
478
479  // Should parse correctly.
480  EXPECT_TRUE(p.ParseFromString(serialized));
481
482  // Overwriting the last byte to 0xFF results in malformed wire.
483  serialized[serialized.size() - 1] = 0xFF;
484  EXPECT_FALSE(p.ParseFromString(serialized));
485}
486
487TEST(MESSAGE_TEST_NAME, ParseFailsIfOneofWireMalformed) {
488  UNITTEST::NestedTestAllTypes o, p;
489  constexpr int kDepth = 5;
490  auto* child = o.mutable_child();
491  for (int i = 0; i < kDepth; i++) {
492    child = child->mutable_child();
493  }
494  // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1
495  child->mutable_payload()->mutable_oneof_nested_message()->set_bb(-1);
496
497  std::string serialized;
498  EXPECT_TRUE(o.SerializeToString(&serialized));
499
500  // Should parse correctly.
501  EXPECT_TRUE(p.ParseFromString(serialized));
502
503  // Overwriting the last byte to 0xFF results in malformed wire.
504  serialized[serialized.size() - 1] = 0xFF;
505  EXPECT_FALSE(p.ParseFromString(serialized));
506}
507
508TEST(MESSAGE_TEST_NAME, ParseFailsIfExtensionWireMalformed) {
509  UNITTEST::TestChildExtension o, p;
510  auto* m = o.mutable_optional_extension()->MutableExtension(
511      UNITTEST::optional_nested_message_extension);
512
513  // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1
514  m->set_bb(-1);
515
516  std::string serialized;
517  EXPECT_TRUE(o.SerializeToString(&serialized));
518
519  // Should parse correctly.
520  EXPECT_TRUE(p.ParseFromString(serialized));
521
522  // Overwriting the last byte to 0xFF results in malformed wire.
523  serialized[serialized.size() - 1] = 0xFF;
524  EXPECT_FALSE(p.ParseFromString(serialized));
525}
526
527TEST(MESSAGE_TEST_NAME, ParseFailsIfGroupFieldMalformed) {
528  UNITTEST::TestMutualRecursionA original, parsed;
529  original.mutable_bb()
530      ->mutable_a()
531      ->mutable_subgroup()
532      ->mutable_sub_message()
533      ->mutable_b()
534      ->set_optional_int32(-1);
535
536  std::string data;
537  ASSERT_TRUE(original.SerializeToString(&data));
538  // Should parse correctly.
539  ASSERT_TRUE(parsed.ParseFromString(data));
540  // Overwriting the last byte of varint (-1) to 0xFF results in malformed wire.
541  data[data.size() - 2] = 0xFF;
542
543  EXPECT_FALSE(parsed.ParseFromString(data));
544}
545
546TEST(MESSAGE_TEST_NAME, ParseFailsIfRepeatedGroupFieldMalformed) {
547  UNITTEST::TestMutualRecursionA original, parsed;
548  original.mutable_bb()
549      ->mutable_a()
550      ->add_subgroupr()
551      ->mutable_payload()
552      ->set_optional_int64(-1);
553
554  std::string data;
555  ASSERT_TRUE(original.SerializeToString(&data));
556  // Should parse correctly.
557  ASSERT_TRUE(parsed.ParseFromString(data));
558  // Overwriting the last byte of varint (-1) to 0xFF results in malformed wire.
559  data[data.size() - 2] = 0xFF;
560
561  EXPECT_FALSE(parsed.ParseFromString(data));
562}
563
564TEST(MESSAGE_TEST_NAME, UninitializedAndMalformed) {
565  UNITTEST::TestRequiredForeign o, p1, p2;
566  o.mutable_optional_message()->set_a(-1);
567
568  // -1 becomes \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1
569  std::string serialized;
570  EXPECT_TRUE(o.SerializePartialToString(&serialized));
571
572  // Should parse correctly.
573  EXPECT_TRUE(p1.ParsePartialFromString(serialized));
574  EXPECT_FALSE(p1.IsInitialized());
575
576  // Overwriting the last byte to 0xFF results in malformed wire.
577  serialized[serialized.size() - 1] = 0xFF;
578  EXPECT_FALSE(p2.ParseFromString(serialized));
579  EXPECT_FALSE(p2.IsInitialized());
580}
581
582// Parsing proto must not access beyond the bound.
583TEST(MESSAGE_TEST_NAME, ParseStrictlyBoundedStream) {
584  UNITTEST::NestedTestAllTypes o, p;
585  constexpr int kDepth = 2;
586  o = InitNestedProto(kDepth);
587  TestUtil::SetAllFields(o.mutable_child()->mutable_payload());
588  o.mutable_child()->mutable_child()->mutable_payload()->set_optional_string(
589      std::string(1024, 'a'));
590
591  std::string data;
592  EXPECT_TRUE(o.SerializeToString(&data));
593
594  TestUtil::BoundedArrayInputStream stream(data.data(), data.size());
595  EXPECT_TRUE(p.ParseFromBoundedZeroCopyStream(&stream, data.size()));
596  TestUtil::ExpectAllFieldsSet(p.child().payload());
597}
598
599// Helper functions to touch any nested lazy field
600void TouchLazy(UNITTEST::NestedTestAllTypes* msg);
601void TouchLazy(UNITTEST::TestAllTypes* msg);
602void TouchLazy(UNITTEST::TestAllTypes::NestedMessage* msg) {}
603
604void TouchLazy(UNITTEST::TestAllTypes* msg) {
605  if (msg->has_optional_lazy_message()) {
606    TouchLazy(msg->mutable_optional_lazy_message());
607  }
608  if (msg->has_optional_unverified_lazy_message()) {
609    TouchLazy(msg->mutable_optional_unverified_lazy_message());
610  }
611  for (auto& child : *msg->mutable_repeated_lazy_message()) {
612    TouchLazy(&child);
613  }
614}
615
616void TouchLazy(UNITTEST::NestedTestAllTypes* msg) {
617  if (msg->has_child()) TouchLazy(msg->mutable_child());
618  if (msg->has_payload()) TouchLazy(msg->mutable_payload());
619  for (auto& child : *msg->mutable_repeated_child()) {
620    TouchLazy(&child);
621  }
622  if (msg->has_lazy_child()) TouchLazy(msg->mutable_lazy_child());
623  if (msg->has_eager_child()) TouchLazy(msg->mutable_eager_child());
624}
625
626TEST(MESSAGE_TEST_NAME, SuccessAfterParsingFailure) {
627  UNITTEST::NestedTestAllTypes o, p, q;
628  constexpr int kDepth = 5;
629  o = InitNestedProto(kDepth);
630  std::string serialized;
631  EXPECT_TRUE(o.SerializeToString(&serialized));
632
633  // Should parse correctly.
634  EXPECT_TRUE(p.ParseFromString(serialized));
635
636  // Overwriting the last byte to 0xFF results in malformed wire.
637  serialized[serialized.size() - 1] = 0xFF;
638  EXPECT_FALSE(p.ParseFromString(serialized));
639
640  // If the affected byte is inside a lazy message, we have no guarantee that it
641  // serializes into error free data because serialization needs to preserve
642  // const correctness on lazy fields: `touch` all lazy fields.
643  TouchLazy(&p);
644  EXPECT_TRUE(q.ParseFromString(p.SerializeAsString()));
645}
646
647TEST(MESSAGE_TEST_NAME, ExceedRecursionLimit) {
648  UNITTEST::NestedTestAllTypes o, p;
649  const int kDepth = io::CodedInputStream::GetDefaultRecursionLimit() + 10;
650  o = InitNestedProto(kDepth);
651  std::string serialized;
652  EXPECT_TRUE(o.SerializeToString(&serialized));
653
654  // Recursion level deeper than the default.
655  EXPECT_FALSE(p.ParseFromString(serialized));
656}
657
658TEST(MESSAGE_TEST_NAME, SupportCustomRecursionLimitRead) {
659  UNITTEST::NestedTestAllTypes o, p;
660  const int kDepth = io::CodedInputStream::GetDefaultRecursionLimit() + 10;
661  o = InitNestedProto(kDepth);
662  std::string serialized;
663  EXPECT_TRUE(o.SerializeToString(&serialized));
664
665  // Should pass with custom limit + reads.
666  io::ArrayInputStream raw_input(serialized.data(), serialized.size());
667  io::CodedInputStream input(&raw_input);
668  input.SetRecursionLimit(kDepth + 10);
669  EXPECT_TRUE(p.ParseFromCodedStream(&input));
670
671  EXPECT_EQ(p.child().payload().optional_int32(), 0);
672  EXPECT_EQ(p.child().child().payload().optional_int32(), 1);
673
674  // Verify p serializes successfully (survives VerifyConsistency).
675  std::string result;
676  EXPECT_TRUE(p.SerializeToString(&result));
677}
678
679TEST(MESSAGE_TEST_NAME, SupportCustomRecursionLimitWrite) {
680  UNITTEST::NestedTestAllTypes o, p;
681  const int kDepth = io::CodedInputStream::GetDefaultRecursionLimit() + 10;
682  o = InitNestedProto(kDepth);
683  std::string serialized;
684  EXPECT_TRUE(o.SerializeToString(&serialized));
685
686  // Should pass with custom limit + writes.
687  io::ArrayInputStream raw_input(serialized.data(), serialized.size());
688  io::CodedInputStream input(&raw_input);
689  input.SetRecursionLimit(kDepth + 10);
690  EXPECT_TRUE(p.ParseFromCodedStream(&input));
691
692  EXPECT_EQ(p.mutable_child()->mutable_payload()->optional_int32(), 0);
693  EXPECT_EQ(
694      p.mutable_child()->mutable_child()->mutable_payload()->optional_int32(),
695      1);
696}
697
698// While deep recursion is never guaranteed, this test aims to catch potential
699// issues with very deep recursion.
700TEST(MESSAGE_TEST_NAME, SupportDeepRecursionLimit) {
701  UNITTEST::NestedTestAllTypes o, p;
702  constexpr int kDepth = 1000;
703  auto* child = o.mutable_child();
704  for (int i = 0; i < kDepth; i++) {
705    child = child->mutable_child();
706  }
707  child->mutable_payload()->set_optional_int32(100);
708
709  std::string serialized;
710  EXPECT_TRUE(o.SerializeToString(&serialized));
711
712  io::ArrayInputStream raw_input(serialized.data(), serialized.size());
713  io::CodedInputStream input(&raw_input);
714  input.SetRecursionLimit(1100);
715  EXPECT_TRUE(p.ParseFromCodedStream(&input));
716}
717
718inline bool IsOptimizeForCodeSize(const Descriptor* descriptor) {
719  return descriptor->file()->options().optimize_for() == FileOptions::CODE_SIZE;
720}
721
722
723TEST(MESSAGE_TEST_NAME, Swap) {
724  UNITTEST::NestedTestAllTypes o;
725  constexpr int kDepth = 5;
726  auto* child = o.mutable_child();
727  for (int i = 0; i < kDepth; i++) {
728    child = child->mutable_child();
729  }
730  TestUtil::SetAllFields(child->mutable_payload());
731
732  std::string serialized;
733  EXPECT_TRUE(o.SerializeToString(&serialized));
734
735  {
736    Arena arena;
737    UNITTEST::NestedTestAllTypes* p1 =
738        Arena::Create<UNITTEST::NestedTestAllTypes>(&arena);
739
740    // Should parse correctly.
741    EXPECT_TRUE(p1->ParseFromString(serialized));
742
743    UNITTEST::NestedTestAllTypes* p2 =
744        Arena::Create<UNITTEST::NestedTestAllTypes>(&arena);
745
746    p1->Swap(p2);
747
748    EXPECT_EQ(o.SerializeAsString(), p2->SerializeAsString());
749  }
750}
751
752TEST(MESSAGE_TEST_NAME, BypassInitializationCheckOnParse) {
753  UNITTEST::TestRequired message;
754  io::ArrayInputStream raw_input(nullptr, 0);
755  io::CodedInputStream input(&raw_input);
756  EXPECT_TRUE(message.MergePartialFromCodedStream(&input));
757}
758
759TEST(MESSAGE_TEST_NAME, InitializationErrorString) {
760  UNITTEST::TestRequired message;
761  EXPECT_EQ("a, b, c", message.InitializationErrorString());
762}
763
764TEST(MESSAGE_TEST_NAME, DynamicCastMessage) {
765  UNITTEST::TestAllTypes test_all_types;
766
767  MessageLite* test_all_types_pointer = &test_all_types;
768  EXPECT_EQ(&test_all_types,
769            DynamicCastMessage<UNITTEST::TestAllTypes>(test_all_types_pointer));
770  EXPECT_EQ(nullptr,
771            DynamicCastMessage<UNITTEST::TestRequired>(test_all_types_pointer));
772
773  const MessageLite* test_all_types_pointer_const = &test_all_types;
774  EXPECT_EQ(&test_all_types, DynamicCastMessage<const UNITTEST::TestAllTypes>(
775                                 test_all_types_pointer_const));
776  EXPECT_EQ(nullptr, DynamicCastMessage<const UNITTEST::TestRequired>(
777                         test_all_types_pointer_const));
778
779  MessageLite* test_all_types_pointer_nullptr = nullptr;
780  EXPECT_EQ(nullptr, DynamicCastMessage<UNITTEST::TestAllTypes>(
781                         test_all_types_pointer_nullptr));
782
783  MessageLite& test_all_types_pointer_ref = test_all_types;
784  EXPECT_EQ(&test_all_types, &DynamicCastMessage<UNITTEST::TestAllTypes>(
785                                 test_all_types_pointer_ref));
786
787  const MessageLite& test_all_types_pointer_const_ref = test_all_types;
788  EXPECT_EQ(&test_all_types, &DynamicCastMessage<UNITTEST::TestAllTypes>(
789                                 test_all_types_pointer_const_ref));
790}
791
792TEST(MESSAGE_TEST_NAME, DynamicCastMessageInvalidReferenceType) {
793  UNITTEST::TestAllTypes test_all_types;
794  const MessageLite& test_all_types_pointer_const_ref = test_all_types;
795  ASSERT_DEATH(
796      DynamicCastMessage<UNITTEST::TestRequired>(
797          test_all_types_pointer_const_ref),
798      absl::StrCat("Cannot downcast ", test_all_types.GetTypeName(), " to ",
799                   UNITTEST::TestRequired::default_instance().GetTypeName()));
800}
801
802TEST(MESSAGE_TEST_NAME, DownCastMessageValidType) {
803  UNITTEST::TestAllTypes test_all_types;
804
805  MessageLite* test_all_types_pointer = &test_all_types;
806  EXPECT_EQ(&test_all_types,
807            DownCastMessage<UNITTEST::TestAllTypes>(test_all_types_pointer));
808
809  const MessageLite* test_all_types_pointer_const = &test_all_types;
810  EXPECT_EQ(&test_all_types, DownCastMessage<const UNITTEST::TestAllTypes>(
811                                 test_all_types_pointer_const));
812
813  MessageLite* test_all_types_pointer_nullptr = nullptr;
814  EXPECT_EQ(nullptr, DownCastMessage<UNITTEST::TestAllTypes>(
815                         test_all_types_pointer_nullptr));
816
817  MessageLite& test_all_types_pointer_ref = test_all_types;
818  EXPECT_EQ(&test_all_types, &DownCastMessage<UNITTEST::TestAllTypes>(
819                                 test_all_types_pointer_ref));
820
821  const MessageLite& test_all_types_pointer_const_ref = test_all_types;
822  EXPECT_EQ(&test_all_types, &DownCastMessage<UNITTEST::TestAllTypes>(
823                                 test_all_types_pointer_const_ref));
824}
825
826TEST(MESSAGE_TEST_NAME, DownCastMessageInvalidPointerType) {
827  UNITTEST::TestAllTypes test_all_types;
828
829  MessageLite* test_all_types_pointer = &test_all_types;
830
831  ASSERT_DEBUG_DEATH(
832      DownCastMessage<UNITTEST::TestRequired>(test_all_types_pointer),
833      absl::StrCat("Cannot downcast ", test_all_types.GetTypeName(), " to ",
834                   UNITTEST::TestRequired::default_instance().GetTypeName()));
835}
836
837TEST(MESSAGE_TEST_NAME, DownCastMessageInvalidReferenceType) {
838  UNITTEST::TestAllTypes test_all_types;
839
840  MessageLite& test_all_types_ref = test_all_types;
841
842  ASSERT_DEBUG_DEATH(
843      DownCastMessage<UNITTEST::TestRequired>(test_all_types_ref),
844      absl::StrCat("Cannot downcast ", test_all_types.GetTypeName(), " to ",
845                   UNITTEST::TestRequired::default_instance().GetTypeName()));
846}
847
848TEST(MESSAGE_TEST_NAME, MessageDebugStringMatchesBehindPointerAndLitePointer) {
849  UNITTEST::TestAllTypes test_all_types;
850  test_all_types.set_optional_string("foo");
851  Message* msg_full_pointer = &test_all_types;
852  MessageLite* msg_lite_pointer = &test_all_types;
853  ASSERT_EQ(test_all_types.DebugString(), msg_full_pointer->DebugString());
854  ASSERT_EQ(test_all_types.DebugString(), msg_lite_pointer->DebugString());
855}
856
857#if GTEST_HAS_DEATH_TEST  // death tests do not work on Windows yet.
858
859TEST(MESSAGE_TEST_NAME, SerializeFailsIfNotInitialized) {
860  UNITTEST::TestRequired message;
861  std::string data;
862  EXPECT_DEBUG_DEATH(
863      EXPECT_TRUE(message.SerializeToString(&data)),
864      absl::StrCat("Can't serialize message of type \"", UNITTEST_PACKAGE_NAME,
865                   ".TestRequired\" because "
866                   "it is missing required fields: a, b, c"));
867}
868
869TEST(MESSAGE_TEST_NAME, CheckInitialized) {
870  UNITTEST::TestRequired message;
871  EXPECT_DEATH(message.CheckInitialized(),
872               absl::StrCat("Message of type \"", UNITTEST_PACKAGE_NAME,
873                            ".TestRequired\" is missing required "
874                            "fields: a, b, c"));
875}
876
877#endif  // GTEST_HAS_DEATH_TEST
878
879namespace {
880// An input stream that repeats a std::string's content for a number of times.
881// It helps us create a really large input without consuming too much memory.
882// Used to test the parsing behavior when the input size exceeds 2G or close to
883// it.
884class RepeatedInputStream : public io::ZeroCopyInputStream {
885 public:
886  RepeatedInputStream(const std::string& data, size_t count)
887      : data_(data), count_(count), position_(0), total_byte_count_(0) {}
888
889  bool Next(const void** data, int* size) override {
890    if (position_ == data_.size()) {
891      if (--count_ == 0) {
892        return false;
893      }
894      position_ = 0;
895    }
896    *data = &data_[position_];
897    *size = static_cast<int>(data_.size() - position_);
898    position_ = data_.size();
899    total_byte_count_ += *size;
900    return true;
901  }
902
903  void BackUp(int count) override {
904    position_ -= static_cast<size_t>(count);
905    total_byte_count_ -= count;
906  }
907
908  bool Skip(int count) override {
909    while (count > 0) {
910      const void* data;
911      int size;
912      if (!Next(&data, &size)) {
913        break;
914      }
915      if (size >= count) {
916        BackUp(size - count);
917        return true;
918      } else {
919        count -= size;
920      }
921    }
922    return false;
923  }
924
925  int64_t ByteCount() const override { return total_byte_count_; }
926
927 private:
928  std::string data_;
929  size_t count_;     // The number of strings that haven't been consumed.
930  size_t position_;  // Position in the std::string for the next read.
931  int64_t total_byte_count_;
932};
933}  // namespace
934
935TEST(MESSAGE_TEST_NAME, TestParseMessagesCloseTo2G) {
936  constexpr int32_t kint32max = std::numeric_limits<int32_t>::max();
937
938  // Create a message with a large std::string field.
939  std::string value = std::string(64 * 1024 * 1024, 'x');
940  UNITTEST::TestAllTypes message;
941  message.set_optional_string(value);
942
943  // Repeat this message in the input stream to make the total input size
944  // close to 2G.
945  std::string data = message.SerializeAsString();
946  size_t count = static_cast<size_t>(kint32max) / data.size();
947  RepeatedInputStream input(data, count);
948
949  // The parsing should succeed.
950  UNITTEST::TestAllTypes result;
951  EXPECT_TRUE(result.ParseFromZeroCopyStream(&input));
952
953  // When there are multiple occurrences of a singular field, the last one
954  // should win.
955  EXPECT_EQ(value, result.optional_string());
956}
957
958TEST(MESSAGE_TEST_NAME, TestParseMessagesOver2G) {
959  constexpr int32_t kint32max = std::numeric_limits<int32_t>::max();
960
961  // Create a message with a large std::string field.
962  std::string value = std::string(64 * 1024 * 1024, 'x');
963  UNITTEST::TestAllTypes message;
964  message.set_optional_string(value);
965
966  // Repeat this message in the input stream to make the total input size
967  // larger than 2G.
968  std::string data = message.SerializeAsString();
969  size_t count = static_cast<size_t>(kint32max) / data.size() + 1;
970  RepeatedInputStream input(data, count);
971
972  // The parsing should fail.
973  UNITTEST::TestAllTypes result;
974  EXPECT_FALSE(result.ParseFromZeroCopyStream(&input));
975}
976
977TEST(MESSAGE_TEST_NAME, BypassInitializationCheckOnSerialize) {
978  UNITTEST::TestRequired message;
979  io::ArrayOutputStream raw_output(nullptr, 0);
980  io::CodedOutputStream output(&raw_output);
981  EXPECT_TRUE(message.SerializePartialToCodedStream(&output));
982}
983
984TEST(MESSAGE_TEST_NAME, FindInitializationErrors) {
985  UNITTEST::TestRequired message;
986  std::vector<std::string> errors;
987  message.FindInitializationErrors(&errors);
988  ASSERT_EQ(3, errors.size());
989  EXPECT_EQ("a", errors[0]);
990  EXPECT_EQ("b", errors[1]);
991  EXPECT_EQ("c", errors[2]);
992}
993
994TEST(MESSAGE_TEST_NAME, ReleaseMustUseResult) {
995  UNITTEST::TestAllTypes message;
996  auto* f = new UNITTEST::ForeignMessage();
997  f->set_c(1000);
998  message.set_allocated_optional_foreign_message(f);
999  auto* mf = message.mutable_optional_foreign_message();
1000  EXPECT_EQ(mf, f);
1001  std::unique_ptr<UNITTEST::ForeignMessage> rf(
1002      message.release_optional_foreign_message());
1003  EXPECT_NE(rf.get(), nullptr);
1004}
1005
1006TEST(MESSAGE_TEST_NAME, ParseFailsOnInvalidMessageEnd) {
1007  UNITTEST::TestAllTypes message;
1008
1009  // Control case.
1010  EXPECT_TRUE(message.ParseFromArray("", 0));
1011
1012  // The byte is a valid varint, but not a valid tag (zero).
1013  EXPECT_FALSE(message.ParseFromArray("\0", 1));
1014
1015  // The byte is a malformed varint.
1016  EXPECT_FALSE(message.ParseFromArray("\200", 1));
1017
1018  // The byte is an endgroup tag, but we aren't parsing a group.
1019  EXPECT_FALSE(message.ParseFromArray("\014", 1));
1020}
1021
1022// Regression test for b/23630858
1023TEST(MESSAGE_TEST_NAME, MessageIsStillValidAfterParseFails) {
1024  UNITTEST::TestAllTypes message;
1025
1026  // 9 0xFFs for the "optional_uint64" field.
1027  std::string invalid_data = "\x20\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF";
1028
1029  EXPECT_FALSE(message.ParseFromString(invalid_data));
1030  message.Clear();
1031  EXPECT_EQ(0, message.optional_uint64());
1032
1033  // invalid data for field "optional_string". Length prefix is 1 but no
1034  // payload.
1035  std::string invalid_string_data = "\x72\x01";
1036  {
1037    Arena arena;
1038    UNITTEST::TestAllTypes* arena_message =
1039        Arena::Create<UNITTEST::TestAllTypes>(&arena);
1040    EXPECT_FALSE(arena_message->ParseFromString(invalid_string_data));
1041    arena_message->Clear();
1042    EXPECT_EQ("", arena_message->optional_string());
1043  }
1044}
1045
1046TEST(MESSAGE_TEST_NAME, NonCanonicalTag) {
1047  UNITTEST::TestAllTypes message;
1048  // optional_lazy_message (27) LEN(3) with non canonical tag: (1).
1049  const char encoded[] = {'\332', 1, 3, '\210', 0, 0};
1050  EXPECT_TRUE(message.ParseFromArray(encoded, sizeof(encoded)));
1051}
1052
1053TEST(MESSAGE_TEST_NAME, Zero5BTag) {
1054  UNITTEST::TestAllTypes message;
1055  // optional_nested_message (18) LEN(6) with 5B but zero tag.
1056  const char encoded[] = {'\222', 1,      6,      '\200', '\200',
1057                          '\200', '\200', '\020', 0};
1058  EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded)));
1059}
1060
1061TEST(MESSAGE_TEST_NAME, Zero5BTagLazy) {
1062  UNITTEST::TestAllTypes message;
1063  // optional_lazy_message (27) LEN(6) with 5B but zero tag.
1064  const char encoded[] = {'\332', 1,      6,      '\200', '\200',
1065                          '\200', '\200', '\020', 0};
1066  EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded)));
1067}
1068
1069namespace {
1070void ExpectMessageMerged(const UNITTEST::TestAllTypes& message) {
1071  EXPECT_EQ(3, message.optional_int32());
1072  EXPECT_EQ(2, message.optional_int64());
1073  EXPECT_EQ("hello", message.optional_string());
1074}
1075
1076void AssignParsingMergeMessages(UNITTEST::TestAllTypes* msg1,
1077                                UNITTEST::TestAllTypes* msg2,
1078                                UNITTEST::TestAllTypes* msg3) {
1079  msg1->set_optional_int32(1);
1080  msg2->set_optional_int64(2);
1081  msg3->set_optional_int32(3);
1082  msg3->set_optional_string("hello");
1083}
1084}  // namespace
1085
1086// Test that if an optional or required message/group field appears multiple
1087// times in the input, they need to be merged.
1088TEST(MESSAGE_TEST_NAME, ParsingMerge) {
1089  UNITTEST::TestParsingMerge::RepeatedFieldsGenerator generator;
1090  UNITTEST::TestAllTypes* msg1;
1091  UNITTEST::TestAllTypes* msg2;
1092  UNITTEST::TestAllTypes* msg3;
1093
1094#define ASSIGN_REPEATED_FIELD(FIELD) \
1095  msg1 = generator.add_##FIELD();    \
1096  msg2 = generator.add_##FIELD();    \
1097  msg3 = generator.add_##FIELD();    \
1098  AssignParsingMergeMessages(msg1, msg2, msg3)
1099
1100  ASSIGN_REPEATED_FIELD(field1);
1101  ASSIGN_REPEATED_FIELD(field2);
1102  ASSIGN_REPEATED_FIELD(field3);
1103  ASSIGN_REPEATED_FIELD(ext1);
1104  ASSIGN_REPEATED_FIELD(ext2);
1105
1106#undef ASSIGN_REPEATED_FIELD
1107#define ASSIGN_REPEATED_GROUP(FIELD)                \
1108  msg1 = generator.add_##FIELD()->mutable_field1(); \
1109  msg2 = generator.add_##FIELD()->mutable_field1(); \
1110  msg3 = generator.add_##FIELD()->mutable_field1(); \
1111  AssignParsingMergeMessages(msg1, msg2, msg3)
1112
1113  ASSIGN_REPEATED_GROUP(group1);
1114  ASSIGN_REPEATED_GROUP(group2);
1115
1116#undef ASSIGN_REPEATED_GROUP
1117
1118  std::string buffer;
1119  generator.SerializeToString(&buffer);
1120  UNITTEST::TestParsingMerge parsing_merge;
1121  parsing_merge.ParseFromString(buffer);
1122
1123  // Required and optional fields should be merged.
1124  ExpectMessageMerged(parsing_merge.required_all_types());
1125  ExpectMessageMerged(parsing_merge.optional_all_types());
1126  ExpectMessageMerged(parsing_merge.optionalgroup().optional_group_all_types());
1127  ExpectMessageMerged(
1128      parsing_merge.GetExtension(UNITTEST::TestParsingMerge::optional_ext));
1129
1130  // Repeated fields should not be merged.
1131  EXPECT_EQ(3, parsing_merge.repeated_all_types_size());
1132  EXPECT_EQ(3, parsing_merge.repeatedgroup_size());
1133  EXPECT_EQ(
1134      3, parsing_merge.ExtensionSize(UNITTEST::TestParsingMerge::repeated_ext));
1135}
1136
1137TEST(MESSAGE_TEST_NAME, MergeFrom) {
1138  UNITTEST::TestAllTypes source, dest;
1139
1140  // Optional fields
1141  source.set_optional_int32(1);  // only source
1142  source.set_optional_int64(2);  // both source and dest
1143  dest.set_optional_int64(3);
1144  dest.set_optional_uint32(4);  // only dest
1145
1146  // Optional fields with defaults
1147  source.set_default_int32(13);  // only source
1148  source.set_default_int64(14);  // both source and dest
1149  dest.set_default_int64(15);
1150  dest.set_default_uint32(16);  // only dest
1151
1152  // Repeated fields
1153  source.add_repeated_int32(5);  // only source
1154  source.add_repeated_int32(6);
1155  source.add_repeated_int64(7);  // both source and dest
1156  source.add_repeated_int64(8);
1157  dest.add_repeated_int64(9);
1158  dest.add_repeated_int64(10);
1159  dest.add_repeated_uint32(11);  // only dest
1160  dest.add_repeated_uint32(12);
1161
1162  dest.MergeFrom(source);
1163
1164  // Optional fields: source overwrites dest if source is specified
1165  EXPECT_EQ(1, dest.optional_int32());   // only source: use source
1166  EXPECT_EQ(2, dest.optional_int64());   // source and dest: use source
1167  EXPECT_EQ(4, dest.optional_uint32());  // only dest: use dest
1168  EXPECT_EQ(0, dest.optional_uint64());  // neither: use default
1169
1170  // Optional fields with defaults
1171  EXPECT_EQ(13, dest.default_int32());   // only source: use source
1172  EXPECT_EQ(14, dest.default_int64());   // source and dest: use source
1173  EXPECT_EQ(16, dest.default_uint32());  // only dest: use dest
1174  EXPECT_EQ(44, dest.default_uint64());  // neither: use default
1175
1176  // Repeated fields: concatenate source onto the end of dest
1177  ASSERT_EQ(2, dest.repeated_int32_size());
1178  EXPECT_EQ(5, dest.repeated_int32(0));
1179  EXPECT_EQ(6, dest.repeated_int32(1));
1180  ASSERT_EQ(4, dest.repeated_int64_size());
1181  EXPECT_EQ(9, dest.repeated_int64(0));
1182  EXPECT_EQ(10, dest.repeated_int64(1));
1183  EXPECT_EQ(7, dest.repeated_int64(2));
1184  EXPECT_EQ(8, dest.repeated_int64(3));
1185  ASSERT_EQ(2, dest.repeated_uint32_size());
1186  EXPECT_EQ(11, dest.repeated_uint32(0));
1187  EXPECT_EQ(12, dest.repeated_uint32(1));
1188  ASSERT_EQ(0, dest.repeated_uint64_size());
1189}
1190
1191TEST(MESSAGE_TEST_NAME, IsInitialized) {
1192  UNITTEST::TestIsInitialized msg;
1193  EXPECT_TRUE(msg.IsInitialized());
1194  UNITTEST::TestIsInitialized::SubMessage* sub_message =
1195      msg.mutable_sub_message();
1196  EXPECT_TRUE(msg.IsInitialized());
1197  UNITTEST::TestIsInitialized::SubMessage::SubGroup* sub_group =
1198      sub_message->mutable_subgroup();
1199  EXPECT_FALSE(msg.IsInitialized());
1200  sub_group->set_i(1);
1201  EXPECT_TRUE(msg.IsInitialized());
1202}
1203
1204TEST(MESSAGE_TEST_NAME, IsInitializedSplitBytestream) {
1205  UNITTEST::TestRequired ab, c;
1206  ab.set_a(1);
1207  ab.set_b(2);
1208  c.set_c(3);
1209
1210  // The protobuf represented by the concatenated string has all required
1211  // fields (a,b,c) set.
1212  std::string bytes =
1213      ab.SerializePartialAsString() + c.SerializePartialAsString();
1214
1215  UNITTEST::TestRequired concatenated;
1216  EXPECT_TRUE(concatenated.ParsePartialFromString(bytes));
1217  EXPECT_TRUE(concatenated.IsInitialized());
1218
1219  UNITTEST::TestRequiredForeign fab, fc;
1220  fab.mutable_optional_message()->set_a(1);
1221  fab.mutable_optional_message()->set_b(2);
1222  fc.mutable_optional_message()->set_c(3);
1223
1224  bytes = fab.SerializePartialAsString() + fc.SerializePartialAsString();
1225
1226  UNITTEST::TestRequiredForeign fconcatenated;
1227  EXPECT_TRUE(fconcatenated.ParsePartialFromString(bytes));
1228  EXPECT_TRUE(fconcatenated.IsInitialized());
1229}
1230
1231TEST(MESSAGE_FACTORY_TEST_NAME, GeneratedFactoryLookup) {
1232  EXPECT_EQ(MessageFactory::generated_factory()->GetPrototype(
1233                UNITTEST::TestAllTypes::descriptor()),
1234            &UNITTEST::TestAllTypes::default_instance());
1235}
1236
1237TEST(MESSAGE_FACTORY_TEST_NAME, GeneratedFactoryUnknownType) {
1238  // Construct a new descriptor.
1239  DescriptorPool pool;
1240  FileDescriptorProto file;
1241  file.set_name("foo.proto");
1242  file.add_message_type()->set_name("Foo");
1243  const Descriptor* descriptor = pool.BuildFile(file)->message_type(0);
1244
1245  // Trying to construct it should return nullptr.
1246  EXPECT_TRUE(MessageFactory::generated_factory()->GetPrototype(descriptor) ==
1247              nullptr);
1248}
1249
// Checks where parsing stops when the input contains an early terminator: a
// zero tag, or an end-group tag while parsing a group type. Each case reads
// through an ArrayInputStream with a block size of 17. The stop position is
// expected to fall within the "last 16 bytes of buffer" named in each case.
TEST(MESSAGE_TEST_NAME, MOMIParserEdgeCases) {
  {
    UNITTEST::TestAllTypes msg;
    // Parser ends in last 16 bytes of buffer due to a 0.
    std::string data;
    // 12 bytes of data: four 3-byte fields (two-byte tag "\370\1", value 1).
    for (int i = 0; i < 4; i++) absl::StrAppend(&data, "\370\1\1");
    // 13th byte is terminator
    data += '\0';  // Terminator
    // followed by the rest of the stream
    // space is ascii 32 so no end group
    data += std::string(30, ' ');
    io::ArrayInputStream zcis(data.data(), data.size(), 17);
    io::CodedInputStream cis(&zcis);
    EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis));
    // Position just past the terminator.
    EXPECT_EQ(cis.CurrentPosition(), 3 * 4 + 1);
  }
  {
    // Parser ends in last 16 bytes of buffer due to an end-group tag.
    // Must use a message that is a group. Otherwise ending on a group end is
    // a failure.
    UNITTEST::TestAllTypes::OptionalGroup msg;
    std::string data;
    for (int i = 0; i < 3; i++) absl::StrAppend(&data, "\370\1\1");
    data += '\14';  // Octal end-group tag 12 (1 * 8 + 4)
    data += std::string(30, ' ');
    io::ArrayInputStream zcis(data.data(), data.size(), 17);
    io::CodedInputStream cis(&zcis);
    EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis));
    EXPECT_EQ(cis.CurrentPosition(), 3 * 3 + 1);
    EXPECT_TRUE(cis.LastTagWas(12));
  }
  {
    // Parser ends in last 16 bytes of buffer due to an end-group tag that
    // follows a length-delimited field ("\22\3foo": field 2, length 3).
    UNITTEST::TestAllTypes::OptionalGroup msg;
    std::string data = "\22\3foo";
    data += '\14';  // Octal end-group tag 12 (1 * 8 + 4)
    data += std::string(30, ' ');
    io::ArrayInputStream zcis(data.data(), data.size(), 17);
    io::CodedInputStream cis(&zcis);
    EXPECT_TRUE(msg.MergePartialFromCodedStream(&cis));
    // 2-byte header + 3-byte payload + 1-byte end-group tag.
    EXPECT_EQ(cis.CurrentPosition(), 6);
    EXPECT_TRUE(cis.LastTagWas(12));
  }
  {
    // Parser fails when ending on 0 if from ZeroCopyInputStream
    UNITTEST::TestAllTypes msg;
    std::string data;
    // 12 bytes of data
    for (int i = 0; i < 4; i++) absl::StrAppend(&data, "\370\1\1");
    // 13th byte is terminator
    data += '\0';  // Terminator
    data += std::string(30, ' ');
    io::ArrayInputStream zcis(data.data(), data.size(), 17);
    EXPECT_FALSE(msg.ParsePartialFromZeroCopyStream(&zcis));
  }
}
1309
1310
1311TEST(MESSAGE_TEST_NAME, CheckSerializationWhenInterleavedExtensions) {
1312  UNITTEST::TestExtensionRangeSerialize in_message;
1313
1314  in_message.set_foo_one(1);
1315  in_message.set_foo_two(2);
1316  in_message.set_foo_three(3);
1317  in_message.set_foo_four(4);
1318
1319  in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_one, 1);
1320  in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_two, 2);
1321  in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_three, 3);
1322  in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_four, 4);
1323  in_message.SetExtension(UNITTEST::TestExtensionRangeSerialize::bar_five, 5);
1324
1325  std::string buffer;
1326  in_message.SerializeToString(&buffer);
1327
1328  UNITTEST::TestExtensionRangeSerialize out_message;
1329  out_message.ParseFromString(buffer);
1330
1331  EXPECT_EQ(1, out_message.foo_one());
1332  EXPECT_EQ(2, out_message.foo_two());
1333  EXPECT_EQ(3, out_message.foo_three());
1334  EXPECT_EQ(4, out_message.foo_four());
1335
1336  EXPECT_EQ(1, out_message.GetExtension(
1337                   UNITTEST::TestExtensionRangeSerialize::bar_one));
1338  EXPECT_EQ(2, out_message.GetExtension(
1339                   UNITTEST::TestExtensionRangeSerialize::bar_two));
1340  EXPECT_EQ(3, out_message.GetExtension(
1341                   UNITTEST::TestExtensionRangeSerialize::bar_three));
1342  EXPECT_EQ(4, out_message.GetExtension(
1343                   UNITTEST::TestExtensionRangeSerialize::bar_four));
1344  EXPECT_EQ(5, out_message.GetExtension(
1345                   UNITTEST::TestExtensionRangeSerialize::bar_five));
1346}
1347
1348TEST(MESSAGE_TEST_NAME, PreservesFloatingPointNegative0) {
1349  UNITTEST::TestAllTypes in_message;
1350  in_message.set_optional_float(-0.0f);
1351  in_message.set_optional_double(-0.0);
1352  std::string serialized;
1353  EXPECT_TRUE(in_message.SerializeToString(&serialized));
1354  UNITTEST::TestAllTypes out_message;
1355  EXPECT_TRUE(out_message.ParseFromString(serialized));
1356  EXPECT_EQ(in_message.optional_float(), out_message.optional_float());
1357  EXPECT_EQ(std::signbit(in_message.optional_float()),
1358            std::signbit(out_message.optional_float()));
1359  EXPECT_EQ(in_message.optional_double(), out_message.optional_double());
1360  EXPECT_EQ(std::signbit(in_message.optional_double()),
1361            std::signbit(out_message.optional_double()));
1362}
1363
1364TEST(MESSAGE_TEST_NAME,
1365     RegressionTestForParseMessageReadingUninitializedLimit) {
1366  UNITTEST::TestAllTypes in_message;
1367  in_message.mutable_optional_nested_message();
1368  std::string serialized = in_message.SerializeAsString();
1369  // We expect this to have 3 bytes: two for the tag, and one for the zero size.
1370  // Break the size by making it overlong.
1371  ASSERT_EQ(serialized.size(), 3);
1372  serialized.back() = '\200';
1373  serialized += std::string(10, '\200');
1374  EXPECT_FALSE(in_message.ParseFromString(serialized));
1375}
1376
1377TEST(MESSAGE_TEST_NAME,
1378     RegressionTestForParseMessageWithSizeBeyondInputFailsToPopLimit) {
1379  UNITTEST::TestAllTypes in_message;
1380  in_message.mutable_optional_nested_message();
1381  std::string serialized = in_message.SerializeAsString();
1382  // We expect this to have 3 bytes: two for the tag, and one for the zero size.
1383  // Make the size a valid varint, but it overflows in the input.
1384  ASSERT_EQ(serialized.size(), 3);
1385  serialized.back() = 10;
1386  EXPECT_FALSE(in_message.ParseFromString(serialized));
1387}
1388
1389namespace {
1390const uint8_t* SkipTag(const uint8_t* buf) {
1391  while (*buf & 0x80) ++buf;
1392  ++buf;
1393  return buf;
1394}
1395
1396// Adds `non_canonical_bytes` bytes to the varint representation at the tail of
1397// the buffer.
1398// `buf` points to the start of the buffer, `p` points to one-past-the-end.
1399// Returns the new one-past-the-end pointer.
1400uint8_t* AddNonCanonicalBytes(const uint8_t* buf, uint8_t* p,
1401                              int non_canonical_bytes) {
1402  // varint can have a max of 10 bytes.
1403  while (non_canonical_bytes-- > 0 && p - buf < 10) {
1404    // Add a dummy byte at the end.
1405    p[-1] |= 0x80;
1406    p[0] = 0;
1407    ++p;
1408  }
1409  return p;
1410}
1411
1412std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) {
1413  uint8_t buf[100];
1414  uint8_t* p = buf;
1415
1416  p = internal::WireFormatLite::WriteBoolToArray(number, value, p);
1417  p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
1418  return std::string(buf, p);
1419}
1420
1421std::string EncodeEnumValue(int number, int value, int non_canonical_bytes,
1422                            bool use_packed) {
1423  uint8_t buf[100];
1424  uint8_t* p = buf;
1425
1426  if (use_packed) {
1427    p = internal::WireFormatLite::WriteEnumNoTagToArray(value, p);
1428    p = AddNonCanonicalBytes(buf, p, non_canonical_bytes);
1429
1430    std::string payload(buf, p);
1431    p = buf;
1432    p = internal::WireFormatLite::WriteStringToArray(number, payload, p);
1433    return std::string(buf, p);
1434
1435  } else {
1436    p = internal::WireFormatLite::WriteEnumToArray(number, value, p);
1437    p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
1438    return std::string(buf, p);
1439  }
1440}
1441
1442std::string EncodeOverlongEnum(int number, bool use_packed) {
1443  uint8_t buf[100];
1444  uint8_t* p = buf;
1445
1446  std::string overlong(16, static_cast<char>(0x80));
1447  if (use_packed) {
1448    p = internal::WireFormatLite::WriteStringToArray(number, overlong, p);
1449    return std::string(buf, p);
1450  } else {
1451    p = internal::WireFormatLite::WriteTagToArray(
1452        number, internal::WireFormatLite::WIRETYPE_VARINT, p);
1453    p = std::copy(overlong.begin(), overlong.end(), p);
1454    return std::string(buf, p);
1455  }
1456}
1457
1458std::string EncodeInt32Value(int number, int32_t value,
1459                             int non_canonical_bytes) {
1460  uint8_t buf[100];
1461  uint8_t* p = buf;
1462
1463  p = internal::WireFormatLite::WriteInt32ToArray(number, value, p);
1464  p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
1465  return std::string(buf, p);
1466}
1467
1468std::string EncodeInt64Value(int number, int64_t value, int non_canonical_bytes,
1469                             bool use_packed = false) {
1470  uint8_t buf[100];
1471  uint8_t* p = buf;
1472
1473  if (use_packed) {
1474    p = internal::WireFormatLite::WriteInt64NoTagToArray(value, p);
1475    p = AddNonCanonicalBytes(buf, p, non_canonical_bytes);
1476
1477    std::string payload(buf, p);
1478    p = buf;
1479    p = internal::WireFormatLite::WriteStringToArray(number, payload, p);
1480    return std::string(buf, p);
1481
1482  } else {
1483    p = internal::WireFormatLite::WriteInt64ToArray(number, value, p);
1484    p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes);
1485    return std::string(buf, p);
1486  }
1487}
1488
1489std::string EncodeOtherField() {
1490  UNITTEST::EnumParseTester obj;
1491  obj.set_other_field(1);
1492  return obj.SerializeAsString();
1493}
1494
1495template <typename T>
1496static std::vector<const FieldDescriptor*> GetFields() {
1497  auto* descriptor = T::descriptor();
1498  std::vector<const FieldDescriptor*> fields;
1499  for (int i = 0; i < descriptor->field_count(); ++i) {
1500    fields.push_back(descriptor->field(i));
1501  }
1502  for (int i = 0; i < descriptor->extension_count(); ++i) {
1503    fields.push_back(descriptor->extension(i));
1504  }
1505  return fields;
1506}
1507}  // namespace
1508
1509TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
1510  UNITTEST::EnumParseTester obj;
1511
1512  const auto other_field = EncodeOtherField();
1513
1514  // Encode an enum field for many different cases and verify that it can be
1515  // parsed as expected.
1516  // There are:
1517  //  - optional/repeated/packed fields
1518  //  - field tags that encode in 1/2/3 bytes
1519  //  - canonical and non-canonical encodings of the varint
1520  //  - last vs not last field
1521  //  - label combinations to trigger different parsers: sequential, small
1522  //  sequential, non-validated.
1523
1524  const std::vector<const FieldDescriptor*> fields =
1525      GetFields<UNITTEST::EnumParseTester>();
1526
1527  constexpr int kInvalidValue = 0x900913;
1528  auto* ref = obj.GetReflection();
1529  PROTOBUF_UNUSED auto* descriptor = obj.descriptor();
1530  for (bool use_packed : {false, true}) {
1531    SCOPED_TRACE(use_packed);
1532    for (bool use_tail_field : {false, true}) {
1533      SCOPED_TRACE(use_tail_field);
1534      for (int non_canonical_bytes = 0; non_canonical_bytes < 9;
1535           ++non_canonical_bytes) {
1536        SCOPED_TRACE(non_canonical_bytes);
1537        for (bool add_garbage_bits : {false, true}) {
1538          if (add_garbage_bits && non_canonical_bytes != 9) {
1539            // We only add garbage on the 10th byte.
1540            continue;
1541          }
1542          SCOPED_TRACE(add_garbage_bits);
1543          for (auto field : fields) {
1544            if (field->name() == "other_field") continue;
1545            if (!field->is_repeated() && use_packed) continue;
1546            SCOPED_TRACE(field->full_name());
1547            const auto* enum_desc = field->enum_type();
1548            for (int e = 0; e < enum_desc->value_count(); ++e) {
1549              const auto* value_desc = enum_desc->value(e);
1550              if (value_desc->number() < 0 && non_canonical_bytes > 0) {
1551                // Negative numbers only have a canonical representation.
1552                continue;
1553              }
1554              SCOPED_TRACE(value_desc->number());
1555              ABSL_CHECK_NE(value_desc->number(), kInvalidValue)
1556                  << "Invalid value is a real label.";
1557              auto encoded =
1558                  EncodeEnumValue(field->number(), value_desc->number(),
1559                                  non_canonical_bytes, use_packed);
1560              if (add_garbage_bits) {
1561                // These bits should be discarded even in the `false` case.
1562                encoded.back() |= 0b0111'1110;
1563              }
1564              if (use_tail_field) {
1565                // Make sure that fields after this one can be parsed too. ie
1566                // test that the "next" jump is correct too.
1567                encoded += other_field;
1568              }
1569
1570              EXPECT_TRUE(obj.ParseFromString(encoded));
1571              if (field->is_repeated()) {
1572                ASSERT_EQ(ref->FieldSize(obj, field), 1);
1573                EXPECT_EQ(ref->GetRepeatedEnumValue(obj, field, 0),
1574                          value_desc->number());
1575              } else {
1576                EXPECT_TRUE(ref->HasField(obj, field));
1577                EXPECT_EQ(ref->GetEnumValue(obj, field), value_desc->number());
1578              }
1579              auto& unknown = ref->GetUnknownFields(obj);
1580              ASSERT_EQ(unknown.field_count(), 0);
1581            }
1582
1583            {
1584              SCOPED_TRACE("Invalid value");
1585              // Try an invalid value, which should go to the unknown fields.
1586              EXPECT_TRUE(obj.ParseFromString(
1587                  EncodeEnumValue(field->number(), kInvalidValue,
1588                                  non_canonical_bytes, use_packed)));
1589              if (field->is_repeated()) {
1590                ASSERT_EQ(ref->FieldSize(obj, field), 0);
1591              } else {
1592                EXPECT_FALSE(ref->HasField(obj, field));
1593                EXPECT_EQ(ref->GetEnumValue(obj, field),
1594                          enum_desc->value(0)->number());
1595              }
1596              auto& unknown = ref->GetUnknownFields(obj);
1597              ASSERT_EQ(unknown.field_count(), 1);
1598              EXPECT_EQ(unknown.field(0).number(), field->number());
1599              EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT);
1600              EXPECT_EQ(unknown.field(0).varint(), kInvalidValue);
1601            }
1602            {
1603              SCOPED_TRACE("Overlong varint");
1604              // Try an overlong varint. It should fail parsing, but not trigger
1605              // any sanitizer warning.
1606              EXPECT_FALSE(obj.ParseFromString(
1607                  EncodeOverlongEnum(field->number(), use_packed)));
1608            }
1609          }
1610        }
1611      }
1612    }
1613  }
1614}
1615
1616TEST(MESSAGE_TEST_NAME, TestEnumParserForUnknownEnumValue) {
1617  DynamicMessageFactory factory;
1618  std::unique_ptr<Message> dynamic(
1619      factory.GetPrototype(UNITTEST::EnumParseTester::descriptor())->New());
1620
1621  UNITTEST::EnumParseTester non_dynamic;
1622
1623  // For unknown enum values, for consistency we must include the
1624  // int32_t enum value in the unknown field set, which might not be exactly the
1625  // same as the input.
1626  PROTOBUF_UNUSED auto* descriptor = non_dynamic.descriptor();
1627
1628  const std::vector<const FieldDescriptor*> fields =
1629      GetFields<UNITTEST::EnumParseTester>();
1630
1631  for (bool use_dynamic : {false, true}) {
1632    SCOPED_TRACE(use_dynamic);
1633    for (auto field : fields) {
1634      if (field->name() == "other_field") continue;
1635      SCOPED_TRACE(field->full_name());
1636      for (bool use_packed : {false, true}) {
1637        SCOPED_TRACE(use_packed);
1638        if (!field->is_repeated() && use_packed) continue;
1639
1640        // -2 is an invalid enum value on all the tests here.
1641        // We will encode -2 as a positive int64 that is equivalent to
1642        // int32_t{-2} when truncated.
1643        constexpr int64_t minus_2_non_canonical =
1644            static_cast<int64_t>(static_cast<uint32_t>(int32_t{-2}));
1645        static_assert(minus_2_non_canonical != -2, "");
1646        std::string encoded = EncodeInt64Value(
1647            field->number(), minus_2_non_canonical, 0, use_packed);
1648
1649        auto& obj = use_dynamic ? *dynamic : non_dynamic;
1650        ASSERT_TRUE(obj.ParseFromString(encoded));
1651
1652        auto& unknown = obj.GetReflection()->GetUnknownFields(obj);
1653        ASSERT_EQ(unknown.field_count(), 1);
1654        EXPECT_EQ(unknown.field(0).number(), field->number());
1655        EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT);
1656        EXPECT_EQ(unknown.field(0).varint(), int64_t{-2});
1657      }
1658    }
1659  }
1660}
1661
1662TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
1663  UNITTEST::BoolParseTester obj;
1664
1665  const auto other_field = EncodeOtherField();
1666
1667  // Encode a boolean field for many different cases and verify that it can be
1668  // parsed as expected.
1669  // There are:
1670  //  - optional/repeated/packed fields
1671  //  - field tags that encode in 1/2/3 bytes
1672  //  - canonical and non-canonical encodings of the varint
1673  //  - last vs not last field
1674
1675  const std::vector<const FieldDescriptor*> fields =
1676      GetFields<UNITTEST::BoolParseTester>();
1677
1678  auto* ref = obj.GetReflection();
1679  PROTOBUF_UNUSED auto* descriptor = obj.descriptor();
1680  for (bool use_tail_field : {false, true}) {
1681    SCOPED_TRACE(use_tail_field);
1682    for (int non_canonical_bytes = 0; non_canonical_bytes < 10;
1683         ++non_canonical_bytes) {
1684      SCOPED_TRACE(non_canonical_bytes);
1685      for (bool add_garbage_bits : {false, true}) {
1686        if (add_garbage_bits && non_canonical_bytes != 9) {
1687          // We only add garbage on the 10th byte.
1688          continue;
1689        }
1690        SCOPED_TRACE(add_garbage_bits);
1691        for (auto field : fields) {
1692          if (field->name() == "other_field") continue;
1693          SCOPED_TRACE(field->full_name());
1694          for (bool value : {false, true}) {
1695            SCOPED_TRACE(value);
1696            auto encoded =
1697                EncodeBoolValue(field->number(), value, non_canonical_bytes);
1698            if (add_garbage_bits) {
1699              // These bits should be discarded even in the `false` case.
1700              encoded.back() |= 0b0111'1110;
1701            }
1702            if (use_tail_field) {
1703              // Make sure that fields after this one can be parsed too. ie test
1704              // that the "next" jump is correct too.
1705              encoded += other_field;
1706            }
1707
1708            EXPECT_TRUE(obj.ParseFromString(encoded));
1709            if (field->is_repeated()) {
1710              ASSERT_EQ(ref->FieldSize(obj, field), 1);
1711              EXPECT_EQ(ref->GetRepeatedBool(obj, field, 0), value);
1712            } else {
1713              EXPECT_TRUE(ref->HasField(obj, field));
1714              EXPECT_EQ(ref->GetBool(obj, field), value)
1715                  << testing::PrintToString(encoded);
1716            }
1717            auto& unknown = ref->GetUnknownFields(obj);
1718            ASSERT_EQ(unknown.field_count(), 0);
1719          }
1720        }
1721      }
1722    }
1723  }
1724}
1725
1726TEST(MESSAGE_TEST_NAME, TestInt32Parsers) {
1727  UNITTEST::Int32ParseTester obj;
1728
1729  const auto other_field = EncodeOtherField();
1730
1731  // Encode an int32 field for many different cases and verify that it can be
1732  // parsed as expected.
1733  // There are:
1734  //  - optional/repeated/packed fields
1735  //  - field tags that encode in 1/2/3 bytes
1736  //  - canonical and non-canonical encodings of the varint
1737  //  - last vs not last field
1738
1739  const std::vector<const FieldDescriptor*> fields =
1740      GetFields<UNITTEST::Int32ParseTester>();
1741
1742  auto* ref = obj.GetReflection();
1743  PROTOBUF_UNUSED auto* descriptor = obj.descriptor();
1744  for (bool use_tail_field : {false, true}) {
1745    SCOPED_TRACE(use_tail_field);
1746    for (int non_canonical_bytes = 0; non_canonical_bytes < 10;
1747         ++non_canonical_bytes) {
1748      SCOPED_TRACE(non_canonical_bytes);
1749      for (bool add_garbage_bits : {false, true}) {
1750        if (add_garbage_bits && non_canonical_bytes != 9) {
1751          // We only add garbage on the 10th byte.
1752          continue;
1753        }
1754        SCOPED_TRACE(add_garbage_bits);
1755        for (auto field : fields) {
1756          if (field->name() == "other_field") continue;
1757          SCOPED_TRACE(field->full_name());
1758          for (int32_t value : {1, 0, -1, (std::numeric_limits<int32_t>::min)(),
1759                              (std::numeric_limits<int32_t>::max)()}) {
1760            SCOPED_TRACE(value);
1761            auto encoded =
1762                EncodeInt32Value(field->number(), value, non_canonical_bytes);
1763            if (add_garbage_bits) {
1764              // These bits should be discarded even in the `false` case.
1765              encoded.back() |= 0b0111'1110;
1766            }
1767            if (use_tail_field) {
1768              // Make sure that fields after this one can be parsed too. ie test
1769              // that the "next" jump is correct too.
1770              encoded += other_field;
1771            }
1772
1773            EXPECT_TRUE(obj.ParseFromString(encoded));
1774            if (field->is_repeated()) {
1775              ASSERT_EQ(ref->FieldSize(obj, field), 1);
1776              EXPECT_EQ(ref->GetRepeatedInt32(obj, field, 0), value);
1777            } else {
1778              EXPECT_TRUE(ref->HasField(obj, field));
1779              EXPECT_EQ(ref->GetInt32(obj, field), value)
1780                  << testing::PrintToString(encoded);
1781            }
1782            auto& unknown = ref->GetUnknownFields(obj);
1783            ASSERT_EQ(unknown.field_count(), 0);
1784          }
1785        }
1786      }
1787    }
1788  }
1789}
1790
1791TEST(MESSAGE_TEST_NAME, TestInt64Parsers) {
1792  UNITTEST::Int64ParseTester obj;
1793
1794  const auto other_field = EncodeOtherField();
1795
1796  // Encode an int64 field for many different cases and verify that it can be
1797  // parsed as expected.
1798  // There are:
1799  //  - optional/repeated/packed fields
1800  //  - field tags that encode in 1/2/3 bytes
1801  //  - canonical and non-canonical encodings of the varint
1802  //  - last vs not last field
1803
1804  const std::vector<const FieldDescriptor*> fields =
1805      GetFields<UNITTEST::Int64ParseTester>();
1806
1807  auto* ref = obj.GetReflection();
1808  PROTOBUF_UNUSED auto* descriptor = obj.descriptor();
1809  for (bool use_tail_field : {false, true}) {
1810    SCOPED_TRACE(use_tail_field);
1811    for (int non_canonical_bytes = 0; non_canonical_bytes < 10;
1812         ++non_canonical_bytes) {
1813      SCOPED_TRACE(non_canonical_bytes);
1814      for (bool add_garbage_bits : {false, true}) {
1815        if (add_garbage_bits && non_canonical_bytes != 9) {
1816          // We only add garbage on the 10th byte.
1817          continue;
1818        }
1819        SCOPED_TRACE(add_garbage_bits);
1820        for (auto field : fields) {
1821          if (field->name() == "other_field") continue;
1822          SCOPED_TRACE(field->full_name());
1823          for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1},
1824                                (std::numeric_limits<int64_t>::min)(),
1825                                (std::numeric_limits<int64_t>::max)()}) {
1826            SCOPED_TRACE(value);
1827            auto encoded =
1828                EncodeInt64Value(field->number(), value, non_canonical_bytes);
1829            if (add_garbage_bits) {
1830              // These bits should be discarded even in the `false` case.
1831              encoded.back() |= 0b0111'1110;
1832            }
1833            if (use_tail_field) {
1834              // Make sure that fields after this one can be parsed too. ie test
1835              // that the "next" jump is correct too.
1836              encoded += other_field;
1837            }
1838
1839            EXPECT_TRUE(obj.ParseFromString(encoded));
1840            if (field->is_repeated()) {
1841              ASSERT_EQ(ref->FieldSize(obj, field), 1);
1842              EXPECT_EQ(ref->GetRepeatedInt64(obj, field, 0), value);
1843            } else {
1844              EXPECT_TRUE(ref->HasField(obj, field));
1845              EXPECT_EQ(ref->GetInt64(obj, field), value)
1846                  << testing::PrintToString(encoded);
1847            }
1848            auto& unknown = ref->GetUnknownFields(obj);
1849            ASSERT_EQ(unknown.field_count(), 0);
1850          }
1851        }
1852      }
1853    }
1854  }
1855}
1856
1857TEST(MESSAGE_TEST_NAME, IsDefaultInstance) {
1858  UNITTEST::TestAllTypes msg;
1859  const auto& default_msg = UNITTEST::TestAllTypes::default_instance();
1860  const auto* r = msg.GetReflection();
1861  EXPECT_TRUE(r->IsDefaultInstance(default_msg));
1862  EXPECT_FALSE(r->IsDefaultInstance(msg));
1863}
1864
1865namespace {
1866std::string EncodeStringValue(int number, const std::string& value) {
1867  uint8_t buf[100];
1868  return std::string(
1869      buf, internal::WireFormatLite::WriteStringToArray(number, value, buf));
1870}
1871
1872class TestInputStream final : public io::ZeroCopyInputStream {
1873 public:
1874  explicit TestInputStream(absl::string_view payload, size_t break_pos)
1875      : payload_(payload), break_pos_(break_pos) {}
1876
1877  bool Next(const void** data, int* size) override {
1878    if (payload_.empty()) return false;
1879    const auto to_consume = payload_.substr(0, break_pos_);
1880    *data = to_consume.data();
1881    *size = to_consume.size();
1882    payload_.remove_prefix(to_consume.size());
1883    // The next time will consume the rest.
1884    break_pos_ = payload_.npos;
1885
1886    return true;
1887  }
1888
1889  void BackUp(int) override { ABSL_CHECK(false); }
1890  bool Skip(int) override {
1891    ABSL_CHECK(false);
1892    return false;
1893  }
1894  int64_t ByteCount() const override {
1895    ABSL_CHECK(false);
1896    return 0;
1897  }
1898
1899 private:
1900  absl::string_view payload_;
1901  size_t break_pos_;
1902};
1903}  // namespace
1904
1905TEST(MESSAGE_TEST_NAME, TestRepeatedStringParsers) {
1906  google::protobuf::Arena arena;
1907
1908  const std::string sample =
1909      "abcdefghijklmnopqrstuvwxyz"
1910      "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
1911
1912  PROTOBUF_UNUSED const auto* const descriptor =
1913      UNITTEST::StringParseTester::descriptor();
1914
1915  const std::vector<const FieldDescriptor*> fields =
1916      GetFields<UNITTEST::StringParseTester>();
1917
1918  static const size_t sso_capacity = std::string().capacity();
1919  if (sso_capacity == 0) GTEST_SKIP();
1920  // SSO, !SSO, and off-by-one just in case
1921  for (size_t size :
1922       {sso_capacity - 1, sso_capacity, sso_capacity + 1, sso_capacity + 2}) {
1923    SCOPED_TRACE(size);
1924    const std::string value = sample.substr(0, size);
1925    for (auto field : fields) {
1926      SCOPED_TRACE(field->full_name());
1927      const auto encoded = EncodeStringValue(field->number(), sample) +
1928                           EncodeStringValue(field->number(), value);
1929      // Check for different breaks in the input stream to test cases where
1930      // the payload can be read and can't be read in one go.
1931      for (size_t i = 1; i <= encoded.size(); ++i) {
1932        TestInputStream input_stream(encoded, i);
1933
1934        auto& obj = *arena.Create<UNITTEST::StringParseTester>(&arena);
1935        auto* ref = obj.GetReflection();
1936        EXPECT_TRUE(obj.ParseFromZeroCopyStream(&input_stream));
1937        if (field->is_repeated()) {
1938          ASSERT_EQ(ref->FieldSize(obj, field), 2);
1939          EXPECT_EQ(ref->GetRepeatedString(obj, field, 0), sample);
1940          EXPECT_EQ(ref->GetRepeatedString(obj, field, 1), value);
1941        } else {
1942          EXPECT_EQ(ref->GetString(obj, field), value);
1943        }
1944      }
1945    }
1946  }
1947}
1948
1949TEST(MESSAGE_TEST_NAME, TestRegressionOnParseFailureNotSettingHasBits) {
1950  std::string single_field;
1951  // We use blocks because we want fully new instances of the proto. We are
1952  // testing .Clear(), so we can't use it to set up the test.
1953  {
1954    UNITTEST::TestAllTypes message;
1955    message.set_optional_int32(17);
1956    single_field = message.SerializeAsString();
1957  }
1958  const auto validate_message = [](auto& message) {
1959    if (!message.has_optional_int32()) {
1960      EXPECT_EQ(message.optional_int32(), 0);
1961    }
1962    message.Clear();
1963    EXPECT_FALSE(message.has_optional_int32());
1964    EXPECT_EQ(message.optional_int32(), 0);
1965  };
1966  {
1967    // Verify the setup is correct.
1968    UNITTEST::TestAllTypes message;
1969    EXPECT_FALSE(message.has_optional_int32());
1970    EXPECT_EQ(message.optional_int32(), 0);
1971    EXPECT_TRUE(message.ParseFromString(single_field));
1972    validate_message(message);
1973  }
1974  {
1975    // Run the regression.
1976    // These are the steps:
1977    // - The stream contains a fast field, and then a failure in MiniParse
1978    // - The parsing fails.
1979    // - We call clear.
1980    // - The fast field should be reset.
1981    UNITTEST::TestAllTypes message;
1982    EXPECT_FALSE(message.has_optional_int32());
1983    EXPECT_EQ(message.optional_int32(), 0);
1984    // The second tag will fail to parse because it has too many continuation
1985    // bits.
1986    auto with_error =
1987        absl::StrCat(single_field, std::string(100, static_cast<char>(0x80)));
1988    EXPECT_FALSE(message.ParseFromString(with_error));
1989    validate_message(message);
1990  }
1991}
1992
1993TEST(MESSAGE_TEST_NAME, TestRegressionOverwrittenLazyOneofDoesNotLeak) {
1994  UNITTEST::TestAllTypes message;
1995  auto* lazy = message.mutable_oneof_lazy_nested_message();
1996  // We need to add enough payload to make the lazy field overflow the SSO of
1997  // Cord. However, NestedMessage does not have enough fields for that. Just add
1998  // some unknown payload to it. Use something that the validator will allow to
1999  // stay as lazy.
2000  lazy->GetReflection()->MutableUnknownFields(lazy)->AddFixed64(10, 10);
2001  lazy->GetReflection()->MutableUnknownFields(lazy)->AddFixed64(11, 10);
2002  // Validate that the size is large enough.
2003  ASSERT_GT(lazy->ByteSizeLong(), 15);
2004
2005  // Append two instances of the oneof: first the lazy field, then any other to
2006  // cause a switch during parsing.
2007  std::string str;
2008  ASSERT_TRUE(message.AppendToString(&str));
2009  message.set_oneof_uint32(7);
2010  ASSERT_TRUE(message.AppendToString(&str));
2011
2012  EXPECT_TRUE(UNITTEST::TestAllTypes().ParseFromString(str));
2013  Arena arena;
2014  // This call had a bug where the LazyField was not destroyed in any way
2015  // causing the Cord inside it to leak its contents.
2016  EXPECT_TRUE(
2017      Arena::Create<UNITTEST::TestAllTypes>(&arena)->ParseFromString(str));
2018}
2019
2020}  // namespace protobuf
2021}  // namespace google
2022
2023#include "google/protobuf/port_undef.inc"
2024