diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h index 074784b96..aff050a81 100644 --- a/src/google/protobuf/extension_set_inl.h +++ b/src/google/protobuf/extension_set_inl.h @@ -206,16 +206,22 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( const char* ptr, const Msg* containing_type, internal::InternalMetadata* metadata, internal::ParseContext* ctx) { std::string payload; - uint32 type_id = 0; - bool payload_read = false; + + uint32_t type_id; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + while (!ctx->Done(&ptr)) { uint32 tag = static_cast(*ptr++); if (tag == WireFormatLite::kMessageSetTypeIdTag) { uint64 tmp; ptr = ParseBigVarint(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - type_id = tmp; - if (payload_read) { + if (state == State::kNoTag) { + type_id = tmp; + state = State::kHasType; + } else if (state == State::kHasPayload) { + type_id = tmp; ExtensionInfo extension; bool was_packed_on_wire; if (!FindExtension(2, type_id, containing_type, ctx, &extension, @@ -241,20 +247,26 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && tmp_ctx.EndedAtLimit()); } - type_id = 0; + state = State::kDone; } } else if (tag == WireFormatLite::kMessageSetMessageTag) { - if (type_id != 0) { - ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, - containing_type, metadata, ctx); + + if (state == State::kHasType) { + ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, + containing_type, metadata, ctx); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); - type_id = 0; + state = State::kDone; } else { - int32 size = ReadSize(&ptr); + + std::string tmp; + int32_t size = ReadSize(&ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - ptr = ctx->ReadString(ptr, size, &payload); + ptr = ctx->ReadString(ptr, size, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - payload_read = true; + if (state == State::kNoTag) { + payload = std::move(tmp); + state = State::kHasPayload; + } } } else { ptr = ReadTag(ptr - 1, &tag); diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index 16edf2ce3..88fb09169 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -659,9 +659,11 @@ struct WireFormat::MessageSetParser { const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { // Parse a MessageSetItem auto metadata = reflection->MutableInternalMetadata(msg); + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + std::string payload; - uint32 type_id = 0; - bool payload_read = false; + uint32_t type_id = 0; while (!ctx->Done(&ptr)) { // We use 64 bit tags in order to allow typeid's that span the whole // range of 32 bit numbers. @@ -670,8 +672,11 @@ struct WireFormat::MessageSetParser { uint64 tmp; ptr = ParseBigVarint(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - type_id = tmp; - if (payload_read) { + if (state == State::kNoTag) { + type_id = tmp; + state = State::kHasType; + } else if (state == State::kHasPayload) { + type_id = tmp; const FieldDescriptor* field; if (ctx->data().pool == nullptr) { field = reflection->FindKnownExtensionByNumber(type_id); @@ -698,17 +703,18 @@ struct WireFormat::MessageSetParser { GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && tmp_ctx.EndedAtLimit()); } - type_id = 0; + state = State::kDone; } continue; } else if (tag == WireFormatLite::kMessageSetMessageTag) { - if (type_id == 0) { - int32 size = ReadSize(&ptr); + + if (state == State::kNoTag) { + int32_t size = ReadSize(&ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); ptr = ctx->ReadString(ptr, size, &payload); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - payload_read = true; - } else { + state = State::kHasPayload; + } else if (state == State::kHasType) { // We're now parsing the payload const FieldDescriptor* field = nullptr; if (descriptor->IsExtensionNumber(type_id)) { @@ -722,7 +728,12 @@ struct WireFormat::MessageSetParser { ptr = WireFormat::_InternalParseAndMergeField( msg, ptr, ctx, static_cast(type_id) * 8 + 2, reflection, field); - type_id = 0; + state = State::kDone; + } else { + int32_t size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ptr = ctx->Skip(ptr, size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); } } else { // An unknown field in MessageSetItem. diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h index c742fe869..4130bc531 100644 --- a/src/google/protobuf/wire_format_lite.h +++ b/src/google/protobuf/wire_format_lite.h @@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { // we can parse it later. std::string message_data; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + while (true) { const uint32 tag = input->ReadTagNoLastTag(); if (tag == 0) return false; @@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { case WireFormatLite::kMessageSetTypeIdTag: { uint32 type_id; if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - - if (!message_data.empty()) { + if (state == State::kNoTag) { + last_type_id = type_id; + state = State::kHasType; + } else if (state == State::kHasPayload) { // We saw some message data before the type_id. Have to parse it // now. io::CodedInputStream sub_input( reinterpret_cast(message_data.data()), static_cast(message_data.size())); sub_input.SetRecursionLimit(input->RecursionBudget()); - if (!ms.ParseField(last_type_id, &sub_input)) { + if (!ms.ParseField(type_id, &sub_input)) { return false; } message_data.clear(); + state = State::kDone; } break; } case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { + if (state == State::kHasType) { + // Already saw type_id, so we can parse this directly. + if (!ms.ParseField(last_type_id, input)) { + return false; + } + state = State::kDone; + } else if (state == State::kNoTag) { // We haven't seen a type_id yet. Append this data to message_data. uint32 length; if (!input->ReadVarint32(&length)) return false; @@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { auto ptr = reinterpret_cast(&message_data[0]); ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); if (!input->ReadRaw(ptr, length)) return false; + state = State::kHasPayload; } else { - // Already saw type_id, so we can parse this directly. - if (!ms.ParseField(last_type_id, input)) { - return false; - } + if (!ms.SkipField(tag, input)) return false; } break;