• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h
2index 074784b96..aff050a81 100644
3--- a/src/google/protobuf/extension_set_inl.h
4+++ b/src/google/protobuf/extension_set_inl.h
5@@ -206,16 +206,22 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
6     const char* ptr, const Msg* containing_type,
7     internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
8   std::string payload;
9-  uint32 type_id = 0;
10-  bool payload_read = false;
11+
12+  uint32_t type_id;
13+  enum class State { kNoTag, kHasType, kHasPayload, kDone };
14+  State state = State::kNoTag;
15+
16   while (!ctx->Done(&ptr)) {
17     uint32 tag = static_cast<uint8>(*ptr++);
18     if (tag == WireFormatLite::kMessageSetTypeIdTag) {
19       uint64 tmp;
20       ptr = ParseBigVarint(ptr, &tmp);
21       GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
22-      type_id = tmp;
23-      if (payload_read) {
24+      if (state == State::kNoTag) {
25+        type_id = tmp;
26+        state = State::kHasType;
27+      } else if (state == State::kHasPayload) {
28+        type_id = tmp;
29         ExtensionInfo extension;
30         bool was_packed_on_wire;
31         if (!FindExtension(2, type_id, containing_type, ctx, &extension,
32@@ -241,20 +247,26 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
33           GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
34                                          tmp_ctx.EndedAtLimit());
35         }
36-        type_id = 0;
37+        state = State::kDone;
38       }
39     } else if (tag == WireFormatLite::kMessageSetMessageTag) {
40-      if (type_id != 0) {
41-        ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
42-                                    containing_type, metadata, ctx);
43+
44+      if (state == State::kHasType) {
45+        ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
46+                                    containing_type, metadata, ctx);
47         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
48-        type_id = 0;
49+        state = State::kDone;
50       } else {
51-        int32 size = ReadSize(&ptr);
52+
53+        std::string tmp;
54+        int32_t size = ReadSize(&ptr);
55         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
56-        ptr = ctx->ReadString(ptr, size, &payload);
57+        ptr = ctx->ReadString(ptr, size, &tmp);
58         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
59-        payload_read = true;
60+        if (state == State::kNoTag) {
61+          payload = std::move(tmp);
62+          state = State::kHasPayload;
63+        }
64       }
65     } else {
66       ptr = ReadTag(ptr - 1, &tag);
67diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc
68index 16edf2ce3..88fb09169 100644
69--- a/src/google/protobuf/wire_format.cc
70+++ b/src/google/protobuf/wire_format.cc
71@@ -659,9 +659,11 @@ struct WireFormat::MessageSetParser {
72   const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
73     // Parse a MessageSetItem
74     auto metadata = reflection->MutableInternalMetadata(msg);
75+    enum class State { kNoTag, kHasType, kHasPayload, kDone };
76+    State state = State::kNoTag;
77+
78     std::string payload;
79-    uint32 type_id = 0;
80-    bool payload_read = false;
81+    uint32_t type_id = 0;
82     while (!ctx->Done(&ptr)) {
83       // We use 64 bit tags in order to allow typeid's that span the whole
84       // range of 32 bit numbers.
85@@ -670,8 +672,11 @@ struct WireFormat::MessageSetParser {
86         uint64 tmp;
87         ptr = ParseBigVarint(ptr, &tmp);
88         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
89-        type_id = tmp;
90-        if (payload_read) {
91+        if (state == State::kNoTag) {
92+          type_id = tmp;
93+          state = State::kHasType;
94+        } else if (state == State::kHasPayload) {
95+          type_id = tmp;
96           const FieldDescriptor* field;
97           if (ctx->data().pool == nullptr) {
98             field = reflection->FindKnownExtensionByNumber(type_id);
99@@ -698,17 +703,18 @@ struct WireFormat::MessageSetParser {
100             GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
101                                            tmp_ctx.EndedAtLimit());
102           }
103-          type_id = 0;
104+          state = State::kDone;
105         }
106         continue;
107       } else if (tag == WireFormatLite::kMessageSetMessageTag) {
108-        if (type_id == 0) {
109-          int32 size = ReadSize(&ptr);
110+
111+        if (state == State::kNoTag) {
112+          int32_t size = ReadSize(&ptr);
113           GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
114           ptr = ctx->ReadString(ptr, size, &payload);
115           GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
116-          payload_read = true;
117-        } else {
118+          state = State::kHasPayload;
119+        } else if (state == State::kHasType) {
120           // We're now parsing the payload
121           const FieldDescriptor* field = nullptr;
122           if (descriptor->IsExtensionNumber(type_id)) {
123@@ -722,7 +728,12 @@ struct WireFormat::MessageSetParser {
124           ptr = WireFormat::_InternalParseAndMergeField(
125               msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection,
126               field);
127-          type_id = 0;
128+          state = State::kDone;
129+        } else {
130+          int32_t size = ReadSize(&ptr);
131+          GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
132+          ptr = ctx->Skip(ptr, size);
133+          GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
134         }
135       } else {
136         // An unknown field in MessageSetItem.
137diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h
138index c742fe869..4130bc531 100644
139--- a/src/google/protobuf/wire_format_lite.h
140+++ b/src/google/protobuf/wire_format_lite.h
141@@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
142   // we can parse it later.
143   std::string message_data;
144
145+  enum class State { kNoTag, kHasType, kHasPayload, kDone };
146+  State state = State::kNoTag;
147+
148   while (true) {
149     const uint32 tag = input->ReadTagNoLastTag();
150     if (tag == 0) return false;
151@@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
152       case WireFormatLite::kMessageSetTypeIdTag: {
153         uint32 type_id;
154         if (!input->ReadVarint32(&type_id)) return false;
155-        last_type_id = type_id;
156-
157-        if (!message_data.empty()) {
158+        if (state == State::kNoTag) {
159+          last_type_id = type_id;
160+          state = State::kHasType;
161+        } else if (state == State::kHasPayload) {
162           // We saw some message data before the type_id.  Have to parse it
163           // now.
164           io::CodedInputStream sub_input(
165               reinterpret_cast<const uint8*>(message_data.data()),
166               static_cast<int>(message_data.size()));
167           sub_input.SetRecursionLimit(input->RecursionBudget());
168-          if (!ms.ParseField(last_type_id, &sub_input)) {
169+          if (!ms.ParseField(type_id, &sub_input)) {
170             return false;
171           }
172           message_data.clear();
173+          state = State::kDone;
174         }
175
176         break;
177       }
178
179       case WireFormatLite::kMessageSetMessageTag: {
180-        if (last_type_id == 0) {
181+        if (state == State::kHasType) {
182+          // Already saw type_id, so we can parse this directly.
183+          if (!ms.ParseField(last_type_id, input)) {
184+            return false;
185+          }
186+          state = State::kDone;
187+        } else if (state == State::kNoTag) {
188           // We haven't seen a type_id yet.  Append this data to message_data.
189           uint32 length;
190           if (!input->ReadVarint32(&length)) return false;
191@@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
192           auto ptr = reinterpret_cast<uint8*>(&message_data[0]);
193           ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
194           if (!input->ReadRaw(ptr, length)) return false;
195+          state = State::kHasPayload;
196         } else {
197-          // Already saw type_id, so we can parse this directly.
198-          if (!ms.ParseField(last_type_id, input)) {
199-            return false;
200-          }
201+          if (!ms.SkipField(tag, input)) return false;
202         }
203
204         break;
205