1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package messageset encodes and decodes the obsolete MessageSet wire format. 6package messageset 7 8import ( 9 "math" 10 11 "google.golang.org/protobuf/encoding/protowire" 12 "google.golang.org/protobuf/internal/errors" 13 pref "google.golang.org/protobuf/reflect/protoreflect" 14) 15 16// The MessageSet wire format is equivalent to a message defined as follows, 17// where each Item defines an extension field with a field number of 'type_id' 18// and content of 'message'. MessageSet extensions must be non-repeated message 19// fields. 20// 21// message MessageSet { 22// repeated group Item = 1 { 23// required int32 type_id = 2; 24// required string message = 3; 25// } 26// } 27const ( 28 FieldItem = protowire.Number(1) 29 FieldTypeID = protowire.Number(2) 30 FieldMessage = protowire.Number(3) 31) 32 33// ExtensionName is the field name for extensions of MessageSet. 34// 35// A valid MessageSet extension must be of the form: 36// message MyMessage { 37// extend proto2.bridge.MessageSet { 38// optional MyMessage message_set_extension = 1234; 39// } 40// ... 41// } 42const ExtensionName = "message_set_extension" 43 44// IsMessageSet returns whether the message uses the MessageSet wire format. 45func IsMessageSet(md pref.MessageDescriptor) bool { 46 xmd, ok := md.(interface{ IsMessageSet() bool }) 47 return ok && xmd.IsMessageSet() 48} 49 50// IsMessageSetExtension reports this field properly extends a MessageSet. 51func IsMessageSetExtension(fd pref.FieldDescriptor) bool { 52 switch { 53 case fd.Name() != ExtensionName: 54 return false 55 case !IsMessageSet(fd.ContainingMessage()): 56 return false 57 case fd.FullName().Parent() != fd.Message().FullName(): 58 return false 59 } 60 return true 61} 62 63// SizeField returns the size of a MessageSet item field containing an extension 64// with the given field number, not counting the contents of the message subfield. 65func SizeField(num protowire.Number) int { 66 return 2*protowire.SizeTag(FieldItem) + protowire.SizeTag(FieldTypeID) + protowire.SizeVarint(uint64(num)) 67} 68 69// Unmarshal parses a MessageSet. 70// 71// It calls fn with the type ID and value of each item in the MessageSet. 72// Unknown fields are discarded. 73// 74// If wantLen is true, the item values include the varint length prefix. 75// This is ugly, but simplifies the fast-path decoder in internal/impl. 76func Unmarshal(b []byte, wantLen bool, fn func(typeID protowire.Number, value []byte) error) error { 77 for len(b) > 0 { 78 num, wtyp, n := protowire.ConsumeTag(b) 79 if n < 0 { 80 return protowire.ParseError(n) 81 } 82 b = b[n:] 83 if num != FieldItem || wtyp != protowire.StartGroupType { 84 n := protowire.ConsumeFieldValue(num, wtyp, b) 85 if n < 0 { 86 return protowire.ParseError(n) 87 } 88 b = b[n:] 89 continue 90 } 91 typeID, value, n, err := ConsumeFieldValue(b, wantLen) 92 if err != nil { 93 return err 94 } 95 b = b[n:] 96 if typeID == 0 { 97 continue 98 } 99 if err := fn(typeID, value); err != nil { 100 return err 101 } 102 } 103 return nil 104} 105 106// ConsumeFieldValue parses b as a MessageSet item field value until and including 107// the trailing end group marker. It assumes the start group tag has already been parsed. 108// It returns the contents of the type_id and message subfields and the total 109// item length. 110// 111// If wantLen is true, the returned message value includes the length prefix. 112func ConsumeFieldValue(b []byte, wantLen bool) (typeid protowire.Number, message []byte, n int, err error) { 113 ilen := len(b) 114 for { 115 num, wtyp, n := protowire.ConsumeTag(b) 116 if n < 0 { 117 return 0, nil, 0, protowire.ParseError(n) 118 } 119 b = b[n:] 120 switch { 121 case num == FieldItem && wtyp == protowire.EndGroupType: 122 if wantLen && len(message) == 0 { 123 // The message field was missing, which should never happen. 124 // Be prepared for this case anyway. 125 message = protowire.AppendVarint(message, 0) 126 } 127 return typeid, message, ilen - len(b), nil 128 case num == FieldTypeID && wtyp == protowire.VarintType: 129 v, n := protowire.ConsumeVarint(b) 130 if n < 0 { 131 return 0, nil, 0, protowire.ParseError(n) 132 } 133 b = b[n:] 134 if v < 1 || v > math.MaxInt32 { 135 return 0, nil, 0, errors.New("invalid type_id in message set") 136 } 137 typeid = protowire.Number(v) 138 case num == FieldMessage && wtyp == protowire.BytesType: 139 m, n := protowire.ConsumeBytes(b) 140 if n < 0 { 141 return 0, nil, 0, protowire.ParseError(n) 142 } 143 if message == nil { 144 if wantLen { 145 message = b[:n:n] 146 } else { 147 message = m[:len(m):len(m)] 148 } 149 } else { 150 // This case should never happen in practice, but handle it for 151 // correctness: The MessageSet item contains multiple message 152 // fields, which need to be merged. 153 // 154 // In the case where we're returning the length, this becomes 155 // quite inefficient since we need to strip the length off 156 // the existing data and reconstruct it with the combined length. 157 if wantLen { 158 _, nn := protowire.ConsumeVarint(message) 159 m0 := message[nn:] 160 message = nil 161 message = protowire.AppendVarint(message, uint64(len(m0)+len(m))) 162 message = append(message, m0...) 163 message = append(message, m...) 164 } else { 165 message = append(message, m...) 166 } 167 } 168 b = b[n:] 169 default: 170 // We have no place to put it, so we just ignore unknown fields. 171 n := protowire.ConsumeFieldValue(num, wtyp, b) 172 if n < 0 { 173 return 0, nil, 0, protowire.ParseError(n) 174 } 175 b = b[n:] 176 } 177 } 178} 179 180// AppendFieldStart appends the start of a MessageSet item field containing 181// an extension with the given number. The caller must add the message 182// subfield (including the tag). 183func AppendFieldStart(b []byte, num protowire.Number) []byte { 184 b = protowire.AppendTag(b, FieldItem, protowire.StartGroupType) 185 b = protowire.AppendTag(b, FieldTypeID, protowire.VarintType) 186 b = protowire.AppendVarint(b, uint64(num)) 187 return b 188} 189 190// AppendFieldEnd appends the trailing end group marker for a MessageSet item field. 191func AppendFieldEnd(b []byte) []byte { 192 return protowire.AppendTag(b, FieldItem, protowire.EndGroupType) 193} 194 195// SizeUnknown returns the size of an unknown fields section in MessageSet format. 196// 197// See AppendUnknown. 198func SizeUnknown(unknown []byte) (size int) { 199 for len(unknown) > 0 { 200 num, typ, n := protowire.ConsumeTag(unknown) 201 if n < 0 || typ != protowire.BytesType { 202 return 0 203 } 204 unknown = unknown[n:] 205 _, n = protowire.ConsumeBytes(unknown) 206 if n < 0 { 207 return 0 208 } 209 unknown = unknown[n:] 210 size += SizeField(num) + protowire.SizeTag(FieldMessage) + n 211 } 212 return size 213} 214 215// AppendUnknown appends unknown fields to b in MessageSet format. 216// 217// For historic reasons, unresolved items in a MessageSet are stored in a 218// message's unknown fields section in non-MessageSet format. That is, an 219// unknown item with typeID T and value V appears in the unknown fields as 220// a field with number T and value V. 221// 222// This function converts the unknown fields back into MessageSet form. 223func AppendUnknown(b, unknown []byte) ([]byte, error) { 224 for len(unknown) > 0 { 225 num, typ, n := protowire.ConsumeTag(unknown) 226 if n < 0 || typ != protowire.BytesType { 227 return nil, errors.New("invalid data in message set unknown fields") 228 } 229 unknown = unknown[n:] 230 _, n = protowire.ConsumeBytes(unknown) 231 if n < 0 { 232 return nil, errors.New("invalid data in message set unknown fields") 233 } 234 b = AppendFieldStart(b, num) 235 b = protowire.AppendTag(b, FieldMessage, protowire.BytesType) 236 b = append(b, unknown[:n]...) 237 b = AppendFieldEnd(b) 238 unknown = unknown[n:] 239 } 240 return b, nil 241} 242